Skip to content

Commit 54775a7

Browse files
committed
feat: move seed into stats to support multi-seed runs
1 parent f5db04b commit 54775a7

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

src/balatrollm/benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def analyze_strategy_runs(self, strategy_dir: Path) -> None:
8989
output_dir / vendor_dir.name / f"{model_dir.name}.json"
9090
)
9191
model_stats_path.parent.mkdir(exist_ok=True, parents=True)
92+
stats_dict = asdict(model_stats)
93+
stats_dict["config"].pop("seed")
9294
with open(model_stats_path, "w") as f:
93-
json.dump(asdict(model_stats), f, indent=2)
95+
json.dump(stats_dict, f, indent=2)
9496

9597
# Create detailed run directories
9698
detailed_output_dir = output_dir / vendor_dir.name

src/balatrollm/data_collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Stats:
4141
completed: bool
4242
ante_reached: int
4343
final_round: int
44+
seed: str
4445
providers: list[str]
4546
calls: CallStats
4647
total: AggregatedStats
@@ -59,6 +60,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Stats":
5960
completed=data["completed"],
6061
ante_reached=data["ante_reached"],
6162
final_round=data["final_round"],
63+
seed=data["seed"],
6264
providers=data["providers"],
6365
calls=calls,
6466
total=total,
@@ -276,6 +278,7 @@ def calculate_stats(self) -> Stats:
276278
completed=state["state"] == 4, # 4 is GAME_OVER gamestate
277279
ante_reached=state["game"]["round_resets"]["ante"],
278280
final_round=state["game"]["round"],
281+
seed=self.config.seed,
279282
providers=stats["providers"],
280283
calls=call_stats,
281284
total=total,

0 commit comments

Comments
 (0)