Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
**__pycache__**
**build**
**egg-info**
**dist**
data/
*.pyc
venv/
*.idea/
data
*.log

.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so
*.yaml
*.sh
*.pth
*.pkl
*.zip
*.bin
.vscode/

# Virtual Environments
venv/
env/
.env

# Distribution / packaging
build/
dist/
work_dirs/

*.egg-info/
*.log.json
nohup.out
34 changes: 34 additions & 0 deletions Changes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Changes: Knowledge Distillation (KD)

Knowledge distillation was added on top of the existing CenterPoint training stack. Two variants exist; only one is used per config:

- **Response-based KD** (`kd.type = "heatmap_mse"`): MSE between student and teacher **heatmap** outputs, weighted by `lambda_kd`.
- **Feature-based KD** (`kd.type = "feature_mse"`): MSE on **`CenterHead.shared_conv`** features, weighted by `lambda_feat`.

The original detection loss (focal heatmap + L1 regression from `centernet_loss.py`) is unchanged. KD terms are added in `CenterHead.loss`. At test time only the student checkpoint is used.

---

## Code changes

| File | Change |
|------|--------|
| `det3d/torchie/apis/train.py` | If `cfg.kd.enabled`, build teacher from `teacher_config`, load `teacher_checkpoint`, freeze it, pass `teacher_model` and `kd_cfg` to the trainer. |
| `det3d/torchie/trainer/trainer.py` | Each training step: teacher forward under `no_grad` (`return_preds` and/or `return_feats`), then student forward with teacher outputs and `kd_cfg`. |
| `det3d/models/detectors/point_pillars.py` | Return `head_shared` from the head; support `return_preds` / `return_feats` for the teacher; pass KD arguments into `bbox_head.loss`. |
| `det3d/models/detectors/voxelnet.py` | Same KD-related forward/loss wiring as PointPillars. |
| `det3d/models/bbox_heads/center_head.py` | In `loss()`: compute `hm_kd_loss` (heatmap MSE) or `feat_kd_loss` (shared-feature MSE) and add to the per-task loss; log both metrics. |
| `det3d/torchie/trainer/hooks/logger/text.py` | Log `hm_kd_loss` and `feat_kd_loss` with 6 decimal places. |
| `det3d/torchie/apis/env.py` | Device selection limited to CUDA or CPU (MPS removed). |

**Unchanged:** `det3d/models/losses/centernet_loss.py` (baseline `FastFocalLoss`, `RegLoss`).

---

## New configs (`configs/nusc/pp/`)

Each file sets a slimmer student (reduced reader/neck/head channels) and a `kd` block (`enabled`, `type`, `lambda_kd` or `lambda_feat`, `teacher_config`, `teacher_checkpoint`).

**Response-based:** `response_based_kd.py`, `response_based_kd_05.py`, `response_based_kd_08.py`, `response_based_kd_smoke.py`, `response_based_kd_resnet.py`, `response_based_kd_resnet_smoke.py`

**Feature-based:** `feature_based_kd.py`, `feature_based_kd_05.py`, `feature_based_kd_08.py`, `feature_based_kd_smoke.py`
9 changes: 5 additions & 4 deletions configs/mvp/nusc_centerpoint_pp_fix_bn_z_scale.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -87,12 +88,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -165,8 +166,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
Expand Down
9 changes: 5 additions & 4 deletions configs/mvp/nusc_centerpoint_pp_fix_bn_z_scale_virtual.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -88,12 +89,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo_virtual.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo_virtual.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -166,8 +167,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -87,12 +88,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -156,8 +157,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -87,12 +88,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -156,8 +157,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -88,12 +89,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo_virtual.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo_virtual.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -157,8 +158,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
Expand Down
9 changes: 5 additions & 4 deletions configs/mvp/nusc_two_stage_base_with_virtual.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import itertools
import logging

Expand Down Expand Up @@ -138,12 +139,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = os.environ.get("NUSCENES_DATA_ROOT", "data/nuScenes")

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo_virtual.pkl",
db_info_path=f"{data_root}/dbinfos_train_10sweeps_withvelo_virtual.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
Expand Down Expand Up @@ -213,8 +214,8 @@
dict(type="Reformat"),
]

train_anno = "data/nuScenes/infos_train_10sweeps_withvelo_filter_painted_True.pkl"
val_anno = "data/nuScenes/infos_val_10sweeps_withvelo_filter_painted_True.pkl"
train_anno = f"{data_root}/infos_train_10sweeps_withvelo_filter_painted_True.pkl"
val_anno = f"{data_root}/infos_val_10sweeps_withvelo_filter_painted_True.pkl"
test_anno = None

data = dict(
Expand Down
Loading