Skip to content

Commit 23ef648

Browse files
committed
Add mzn-bench plot-cactus to create plots to compare solved instances
1 parent fa39a48 commit 23ef648

File tree

5 files changed

+339
-4
lines changed

5 files changed

+339
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ mzn-bench = 'mzn_bench.cli:main'
1919

2020
[project.optional-dependencies]
2121
scripts = ["pandas>=2.2.3", "pytest>=8.3.4", "tabulate>=0.9.0"]
22-
plotting = ["bokeh>=3.6.2"]
22+
plotting = ["bokeh>=3.6.2", "matplotlib>=3.10.1", "seaborn>=0.13.2"]

src/mzn_bench/analysis/plot.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,57 @@
99
from bokeh.palettes import Palette, Spectral5
1010
from bokeh.plotting import figure, gridplot
1111
from bokeh.transform import factor_cmap
12+
import matplotlib.pyplot as plt
13+
import seaborn as sns
14+
15+
16+
def plot_cactus(stats: pd.DataFrame):
    """Draw a cactus plot comparing how many instances each configuration solved.

    A row of ``stats`` counts as solved when its ``status`` is
    ``OPTIMAL_SOLUTION``, or when it is ``SATISFIED`` and its ``method`` is
    ``satisfy``. For every configuration the solving times of those rows are
    sorted in ascending order, so the k-th point of a line is the time within
    which k instances were solved.

    Args:
        stats: Statistics table with at least the columns ``configuration``,
            ``status``, ``method`` and ``time``.

    Returns:
        The matplotlib figure holding the plot. When no row counts as solved
        the figure contains only the titled, empty axes.
    """
    # The "solved" criterion does not depend on the configuration, so it is
    # evaluated once instead of once per configuration.
    solved = stats[
        (stats["status"] == "OPTIMAL_SOLUTION")
        | ((stats["status"] == "SATISFIED") & (stats["method"] == "satisfy"))
    ]

    frames = []
    for conf in stats["configuration"].unique():
        # Solving times in ascending order
        times = sorted(solved.loc[solved["configuration"] == conf, "time"])
        if not times:
            # A configuration that solved nothing contributes no points;
            # skipping it keeps empty frames out of the concatenation.
            continue
        frames.append(
            pd.DataFrame(
                {
                    "time": times,
                    # Position in the sorted column (1..n) is the number of
                    # instances solved in up to the time in that row
                    "n_solved": range(1, len(times) + 1),
                    "configuration": conf,
                }
            )
        )

    fig, ax = plt.subplots(figsize=(12, 5))
    if frames:
        # pd.concat rejects an empty list and sns.move_legend rejects axes
        # without a legend, so both only run when there is data to draw.
        data = pd.concat(frames, ignore_index=True)
        sns.lineplot(
            ax=ax,
            data=data,
            y="time",
            x="n_solved",
            hue="configuration",
            style="configuration",
            markers=True,
            dashes=False,
        )
        # Place the legend outside the axes, to the right of the plot
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1.01, 1), borderaxespad=0)
    ax.set(
        title="Comparison of solved instances between different configurations",
        ylabel="CPU time(seconds)",
        xlabel="# of instances solved",
    )
    fig.tight_layout()

    return fig
1263

1364

1465
def plot_all_instances(

src/mzn_bench/cli.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def collect_instances(shared_data: Optional[str], benchmarks_location: str):
7474
"--param",
7575
"-p",
7676
help="Additional solution parameters to add to each row of the CSV file",
77-
default=["configuration"],
77+
default=[],
7878
multiple=True,
7979
)
8080
@click.argument("dirs", nargs=-1, type=click.Path(exists=True, dir_okay=True))
@@ -102,9 +102,12 @@ def collect_objectives_(
102102
count = 0
103103
additional_params = list(additional_params)
104104
with Path(out_file).open(mode="w") as file:
105+
labels = STANDARD_KEYS.copy()
106+
labels.remove("status") # No need to output SAT every time
107+
labels = labels + ["run", "objective"] + additional_params
105108
writer = csv.DictWriter(
106109
file,
107-
STANDARD_KEYS + ["run", "objective"] + additional_params,
110+
labels,
108111
dialect="unix",
109112
extrasaction="ignore",
110113
)
@@ -419,5 +422,66 @@ def compare_configurations(
419422
exit(1)
420423

421424

425+
@main.command()
@click.argument(
    "objectives", metavar="objs_file", type=click.Path(exists=True, file_okay=True)
)
@click.argument(
    "statistics", metavar="stats_file", type=click.Path(exists=True, file_okay=True)
)
@click.argument("out_file", type=click.Path(file_okay=True))
def plot_all_instances(
    objectives: str,
    statistics: str,
    out_file: str,
):
    """Plot all instances in a grid

    OBJS_FILE is the CSV file containing aggregated solutions data
    STATS_FILE is the CSV file containing aggregated statistics data
    OUT_FILE is the file to write the plot to
    """
    # The docstring above is the command's help text; the arguments are listed
    # in the order they must be given on the command line.
    try:
        # Imported lazily so a missing optional dependency is reported through
        # IMPORT_ERROR below rather than breaking the whole CLI at import time.
        from .analysis.collect import read_csv
        from .analysis.plot import plot_all_instances as fn
        from bokeh.plotting import save

        objs, stats = read_csv(objectives, statistics)
        figure = fn(objs, stats)

        save(figure, filename=out_file)
    except ImportError:
        click.echo(IMPORT_ERROR, err=True)
        exit(1)
456+
457+
458+
@main.command()
@click.argument(
    "statistics", metavar="stats_file", type=click.Path(exists=True, file_okay=True)
)
@click.argument("out_file", type=click.Path(file_okay=True))
def plot_cactus(
    statistics: str,
    out_file: str,
):
    """Plots all configurations in a cactus plot of solved instances

    STATS_FILE is the CSV file containing aggregated statistics data
    OUT_FILE is the file to write the plot to
    """
    try:
        # Imported lazily so a missing optional dependency is reported through
        # IMPORT_ERROR below rather than breaking the whole CLI at import time.
        import pandas as pd
        from .analysis.plot import plot_cactus as fn

        stats = pd.read_csv(statistics)
        # Blank out missing data files, but only when the CSV has that column:
        # attribute access on an absent column would raise AttributeError.
        if "data_file" in stats.columns:
            stats["data_file"] = stats["data_file"].fillna("")
        fig = fn(stats)
        fig.savefig(out_file)

    except ImportError:
        click.echo(IMPORT_ERROR, err=True)
        exit(1)
484+
485+
422486
if __name__ == "__main__":
423487
main()

src/mzn_bench/mzn_slurm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
import time
88
import traceback
9-
from dataclasses import asdict, dataclass, field, fields
9+
from dataclasses import dataclass, field, fields
1010
from datetime import timedelta
1111
from pathlib import Path
1212
from typing import Any, Dict, Iterable, List, NoReturn, Optional

0 commit comments

Comments
 (0)