# resnet.py
# Modified from
# https://github.com/pytorch/vision/blob/release/0.8.0/torchvision/models/resnet.py
import torch
from torch import Tensor
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
# Public names exported by ``from <this module> import *``.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]
# Download locations of the pretrained torchvision (release 0.8.0) checkpoints,
# keyed by architecture name. These must point at the official
# download.pytorch.org host: checkpoints are pickled files, so fetching them
# through any third-party mirror/proxy is both fragile and a security risk.
model_urls = {
    'resnet18':
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34':
    'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50':
    'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101':
    'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152':
    'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d':
    'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
    'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2':
    'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2':
    'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            groups: int = 1,
            dilation: int = 1) -> nn.Conv2d:
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``. For a 3x3 kernel this keeps the
    spatial size unchanged when ``stride == 1``, because the effective kernel
    extent is ``2 * dilation + 1``.
    """
    conv_kwargs = dict(stride=stride,
                       padding=dilation,
                       dilation=dilation,
                       groups=groups,
                       bias=False)
    return nn.Conv2d(in_planes, out_planes, 3, **conv_kwargs)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` with no padding.

    It mixes channels only. With ``stride > 1`` it also subsamples spatially.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, 1, stride, bias=False)
    return pointwise
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, laid out for quantization.

    ``expansion`` is 1, so the block outputs ``planes`` channels. The residual
    addition goes through ``nn.quantized.FloatFunctional`` and the two ReLUs
    are distinct module instances, so the layers can be fused and quantized.
    """
    expansion: int = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        """Build the block.

        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            stride: stride of ``conv1``.
            downsample: optional module applied to the input to form the
                shortcut. Presumably it must match the main path's output
                shape; this is not checked here.
            groups: must be 1; anything else raises ValueError.
            base_width: must be 64; anything else raises ValueError.
            dilation: must be <= 1; larger values raise NotImplementedError.
            norm_layer: normalization factory called with a channel count.
                ``None`` selects ``nn.BatchNorm2d``.
        """
        super(BasicBlock, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Main path. When stride != 1, both conv1 and `downsample` reduce the
        # spatial resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        # Shortcut path and bookkeeping.
        self.downsample = downsample
        self.stride = stride
        # FloatFunctional addition, for quantization compatibility.
        self.skip_add = nn.quantized.FloatFunctional()
        # A second ReLU instance, independent of relu1, as needed for layer
        # fusion.
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Compute relu2(bn2(conv2(relu1(bn1(conv1(x))))) + shortcut(x)).

        ``shortcut`` is ``downsample`` when it is set, otherwise the identity.
        """
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # The shortcut is evaluated after the main path, as in the reference.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.skip_add.add(shortcut, y)
        return self.relu2(y)
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convs plus a residual add).

    It is laid out for quantization in the same way as the basic block:
    - the residual addition goes through ``nn.quantized.FloatFunctional``;
    - each ReLU is its own module instance (relu1 after bn1, relu2 after bn2,
      relu3 after the addition), so layer fusion can act on each separately.

    The block outputs ``planes * expansion`` channels.
    """
    # Bottleneck in torchvision places the stride for downsampling at the 3x3
    # convolution (self.conv2), while the original implementation places the
    # stride at the first 1x1 convolution (self.conv1) according to
    # "Deep residual learning for image recognition"
    # (https://arxiv.org/abs/1512.03385).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion: int = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        """Build the block.

        Args:
            inplanes: number of input channels.
            planes: base channel count. The output has
                ``planes * expansion`` channels.
            stride: stride of ``conv2`` (the 3x3 convolution).
            downsample: optional module applied to the input to form the
                shortcut. Presumably it must match the main path's output
                shape; this is not checked here.
            groups: number of groups of the 3x3 convolution.
            base_width: width per group, used to size the inner convolutions.
            dilation: dilation (and padding) of the 3x3 convolution.
            norm_layer: normalization factory called with a channel count.
                ``None`` selects ``nn.BatchNorm2d``.
        """
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Channel count of the inner convolutions: planes scaled by
        # base_width / 64, times the number of groups.
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when
        # stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        # Three independent ReLU modules. They must not be shared: layer
        # fusion (e.g. conv1+bn1+relu1, conv2+bn2+relu2) operates per module
        # instance. ReLU has no parameters, so state_dict keys are unaffected.
        self.relu1 = nn.ReLU(inplace=True)  # after bn1
        self.relu2 = nn.ReLU(inplace=True)  # after bn2
        self.downsample = downsample
        self.stride = stride
        # FloatFunctional addition, for quantization compatibility.
        self.skip_add = nn.quantized.FloatFunctional()
        self.relu3 = nn.ReLU(inplace=True)  # after the residual addition

    def forward(self, x: Tensor) -> Tensor:
        """Compute relu3(bn3(conv3(...)) + shortcut(x)).

        ``shortcut`` is ``downsample`` when it is set, otherwise the identity.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Use a dedicated ReLU module here; the block defines no `self.relu`.
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # FloatFunctional add instead of `out += identity`, for quantization.
        out = self.skip_add.add(identity, out)
        out = self.relu3(out)
        return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(
replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block,
128,
layers[1],
stride=2,

踟蹰横渡口,彳亍上滩舟。
- 粉丝: 2124
最新资源
- 基于机器学习的商品评论情感分析-毕业设计项目
- 【C语言编程】字符串初始化与操作:字符数组定义、打印及指针访问方法解析
- 【C语言编程】字符串拷贝函数的多种实现方式及其应用场景分析:基础算法学习与实践
- 【C语言编程】基于while和do-while循环的strstr函数实现:字符串匹配与计数算法分析
- ensp软件安装包(包含virtualbox、wireshark、winpcap)
- 【C语言编程】指针与数组操作示例:内存管理及字符串处理函数应用详解
- 【C语言编程】两头堵模型实现:去除字符串首尾空格及长度计算功能开发
- 基于机器学习技术的商品评论情感分析毕业设计项目
- 5-分析式AI基础 6-不同领域的AI算法 7-机器学习神器
- 8-时间序列模型 9-时间序列AI大赛 10-神经网络基础与Tensorflow实战
- Java并发编程的设计原则与模式
- 机器学习基础算法模型实现
- 人工智能与机器学习课程群
- 毕业论文答辩发言稿.docx
- 本科学位论文答辩的技巧与应变能力.docx
- 本科毕业论文答辩范文.docx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈


