用pytorch搭建U-Net并进行图像分割

学习背景和目标:

背景

作者刚开始接触机器学习,完全不懂怎么进行搭建模型、训练数据、测试效果,刚好找到教程一气呵成教完以上内容。

目标

1.使用pytorch搭建U-Net网络进行图像分割
2.熟悉模型训练流程


学习资料:

1.了解U-Net
https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

2.观看此视频搭建用于图像分割的U-Net
【使用 U-NET 的 PyTorch 图像分割教程】 https://www.bilibili.com/video/BV1Cz4y1J7hP/?share_source=copy_web&vd_source=38574f6963ec994e868c8d4db8fcba20

3.数据集点击此处获取
https://www.kaggle.com/c/carvana-image-masking-challenge
也可以通过百度网盘直接下载
链接:https://pan.baidu.com/s/1uV9sPTu5p_-f1v95FJoTEg
提取码:1234

4.源码点击此处获取
https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet


学习内容:

1.搭建U-Net网络

完整的U-Net网络如下图
在这里插入图片描述

1.1我们观察到每一步都有双卷积,也就是深蓝色箭头部分。先实现双卷积。创建model.py,完成以下代码

class DoubleConv(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(DoubleConv,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),#kernel_size=3,stride=1,padding=1
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self,x):
        return self.conv(x)

1.2双卷积实现完毕,接着实现下采样,也就是红色箭头部分

先定义一个features数组,用于存储通道数

features=[64,128,256,512]

实现下采样

#Down part of UNET
for feature in features:
    self.downs.append(DoubleConv(in_channels,feature))
    in_channels=feature

1.3 实现上采样及绿色箭头部分

for feature in reversed(features):
    self.ups.append(
        nn.ConvTranspose2d(
            feature*2,feature,kernel_size=2,stride=2,#feature*2是因为进行了跳连接
        )
    )
    self.ups.append(DoubleConv(feature*2,feature))

1.4实现跳连接,也就是灰色箭头部分

def forward(self,x):
    skip_connnections=[]

    for down in self.downs:
        x=down(x)
        skip_connnections.append(x)
        x=self.pool(x)

    x=self.bottleneck(x)
    skip_connnections=skip_connnections[::-1]

    for idx in range(0,len(self.ups),2):
        x=self.ups[idx](x)
        skip_connnection=skip_connnections[idx//2]

        if x.shape !=skip_connnection.shape:
            x=TF.resize(x,size=skip_connnections.shape[2:])

        concat_skip=torch.cat((skip_connnection,x),dim=1)
        x=self.ups[idx+1](concat_skip)

    return self.final_conv(x)

1.5 完善网络

注意到,以上代码还不完善,没有实现网络最下面一层,也就是这部分:
在这里插入图片描述
添加代码

self.bottleneck=DoubleConv(features[-1],features[-1]*2)#features[-1]=512

也没有输出部分,如下图
在这里插入图片描述
添加代码

self.final_conv=nn.Conv2d(features[0],out_channels,kernel_size=1)

1.6 完整的U-Net代码

完成以上步骤后,将以上代码进行组合,完整的代码如下

import torch
import torch.nn as nn
import torchvision.
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值