forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun.py
More file actions
135 lines (112 loc) · 4 KB
/
run.py
File metadata and controls
135 lines (112 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import sys
from dotenv import load_dotenv
# Pull variables from a .env file (when one can be found) into os.environ
# before anything below reads the environment.
load_dotenv()

# Default HF_HUB_ENABLE_HF_TRANSFER to "1" but respect a value the user (or
# the .env file) already set.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Unconditionally set; presumably silences albumentations' update check.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
def parse_seed(raw):
    """Parse the raw SEED environment value into an RNG seed.

    The seed is later fed to ``random.seed``, ``np.random.seed``,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``. numpy's legacy
    ``np.random.seed`` only accepts integers in ``[0, 2**32 - 1]`` and raises
    ``ValueError`` otherwise, so such values are rejected here with a warning
    (the same treatment a non-integer SEED already gets) instead of crashing
    the script at startup.

    Args:
        raw: The SEED string, or None when SEED is not set.

    Returns:
        The seed as an int, or None when ``raw`` is None or invalid.
    """
    if raw is None:
        return None
    try:
        value = int(raw)
    except ValueError:
        print(f"Invalid SEED value: {raw}. SEED must be an integer.")
        return None
    max_seed = 2**32 - 1  # upper bound accepted by np.random.seed
    if not 0 <= value <= max_seed:
        print(f"Invalid SEED value: {raw}. SEED must be between 0 and {max_seed}.")
        return None
    return value


# None means "do not seed anything"; consumed by the RNG seeding step below.
seed = parse_seed(os.environ.get("SEED"))
# Prepend the current working directory to the import path, presumably so the
# local `toolkit` package imported further down resolves when the script is
# launched from the repository root.
sys.path.insert(0, os.getcwd())

# NOTE: if the cuda_malloc import below is ever re-enabled, it must run
# before ANY torch or fastai import.
# import toolkit.cuda_malloc

# Disable diffusers telemetry (the author intends to make this opt-in later).
# Set before the toolkit imports further down.
os.environ['DISABLE_TELEMETRY'] = 'YES'

import torch

# DEBUG_TOOLKIT=1 enables torch autograd anomaly detection.
if os.getenv("DEBUG_TOOLKIT", "0") == "1":
    torch.autograd.set_detect_anomaly(True)

# When a seed was supplied, seed Python, numpy and torch (CPU and all CUDA
# devices) with it.
if seed is not None:
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
import argparse
from toolkit.job import get_job
from toolkit.accelerator import get_accelerator
from toolkit.print import print_acc, setup_log_to_file
# Module-level accelerator shared by print_end_message() and main().
accelerator = get_accelerator()


def print_end_message(jobs_completed, jobs_failed):
    """Print a summary of how many jobs completed and how many failed.

    Only the main process prints; every other process returns immediately.

    Args:
        jobs_completed: Count of jobs that finished successfully.
        jobs_failed: Count of jobs that failed; the failure line is omitted
            when this is 0.
    """
    if not accelerator.is_main_process:
        return

    def plural_suffix(count):
        return '' if count == 1 else 's'

    summary = [f" - {jobs_completed} completed job{plural_suffix(jobs_completed)}"]
    if jobs_failed > 0:
        summary.append(f" - {jobs_failed} failure{plural_suffix(jobs_failed)}")

    rule = "========================================"
    print_acc("")
    print_acc(rule)
    print_acc("Result:")
    for entry in summary:
        print_acc(entry)
    print_acc(rule)
def _notify_job_error(job, error):
    """Best-effort forward of ``error`` to ``job.process[0].on_error``.

    Does nothing when ``job`` is None, i.e. when ``get_job`` itself failed
    and no job object exists for the current config. Any exception raised
    while notifying is printed and swallowed so it cannot mask ``error``.
    """
    if job is None:
        return
    try:
        job.process[0].on_error(error)
    except Exception as e2:
        print_acc(f"Error running on_error: {e2}")


def main():
    """Parse CLI arguments and run each given config file as a job, in order.

    Each job is built with ``get_job``, then ``run()`` and ``cleanup()`` are
    called. A summary is printed via ``print_end_message`` when a failure
    aborts the run.

    Raises:
        Exception: re-raised from a failing job unless ``--recover`` is set.
        KeyboardInterrupt: re-raised unless ``--recover`` is set.
    """
    parser = argparse.ArgumentParser()
    # require at least one config file
    parser.add_argument(
        'config_file_list',
        nargs='+',
        type=str,
        help='Name of config file (eg: person_v1 for config/person_v1.json/yaml), or full path if it is not in config folder, you can pass multiple config files and run them all sequentially'
    )
    # flag to continue if failed job
    parser.add_argument(
        '-r', '--recover',
        action='store_true',
        help='Continue running additional jobs even if a job fails'
    )
    # optional value substituted for the [name] tag in the config file(s)
    parser.add_argument(
        '-n', '--name',
        type=str,
        default=None,
        help='Name to replace [name] tag in config file, useful for shared config file'
    )
    parser.add_argument(
        '-l', '--log',
        type=str,
        default=None,
        help='Log file to write output to'
    )
    args = parser.parse_args()

    if args.log is not None:
        setup_log_to_file(args.log)

    config_file_list = args.config_file_list
    # Defensive: argparse's nargs='+' should already reject an empty list.
    if len(config_file_list) == 0:
        raise Exception("You must provide at least one config file")

    jobs_completed = 0
    jobs_failed = 0

    if accelerator.is_main_process:
        print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

    for config_file in config_file_list:
        # Reset every iteration: if get_job() raises, there is no job for this
        # config, and the previous config's job must not receive on_error.
        job = None
        try:
            job = get_job(config_file, args.name)
            job.run()
            job.cleanup()
            jobs_completed += 1
        except Exception as e:
            print_acc(f"Error running job: {e}")
            jobs_failed += 1
            _notify_job_error(job, e)
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise e
        except KeyboardInterrupt as e:
            # With --recover an interrupt skips to the next config; it is not
            # counted in jobs_failed.
            _notify_job_error(job, e)
            if not args.recover:
                print_end_message(jobs_completed, jobs_failed)
                raise e


if __name__ == '__main__':
    main()