Skip to content
This repository was archived by the owner on Aug 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import multiprocessing as mp

import logging

import pprint
import yaml

Expand All @@ -29,9 +31,15 @@ def process_main(rank, fname, world_size, devices):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

import logging
logging.basicConfig()
logger = logging.getLogger()

# Add a log handler (e.g., StreamHandler or FileHandler)
log_handler = logging.StreamHandler() # Output log messages to console
# log_handler = logging.FileHandler('logfile.txt') # Output log messages to a file

logger.addHandler(log_handler)

if rank == 0:
logger.setLevel(logging.INFO)
else:
Expand All @@ -42,7 +50,7 @@ def process_main(rank, fname, world_size, devices):
# -- load script params
params = None
with open(fname, 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
params = yaml.safe_load(y_file)
logger.info('loaded params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)
Expand Down
124 changes: 67 additions & 57 deletions main_distributed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import logging
import os
Expand All @@ -14,9 +7,12 @@

import submitit

from src.train import main as app_main

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.basicConfig(
filename='application.log', # Specify the log file name
filemode='a', # Append to the log file
format='%(asctime)s %(levelname)s - %(message)s', # Specify the log format
level=logging.INFO
)
logger = logging.getLogger()


Expand All @@ -26,7 +22,7 @@
help='location to save submitit logs')
parser.add_argument(
'--batch-launch', action='store_true',
help='whether fname points to a file to batch-lauch several config files')
help='whether fname points to a file to batch-launch several config files')
parser.add_argument(
'--fname', type=str,
help='yaml file containing config file names to launch',
Expand All @@ -39,66 +35,80 @@
help='num. nodes to request for job')
parser.add_argument(
'--tasks-per-node', type=int, default=1,
help='num. procs to per node')
help='num. procs per node')
parser.add_argument(
'--time', type=int, default=4300,
help='time in minutes to run job')


class Trainer:

def __init__(self, fname='configs.yaml', load_model=None):
def __init__(self, fname='configs.yaml', resume_training=False):
self.fname = fname
self.load_model = load_model
self.resume_training = resume_training

def __call__(self):
fname = self.fname
load_model = self.load_model
logger.info(f'called-params {fname}')

# -- load script params
params = None
with open(fname, 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
logger.info('loaded params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)

resume_preempt = False if load_model is None else load_model
app_main(args=params, resume_preempt=resume_preempt)
try:
fname = self.fname
resume_training = self.resume_training
logger.info(f'called-params {fname}')

# Load script params
with open(fname, 'r') as y_file:
params = yaml.safe_load(y_file)
logger.info('loaded params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)

resume_preempt = False if resume_training is None else resume_training
app_main(args=params, resume_preempt=resume_preempt)
except Exception as e:
logger.exception(f'An error occurred: {str(e)}')
sys.exit(1)

def checkpoint(self):
fb_trainer = Trainer(self.fname, True)
return submitit.helpers.DelayedSubmission(fb_trainer,)
try:
fb_trainer = Trainer(self.fname, True)
return submitit.helpers.DelayedSubmission(fb_trainer)
except Exception as e:
logger.exception(f'An error occurred: {str(e)}')
sys.exit(1)


def launch():
executor = submitit.AutoExecutor(
folder=os.path.join(args.folder, 'job_%j'),
slurm_max_num_timeout=20)
executor.update_parameters(
slurm_partition=args.partition,
slurm_mem_per_gpu='55G',
timeout_min=args.time,
nodes=args.nodes,
tasks_per_node=args.tasks_per_node,
cpus_per_task=10,
gpus_per_node=args.tasks_per_node)

config_fnames = [args.fname]

jobs, trainers = [], []
with executor.batch():
for cf in config_fnames:
fb_trainer = Trainer(cf)
job = executor.submit(fb_trainer,)
trainers.append(fb_trainer)
jobs.append(job)

for job in jobs:
print(job.job_id)
try:
executor = submitit.AutoExecutor(
folder=os.path.join(args.folder, 'job_%j'),
slurm_max_num_timeout=20)
executor.update_parameters(
slurm_partition=args.partition,
slurm_mem_per_gpu='55G',
timeout_min=args.time,
nodes=args.nodes,
tasks_per_node=args.tasks_per_node,
cpus_per_task=10,
gpus_per_node=args.tasks_per_node)

config_fnames = [args.fname]

jobs, trainers = [], []
with executor.batch():
for cf in config_fnames:
fb_trainer = Trainer(cf)
job = executor.submit(fb_trainer)
trainers.append(fb_trainer)
jobs.append(job)

for job in jobs:
print(job.job_id)
except Exception as e:
logger.exception(f'An error occurred: {str(e)}')
sys.exit(1)


if __name__ == '__main__':
args = parser.parse_args()
launch()
try:
args = parser.parse_args()
launch()
except Exception as e:
logger.exception(f'An error occurred: {str(e)}')
sys.exit(1)