1- import json
21import os
32import platform
43import sqlite3
1312 fcntl = None
1413
1514import huggingface_hub as hf
15+ import orjson
1616import pandas as pd
1717
1818try : # absolute imports when installed from PyPI
@@ -166,7 +166,7 @@ def export_to_parquet():
166166 metrics = df ["metrics" ].copy ()
167167 metrics = pd .DataFrame (
168168 metrics .apply (
169- lambda x : deserialize_values (json .loads (x ))
169+ lambda x : deserialize_values (orjson .loads (x ))
170170 ).values .tolist (),
171171 index = df .index ,
172172 )
@@ -196,9 +196,9 @@ def import_from_parquet():
196196 for col in other_cols :
197197 del metrics [col ]
198198 # combine them all into a single metrics col
199- metrics = json .loads (metrics .to_json (orient = "records" ))
199+ metrics = orjson .loads (metrics .to_json (orient = "records" ))
200200 df ["metrics" ] = [
201- json .dumps (serialize_values (row )) for row in metrics
201+ orjson .dumps (serialize_values (row )) for row in metrics
202202 ]
203203 df .to_sql ("metrics" , conn , if_exists = "replace" , index = False )
204204
@@ -273,7 +273,7 @@ def log(project: str, run: str, metrics: dict, step: int | None = None):
273273 current_timestamp ,
274274 run ,
275275 current_step ,
276- json .dumps (serialize_values (metrics )),
276+ orjson .dumps (serialize_values (metrics )),
277277 ),
278278 )
279279 conn .commit ()
@@ -335,7 +335,7 @@ def bulk_log(
335335 timestamps [i ],
336336 run ,
337337 steps [i ],
338- json .dumps (serialize_values (metrics )),
338+ orjson .dumps (serialize_values (metrics )),
339339 )
340340 )
341341
@@ -356,7 +356,11 @@ def bulk_log(
356356 (run_name, config, created_at)
357357 VALUES (?, ?, ?)
358358 """ ,
359- (run , json .dumps (serialize_values (config )), current_timestamp ),
359+ (
360+ run ,
361+ orjson .dumps (serialize_values (config )),
362+ current_timestamp ,
363+ ),
360364 )
361365
362366 conn .commit ()
@@ -383,7 +387,7 @@ def get_logs(project: str, run: str) -> list[dict]:
383387 rows = cursor .fetchall ()
384388 results = []
385389 for row in rows :
386- metrics = json .loads (row ["metrics" ])
390+ metrics = orjson .loads (row ["metrics" ])
387391 metrics = deserialize_values (metrics )
388392 metrics ["timestamp" ] = row ["timestamp" ]
389393 metrics ["step" ] = row ["step" ]
@@ -490,7 +494,7 @@ def store_config(project: str, run: str, config: dict) -> None:
490494 (run_name, config, created_at)
491495 VALUES (?, ?, ?)
492496 """ ,
493- (run , json .dumps (serialize_values (config )), current_timestamp ),
497+ (run , orjson .dumps (serialize_values (config )), current_timestamp ),
494498 )
495499 conn .commit ()
496500
@@ -513,7 +517,7 @@ def get_run_config(project: str, run: str) -> dict | None:
513517
514518 row = cursor .fetchone ()
515519 if row :
516- config = json .loads (row ["config" ])
520+ config = orjson .loads (row ["config" ])
517521 return deserialize_values (config )
518522 return None
519523 except sqlite3 .OperationalError as e :
@@ -557,7 +561,7 @@ def get_all_run_configs(project: str) -> dict[str, dict]:
557561
558562 results = {}
559563 for row in cursor .fetchall ():
560- config = json .loads (row ["config" ])
564+ config = orjson .loads (row ["config" ])
561565 results [row ["run_name" ]] = deserialize_values (config )
562566 return results
563567 except sqlite3 .OperationalError as e :
0 commit comments