一个使用简单 CNN 架构实现 CIFAR-10 图像分类的 PyTorch 项目。
SimpleTorch/
├── README.md # 项目文档
├── requirements.txt # 项目依赖
├── main.py # 主训练脚本
├── model.py # CNN 模型定义
├── configs/ # 配置文件
│ ├── __init__.py
│ └── config.py # 配置类
├── utils/ # 工具函数
│ ├── __init__.py
│ ├── data_utils.py # 数据加载和处理
│ ├── train_utils.py # 训练工具
│ └── visualization.py # 可视化工具
├── tests/ # 测试文件
│ └── __init__.py
├── checkpoints/ # 模型检查点
└── data/ # 数据集目录
- 创建新的 conda 环境:
conda create -n simpletorch python=3.9
conda activate simpletorch- 安装依赖:
pip install -r requirements.txtCNN 模型结构:
- 输入:3x32x32 RGB 图像
- 卷积层1:3->16 通道,3x3 卷积核
- 卷积层2:16->32 通道,3x3 卷积核
- 全连接层:3288 -> 10 个类别
项目使用模块化的配置系统:
TrainingConfig:训练参数(批次大小、学习率等)ModelConfig:模型架构参数DataConfig:数据集和数据加载参数
配置示例:
from configs.config import Config
config = Config()
config.training.batch_size = 32
config.training.learning_rate = 0.01- 数据集大小:50,000 张图像
- 图像尺寸:3x32x32(RGB 图像)
- 类别数:10(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车)
from utils.data_utils import get_data_transforms, load_dataset
transform = get_data_transforms(train=True)
dataset = load_dataset(config.data)运行以下命令查看数据集样本和分布:
python show_dataset.py这将生成两个可视化文件:
cifar10_samples.png:数据集样本图像class_distribution.png:类别分布可视化
- 运行训练:
python main.py- 脚本将执行以下操作:
- 下载 CIFAR-10 数据集
- 训练模型
- 显示训练进度
- 保存模型检查点
- 生成训练可视化
项目包含多个可视化工具:
- 数据集样本可视化
- 类别分布绘图
- 训练历史绘图
- 模型预测可视化
- Python 3.9
- PyTorch 2.6.0
- torchvision
- numpy
- matplotlib
- 添加验证集评估
- 实现模型检查点保存
- 添加训练可视化
- 优化训练策略
- 添加测试集评估
- 添加单元测试
- 添加 CI/CD 流程
- 添加性能分析工具
MIT 许可证