-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathgenerate_graphs.py
More file actions
67 lines (53 loc) · 2.44 KB
/
generate_graphs.py
File metadata and controls
67 lines (53 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import json
import random
from utils import relabel_and_name_vertices, generate_ba_graph, generate_ws_graph, generate_delaunay_triangulation
import networkx as nx
from networkx.readwrite import json_graph
import matplotlib.pyplot as plt
# Registry mapping each short model key accepted on the command line to the
# generator function that builds a graph of that family. The script invokes
# each generator as ``generator(graph_size, seed=...)``.
GRAPH_MODELS = dict(
    ws=generate_ws_graph,  # presumably Watts-Strogatz; see utils to confirm
    ba=generate_ba_graph,  # presumably Barabasi-Albert; see utils to confirm
    dt=generate_delaunay_triangulation,  # Delaunay triangulation graph
)

# Directory, relative to the current working directory, that graph JSON files
# (and their PNG drawings) are written to and loaded from.
OUTPUT_DIR = "graphs"
def get_graph_path(graph_model, graph_size, num_sample):
    """Return the JSON file path for one graph sample.

    The file lives directly inside ``OUTPUT_DIR`` and is named
    ``graph_<model>_<size>_<sample>.json``.

    Args:
        graph_model: Short model key, e.g. ``"ws"``.
        graph_size: Size value the graph was generated with.
        num_sample: Sample index for this model/size pair (the script passes
            indices starting at 0).

    Returns:
        The path as a string built with ``os.path.join``.
    """
    basename = "graph_{}_{}_{}.json".format(graph_model, graph_size, num_sample)
    return os.path.join(OUTPUT_DIR, basename)
def get_graph(graph_model, graph_size, num_sample):
    """Load a previously generated graph sample from disk.

    Prints the path being read, parses the JSON document, and rebuilds a
    NetworkX graph from the node-link data stored under its ``"graph"`` key
    (edges are expected under ``"links"``).

    Raises:
        FileNotFoundError: If the sample's JSON file does not exist.
        KeyError: If the JSON document has no ``"graph"`` entry.
    """
    source_path = get_graph_path(graph_model, graph_size, num_sample)
    print(f"Loading graph from {source_path}")
    with open(source_path) as handle:
        payload = json.load(handle)
    node_link_payload = payload["graph"]
    return json_graph.node_link_graph(node_link_payload, edges="links")
def _parse_args(argv=None):
    """Parse the command-line options controlling which graphs are generated.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Restricting values to the registry keys makes a typo fail immediately,
    # instead of raising KeyError after earlier models' files were written.
    parser.add_argument(
        "--graph_models",
        type=str,
        nargs="+",
        default=["ws", "ba", "dt"],
        choices=sorted(GRAPH_MODELS),
    )
    parser.add_argument("--samples_per_graph_model", type=int, default=3)
    parser.add_argument("--graph_sizes", type=int, nargs="+", default=[4, 8, 16])
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args(argv)


def _save_graph(graph, filepath, graph_seed, seed):
    """Write *graph* and its summary statistics to *filepath*, plus a PNG drawing.

    The PNG is written next to the JSON file with the same stem.

    Args:
        graph: NetworkX graph to serialize and draw.
        filepath: Destination path of the ``.json`` file.
        graph_seed: Seed passed to the graph generator, recorded in the file.
        seed: Top-level ``--seed`` value, recorded in the file.

    Raises:
        networkx.NetworkXError: From ``nx.diameter`` if the graph is not
            connected. No JSON file is created or truncated in that case.
    """
    # Build the whole record before opening the file. If a statistic raises,
    # no empty or half-written JSON file is left on disk.
    record = {
        "graph": json_graph.node_link_data(graph, edges="links"),
        "num_nodes": len(graph.nodes()),
        # NOTE(review): assumes every generator yields a connected graph,
        # since nx.diameter raises on disconnected input. Confirm for each
        # model in utils.
        "diameter": nx.diameter(graph),
        "max_degree": max(dict(graph.degree()).values()),
        "graph_seed": graph_seed,
        "seed": seed,
    }
    with open(filepath, "w") as f:
        json.dump(record, f, indent=4)

    # Draw on an explicit figure and close exactly that figure, so figures do
    # not accumulate across the generation loop.
    fig = plt.figure()
    nx.draw(graph, with_labels=True)
    # Derive the image path from the JSON stem instead of slicing a fixed
    # number of characters off the end.
    fig.savefig(os.path.splitext(filepath)[0] + ".png")
    plt.close(fig)


def main(argv=None):
    """Generate, save and plot every requested (model, size, sample) graph."""
    args = _parse_args(argv)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    random.seed(args.seed)
    # Seeds are drawn once per sample index and reused for every model and
    # size, so sample i of each family is generated with the same seed.
    graph_seeds = [
        random.randint(1, 10000) for _ in range(args.samples_per_graph_model)
    ]
    for graph_model in args.graph_models:
        generate = GRAPH_MODELS[graph_model]
        for graph_size in args.graph_sizes:
            for i, graph_seed in enumerate(graph_seeds):
                graph = generate(graph_size, seed=graph_seed)
                graph = relabel_and_name_vertices(graph)
                filepath = get_graph_path(graph_model, graph_size, i)
                _save_graph(graph, filepath, graph_seed, args.seed)
                print(f"Graph saved to {filepath}")


if __name__ == "__main__":
    main()