-
Notifications
You must be signed in to change notification settings - Fork 1
/
arg_parser.py
129 lines (84 loc) · 5.61 KB
/
arg_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
from pathlib import Path
def _build_parser():
    """Build the argument parser holding every supported command-line option.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser()
    # Input paths (validated as existing files in ``_check_input_files``).
    parser.add_argument("--hdf5_path", type=str, default="",
                        help="path to a dataset .hdf5 file")
    parser.add_argument("--object_path", type=str, default="",
                        help="Path to an object .off file")
    parser.add_argument("--resume_path", type=str, default="",
                        help="Path to a checkpoint .pth file for resuming the training from it.")
    parser.add_argument("--checkpoint_path", type=str, default="",
                        help="Path to a checkpoint .pth file for loading a model for evaluation or inference.")
    parser.add_argument("--id_to_class_path", type=str, default="",
                        help="Path to a id_to_class .pkl file")
    # Output directories (created on demand in ``_ensure_output_dirs``).
    parser.add_argument("--checkpoint_save_dir", type=str, default="",
                        help="Dir for saving model checkpoints during training.")
    parser.add_argument("--log_dir", type=str, default="",
                        help="Dir for logging.")
    parser.add_argument("--plots_dir", type=str, default="",
                        help="dir for saving plots.")
    # Training schedule and batching.
    parser.add_argument("--start_epoch", type=int, default=1,
                        help="Start epoch number (useful for resuming training) [default: 1]")
    parser.add_argument("--num_epochs", type=int,
                        help="Number of epochs to train [default: None]")
    parser.add_argument("--batch_size_train", type=int, default=32,
                        help="Batch size for training [default: 32].")
    parser.add_argument("--batch_size_valid", type=int, default=32,
                        help="Batch size for validation [default: 32].")
    parser.add_argument("--batch_size_test", type=int, default=32,
                        help="batch size for testing [default: 32].")
    # 0/1 integer flags; converted to bool in ``_flags_to_bool``.
    parser.add_argument("--balance", type=int, default=0, choices=[0, 1],
                        help="Flag for weighted sampling for balancing the training dataset [default: 0]")
    parser.add_argument("--data_augment", type=int, default=0, choices=[0, 1],
                        help="Flag data augmentation [default: 0]")
    # Optimizer / learning-rate schedule.
    parser.add_argument("--lr_init", type=float, default=0.001,
                        help="Initial learning rate [default: 0.001]")
    parser.add_argument("--lr_step_size", type=int, default=10,
                        help="Period of learning rate decay [default: 10]")
    parser.add_argument("--lr_gamma", type=float, default=0.1,
                        help="Multiplicative factor of learning rate decay [default: 0.1]")
    parser.add_argument("--reg_weights", type=float, nargs="*", default=[0.0, 0.001],
                        help="Weights for the TNets regularization terms [default: [0.0, 0.001]]")
    # Checkpointing, reporting and plotting.
    parser.add_argument("--save_checkpoint_every", type=int, default=1,
                        help="Save checkpoint every given epoch [default: 1]")
    parser.add_argument("--print_every_batch", type=int, default=0,
                        help="Print training stats every given batch [default: 0]")
    parser.add_argument("--plot_confusion_mat", type=int, default=0, choices=[0, 1],
                        help="flag for plotting confusion matrix [default: 0]")
    parser.add_argument("--plot_losses", type=int, default=0, choices=[0, 1],
                        help="Flag for plotting losses w.r.t epochs [default: 0]")
    # Runtime environment.
    parser.add_argument("--device", type=str, default="", choices=["cpu", "cuda"],
                        help="Device to use (cpu or cuda).")
    parser.add_argument("--num_classes", type=int, default=10,
                        help="Number of object classes [default: 10]")
    parser.add_argument("--top_k", type=int, default=5,
                        help="Number of top k classes to predict [default: 5]")
    # The help text previously claimed a default of 42; keep it in sync with
    # the actual default value.
    parser.add_argument("--seed", type=int, default=1235976,
                        help="random seed [default: 1235976]")
    return parser


def _check_input_files(args):
    """Raise FileNotFoundError for any non-empty input path that is not an existing file.

    Empty strings (the defaults) mean "not supplied" and are skipped.
    Checks run in a fixed order, so the first missing path is the one reported.
    """
    for name in ("hdf5_path", "object_path", "checkpoint_path", "id_to_class_path"):
        value = getattr(args, name)
        if value and not Path(value).is_file():
            raise FileNotFoundError(f"The {name} = {value} not found.")
    if args.resume_path and not Path(args.resume_path).is_file():
        raise FileNotFoundError(f"The resume_path = {args.resume_path} file not found.")


def _ensure_output_dirs(args):
    """Create every non-empty output directory, including missing parent directories.

    ``exist_ok=True`` makes this a no-op for directories that already exist.
    A path that exists as a non-directory still raises FileExistsError.
    """
    for name in ("checkpoint_save_dir", "log_dir", "plots_dir"):
        value = getattr(args, name)
        if value:
            Path(value).mkdir(parents=True, exist_ok=True)


def _flags_to_bool(args):
    """Convert the 0/1 integer flag options on ``args`` to ``bool`` in place."""
    for name in ("plot_confusion_mat", "data_augment", "plot_losses", "balance"):
        setattr(args, name, bool(getattr(args, name)))


def parse_args(argv=None):
    """Parse and validate the command-line options.

    Besides parsing, this validates the device, checks that supplied input
    files exist, creates the requested output directories, and converts the
    0/1 flag options to booleans.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the previous zero-argument
            behaviour.

    Returns:
        argparse.Namespace: the parsed and post-processed options.

    Raises:
        ValueError: if ``--device`` was not given.
        FileNotFoundError: if a supplied input path is not an existing file.
        FileExistsError: if an output-directory path exists but is not a directory.
        SystemExit: raised by argparse for invalid options or ``--help``.
    """
    args = _build_parser().parse_args(argv)
    # The "" default is deliberately outside ``choices``. argparse only
    # validates values actually given on the command line, so an omitted
    # --device has to be rejected here.
    if args.device == "":
        raise ValueError("The device argument must be provided ('cpu' or 'cuda').")
    _check_input_files(args)
    _ensure_output_dirs(args)
    _flags_to_bool(args)
    return args