-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_prime.sh
More file actions
97 lines (81 loc) · 2.4 KB
/
train_prime.sh
File metadata and controls
97 lines (81 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/bin/bash
# ==========================================
# PRIME Training Script
# ==========================================
# Launches train_prime.py with the settings defined in the Config sections
# below; edit the values here to change the run.

# Fail fast: abort on unhandled errors, on use of unset variables, and on
# failures anywhere in a pipeline.
set -euo pipefail

# -----------------------------
# Config
# -----------------------------
DATA_CONFIG="./config/data_config.yaml"
MODEL_CONFIG="./config/model_config.yaml"
TASK="GeneOntology" # FoldClassification | ECReaction | GeneOntology | BindingSite
BATCH_SIZE=32
EPOCHS=100
LR=1e-3
# -----------------------------
# GPU Selection
# -----------------------------
DEVICE_ID=1
# -----------------------------
# Optional (for GeneOntology)
# -----------------------------
GO_BRANCH="CC" # MF | BP | CC
# -----------------------------
# Hierarchy Ablation
# -----------------------------
# Each element is passed to --active_levels as a separate argument.
ACTIVE_LEVELS=("surface" "atom" "residue" "sse" "protein")
READOUT_LEVEL="residue"
# -----------------------------
# Cross-Attention Option
# -----------------------------
CROSS_ATTENTION="true" # true or false
# -----------------------------
# Resume Option
# Set to checkpoint path to resume, or empty to train from scratch
# -----------------------------
RESUME=""

# The launch logic only adds --cross_attention for the exact string "true",
# so any other spelling ("True", "yes", "1") would silently disable it.
# Reject such values up front instead.
case "$CROSS_ATTENTION" in
  true | false) ;;
  *)
    echo "Error: CROSS_ATTENTION must be 'true' or 'false' (got '$CROSS_ATTENTION')" >&2
    exit 1
    ;;
esac
# -----------------------------
# Run summary
# -----------------------------
# Prints the effective training configuration, read from the global config
# variables, so the run log records what was launched.
print_banner() {
  local sep="===================================="
  echo "$sep"
  echo "Training PRIME"
  echo "Task: $TASK"
  echo "Batch Size: $BATCH_SIZE"
  echo "Epochs: $EPOCHS"
  echo "LR: $LR"
  echo "GPU: $DEVICE_ID"
  # [*] joins the levels into one space-separated word. [@] inside a larger
  # string splits it into several arguments (ShellCheck SC2145).
  echo "Active Levels: ${ACTIVE_LEVELS[*]}"
  echo "Readout Level: $READOUT_LEVEL"
  echo "Cross Attention: $CROSS_ATTENTION"
  echo "Resume: ${RESUME:-none}"
  echo "$sep"
}
print_banner
# -----------------------------
# Set CUDA visibility
# -----------------------------
# Restrict the trainer launched below to the GPU chosen in DEVICE_ID.
CUDA_VISIBLE_DEVICES="${DEVICE_ID}"
export CUDA_VISIBLE_DEVICES
# -----------------------------
# Build base command
# -----------------------------
# The command is held in a bash array rather than a flat string run through
# eval. Every value (e.g. a RESUME path containing spaces, quotes, '$' or
# backticks) therefore reaches train_prime.py as exactly one argument and is
# never re-parsed or executed by the shell. Each ACTIVE_LEVELS element still
# becomes its own argument after --active_levels.
CMD=(
  python train_prime.py
  --data_config "$DATA_CONFIG"
  --model_config "$MODEL_CONFIG"
  --task "$TASK"
  --batch_size "$BATCH_SIZE"
  --epochs "$EPOCHS"
  --lr "$LR"
  --active_levels "${ACTIVE_LEVELS[@]}"
  --readout_level "$READOUT_LEVEL"
)
# add cross_attention flag if enabled
if [[ "$CROSS_ATTENTION" == "true" ]]; then
  CMD+=(--cross_attention)
fi
# add go_branch if GeneOntology
if [[ "$TASK" == "GeneOntology" ]]; then
  echo "GO Branch: $GO_BRANCH"
  CMD+=(--go_branch "$GO_BRANCH")
fi
# add resume if set
if [[ -n "$RESUME" ]]; then
  echo "Resuming from: $RESUME"
  CMD+=(--resume "$RESUME")
fi
# -----------------------------
# Run
# -----------------------------
# Run the array directly (no eval). As the last command, its exit status
# becomes the script's exit status.
"${CMD[@]}"