Skip to content

Commit 8a8957e

Browse files
abidlabsAbubakarcontrastive-vae
authored
Add trackio.save() (#351)
Co-authored-by: Abubakar <abubakar@Abubakars-MacBook-Pro.local> Co-authored-by: Abubakar Abid <aaabid93@gmail.com>
1 parent 8abe691 commit 8a8957e

29 files changed

Lines changed: 1032 additions & 696 deletions

.changeset/sweet-suns-know.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"trackio": minor
3+
---
4+
5+
feat:Add `trackio.save()`

docs/source/api.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
[[autodoc]] finish
1818

19+
## save
20+
21+
[[autodoc]] save
22+
1923
## show
2024

2125
[[autodoc]] show

examples/files/config1.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
model:
2+
name: "simple_mlp"
3+
layers: [128, 64, 32]
4+
activation: "relu"
5+
dropout: 0.2
6+
7+
training:
8+
epochs: 10
9+
batch_size: 32
10+
learning_rate: 0.001
11+
optimizer: "adam"
12+
13+
data:
14+
dataset: "mnist"
15+
train_split: 0.8
16+
normalize: true
17+

examples/files/config2.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
model:
2+
name: "cnn"
3+
layers: [32, 64, 128]
4+
activation: "relu"
5+
dropout: 0.3
6+
7+
training:
8+
epochs: 15
9+
batch_size: 64
10+
learning_rate: 0.0005
11+
optimizer: "sgd"
12+
13+
data:
14+
dataset: "cifar10"
15+
train_split: 0.75
16+
normalize: true
17+
augmentation: true
18+

examples/files/models/model1.pth

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
FAKE_MODEL_FILE_1
2+
This is a fake PyTorch model checkpoint file.
3+
In reality, this would contain binary model weights.
4+
For testing purposes, this is just a text placeholder.
5+

examples/files/models/model2.pth

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
FAKE_MODEL_FILE_2
2+
This is a fake PyTorch model checkpoint file.
3+
In reality, this would contain binary model weights.
4+
For testing purposes, this is just a text placeholder.
5+
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import random
2+
import time
3+
from pathlib import Path
4+
5+
import trackio as wandb
6+
7+
PROJECT_ID = random.randint(100000, 999999)
8+
EPOCHS = 10
9+
10+
current_dir = Path(__file__).parent
11+
12+
13+
def main():
14+
wandb.init(
15+
project=f"test-save-{PROJECT_ID}",
16+
name="test-run",
17+
config={
18+
"epochs": EPOCHS,
19+
"learning_rate": 0.001,
20+
"batch_size": 32,
21+
},
22+
)
23+
24+
wandb.save(current_dir / "config1.yml")
25+
wandb.save(current_dir / "config2.yml")
26+
wandb.save(current_dir / "models/*.pth")
27+
28+
for epoch in range(EPOCHS):
29+
loss = 2.0 * (1 - epoch / EPOCHS) + random.uniform(-0.1, 0.1)
30+
accuracy = 0.5 + 0.4 * (epoch / EPOCHS) + random.uniform(-0.05, 0.05)
31+
32+
wandb.log(
33+
{
34+
"loss": round(loss, 4),
35+
"accuracy": round(accuracy, 4),
36+
}
37+
)
38+
39+
time.sleep(0.1)
40+
41+
wandb.finish()
42+
print(f"* Test completed. Check project 'test-save-{PROJECT_ID}' for saved files.")
43+
44+
45+
if __name__ == "__main__":
46+
main()

examples/test-save.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import random
2+
import time
3+
4+
import trackio as wandb
5+
6+
PROJECT_ID = random.randint(100000, 999999)
7+
EPOCHS = 10
8+
9+
10+
def main():
11+
wandb.init(
12+
project=f"test-save-{PROJECT_ID}",
13+
name="test-run",
14+
config={
15+
"epochs": EPOCHS,
16+
"learning_rate": 0.001,
17+
"batch_size": 32,
18+
},
19+
)
20+
21+
wandb.save("files/config1.yml")
22+
wandb.save("files/config2.yml")
23+
wandb.save("files/models/*.pth")
24+
25+
for epoch in range(EPOCHS):
26+
loss = 2.0 * (1 - epoch / EPOCHS) + random.uniform(-0.1, 0.1)
27+
accuracy = 0.5 + 0.4 * (epoch / EPOCHS) + random.uniform(-0.05, 0.05)
28+
29+
wandb.log(
30+
{
31+
"loss": round(loss, 4),
32+
"accuracy": round(accuracy, 4),
33+
}
34+
)
35+
36+
time.sleep(0.1)
37+
38+
wandb.finish()
39+
print(f"* Test completed. Check project 'test-save-{PROJECT_ID}' for saved files.")
40+
41+
42+
if __name__ == "__main__":
43+
main()

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def temp_dir(monkeypatch):
1414
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
1515
for name in ["trackio.sqlite_storage"]:
1616
monkeypatch.setattr(f"{name}.TRACKIO_DIR", Path(tmpdir))
17-
for name in ["trackio.media.media", "trackio.media.file_storage"]:
17+
for name in ["trackio.media.media", "trackio.media.utils"]:
1818
monkeypatch.setattr(f"{name}.MEDIA_DIR", Path(tmpdir) / "media")
1919
yield tmpdir
2020

tests/e2e/test_import_from_tf.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)