项目采用完整六阶段训练端结构:
base_model -> pruning -> qat -> onnx -> amct -> atc
本文档按代码实际结构说明各模块职责。
补充约定:
- 模型家族为 2D ResNet,但输入样本支持 2D
(H, W)与 3D(C, H, W).npy - 类别名与类别数运行时从
Data/<class>/一级子目录动态推断 - 阶段入口统一通过包内
__main__.py暴露,用户侧命令推荐使用uv run python -m <package>;ATC 例外为pixi run python -m atc
src/
├── base_model/
│ ├── __main__.py
│ ├── args.py
│ ├── confusionMatrix.py
│ ├── dataset.py
│ ├── lr_scheduler.py
│ ├── plotting.py
│ ├── resnet_lightweight.py
│ ├── resnet_standard.py
│ ├── tester.py
│ ├── trainer.py
│ ├── utils.py
│ └── visualizer.py
├── pruning/
│ ├── __main__.py
│ ├── args.py
│ ├── checkpoint.py
│ ├── evaluator.py
│ ├── output.py
│ ├── pruner.py
│ ├── topology.py
│ ├── trainer.py
│ ├── utils.py
│ └── README.md
├── qat/
│ ├── __main__.py
│ ├── args.py
│ ├── checkpoint.py
│ ├── evaluator.py
│ ├── output.py
│ ├── quantization.py
│ ├── trainer.py
│ ├── utils.py
│ └── README.md
├── onnx_export/
│ ├── __main__.py
│ ├── args.py
│ ├── evaluator.py
│ ├── exporter.py
│ ├── output.py
│ ├── rewrite.py
│ ├── utils.py
│ └── validate.py
├── amct/
│ ├── __main__.py
│ ├── __init__.py
│ ├── args.py
│ ├── converter.py
│ ├── output.py
│ └── utils.py
├── atc/
│ ├── __main__.py
│ ├── __init__.py
│ ├── args.py
│ ├── converter.py
│ ├── output.py
│ └── utils.py
└── thesis_figures/
├── __main__.py
├── __init__.py
├── args.py
├── contracts.py
├── output.py
├── plots.py
└── scanner.py
负责:
- 解析基座训练参数
- 调用数据切分与 DataLoader
- 扫描数据集类别并推断输入 shape / 通道数
- 初始化模型
- 执行训练 / 测试 / UMAP
负责:
- 解析 pruning 参数
- 按
--model自动选择最佳基座 checkpoint - 恢复基座 checkpoint
- 执行 iterative pruning
- 每轮评估与可选微调
- 导出 pruning checkpoint 与 summary
负责:
- 解析 QAT 参数
- 读取 pruning checkpoint
- 恢复剪枝后的浮点模型
prepare_qat_fx- 执行保守 QAT 微调
- 导出 QAT prepare checkpoint 与 summary
负责:
- 解析 ONNX 导出参数
- 调度
pruning_fp16/qat_convert两条导出分支 - 构建测试 DataLoader
- 做 Torch / ORT 精度对照
- 输出
onnx_summary.json
负责:
- 解析 AMCT 参数
- 调用 AMCT 转换逻辑
- 保存
amct_summary.json
负责:
- 解析 ATC 参数
- 调用 ATC 编译逻辑
- 保存
atc_summary.json
- 定义基座训练 CLI 参数
- 管理训练、测试、UMAP、性能选项
.npy数据集加载- 自然排序扫描
- 从
Data/<class>/动态推断类别映射 - 从首个可读样本推断输入 shape 与通道数
- 稳定 train / val / test 划分
output/splitsmanifest 落盘与复用
提供跨阶段复用的公共函数:
load_model_map()create_optimized_dataloader()load_state_dict_safely()get_raw_model()build_architecture_signature()- 设备、显存、
torch.compile相关辅助函数
- 基座训练主循环
- AMP、优化器、学习率调度
- TensorBoard
- 最优模型判定与结构化 checkpoint 保存
- 加载基座权重
- 测试集评估
- 生成混淆矩阵
- 提取特征
- 执行 PCA + UMAP 可视化
- 轻量级 ResNet 定义
- 默认工厂:
resnet6_2dresnet10_2dresnet14_2d
- 输入通道数与分类头维度由调用方按数据集推断结果传入
*_from_cfg()恢复入口
- 标准 ResNet 定义
- 默认工厂:
resnet18_2dresnet34_2d
- 输入通道数与分类头维度由调用方按数据集推断结果传入
*_from_cfg()恢复入口
- 定义 pruning CLI 参数
--pruning_ratio在入口统一规范到 2 位小数
- 扫描
output/base_model/<model>/下实验目录并选择最佳best_model.pth - 严格恢复基座 checkpoint
- 对 base checkpoint 执行
architecture_signature强校验
- 封装
torch-pruning - 执行单轮结构化通道剪枝
- 计算 step ratio 与剪枝统计
- 从剪枝后的真实模型提取
channel_cfg - 生成
architecture_signature - 保证 pruning 产物与
*_from_cfg()恢复入口对齐
- 剪枝后微调
- 每轮仅保留内存中的最佳权重
- 最终轮保存 pruning checkpoint
- 验证 / 测试评估
- 参数量与 MACs 统计
- 最终测试阶段生成混淆矩阵
- pruning 输出目录命名
- 保存
pruning_summary.json
- 复用
base_model中稳定公共函数 - 提供 pruning 阶段的路径与元数据辅助函数
- 定义 QAT CLI 参数
- 读取 pruning / QAT checkpoint
- 用
*_from_cfg()恢复剪枝结构 - 对 pruning / QAT checkpoint 执行
architecture_signature强校验 - 严格加载权重
- 构建 canonical QAT 方案
- 执行
prepare_qat_fx - 管理
quantization_meta - 管理 observer / BN 冻结策略
- QAT 微调主循环
- 最优模型判定与 prepare checkpoint 保存
- 复用 pruning 阶段的验证 / 测试逻辑
- 明确在 QAT 阶段禁用 AMP
- QAT 输出目录命名
- 保存
qat_summary.json
- 复用公共工具
- 提供 QAT 阶段路径与相对路径辅助函数
- 定义 ONNX 导出相关参数
- 按分支构建导出产物
- 管理导出设备、输入 dtype、动态 batch
- 创建 ORT session
- 执行 Torch / ORT 评估
- 生成导出后混淆矩阵
- 对
qat_convertONNX 做 graph rewrite - 调整量化节点连接模式
- 收敛到 CANN/AMCT/ATC 兼容图形态
- 校验 ONNX opset、输入输出 dtype、量化节点模式
- 验证 QAT 图是否满足当前约束
- ONNX 输出目录命名
- 保存
onnx_summary.json - 在 ONNX 无法可靠嵌入
architecture_signature时,通过 summary 传递上游签名引用
- ONNX 阶段通用路径与 DataLoader 工具
- 定义 AMCT CLI 参数
- 校验输入
model_quant.onnx - 校验
onnx_summary.json - 在 ONNX 实体无法可靠承载签名时,消费 summary 中的
architecture_signature引用 - 调用
amct_onnx.convert_qat_model(...) - 校验 deploy / fakequant ONNX 输出
- 在
amct_summary.json中继续传递上游签名引用
- AMCT 输出目录命名
- 保存
amct_summary.json
- 仓库根路径解析
- JSON 读取
- 路径解析与文件存在性检查
- 定义 ATC CLI 参数
- 按
pruning_fp16/amct_deploy加载输入契约 - 在 pixi 环境下消费 summary 契约与上游
architecture_signature引用,不再在 ATC 阶段导入 ONNX 做 deploy 实体复核 - 构建
atc子进程命令 - 收集
.om与工具产物 - 生成
atc_summary.json
- ATC 输出目录命名
- 保存
atc_summary.json
- 构建
atc子进程环境 - 路径解析与文件检查
- 汇总工具链环境变量
- 定义论文插图后处理 CLI 参数
- 支持
--output_root、--figure_dir、--formats、--dry_run等只读扫描选项
- 扫描
output/下已有 summary - 跳过自身生成的
output/thesis_figures/ - 按模型名、实验名、阶段和分支排序记录
- 将 pruning / QAT / ONNX / AMCT / ATC summary 归一化为统一图表记录
- 将
from_ratio...实验名对齐为ratio... - 只读取 JSON 字段,不加载 checkpoint、ONNX 或 OM 实体
- 生成剪枝折中、复杂度压缩、阶段错误率流转、ONNX 差异和接口矩阵图
- 输出
png/svg图片以及 CSV 表格
- 创建固定输出目录
output/thesis_figures/ - 保存
figures_manifest.json与tables/*.csv
base_model:产出结构化基座 checkpointpruning:消费基座 checkpoint,产出 pruning checkpointqat:消费 pruning checkpoint,产出 QAT prepare checkpointonnx:消费 pruning / QAT checkpoint,产出 ONNX 与评估摘要amct:消费qat_convertONNX,产出 deploy / fakequant ONNXatc:消费pruning_fp16或amct_deployONNX,产出.omthesis_figures:只读消费output/summary,产出论文插图和表格