Skip to content

Commit 8d98b3f

Browse files
authored
Fix deployment to Spaces (#146)
* Fix deployment to Spaces * revert changes * revert more * revert * simplify * final
1 parent a9db2d5 commit 8d98b3f

5 files changed

Lines changed: 95 additions & 67 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ tests/__pycache__/
44
.trackio/
55
trackio.db
66
*.pyc
7+
.venv/

examples/deploy-on-spaces.py

Lines changed: 79 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,87 @@
1+
import math
12
import random
23
import time
34

4-
from tqdm import tqdm
5-
65
import trackio as wandb
76

8-
project_id = random.randint(10000, 99999)
9-
10-
wandb.init(
11-
project=f"fake-training-{project_id}",
12-
name="test-run",
13-
config=dict(
14-
epochs=5,
15-
learning_rate=0.001,
16-
batch_size=32,
17-
),
18-
space_id=f"trackio-{project_id}",
19-
)
20-
21-
EPOCHS = 5
22-
NUM_TRAIN_BATCHES = 100
23-
NUM_VAL_BATCHES = 20
24-
25-
for epoch in range(EPOCHS):
26-
train_loss = 0
27-
train_accuracy = 0
28-
val_loss = 0
29-
val_accuracy = 0
30-
31-
for _ in tqdm(range(NUM_TRAIN_BATCHES), desc=f"Epoch {epoch + 1} - Training"):
32-
loss = random.uniform(0.2, 1.0)
33-
accuracy = random.uniform(0.6, 0.95)
34-
train_loss += loss
35-
train_accuracy += accuracy
36-
37-
for _ in tqdm(range(NUM_VAL_BATCHES), desc=f"Epoch {epoch + 1} - Validation"):
38-
loss = random.uniform(0.2, 0.9)
39-
accuracy = random.uniform(0.65, 0.98)
40-
val_loss += loss
41-
val_accuracy += accuracy
42-
43-
train_loss /= NUM_TRAIN_BATCHES
44-
train_accuracy /= NUM_TRAIN_BATCHES
45-
val_loss /= NUM_VAL_BATCHES
46-
val_accuracy /= NUM_VAL_BATCHES
47-
48-
wandb.log(
49-
{
50-
"train_loss": train_loss,
51-
"train_accuracy": train_accuracy,
52-
"val_loss": val_loss,
53-
"val_accuracy": val_accuracy,
54-
}
7+
EPOCHS = 20
8+
PROJECT_ID = random.randint(100000, 999999)
9+
10+
11+
def generate_loss_curve(epoch, max_epochs, base_loss=2.5, min_loss=0.1):
12+
"""Generate a realistic loss curve that decreases over time with noise"""
13+
progress = epoch / max_epochs
14+
base_curve = base_loss * math.exp(-3 * progress) + min_loss
15+
16+
noise_scale = 0.3 * (1 - progress * 0.7)
17+
noise = random.gauss(0, noise_scale)
18+
19+
return max(min_loss * 0.5, base_curve + noise)
20+
21+
22+
def generate_accuracy_curve(epoch, max_epochs, max_acc=0.95, min_acc=0.1):
23+
"""Generate a realistic accuracy curve that increases over time with noise"""
24+
progress = epoch / max_epochs
25+
base_curve = max_acc / (1 + math.exp(-6 * (progress - 0.5))) + min_acc
26+
27+
noise_scale = 0.08 * (1 - progress * 0.5)
28+
noise = random.gauss(0, noise_scale)
29+
30+
return max(0, min(max_acc, base_curve + noise))
31+
32+
33+
for run in range(3):
34+
wandb.init(
35+
project=f"deploy-on-spaces-{PROJECT_ID}",
36+
name=f"test-run-{run}",
37+
config=dict(
38+
epochs=EPOCHS,
39+
learning_rate=0.001,
40+
batch_size=32,
41+
),
42+
space_id=f"trackio-on-spaces-{PROJECT_ID}",
5543
)
56-
time.sleep(1)
44+
45+
for epoch in range(EPOCHS):
46+
train_loss = generate_loss_curve(
47+
epoch,
48+
EPOCHS,
49+
base_loss=random.uniform(2.5, 3.5),
50+
min_loss=random.uniform(0.05, 0.15),
51+
)
52+
val_loss = generate_loss_curve(
53+
epoch,
54+
EPOCHS,
55+
base_loss=random.uniform(2.5, 3.5),
56+
min_loss=random.uniform(0.05, 0.15),
57+
)
58+
59+
train_accuracy = generate_accuracy_curve(
60+
epoch,
61+
EPOCHS,
62+
max_acc=random.uniform(0.7, 0.9),
63+
min_acc=random.uniform(0.1, 0.3),
64+
)
65+
val_accuracy = generate_accuracy_curve(
66+
epoch,
67+
EPOCHS,
68+
max_acc=random.uniform(0.7, 0.9),
69+
min_acc=random.uniform(0.1, 0.3),
70+
)
71+
72+
if epoch > 2 and random.random() < 0.3:
73+
val_loss *= 1.1
74+
val_accuracy *= 0.95
75+
76+
wandb.log(
77+
{
78+
"train_loss": round(train_loss, 4),
79+
"train_accuracy": round(train_accuracy, 4),
80+
"val_loss": round(val_loss, 4),
81+
"val_accuracy": round(val_accuracy, 4),
82+
}
83+
)
84+
85+
time.sleep(0.2)
5786

5887
wandb.finish()

examples/fake-training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,6 @@ def generate_accuracy_curve(epoch, max_epochs, max_acc=0.95, min_acc=0.1):
8181
}
8282
)
8383

