-
Notifications
You must be signed in to change notification settings - Fork 14
Open
Labels
bug: Something isn't working
Description
When the weights of every module in the chain are frozen, no parameter requires a gradient. The loss computed during model evaluation therefore has no computational graph to backpropagate through, and the backward pass fails.
To Reproduce
A clear set of commands and resources needed to reproduce the issue:
Run any training configuration with `freeze_weights: true` set for every module in the chain, for example:
# Base configuration
base:
world_size: 1
iterations: 10000 # 200k/256 -> ~1000/epoch -> ~ 10 epochs
seed: 0
unwrap: false
log_dir: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/logs/full_chain/default
log_step: 1
train:
weight_prefix: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/default/snapshot
save_step: 1000 # ~1 epoch
optimizer:
name: Adam
lr: 0.001
# IO configuration
io:
loader:
batch_size: 256
shuffle: false
num_workers: 8
collate_fn: all
sampler: random_sequence
dataset:
name: larcv
file_keys: /sdf/data/neutrino/2x2/sim/mpvmpr_v2/train_file_list.txt
schema:
data:
parser: sparse3d
sparse_event: sparse3d_pcluster
seg_label:
parser: sparse3d
sparse_event: sparse3d_pcluster_semantics
ppn_label:
parser: particle_points
sparse_event: sparse3d_pcluster
particle_event: particle_pcluster
include_point_tagging: false
clust_label:
parser: cluster3d
cluster_event: cluster3d_pcluster
particle_event: particle_pcluster
sparse_semantics_event: sparse3d_pcluster_semantics
add_particle_info: true
clean_data: true
type_include_secondary: false
type_include_mpr: false
primary_include_mpr: false
coord_label:
parser: particle_coords
particle_event: particle_pcluster
cluster_event: cluster3d_pcluster
# Model configuration
model:
name: full_chain
weight_path: null
network_input:
data: data
seg_label: seg_label
clust_label: clust_label
loss_input:
seg_label: seg_label
ppn_label: ppn_label
clust_label: clust_label
coord_label: coord_label
modules:
# General chain configuration
chain:
deghosting: null
charge_rescaling: null
segmentation: uresnet
point_proposal: ppn
fragmentation: graph_spice
shower_aggregation: grappa
shower_primary: grappa
track_aggregation: grappa
particle_aggregation: null
inter_aggregation: grappa
particle_identification: grappa
primary_identification: grappa
orientation_identification: grappa
calibration: null
# Semantic segmentation + point proposal
uresnet_ppn:
weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/graph_spice/snapshot-999.ckpt
freeze_weights: true
uresnet:
num_input: 1
num_classes: 5
filters: 32
depth: 5
reps: 2
allow_bias: false
activation:
name: lrelu
negative_slope: 0.33
norm_layer:
name: batch_norm
eps: 0.0001
momentum: 0.01
ppn:
classify_endpoints: false
uresnet_ppn_loss:
uresnet_loss:
balance_loss: false
ppn_loss:
mask_loss: CE
resolution: 5.0
# Dense clustering
graph_spice:
weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/graph_spice/snapshot-999.ckpt
freeze_weights: true
shapes: [shower, track, michel, delta]
use_raw_features: true
invert: true
make_clusters: true
embedder:
spatial_embedding_dim: 3
feature_embedding_dim: 16
occupancy_mode: softplus
covariance_mode: softplus
uresnet:
num_input: 4 # 1 feature + 3 normalized coords
filters: 32
input_kernel: 5
depth: 5
reps: 2
spatial_size: 320
allow_bias: false
activation:
name: lrelu
negative_slope: 0.33
norm_layer:
name: batch_norm
eps: 0.0001
momentum: 0.01
kernel:
name: bilinear
num_features: 32
constructor:
edge_threshold: 0.1
min_size: 3
label_edges: true
graph:
name: radius
r: 1.9
orphan:
mode: radius
radius: 1.9
iterate: true
assign_all: true
graph_spice_loss:
name: edge
loss: binary_log_dice_ce
# Shower fragment aggregation + shower primary identification
grappa_shower:
weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/grappa_shower/snapshot-2999.ckpt
freeze_weights: true
nodes:
source: cluster
shapes: [shower, michel, delta]
min_size: -1
make_groups: true
grouping_method: score
graph:
name: complete
max_length: [500, 0, 500, 500, 0, 0, 0, 25, 0, 25]
dist_algorithm: recursive
node_encoder:
name: geo
use_numpy: true
add_value: true
add_shape: true
add_points: true
add_local_dirs: true
dir_max_dist: 5
add_local_dedxs: true
dedx_max_dist: 5
edge_encoder:
name: geo
use_numpy: true
gnn_model:
name: meta
node_feats: 33 # 16 (geo) + 3 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
edge_feats: 19
node_pred: 2
edge_pred: 2
edge_layer:
name: mlp
mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
node_layer:
name: mlp
reduction: max
attention: false
message_mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
aggr_mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
grappa_shower_loss:
node_loss:
name: shower_primary
high_purity: true
use_group_pred: true
edge_loss:
name: channel
target: group
high_purity: true
# Track aggregation
grappa_track:
weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/full_chain/grappa_track/snapshot-2999.ckpt
freeze_weights: true
nodes:
source: cluster
shapes: [track]
min_size: -1
make_groups: true
grouping_method: score
graph:
name: complete
max_length: 100
dist_algorithm: recursive
node_encoder:
name: geo
use_numpy: true
add_value: true
add_shape: false
add_points: true
add_local_dirs: true
dir_max_dist: 5
add_local_dedxs: true
dedx_max_dist: 5
edge_encoder:
name: geo
use_numpy: true
gnn_model:
name: meta
node_feats: 32 # 16 (geo) + 2 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
edge_feats: 19
node_pred: 2
edge_pred: 2
edge_layer:
name: mlp
mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
node_layer:
name: mlp
reduction: max
attention: false
message_mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
aggr_mlp:
depth: 3
width: 64
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
grappa_track_loss:
edge_loss:
name: channel
target: group
# Interaction aggregation, PID, primary, orientation
grappa_inter:
weight_path: /sdf/data/neutrino/2x2/spine/train/mpvmpr_v2/weights/grappa_inter/balanced/snapshot-19999.ckpt
freeze_weights: true
model_name: ''
nodes:
source: group
shapes: [shower, track, michel, delta]
min_size: -1
make_groups: true
graph:
name: complete
max_length: [500, 500, 0, 0, 25, 25, 25, 0, 0, 0]
dist_algorithm: recursive
node_encoder:
name: geo
use_numpy: true
add_value: true
add_shape: true
add_points: true
add_local_dirs: true
dir_max_dist: 5
add_local_dedxs: true
dedx_max_dist: 5
edge_encoder:
name: geo
use_numpy: true
gnn_model:
name: meta
node_feats: 33 # 16 (geo) + 3 (extra) + 6 (points) + 6 (directions) + 2 (local dedxs)
edge_feats: 19
node_pred:
type: 6
primary: 2
orient: 2
#momentum: 1
#vertex: 5
edge_pred: 2
edge_layer:
name: mlp
mlp:
depth: 3
width: 128
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
node_layer:
name: mlp
reduction: max
attention: false
message_mlp:
depth: 3
width: 128
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
aggr_mlp:
depth: 3
width: 128
activation:
name: lrelu
negative_slope: 0.1
normalization: batch_norm
grappa_inter_loss:
node_loss:
type:
name: class
target: pid
loss: ce
balance_loss: true
primary:
name: class
target: inter_primary
loss: ce
balance_loss: true
orient:
name: orient
loss: ce
#momentum:
# name: reg
# target: 16
# loss: berhu
#vertex:
# name: vertex
# primary_loss: ce
# balance_primary_loss: true
# regression_loss: mse
# only_contained: true
# normalize_positions: true
# use_anchor_points: true
# return_vertex_labels: true
# detector: icarus
edge_loss:
name: channel
target: interaction
Observed behavior
Instead of training, the run fails during the backward pass with an error (the error output was not captured in this copy of the issue).
Code base
Provide the following:
- SPINE release version: https://github.com/DeepLearnPhysics/spine/releases/tag/v0.1.3
- Singularity/apptainer/docker image: deeplearnphysics/larcv2:ub20.04-cuda11.6-pytorch1.13-larndsim
Possible solution (optional)
We can, of course, remove `freeze_weights` from our configuration. Ideally, though, SPINE would check at the start of training that at least one parameter is trainable and raise a clear, explicit error if none is. That would replace the error above, which is difficult to interpret.
Metadata
Assignees
Labels
bug: Something isn't working
