项目包含 6 套流水线 CLI 与 1 套论文插图后处理 CLI:
- 基座训练:
uv run python -m base_model ... - 剪枝:
uv run python -m pruning ... - QAT:
uv run python -m qat ... - ONNX 导出:
uv run python -m onnx_export ... - AMCT 转换:
uv run python -m amct ... - ATC 编译:
pixi run python -m atc ... - 论文插图后处理:
uv run python -m thesis_figures ...
各入口参数并不完全相同,阅读时需要严格区分阶段。
补充说明:
- 所有训练相关 CLI 都以
Data/<class>/一级子目录动态推断类别名与类别数,无需显式传入类别数 resnet*_2d是模型家族命名;数据样本支持 2D(H, W)与 3D(C, H, W)
gitpixiuvdirenv(推荐)
pixi install- Python 3.12 运行时
- GCC / G++ / Make / CMake
cuda-runtime、cudnnascend-cann-toolkit、ascend-cann-310b-ops
uv synctorchonnxonnxruntime-gputorch-pruning- 以及
pyproject.toml中声明的其余 Python 依赖
- 若要使用 CUDA 加速训练,宿主机需要可用的 NVIDIA GPU 与驱动。
- 若要做真实的 Ascend 编译或部署验证,宿主机需要对应的 Ascend 设备/驱动环境。
推荐先在项目根目录执行:
pixi install
uv sync
direnv allow其中 .envrc 提供仓库级公共变量:
REPO_ROOTPYTHONPATH=$REPO_ROOT/src
说明:
direnv为推荐方案;若不使用direnv自动激活,也必须手动提供与.envrc等价的环境变量- 所有脚本统一通过
.envrc提供的REPO_ROOT识别仓库根
各阶段如需额外环境变量,再按需 source scripts/load_*_env.sh。
AMCT CLI 额外依赖仓库自带组件:
amct_onnx/amct_onnx-0.23.2-py3-none-linux_x86_64.whlamct_onnx/amct_onnx_op.tar.gz
这两项不在 uv sync / pixi install 自动安装范围内,运行 AMCT 前需按目标环境自行安装或部署。
uv run python -m base_model --help| 参数 | 默认值 | 说明 |
|---|---|---|
--epochs |
60 |
训练轮数 |
--lr |
0.0003 |
学习率 |
--batch_size |
64 |
批次大小 |
--model_path |
best_model.pth |
模型保存文件名 |
--model |
resnet6_2d |
模型名 |
--data_dir |
Data |
数据集路径 |
--data_dtype |
fp16 |
数据集输出 tensor 精度 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--full_load |
False |
是否全量加载数据集 |
--num_workers |
None |
DataLoader 工作线程数 |
--prefetch_factor |
2 |
DataLoader 预取因子 |
--persistent_workers |
True |
是否保持工作线程 |
--pin_memory |
True |
是否启用 pin_memory |
--cudnn_benchmark |
True |
是否启用 cuDNN benchmark |
--cudnn_deterministic |
False |
是否启用确定性算法 |
--compile_model |
True |
是否启用 torch.compile |
--compile_mode |
default |
编译模式 |
| 参数 | 默认值 | 说明 |
|---|---|---|
--Train |
True |
是否执行训练 |
--Test |
True |
是否执行测试 |
--UMAP |
False |
是否执行 UMAP |
--dropout_p |
0.3 |
Dropout 概率 |
--weight_decay |
1e-4 |
权重衰减 |
--warmup_ratio |
0.05 |
Warmup 占总步数比例 |
--warmup_steps |
0 |
Warmup 步数 |
--min_lr |
1e-6 |
最小学习率 |
--plot_lr_schedule |
True |
是否绘制学习率曲线 |
uv run python -m base_model --epochs 100 --batch_size 64 --model resnet18_2d说明:
--model只决定使用哪一种 2D ResNet 拓扑,不限制原始样本必须是 2D 数组- 输入通道数与分类头维度由入口根据
Data/的样本 shape 与类别目录自动推断
uv run python -m pruning --help完整实现见 src/pruning/args.py。
| 参数 | 默认值 | 说明 |
|---|---|---|
--model |
必填 | 基座模型名,自动扫描 output/base_model/<model>/ 下实验目录并选择最佳 best_model.pth |
--model_path |
best_pruned_model.pth |
最终剪枝模型文件名 |
--data_dir |
Data |
数据集路径 |
--data_dtype |
fp16 |
数据集输出 tensor 精度 |
--full_load |
False |
是否全量加载数据集 |
--num_workers |
None |
DataLoader 工作线程数 |
--prefetch_factor |
2 |
DataLoader 预取因子 |
--persistent_workers |
True |
是否保持 DataLoader 工作线程 |
--pin_memory |
True |
是否启用 pin_memory |
--pruning_ratio |
0.30 |
最终总剪枝率,统一规范到 2 位小数 |
--pruning_steps |
5 |
iterative pruning 剪枝轮数 |
--global_pruning |
True |
是否启用全局剪枝 |
--ignore_fc |
True |
是否忽略分类头 |
--finetune_epochs |
10 |
每轮剪枝后微调轮数 |
--batch_size |
64 |
批次大小 |
--lr |
1e-4 |
微调学习率 |
--weight_decay |
1e-4 |
权重衰减 |
--warmup_ratio |
0.05 |
Warmup 占总步数比例 |
--warmup_steps |
0 |
Warmup 步数 |
--min_lr |
1e-7 |
最小学习率 |
--cudnn_benchmark |
True |
是否启用 cuDNN benchmark |
--cudnn_deterministic |
False |
是否启用 cuDNN 确定性算法 |
--evaluate_test |
True |
是否在最终阶段评估测试集 |
- pruning 通过扫描 base_model 实验目录自动选择最佳
best_model.pth - pruning 会重新扫描
--data_dir,用当前数据集类别数校验被选中的基座模型分类头 - 产物用于后续 QAT / ONNX 恢复
uv run python -m pruning --model resnet34_2d --pruning_ratio 0.80 --pruning_steps 8uv run python -m qat --help完整实现见 src/qat/args.py。
| 参数 | 默认值 | 说明 |
|---|---|---|
--pruning_checkpoint |
必填 | 输入 pruning checkpoint 路径 |
--model_path |
best_qat_prepare_model.pth |
QAT prepare 模型文件名 |
--data_dir |
Data |
数据集路径 |
--full_load |
False |
是否全量加载数据集 |
--num_workers |
None |
DataLoader 工作线程数 |
--prefetch_factor |
2 |
DataLoader 预取因子 |
--persistent_workers |
True |
是否保持 DataLoader 工作线程 |
--pin_memory |
True |
是否启用 pin_memory |
--qat_epochs |
10 |
QAT 微调轮数 |
--batch_size |
64 |
批次大小 |
--lr |
1e-5 |
QAT 微调学习率 |
--weight_decay |
1e-4 |
权重衰减 |
--warmup_ratio |
0.05 |
Warmup 占总步数比例 |
--warmup_steps |
0 |
Warmup 步数 |
--min_lr |
1e-7 |
最小学习率 |
--cudnn_benchmark |
True |
是否启用 cuDNN benchmark |
--cudnn_deterministic |
False |
是否启用 cuDNN 确定性算法 |
--evaluate_test |
True |
是否在最终阶段评估测试集 |
- QAT 固定纯
fp32 - QAT 仍通过
--data_dir重新获取类别名与类别数,用于恢复与校验 pruning checkpoint - QAT 阶段输出 prepare checkpoint,不在本阶段做
torch.convert - QAT checkpoint 使用最小
quantization_meta契约,供后续 ONNX 恢复
uv run python -m qat \
--pruning_checkpoint output/pruning/resnet18_2d/ratio0.60_steps8_global_ft10_bs64/best_pruned_model.pthuv run python -m onnx_export --help| 参数 | 默认值 | 说明 |
|---|---|---|
--branch |
必填 | pruning_fp16 或 qat_convert |
--checkpoint |
必填 | pruning 分支传 pruning checkpoint;QAT 分支传 QAT checkpoint |
--data_dir |
Data |
数据集路径 |
--full_load |
False |
是否全量加载数据集 |
--num_workers |
None |
数据加载工作线程数 |
--evaluate_test |
True |
是否在导出后执行测试集精度评估 |
--eval_batch_size |
64 |
Torch / ORT 评估 batch size,不影响导出图结构 |
--opset_version |
16 |
固定 ONNX opset 16 |
pruning_fp16- pruning checkpoint -> FP16 ONNX
qat_convert- QAT checkpoint ->
convert_fx-> quantized ONNX
- QAT checkpoint ->
- 导出使用动态 batch
--eval_batch_size仅影响评估qat_convert导出后还会执行 graph rewrite 与 validate
uv run python -m onnx_export \
--branch pruning_fp16 \
--checkpoint output/pruning/resnet10_2d/ratio0.40_steps5_global_ft10_bs64/best_pruned_model.pth \
--eval_batch_size 64
uv run python -m onnx_export \
--branch qat_convert \
--checkpoint output/qat/resnet10_2d/from_ratio0.40_steps5_global_ft10_bs64/best_qat_prepare_model.pth \
--eval_batch_size 64uv run python -m amct --help| 参数 | 默认值 | 说明 |
|---|---|---|
--onnx_model |
必填 | 输入的 qat_convert ONNX 路径,固定为仓库导出的 model_quant.onnx |
运行前请先按目标环境自行安装或部署仓库附带的:
amct_onnx/amct_onnx-0.23.2-py3-none-linux_x86_64.whlamct_onnx/amct_onnx_op.tar.gz
- 只接受仓库
uv run python -m onnx_export --branch qat_convert导出的model_quant.onnx - 同目录必须存在
onnx_summary.json onnx_summary.json.branch必须为qat_convert
deploy_model.onnxfake_quant_model.onnxscale_offset_record.txtamct_summary.json
uv run python -m amct \
--onnx_model output/onnx/qat_convert/resnet6_2d/from_ratio0.60_steps8_global_ft10_bs64/model_quant.onnxpixi run python -m atc --help| 参数 | 默认值 | 说明 |
|---|---|---|
--branch |
必填 | pruning_fp16 或 amct_deploy |
--onnx_model |
必填 | pruning 分支传 model_fp16.onnx,AMCT 分支传 deploy_model.onnx |
--soc_version |
Ascend310B4 |
目标芯片版本 |
--input_shape |
None |
可选显式输入形状;默认从上游摘要中的输入接口派生并将 batch 固定为 1,显式传入时必须与自动派生结果完全一致 |
--input_format |
NCHW |
输入格式 |
pruning_fp16:- 输入
model_fp16.onnx - 同目录必须存在
onnx_summary.json - 读取并校验
onnx_summary.json.source_architecture_signature,作为上游签名引用
- 输入
amct_deploy:- 输入
deploy_model.onnx - 同目录必须存在
amct_summary.json amct_summary.json.deploy_interface必须与source_interface一致- 同时回读
amct_summary.json.source_onnx_summary_path指向的onnx_summary.json - 要求
onnx_path、source_architecture_signature、interface与amct_summary.json桥接一致
- 输入
说明:
- 所有直接消费
.pth checkpoint的链路都执行architecture_signature强校验 - ONNX / deploy ONNX 阶段通过同目录 summary 读取并校验上游签名与来源信息
pruning_fp16 -> atc的校验内容包括上游签名引用与摘要契约校验amct_deploy -> atc的校验内容包括摘要桥接闭环与 summary 内部一致性校验,ATC 在 pixi 环境下不依赖 ONNX 实体复核
.omatc_summary.jsoncheck_result.json/fusion_result.json(若 ATC 生成)
pixi run python -m atc \
--branch pruning_fp16 \
--onnx_model output/onnx/pruning_fp16/resnet10_2d/from_ratio0.40_steps5_global_ft10_bs64/model_fp16.onnx
pixi run python -m atc \
--branch amct_deploy \
--onnx_model output/amct/resnet6_2d/from_ratio0.60_steps8_global_ft10_bs64/deploy_model.onnxuv run python -m thesis_figures --help| 参数 | 默认值 | 说明 |
|---|---|---|
--output_root |
output |
只读扫描的训练端产物根目录 |
--figure_dir |
output/thesis_figures |
论文插图输出根目录 |
--formats |
png,svg |
逗号分隔的图片格式,当前支持 png、svg |
--model |
all |
筛选单个模型,或扫描全部模型 |
--experiment |
all |
按实验名子串筛选;from_ratio... 会按 ratio... 对齐 |
--dry_run |
False |
只扫描并打印记录数,不创建输出目录 |
--strict |
False |
遇到坏 JSON、缺关键字段或无记录时直接失败 |
- 只消费
output/下已有 summary:pruning_summary.jsonqat_summary.jsononnx_summary.jsonamct_summary.jsonatc_summary.json
- 不读取
Data/,不加载.pth、.onnx、.om。 - 不生成或模拟 ResNet_Acl 的真实推理延迟、吞吐、能耗等指标。
fig1_pruning_accuracy_complexity.<png|svg>fig2_compression_by_model.<png|svg>fig3_stage_accuracy_flow.<png|svg>fig4_onnx_metric_delta.<png|svg>fig5_atc_amct_interface_matrix.<png|svg>figures_manifest.jsontables/*.csv
图表使用错误率口径,错误率按 1 - acc 派生;错误率、参数量与 MACs 相关图使用对数坐标。文件名保持兼容,不再额外改名。
uv run python -m thesis_figures --output_root output --dry_run
uv run python -m thesis_figures --output_root output --formats png,svg
uv run python -m thesis_figures --model resnet6_2d --experiment ratio0.60