项目采用按阶段拆分的结构,主线为:
base_model -> pruning -> qat -> onnx -> amct -> atc -> deploy
各阶段职责如下:
base_model- 训练 2D ResNet 模型族
- 输入样本可为 2D
(H, W)或 3D(C, H, W) - 输出结构化基座 checkpoint
pruning- 按
--model自动选择最佳基座 checkpoint - 执行 iterative structured pruning + 微调
- 输出 pruning checkpoint 与剪枝拓扑
- 按
qat- 读取 pruning checkpoint
- 按剪枝拓扑重建浮点模型
- 执行 FX graph mode QAT
- 输出 prepare 后 QAT checkpoint
onnx- 读取 pruning 或 QAT checkpoint
- 导出
pruning_fp16或qat_convertONNX - 用 ONNX Runtime 做一致性评估
amct- 读取
qat_convertONNX - 生成
deploy_model.onnx与fake_quant_model.onnx
- 读取
atc- 读取
pruning_fp16ONNX 或amct_deployONNX - 生成
Ascend310B4目标.om
- 读取
- 数据集采用
Data/<class>/*.npy目录组织 - 一级子目录名即类别名,类别数运行时动态推断
- 训练期输入 shape 由数据集首个可读样本推断
data_set_split()负责:- 自然排序扫描
- 推断原始样本 shape,并补齐
CHW / NCHW - 分层切分 train / val / test
- 将切分结果落盘到
output/splits/ - 后续优先复用 manifest
- 轻量模型定义位于
base_model/resnet_lightweight.py - 标准模型定义位于
base_model/resnet_standard.py resnet*_2d表示模型为 2D 卷积结构,不限制原始样本必须是 2D 数组- 仅支持 5 个模型:
resnet6_2dresnet10_2dresnet14_2dresnet18_2dresnet34_2d
- 两类模型都支持:
- 默认构造函数
*_from_cfg()恢复入口- 基于
channel_cfg的逐层重建
各阶段入口位于对应包内的 __main__.py:
src/base_model/__main__.pysrc/pruning/__main__.pysrc/qat/__main__.pysrc/onnx_export/__main__.pysrc/amct/__main__.pysrc/atc/__main__.py
这 6 个入口分别承担单阶段编排,不直接混写彼此逻辑;用户侧命令统一推荐使用 uv run python -m <package>,ATC 阶段使用 pixi run python -m atc。
环境采用“公共层 + 阶段增量层”:
- 公共层:
.envrcREPO_ROOTPYTHONPATH=$REPO_ROOT/srcautorun/autorun_*.sh与阶段环境脚本统一直接依赖这里提供的REPO_ROOT- 所有脚本统一通过
.envrc提供的REPO_ROOT识别仓库根
- 阶段增量层:
load_base_model_env.shload_onnx_env.shload_amct_env.shload_atc_env.sh
说明:
pixi install负责系统工具链、Python 运行时、CUDA runtime、cuDNN、CANN toolkit 等自动安装内容uv sync负责 Python 包依赖amct_onnx相关 wheel 与算子包不在uv sync管理范围内,由 AMCT 阶段按目标环境单独准备
已实现:
- 基座训练 / 验证 / 测试
- AMP +
torch.compile - Warmup + Cosine Annealing 学习率调度
- TensorBoard、混淆矩阵、UMAP
- 结构化基座 checkpoint
关键输出字段:
model_state_dicttrain_contextmodel_structureinput_tensor_metaarchitecture_signature
已实现:
- 通过
--model自动解析best_model.pth符号链接 - iterative structured pruning
- 每轮评估与可选微调
- 仅最终轮保存 pruning checkpoint
pruning_summary.json- 最终混淆矩阵
关键输出字段:
model_structure.channel_cfgmodel_structure.architecture_signaturepruning_meta
已实现:
- 读取 pruning checkpoint
- 用
*_from_cfg()重建剪枝后的浮点模型 prepare_qat_fx- 保守单路径 QAT 微调
best_qat_prepare_model.pthqat_summary.json
约束:
- 数据链固定
fp32 quantization_meta采用最小恢复契约- QAT 阶段输出 prepare checkpoint,由后续 ONNX 导出阶段继续消费
已实现:
pruning_fp16:- pruning checkpoint -> FP16 ONNX
qat_convert:- QAT checkpoint ->
convert_fx-> quantized ONNX
- QAT checkpoint ->
- ONNX Runtime 精度评估
- 动态 batch 导出
rewrite + validate用于 CANN/AMCT/ATC 兼容约束
已实现:
- 只接受仓库
qat_convert导出的model_quant.onnx - 自动读取同目录
onnx_summary.json - 调用
amct_onnx.convert_qat_model(...) - 输出:
deploy_model.onnxfake_quant_model.onnxscale_offset_record.txtamct_summary.json
说明:
- AMCT 代码已接入主线
- 运行该阶段前需要额外准备仓库附带的
amct_onnxwheel 与算子包
已实现:
- 支持两条输入分支:
pruning_fp16amct_deploy
- 调用
atc编译 - 输出:
.omatc_summary.jsoncheck_result.json/fusion_result.json(若工具链生成)
默认:
soc_version=Ascend310B4input_format=NCHWinput_shape默认从上游摘要中的输入接口派生,并将 batch 固定为1- 若用户显式传入
--input_shape,其输入名与各维度必须与自动派生结果完全一致,否则直接报错
pruning 依赖:
model_structure.model_namemodel_structure.model_kwargsmodel_state_dictinput_tensor_meta
QAT 依赖:
model_structure.model_namemodel_structure.model_kwargsmodel_structure.channel_cfgmodel_structure.architecture_signaturemodel_state_dict
说明:
- 所有
.pthcheckpoint 消费步骤统一对architecture_signature执行强校验
ONNX 依赖:
model_structure.model_namemodel_structure.model_kwargsmodel_structure.channel_cfgmodel_structure.architecture_signaturequantization_meta- prepare 后 graph 的
model_state_dict
说明:
qat_convert与pruning_fp16两条 ONNX 导出路径都属于.pth消费链,因此同样执行architecture_signature强校验
AMCT 依赖:
- ONNX 实体接口与图事实
onnx_summary.json.branch == "qat_convert"onnx_summary.json.onnx_pathonnx_summary.json.model_nameonnx_summary.json.source_checkpoint_pathonnx_summary.json.source_architecture_signatureonnx_summary.json.example_input_shapeonnx_summary.json.opset_version
说明:
- AMCT 统一通过同目录
onnx_summary.json读取并校验上游签名、来源路径与输入接口信息
ATC 依赖:
pruning_fp16分支:model_fp16.onnx- 同目录
onnx_summary.json onnx_summary.json.source_architecture_signature
amct_deploy分支:deploy_model.onnx- 同目录
amct_summary.json amct_summary.json.source_architecture_signatureamct_summary.json.source_onnx_summary_path指向的onnx_summary.json
说明:
pruning_fp16分支通过onnx_summary.json.source_architecture_signature引用并校验上游签名,同时结合onnx_path与interface完成摘要契约校验amct_deploy -> atc在 pixi 环境下不导入 ONNX,而是基于amct_summary.json与上游onnx_summary.json做摘要契约校验amct_summary.json.deploy_interface必须与source_interface一致,且source_onnx_summary_path指向的onnx_summary.json必须在onnx_path、source_architecture_signature、interface三者上与amct_summary.json桥接一致- 因而当前
amct_deploy -> atc形成的是摘要桥接闭环与 summary 内部一致性校验,而不再在 ATC 阶段做 deploy ONNX 实体复核
基座、pruning、QAT 三类 checkpoint 都不只是 state_dict,而是带有:
- 模型结构描述
- 输入信息
- 结构签名
- 上下文元数据
这保证了跨阶段恢复是明确契约,而不是隐式猜测。
- pruning 从真实模型中提取
channel_cfg - QAT 用
*_from_cfg()按channel_cfg重建剪枝结构 - ONNX 导出和后续部署链都建立在这条恢复链上
当前 ONNX 阶段不只是“导出一个文件”,还承担:
convert_fx- 量化图重写
- 结构校验
- ORT 精度对照
因此它是训练端与 Ascend 部署链之间最关键的衔接层。
项目的重点目标不是继续扩展模型种类,而是收敛部署前主线:
- 保持基座 / pruning / QAT 的恢复契约稳定
- 让
qat_convertONNX 的 rewrite / validate 规则更稳 - 让 AMCT / ATC 在当前环境分层下可重复运行
- 保持文档、环境脚本和代码实现三者一致