84-
time.sleep(0.5)
84+
time.sleep(0.2)
8585

86-
wandb.finish()
86+
wandb.finish()

trackio/run.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from trackio.typehints import LogEntry
99
from trackio.utils import RESERVED_KEYS, fibo, generate_readable_name
1010

11+
BATCH_SEND_INTERVAL = 0.5
12+
1113

1214
class Run:
1315
def __init__(
@@ -33,15 +35,17 @@ def __init__(
3335
self._client_thread.start()
3436

3537
def _batch_sender(self):
36-
"""Send batched logs every 500ms."""
37-
while not self._stop_flag.is_set():
38-
time.sleep(0.5)
38+
"""Send batched logs every BATCH_SEND_INTERVAL."""
39+
while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
40+
# If the stop flag has been set, then just quickly send all
41+
# the logs and exit.
42+
if not self._stop_flag.is_set():
43+
time.sleep(BATCH_SEND_INTERVAL)
3944

4045
with self._client_lock:
4146
if self._queued_logs and self._client is not None:
4247
logs_to_send = self._queued_logs.copy()
4348
self._queued_logs.clear()
44-
4549
self._client.predict(
4650
api_name="/bulk_log",
4751
logs=logs_to_send,
@@ -54,6 +58,7 @@ def _init_client_background(self):
5458
for sleep_coefficient in fib:
5559
try:
5660
client = Client(self.url, verbose=False)
61+
5762
with self._client_lock:
5863
self._client = client
5964
break
@@ -85,16 +90,9 @@ def finish(self):
8590
"""Cleanup when run is finished."""
8691
self._stop_flag.set()
8792

88-
with self._client_lock:
89-
if self._queued_logs and self._client is not None:
90-
logs_to_send = self._queued_logs.copy()
91-
self._queued_logs.clear()
92-
self._client.predict(
93-
api_name="/bulk_log",
94-
logs=logs_to_send,
95-
hf_token=huggingface_hub.utils.get_token(),
96-
)
93+
# Wait for the batch sender to finish before joining the client thread.
94+
time.sleep(2 * BATCH_SEND_INTERVAL)
9795

9896
if self._client_thread is not None:
9997
print(f"* Uploading logs to Trackio Space: {self.url} (please wait...)")
100-
self._client_thread.join(timeout=30)
98+
self._client_thread.join()

trackio/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.2.5
1+
0.2.6

0 commit comments

Comments
 (0)