-
Notifications
You must be signed in to change notification settings - Fork 0
/
configs.py
110 lines (91 loc) · 3.07 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import os.path
import ast
from typing import List
from pathlib import Path
class Config:
    """Container for the settings persisted in ``config.json``.

    Instances are created either directly with keyword arguments or from a
    previously saved file via :meth:`from_pretrained`.  :meth:`save_config`
    writes every instance attribute back out as a JSON object of strings.
    """

    intents_list: List
    tags_list: List
    max_sentence_length: int
    device: str
    batch_size: int
    num_workers: int
    dropout_rate: float
    learning_rate: float
    weight_decay: float
    start_epoch: int
    epochs: int
    ignore_index: int
    slot_loss_coef: float
    output_dir: str
    version: str
    bert_model_name: str

    def __init__(self,
                 device: str = 'cpu',
                 batch_size: int = 2,
                 num_workers: int = 1,
                 max_sentence_length: int = 50,
                 dropout: float = 0.20,
                 learning_rate: float = 0.0001,
                 weight_decay: float = 0.0005,
                 start_epoch: int = 0,
                 epochs: int = 10,
                 ignore_index: int = -100,
                 slot_loss_coef: float = 1.0,
                 output_dir: str = './pretrained/',
                 version: str = 'v.0.1',
                 bert_model_name: str = 'bert-base-uncased',
                 ):
        """Store the given settings on the instance.

        Note that the ``dropout`` argument is stored as ``dropout_rate``.
        ``tags_list`` and ``intents_list`` start out empty and are expected
        to be assigned by the caller (or filled in by ``from_pretrained``).
        """
        self.device = device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_sentence_length = max_sentence_length
        self.dropout_rate = dropout
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.start_epoch = start_epoch
        self.epochs = epochs
        self.ignore_index = ignore_index
        self.slot_loss_coef = slot_loss_coef
        self.output_dir = output_dir
        self.version = version
        self.bert_model_name = bert_model_name
        # Always present so that a freshly built Config can be inspected and
        # saved without an AttributeError on these two fields.
        self.tags_list = []
        self.intents_list = []

    @staticmethod
    def _coerce(default, raw):
        """Convert *raw* (a value read from JSON) to the type of *default*.

        ``save_config`` writes every value as a string, so strings are parsed
        back; values that are already native JSON types are accepted as well.
        """
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(default, bool):
            if isinstance(raw, str):
                return raw.capitalize() == 'True'
            return bool(raw)
        if isinstance(default, int):
            return int(raw)
        if isinstance(default, float):
            return float(raw)
        if isinstance(default, list):
            # literal_eval only evaluates Python literals, never arbitrary code.
            return ast.literal_eval(raw) if isinstance(raw, str) else list(raw)
        return raw

    @classmethod
    def from_pretrained(cls, config_path: str):
        """Build a config from a JSON file.

        Args:
            config_path: Path to a JSON file, or to a directory that contains
                a ``config.json`` file.

        Returns:
            A new instance whose attributes are the defaults overridden by
            the values found in the file, converted to the type of the
            corresponding default.

        Raises:
            AttributeError: If the file contains a key that is not an
                attribute of a default-constructed instance.
        """
        config_path = Path(config_path)
        if config_path.is_dir():
            config_path = config_path / 'config.json'

        loaded = cls()
        # The context manager closes the file even if parsing fails.
        with open(config_path, encoding='utf-8') as f:
            stored = json.load(f)

        for key, raw in stored.items():
            default = getattr(loaded, key)
            setattr(loaded, key, cls._coerce(default, raw))
        return loaded

    def save_config(self, path: str = None):
        """Write all instance attributes to ``<path>/config.json``.

        Every value is serialised with ``str()``, which is the format
        ``from_pretrained`` parses back.

        Args:
            path: Target directory.  When omitted, ``output_dir`` is used,
                extended by ``version`` if that is non-empty.
        """
        serialised = {attr: str(value) for attr, value in self.__dict__.items()}

        # NOTE(review): the version suffix is applied only when falling back
        # to output_dir; an explicit path is used as given -- confirm this
        # matches the intended behaviour.
        if not path:
            path = Path(self.output_dir)
            if self.version:
                path = path / self.version
        path = Path(path)

        # Create the directory on first save instead of failing in open().
        path.mkdir(parents=True, exist_ok=True)
        with open(path / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(serialised, f, indent=2)