Skip to content

Commit 586e65a

Browse files
AkiSakuraiabidlabs
andauthored
use orjson for faster JSON parsing (#253)
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
1 parent 73eca21 commit 586e65a

3 files changed

Lines changed: 18 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"gradio>=5.46.0,<6.0.0",
1818
"numpy<3.0.0",
1919
"pillow<12.0.0",
20+
"orjson>=3.0,<4.0.0"
2021
]
2122
classifiers = [
2223
"Programming Language :: Python :: 3",

tests/test_sqlite_storage.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tempfile
77
import time
88

9+
import orjson
910
import pytest
1011

1112
from trackio.sqlite_storage import SQLiteStorage
@@ -203,8 +204,6 @@ def test_config_storage_in_database(temp_dir):
203204

204205
def test_old_database_without_configs_table(temp_dir):
205206
# To make sure that we can continue to work with projects created with older versions of Trackio.
206-
import json
207-
208207
db_path = SQLiteStorage.get_project_db_path("test")
209208
db_path.parent.mkdir(parents=True, exist_ok=True)
210209

@@ -220,7 +219,7 @@ def test_old_database_without_configs_table(temp_dir):
220219
""")
221220
conn.execute(
222221
"INSERT INTO metrics (timestamp, run_name, step, metrics) VALUES (?, ?, ?, ?)",
223-
("2024-01-01", "test_run", 0, json.dumps({"loss": 0.5})),
222+
("2024-01-01", "test_run", 0, orjson.dumps({"loss": 0.5})),
224223
)
225224

226225
config = SQLiteStorage.get_run_config("test", "test_run")

trackio/sqlite_storage.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import os
32
import platform
43
import sqlite3
@@ -13,6 +12,7 @@
1312
fcntl = None
1413

1514
import huggingface_hub as hf
15+
import orjson
1616
import pandas as pd
1717

1818
try: # absolute imports when installed from PyPI
@@ -166,7 +166,7 @@ def export_to_parquet():
166166
metrics = df["metrics"].copy()
167167
metrics = pd.DataFrame(
168168
metrics.apply(
169-
lambda x: deserialize_values(json.loads(x))
169+
lambda x: deserialize_values(orjson.loads(x))
170170
).values.tolist(),
171171
index=df.index,
172172
)
@@ -196,9 +196,9 @@ def import_from_parquet():
196196
for col in other_cols:
197197
del metrics[col]
198198
# combine them all into a single metrics col
199-
metrics = json.loads(metrics.to_json(orient="records"))
199+
metrics = orjson.loads(metrics.to_json(orient="records"))
200200
df["metrics"] = [
201-
json.dumps(serialize_values(row)) for row in metrics
201+
orjson.dumps(serialize_values(row)) for row in metrics
202202
]
203203
df.to_sql("metrics", conn, if_exists="replace", index=False)
204204

@@ -273,7 +273,7 @@ def log(project: str, run: str, metrics: dict, step: int | None = None):
273273
current_timestamp,
274274
run,
275275
current_step,
276-
json.dumps(serialize_values(metrics)),
276+
orjson.dumps(serialize_values(metrics)),
277277
),
278278
)
279279
conn.commit()
@@ -335,7 +335,7 @@ def bulk_log(
335335
timestamps[i],
336336
run,
337337
steps[i],
338-
json.dumps(serialize_values(metrics)),
338+
orjson.dumps(serialize_values(metrics)),
339339
)
340340
)
341341

@@ -356,7 +356,11 @@ def bulk_log(
356356
(run_name, config, created_at)
357357
VALUES (?, ?, ?)
358358
""",
359-
(run, json.dumps(serialize_values(config)), current_timestamp),
359+
(
360+
run,
361+
orjson.dumps(serialize_values(config)),
362+
current_timestamp,
363+
),
360364
)
361365

362366
conn.commit()
@@ -383,7 +387,7 @@ def get_logs(project: str, run: str) -> list[dict]:
383387
rows = cursor.fetchall()
384388
results = []
385389
for row in rows:
386-
metrics = json.loads(row["metrics"])
390+
metrics = orjson.loads(row["metrics"])
387391
metrics = deserialize_values(metrics)
388392
metrics["timestamp"] = row["timestamp"]
389393
metrics["step"] = row["step"]
@@ -490,7 +494,7 @@ def store_config(project: str, run: str, config: dict) -> None:
490494
(run_name, config, created_at)
491495
VALUES (?, ?, ?)
492496
""",
493-
(run, json.dumps(serialize_values(config)), current_timestamp),
497+
(run, orjson.dumps(serialize_values(config)), current_timestamp),
494498
)
495499
conn.commit()
496500

@@ -513,7 +517,7 @@ def get_run_config(project: str, run: str) -> dict | None:
513517

514518
row = cursor.fetchone()
515519
if row:
516-
config = json.loads(row["config"])
520+
config = orjson.loads(row["config"])
517521
return deserialize_values(config)
518522
return None
519523
except sqlite3.OperationalError as e:
@@ -557,7 +561,7 @@ def get_all_run_configs(project: str) -> dict[str, dict]:
557561

558562
results = {}
559563
for row in cursor.fetchall():
560-
config = json.loads(row["config"])
564+
config = orjson.loads(row["config"])
561565
results[row["run_name"]] = deserialize_values(config)
562566
return results
563567
except sqlite3.OperationalError as e:

0 commit comments

Comments
 (0)