-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy path train.sh
More file actions
73 lines (66 loc) · 2.39 KB
/
train.sh
File metadata and controls
73 lines (66 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env bash
# Training launcher configuration: hyperparameters and data/output paths.
# Edit the *_dir placeholders below before running.
set -euo pipefail

# --- Training hyperparameters ---
seed=1
lr='2e-4'
batch_size=8
epoch=5
input_steps=1
predict_steps=1
max_t=6
# Ocean input variables; space-separated on purpose so the string word-splits
# into multiple CLI arguments when passed unquoted to train.py.
input_var_list='so thetao tos uo vo zos'
save_eval_steps=800
# Random rendezvous port in [12345, 24689] so concurrent jobs don't collide.
# Uses POSIX $(( ... )) arithmetic; the old deprecated $[ ... ] form was replaced.
dist_port=$(( 12345 + RANDOM % 12345 ))

# --- Paths ---
output_dir=./output/train/exp1 # configure your output directory
data_dir=YOUR_CMIP_DATA_DIR # replace with your CMIP data directory, e.g., ./download/train_data/
soda_dir=YOUR_SODA_DATA_DIR # replace with your SODA data directory, e.g., ./download/valid_test_data/SODA2
oras5_dir=YOUR_ORAS5_DATA_DIR # replace with your ORAS5 data directory e.g., ./download/valid_test_data/ORAS5
### If you use SLURM to launch the training script, you can use the following command:
# node_num=1
# gpu_per_node=4
# srun -p YOUR_PARTITION_NAME --ntasks-per-node=$gpu_per_node -N $node_num --gres=gpu:$gpu_per_node --async \
# python -u train.py
### Otherwise, you can use torchrun to launch the training script.
# Single-value variables (paths, numerics) are quoted so user-configured paths
# containing spaces cannot break the command line. $input_var_list is left
# unquoted ON PURPOSE: it must word-split into one CLI argument per variable.
# shellcheck disable=SC2086
torchrun --nproc_per_node=4 \
  train.py \
  --in_chans 16 16 1 16 16 1 \
  --out_chans 16 16 1 16 16 1 \
  --max_t "$max_t" \
  --atmo_var_list tauu tauv \
  --atmo_dims 2 \
  --ignore_mismatched_sizes True \
  --do_train \
  --dist_port "$dist_port" \
  --data_dir "$data_dir" \
  --input_var_list $input_var_list \
  --input_steps "$input_steps" \
  --predict_steps "$predict_steps" \
  --output_dir "$output_dir" \
  --seed "$seed" \
  --report_to tensorboard \
  --log_level info \
  --logging_dir "$output_dir/log" \
  --logging_steps 5 \
  --log_on_each_node False \
  --save_strategy steps \
  --save_steps "$save_eval_steps" \
  --save_total_limit 3 \
  --ddp_find_unused_parameters False \
  --num_train_epochs "$epoch" \
  --per_device_train_batch_size "$batch_size" \
  --per_device_eval_batch_size "$batch_size" \
  --gradient_accumulation_steps 1 \
  --dataloader_num_workers 8 \
  --gradient_checkpointing False \
  --fsdp "full_shard auto_wrap" \
  --learning_rate "$lr" \
  --weight_decay 0.1 \
  --max_grad_norm 0.0 \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --adam_epsilon 1e-6 \
  --lr_scheduler_type cosine \
  --warmup_ratio 0.1 \
  --do_eval \
  --valid_data_dir "$soda_dir" "$oras5_dir" \
  --end_year 1980 \
  --evaluation_strategy steps \
  --eval_steps "$save_eval_steps" \
  --load_best_model_at_end True