-
Notifications
You must be signed in to change notification settings - Fork 405
Expand file tree
/
Copy pathgolden_config.bzl
More file actions
105 lines (90 loc) · 3.43 KB
/
Copy pathgolden_config.bzl
File metadata and controls
105 lines (90 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Bazel rules for AXLearn."""
load("@rules_python//python:defs.bzl", "py_test")
def golden_config_test(name, module, data = [], match = None, checks = None, deps = [], **kwargs):
    """Creates a golden config test for an experiment module.

    Generates `<name>_main.py` with a private rule, then wraps it in a py_test
    whose entry point calls `golden_config.test_main` on `module`.

    Args:
        name: Name of the test target.
        module: Label of the py_library containing named_trainer_configs().
        data: Data files (golden config files) made available to the test.
        match: Optional regex to filter config names. When omitted or empty,
            no `match` argument is passed to golden_config.test_main.
        checks: Optional list of golden_config check function names
            (e.g., ["check", "check_init"]). When omitted or empty, no `checks`
            argument is passed, so golden_config.test_main's own default
            applies (presumably ["check", "check_init"]; confirm in
            axlearn/common/golden_config.py).
        deps: Additional deps beyond module and golden_config.
        **kwargs: Additional arguments passed to py_test. `tags`, if given, is
            also applied to the script-generation target.
    """
    script_name = name + "_main.py"
    _generate_golden_test_script(
        name = name + "_gen",
        module = module,
        output = script_name,
        match = match or "",
        checks = checks or [],
        # The generator exists only to feed this test. Marking it testonly lets
        # `module` itself be testonly, and mirroring the test's tags (e.g.
        # "manual") keeps wildcard builds from building the generator on its own.
        testonly = True,
        tags = kwargs.get("tags", []),
    )
    py_test(
        name = name,
        srcs = [script_name],
        data = data,
        deps = [module, "//axlearn/common:golden_config"] + deps,
        main = script_name,
        **kwargs
    )
def _module_import_path(label):
    """Returns the dotted Python import path of the module behind `label`.

    The package path becomes the dotted prefix. The target name is appended
    unless it repeats the last package component, e.g. `//a/b:b` -> `a.b` and
    `//a/b:c` -> `a.b.c`. A root-package target maps to just its name.
    """
    package = label.package
    if not package:
        # Without this guard the result would start with a stray "." (".name").
        return label.name
    module_path = package.replace("/", ".")
    if label.name != package.split("/")[-1]:
        module_path = module_path + "." + label.name
    return module_path

def _relative_testdata_path(package, testdata_subdir):
    """Returns the golden-file directory relative to the generated script.

    The script is written into `package`. If `package` contains an
    `experiments` component (the first occurrence is used), golden files live
    under `<...>/experiments/testdata/<testdata_subdir>`. Otherwise they live
    under `<package>/testdata/<testdata_subdir>`.

    The result is always a relative path. Joining path segments (rather than
    prefixing "/testdata/") avoids producing an absolute path when the script
    sits directly in `experiments/`, which os.path.join would otherwise treat
    as a new root.
    """
    package_parts = package.split("/") if package else []
    up_dirs = []
    if "experiments" in package_parts:
        # One ".." per directory between the script's package and experiments/.
        depth = len(package_parts) - package_parts.index("experiments") - 1
        up_dirs = [".."] * depth
    return "/".join(up_dirs + ["testdata", testdata_subdir])

def _generate_golden_test_script_impl(ctx):
    """Writes a Python entry point that runs golden_config.test_main on a module.

    The generated script imports `ctx.attr.module` and calls
    `golden_config.test_main(module, <testdata dir>[, match=...][, checks=(...)])`.
    The testdata dir is resolved relative to the script's own location.
    """
    label = ctx.attr.module.label
    module_path = _module_import_path(label)

    # Golden files for a module live in a testdata subdirectory named after its
    # full dotted import path.
    rel_testdata = _relative_testdata_path(label.package, module_path)

    # repr() emits a double-quoted, escaped string literal, so regexes with
    # quotes or backslashes (e.g. "\\b") survive into the generated Python source
    # unchanged instead of being reinterpreted as Python escapes.
    extra_args = ""
    if ctx.attr.match:
        extra_args += ", match={}".format(repr(ctx.attr.match))
    if ctx.attr.checks:
        extra_args += ", checks=({},)".format(
            ", ".join(["golden_config." + c for c in ctx.attr.checks]),
        )

    module_name = module_path.split(".")[-1]
    parent_module = ".".join(module_path.split(".")[:-1])
    if parent_module:
        import_stmt = "from {} import {}".format(parent_module, module_name)
    else:
        # A top-level module has no parent package; "from  import x" is invalid.
        import_stmt = "import {}".format(module_name)

    content = """\
import os

from axlearn.common import golden_config
{import_stmt}

if __name__ == "__main__":
    golden_config.test_main(
        {module_name},
        os.path.join(os.path.dirname(__file__), {rel_testdata}){extra_args},
    )
""".format(
        import_stmt = import_stmt,
        module_name = module_name,
        rel_testdata = repr(rel_testdata),
        extra_args = extra_args,
    )
    ctx.actions.write(output = ctx.outputs.output, content = content)
# Private rule behind golden_config_test: renders the Python test entry point
# for one experiment module into the predeclared `output` file.
_generate_golden_test_script = rule(
    doc = "Generates a Python script that runs golden_config.test_main on `module`.",
    implementation = _generate_golden_test_script_impl,
    attrs = {
        # The experiment library; it must carry PyInfo, i.e. be a Python target.
        "module": attr.label(
            doc = "py_library whose module defines named_trainer_configs().",
            mandatory = True,
            providers = [PyInfo],
        ),
        # Predeclared file the implementation writes the script into.
        "output": attr.output(
            doc = "Path of the generated .py file.",
            mandatory = True,
        ),
        # Empty string means the generated call omits `match=`.
        "match": attr.string(
            doc = "Optional regex forwarded to test_main as `match`.",
            default = "",
        ),
        # Empty list means the generated call omits `checks=`.
        "checks": attr.string_list(
            doc = "Names of golden_config check functions forwarded as `checks`.",
            default = [],
        ),
    },
)