-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli_parser.py
More file actions
146 lines (134 loc) · 4.63 KB
/
cli_parser.py
File metadata and controls
146 lines (134 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import re
from argparse import ArgumentParser, ArgumentTypeError
from pathlib import Path
from typing import List
import torchvision.models
class ArgTypes:
    """Collection of ``type=`` callables for ``argparse.ArgumentParser``.

    Each static method receives the raw command-line string, validates it, and
    either returns the converted value or raises ``ArgumentTypeError`` (which
    argparse reports as a usage error).
    """

    @staticmethod
    def int_list(in_string: str) -> List[int]:
        """Parse a comma-separated string of integers, e.g. ``"500,400"``.

        Raises:
            ArgumentTypeError: if any comma-separated item is not an integer
                (this includes empty items such as in ``"1,,2"``).
        """
        try:
            return [int(x) for x in in_string.split(",")]
        except ValueError as exc:
            raise ArgumentTypeError(
                f"{in_string} must be a comma-separated list of integers"
            ) from exc

    @staticmethod
    def torchvision_model_name(in_string: str) -> str:
        """Verify that in_string names a supported model builder in torchvision.models.

        The name must start with ``dense``, ``alex``, ``res`` or ``vgg`` and must
        resolve to a callable attribute of ``torchvision.models``.

        Raises:
            ArgumentTypeError: if the name is not from a supported family, or is
                not a callable in ``torchvision.models``.
        """
        # re.match anchors at the start of the string. The group is required so
        # the anchor applies to every alternative; the previous pattern
        # "^dense|alex|res|vgg" only anchored "dense" and let any name merely
        # containing "alex"/"res"/"vgg" (e.g. "xvgg16") through.
        if not re.match(r"(?:dense|alex|res|vgg)", in_string):
            raise ArgumentTypeError(
                "This tool only supports DenseNet, ResNet, VGG, or AlexNet models"
            )
        # torchvision.models also exposes its submodules as attributes
        # (e.g. torchvision.models.resnet), so hasattr() alone would accept
        # names that cannot be called to build a model.
        if callable(getattr(torchvision.models, in_string, None)):
            return in_string
        raise ArgumentTypeError(f"{in_string} is not a torchvision model")

    @staticmethod
    def valid_dir(in_string: str) -> Path:
        """Verify that in_string is an existing directory and return it as an absolute Path.

        Raises:
            ArgumentTypeError: if the path does not exist or is not a directory.
        """
        p: Path = Path(in_string).absolute()
        # is_dir() is False for missing paths, so it also covers existence.
        if not p.is_dir():
            raise ArgumentTypeError(f"{in_string} is not a valid directory path.")
        return p

    @staticmethod
    def valid_file(in_string: str) -> Path:
        """Verify that in_string is an existing regular file and return it as an absolute Path.

        Raises:
            ArgumentTypeError: if the path does not exist or is not a file.
        """
        p: Path = Path(in_string).absolute()
        # is_file() is False for missing paths, so it also covers existence.
        if not p.is_file():
            raise ArgumentTypeError(f"{in_string} is not a valid file path.")
        return p

    @staticmethod
    def abs_path(in_string: str) -> Path:
        """Convert in_string to an absolute Path without checking that it exists."""
        return Path(in_string).absolute()
def create_training_parser() -> ArgumentParser:
    """Build the command-line parser for the ``train`` program.

    The positional ``data_dir`` must be an existing directory. The optional flags
    choose where checkpoints are saved, which torchvision architecture is used,
    and the training hyper-parameters (learning rate, hidden layer sizes, epochs,
    GPU usage).

    Returns:
        A configured ``ArgumentParser``; call ``parse_args()`` on it.
    """
    parser = ArgumentParser(
        prog="train",
        description="Train a neural net classifier for a pretrained model from PyTorch's torchvision module",
    )
    parser.add_argument(
        "data_dir",
        type=ArgTypes.valid_dir,
        help="Directory of the structure specified by PyTorch's ImageFolder class (https://pytorch.org"
        "/docs/0.4.0/torchvision/datasets.html?highlight=imagefolder#imagefolder)",
    )
    parser.add_argument(
        "-s",
        "--save-dir",
        type=ArgTypes.abs_path,
        # argparse runs string defaults through ``type``, so when the flag is
        # omitted this becomes an absolute Path resolved at parse time.
        default="checkpoints",
        help="Directory to save checkpoints of trained neural nets",
    )
    parser.add_argument(
        "-a",
        "--arch",
        type=ArgTypes.torchvision_model_name,
        default="vgg16",
        help="Name of pretrained neural net to use. Must be one of the models in PyTorch's torchvision.models ("
        "https://pytorch.org/docs/0.4.0/torchvision/models.html) and must be one of the AlexNet, DenseNet, "
        "ResNet, or VGG models",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        type=float,
        default=0.003,
        help="Learning rate for the model being trained",
    )
    parser.add_argument(
        "-hu",
        "--hidden-units",
        type=ArgTypes.int_list,
        # String default is converted by ``type`` into [500, 400].
        default="500,400",
        help="Comma-separated list of ints representing the size of each hidden layer. A single layer "
        "(e.g. 400) is allowed.",
    )
    parser.add_argument(
        "-e",
        "--epochs",
        type=int,
        default=10,
        help="Number of epochs to train the model for",
    )
    parser.add_argument(
        "-g",
        "--gpu",
        action="store_true",
        help="Whether to use a GPU. Will only work if PyTorch has access to a GPU",
    )
    return parser
def create_prediction_parser() -> ArgumentParser:
    """Build the command-line parser for the ``predict`` program.

    The positional arguments are the image to classify and the saved model
    checkpoint; both must be existing files. Optional flags set how many top
    predictions to return, an optional JSON category-name mapping file, and
    GPU usage.

    Returns:
        A configured ``ArgumentParser``; call ``parse_args()`` on it.
    """
    parser = ArgumentParser(
        prog="predict",
        description="Predict an image's classification using a saved PyTorch model",
    )
    parser.add_argument(
        "input",
        type=ArgTypes.valid_file,
        help="Path to an image file",
    )
    parser.add_argument(
        "checkpoint",
        type=ArgTypes.valid_file,
        help="Path to a model checkpoint",
    )
    # ``--top-k`` matches the kebab-case long options used by the training
    # parser; the original ``-top_k`` spelling is kept so existing invocations
    # still work. Both spellings store to ``args.top_k``.
    parser.add_argument(
        "-top_k",
        "--top-k",
        type=int,
        default=5,
        help="Return the top K best predictions",
    )
    parser.add_argument(
        "-cn",
        "--category_names",
        type=ArgTypes.valid_file,
        help="JSON file that maps category numbers of the image directories to human-readable names",
    )
    parser.add_argument(
        "-g",
        "--gpu",
        action="store_true",
        help="Whether to use a GPU. Will only work if PyTorch has access to a GPU",
    )
    return parser