学习背景和目标:
背景
作者刚开始接触机器学习,完全不懂怎么进行搭建模型、训练数据、测试效果,刚好找到教程一气呵成教完以上内容。
目标
1.使用pytorch搭建U-Net网络进行图像分割
2.熟悉模型训练流程
学习资料:
1.了解U-Net
https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
2.观看此视频搭建用于图像分割的U-Net
【使用 U-NET 的 PyTorch 图像分割教程】 https://www.bilibili.com/video/BV1Cz4y1J7hP/?share_source=copy_web&vd_source=38574f6963ec994e868c8d4db8fcba20
3.数据集点击此处获取
https://www.kaggle.com/c/carvana-image-masking-challenge
也可以通过百度网盘直接下载
链接:https://pan.baidu.com/s/1uV9sPTu5p_-f1v95FJoTEg
提取码:1234
学习内容:
1.搭建U-Net网络
完整的U-Net网络如下图
1.1我们观察到每一步都有双卷积,也就是深蓝色箭头部分。先实现双卷积。创建model.py,完成以下代码
class DoubleConv(nn.Module):
def __init__(self,in_channels,out_channels):
super(DoubleConv,self).__init__()
self.conv=nn.Sequential(
nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),#kernel_size=3,stride=1,padding=1
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self,x):
return self.conv(x)
1.2双卷积实现完毕,接着实现下采样,也就是红色箭头部分
先定义一个features数组,用于存储通道数
features=[64,128,256,512]
实现下采样
#Down part of UNET
for feature in features:
self.downs.append(DoubleConv(in_channels,feature))
in_channels=feature
1.3 实现上采样及绿色箭头部分
for feature in reversed(features):
self.ups.append(
nn.ConvTranspose2d(
feature*2,feature,kernel_size=2,stride=2,#feature*2是因为进行了跳连接
)
)
self.ups.append(DoubleConv(feature*2,feature))
1.4实现跳连接,也就是灰色箭头部分
def forward(self,x):
skip_connnections=[]
for down in self.downs:
x=down(x)
skip_connnections.append(x)
x=self.pool(x)
x=self.bottleneck(x)
skip_connnections=skip_connnections[::-1]
for idx in range(0,len(self.ups),2):
x=self.ups[idx](x)
skip_connnection=skip_connnections[idx//2]
if x.shape !=skip_connnection.shape:
x=TF.resize(x,size=skip_connnections.shape[2:])
concat_skip=torch.cat((skip_connnection,x),dim=1)
x=self.ups[idx+1](concat_skip)
return self.final_conv(x)
1.5 完善网络
注意到,以上代码还不完善,没有实现网络最下面一层,也就是这部分:
添加代码
self.bottleneck=DoubleConv(features[-1],features[-1]*2)#features[-1]=512
也没有输出部分,如下图
添加代码
self.final_conv=nn.Conv2d(features[0],out_channels,kernel_size=1)
1.6 完整的U-Net代码
完成以上步骤后,将以上代码进行组合,完整的代码如下
import torch
import torch.nn as nn
import torchvision.