This repository has been archived by the owner on May 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
run_experiment.py
66 lines (48 loc) · 1.37 KB
/
run_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
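Combine several argument files into a single training run.

Put each configuration in `args/<name>.args`, one option per line without
the leading `--` (e.g. `key=value`), then run:

    python run_experiment.py <name> [<name> ...] [--option=value ...]

`args/base.args` is always applied first, the named files follow in the
order given, and everything from the first `--option` onward is passed
through unchanged. Unless a run name is supplied, the run is named after
the abbreviated file names.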
"""
import sys
import os
from unittest.mock import patch
from transformer_vae.train import main
# Switch off the tokenizers library's internal parallelism for this process
# (NOTE(review): presumably to silence its fork/parallelism warning -- confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Registry of args files: file name without ".args" -> list of "--"-prefixed
# argument lines, filled in by the directory scan below.
args_with_content: "dict[str, list[str]]" = {}
def format_args(txt):
    """Turn the text of an ``.args`` file into a list of command-line flags.

    Each non-blank line becomes one flag by stripping surrounding whitespace
    and prefixing ``--``, so files list options without their dashes
    (e.g. ``key=value``). Blank lines are skipped: turning them into a bare
    ``--`` would make argparse stop recognising the options that follow.

    Args:
        txt: Full contents of an ``.args`` file.

    Returns:
        A list such as ``["--key=value"]``; empty for a blank file.
    """
    # splitlines() also copes with "\r\n" files.
    return [f"--{line.strip()}" for line in txt.splitlines() if line.strip()]
def shorten(arg):
    """Build a compact CamelCase tag from an underscore-separated name.

    Each ``_``-separated word contributes its first character upper-cased
    plus up to four following characters, e.g. ``"learning_rate"`` ->
    ``"LearnRate"``. Empty words (from leading, trailing or doubled
    underscores) are skipped instead of raising ``IndexError``.

    Args:
        arg: An args-file name such as ``"learning_rate"``.

    Returns:
        The abbreviated tag; ``""`` if ``arg`` has no non-empty words.
    """
    return "".join(word[0].upper() + word[1:5] for word in arg.split("_") if word)
# Index every "args/<name>.args" file (path relative to the current working
# directory) by <name>, storing its contents already converted to flags.
for filename in os.listdir("args"):
    if not filename.endswith(".args"):
        continue
    name = filename[: -len(".args")]
    if name in args_with_content:
        raise ValueError(f"Duplicate args file name: {name!r}")
    # Use an explicit encoding and close the handle promptly instead of
    # leaking it until garbage collection.
    with open(os.path.join("args", filename), "r", encoding="utf-8") as handle:
        args_with_content[name] = format_args(handle.read())
def _build_argv(cli_args, configs):
    """Expand config names from the command line into an argv list for ``main``.

    Args:
        cli_args: ``sys.argv[1:]`` -- args-file names (without ``.args``),
            optionally followed by raw ``--option`` arguments.
        configs: Mapping of args-file name to its list of ``--``-prefixed
            lines, as produced by ``format_args``.

    Returns:
        A list of argv tokens starting with ``"train.py"``. A ``--run_name``
        built from the abbreviated config names is appended unless one is
        already present.

    Raises:
        SystemExit: if a name does not match any loaded args file.
    """
    argv = ["train.py"]
    short_names = []
    requested = ["base"] + list(cli_args)
    for i, name in enumerate(requested):
        if name.startswith("--"):
            # The first flag ends the list of config names. Forward it and
            # everything after it verbatim: the shell already tokenized these,
            # so values containing spaces must not be re-split.
            argv += requested[i:]
            break
        if name not in configs:
            raise SystemExit(
                f"Unknown args file {name!r}; available: {', '.join(sorted(configs))}"
            )
        # Args-file lines may be written as "--key value", so split each
        # line on whitespace into separate tokens.
        for line in configs[name]:
            argv += line.split()
        if name != "base":
            short_names.append(shorten(name))
    # Accept both "--run_name=x" and "--run_name x"; otherwise an appended
    # auto-generated name would override the user's choice.
    has_run_name = any(
        token == "--run_name" or token.startswith("--run_name=") for token in argv
    )
    if not has_run_name:
        argv.append(f'--run_name={"_".join(sorted(short_names))}')
    return argv


train_argv = _build_argv(sys.argv[1:], args_with_content)
args_str = "\n".join(train_argv)
print(
    f"""
Running With Arguments:
-----------------------
{args_str}
"""
)
# Substitute the built argument list for sys.argv while main() runs
# (main presumably parses sys.argv itself).
with patch.object(sys, "argv", train_argv):
    main()