深度学习 — 基于ResNet50的野外可食用鲜花分类项目
1.项目简介
- 基本意义
在野外,有多种可食用植物可以采集,常见的包括荠菜、蒲公英、野黄花菜等,这些植物不仅美味,还有一定的药用价值。但是怎样分辨他们成为了一个巨大的难题,人工分类耗时且容易误判,所以我们使用人工智能进行学习,本项目使用了卷积神经网络进行50野外可食用植物识别分类。
1.1 项目名称
基于resnet的50种野外可食用植物分类
1.2 项目简介
在野外生存、生态监测及生物多样性调查等场景中,快速、准确地识别可食用植物具有重要价值。然而,许多可食用植物与其近缘有毒物种形态极其相似,传统人工鉴定耗时且易误判。本项目利用卷积神经网络CNN,以RenSet做迁移学习,构建一个 50 类野外可食用植物的视觉识别系统,实现高精度分类
1.3 项目结构
首先是使用了pytorch框架下的resnet50模型,区别于项目的50分类,resnet50模型可支持1000分类,首先我们对resnet50模型添加了CBAM自注意力机制,然后进行迁移学习,迁移设定我们想分类的种类数的网络结构。
在训练阶段,我们的步骤是,首先进行数据获取,数据预处理,划分数据集以及验证集,设定模型加载器;然后引用模型,定义损失函数,优化器,然后进行模型训练,设置继续训练部分;在验证部分,生成了softmax表,以及打印准确度,精确度,召回率,f1分数的指标报表,以及混淆矩阵。
在模型应用方面,进行了图片预处理,模型推理以及类别显示,并做了模型移植,最后将我们的模型部署到网页
1.4 项目目录一览
2.数据
2.1数据来源
- 数据获取于kaggle,下面为数据集路径链接
Edible wild plants:
https://www.kaggle.com/datasets/gverzea/edible-wild-plants
- 数据为62种野外可使用植物识别,由于部分植物数据较少,故我们使用50种进行分类识别
2.2 数据增强
我们使用了随机翻转,随机旋转,转张量以及归一化处理
data_transform = {
"train": transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
"val": transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
2.3 数据分割
我们划分了训练集以及验证集,使用了8 : 2的比例
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, validate_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
3.神经网络
网络介绍
- 诞生背景
• 2015 年微软亚洲研究院提出,解决“随着网络加深反而训练误差增大”的退化问题(degradation)。
• 核心创新:残差连接(Residual / Skip Connection) 让信息“跳过”若干层,直接加到后面,从而训练出超过 1000 层的网络而不发散。 - 结构
Input 224×224×3
↓ 7×7 conv, 64, s=2
↓ 3×3 maxpool, s=2
↓ Stage 1 (3× Residual Bottleneck blocks) → 256-d
↓ Stage 2 (4×) → 512-d
↓ Stage 3 (6×) → 1024-d
↓ Stage 4 (3×) → 2048-d
↓ Global Average Pooling
↓ FC-2048 → FC-50 (本项目改为 50 类)
模型权重参数
模型使用的layers
4.模型训练
4.1 训练参数
轮次:ecpochs = 100
批次:batch_size=32
学习率:lr=0.001
4.2 损失函数
交叉熵损失
4.3 优化器
使用动量优化器
optim.Adam()
4.4 训练过程可视化
训练过程acc变化
可视化ACC
可视化loss
5.模型验证
5.1 验证过程数据化
5.2指标报表
5.3 混淆矩阵
6.模型优化
继续训练
# -------------- 继续训练:加载已有权重 --------------
start_epoch = 0
best_acc = 0.0
if os.path.isfile(best_model_path):
print("✅ 发现已有权重,继续训练...")
net.load_state_dict(torch.load(best_model_path, map_location=device))
else:
print("⏩ 从头训练...")
7.模型应用
7.1图片预处理
简单的图片预处理
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
])
7.2 模型推理
def image_predict(image_path: str):
start_time = time.time()
if not os.path.exists(image_path):
predicted_class = torch.tensor([-1])
else:
img = Image.open(image_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0)
predicted_class = predict(img_tensor)
end_time = time.time()
print(f'植物识别时间: {(end_time - start_time) * 1000:.2f} ms')
return predicted_class
7.3 类别显示
模型识别分类类别
7.4 使用网页进行测试
8.模型移植
onnx
使用onnx,网址为https://netron.app/
展示部分,完整长截图地址为/卷积神经网络项目文档.assets/onnx.png