-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathselect_best_models.py
More file actions
138 lines (117 loc) · 2.78 KB
/
select_best_models.py
File metadata and controls
138 lines (117 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# pylint: disable=no-member
import argparse
import os
import os.path as osp
# Local imports
from lambdaml.deploy import select_best_models
from lambdaml.log import set_global_log_level
from lambdaml.util import load_yaml
# Command-line interface. After parsing, every option except `log_level`
# and `config` is forwarded as a keyword argument to `select_best_models`,
# so option `dest` names must match that function's parameter names.
argparser = argparse.ArgumentParser(
    description="Select model from Optuna database with SQL query"
)
argparser.add_argument(
    "--log_level",
    type=str,
    default="INFO",
    # Both lower- and upper-case spellings are accepted as-is.
    choices=[
        "debug",
        "info",
        "warning",
        "error",
        "critical",
        "DEBUG",
        "INFO",
        "WARNING",
        "ERROR",
        "CRITICAL",
    ],
    help="Log level",
)
argparser.add_argument(
    "--config",
    type=str,
    default=None,
    # When supplied, the yaml contents are used in place of the other
    # command-line options.
    help="Path to config yaml",
)
argparser.add_argument(
    "--n_best_trials",
    type=int,
    default=1,
    help="Number of best trials to select",
)
argparser.add_argument(
    "--optuna_storage_url",
    type=str,
    default="sqlite:///optuna_study.db",
    help="Storage URL for Optuna",
)
argparser.add_argument(
    "--optuna_study_name",
    type=str,
    default="model_hpo",
    help="Optuna study name",
)
argparser.add_argument(
    "--optuna_study_direction",
    type=str,
    default="maximize",
    choices=["minimize", "maximize"],
    help="Optuna study direction",
)
argparser.add_argument(
    "--registry",
    type=str,
    # The LAMBDAML_REGISTRY environment variable (read when this module
    # is loaded) overrides the built-in default directory.
    default=os.environ.get("LAMBDAML_REGISTRY", "app/registry"),
    help="Flask app directory where model state dictionaries and parameters will be copied",
)
argparser.add_argument(
    "--codename_separator",
    type=str,
    default="-",
    help="Codename separator",
)
argparser.add_argument(
    "--encoder_fname",
    type=str,
    default="encoder.pt",
    help="Encoder model state dictionary file name",
)
argparser.add_argument(
    "--encoder_params_fname",
    type=str,
    default="encoder_params.json",
    help="Encoder parameters json file name",
)
argparser.add_argument(
    "--clf_fname",
    type=str,
    default="clf.pt",
    help="Classifier model state dictionary file name",
)
argparser.add_argument(
    "--clf_params_fname",
    type=str,
    default="clf_params.json",
    help="Classifier parameters json file name",
)
# Parse arguments and build the keyword-argument dictionary for
# select_best_models, either from a yaml config or from the command line.
args_raw = argparser.parse_args()
if args_raw.config is not None:
    # An explicitly requested config that cannot be found is a user error:
    # silently falling back to command-line defaults could select and copy
    # models from the wrong study, so fail loudly instead (exits with usage).
    if not osp.isfile(args_raw.config):
        argparser.error(f"config file not found: {args_raw.config}")
    # Guard against an empty yaml file; NOTE(review): assumes load_yaml
    # returns a mapping (or None when the file is empty) — confirm.
    args = dict(load_yaml(args_raw.config) or {})
else:
    # Copy so the pops below do not mutate the parsed namespace itself
    # (vars() returns the namespace's own __dict__).
    args = dict(vars(args_raw))
# Set log level. A yaml config need not specify it, so fall back to the
# command-line value (which has its own default) instead of raising KeyError.
set_global_log_level(args.pop("log_level", args_raw.log_level))
# Remove the config argument; a yaml config normally does not contain it,
# so a missing key is not an error.
args.pop("config", None)
# Select top n models and copy model states
# and parameter definitions to flask app directory
select_best_models(**args)