Skip to content

Commit 2b6da28

Browse files
szagoruykofmassa
authored andcommitted
Add pretrained Wide ResNet (#912)
* add wide resnet * add docstring for wide resnet * update WRN-50-2 model * add docs * extend WRN docstring * use pytorch storage for WRN * fix rebase * fix typo in docs
1 parent 15e24bd commit 2b6da28

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

docs/source/models.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ architectures for image classification:
2424
- `ShuffleNet`_ v2
2525
- `MobileNet`_ v2
2626
- `ResNeXt`_
27+
- `Wide ResNet`_
2728
- `MNASNet`_
2829

2930
You can construct a model with random weights by calling its constructor:
@@ -41,6 +42,7 @@ You can construct a model with random weights by calling its constructor:
4142
shufflenet = models.shufflenet_v2_x1_0()
4243
mobilenet = models.mobilenet_v2()
4344
resnext50_32x4d = models.resnext50_32x4d()
45+
wide_resnet50_2 = models.wide_resnet50_2()
4446
mnasnet = models.mnasnet1_0()
4547
4648
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
@@ -59,6 +61,7 @@ These can be constructed by passing ``pretrained=True``:
5961
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
6062
mobilenet = models.mobilenet_v2(pretrained=True)
6163
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
64+
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
6265
mnasnet = models.mnasnet1_0(pretrained=True)
6366
6467
Instancing a pre-trained model will download its weights to a cache directory.
@@ -114,6 +117,8 @@ ShuffleNet V2 30.64 11.68
114117
MobileNet V2 28.12 9.71
115118
ResNeXt-50-32x4d 22.38 6.30
116119
ResNeXt-101-32x8d 20.69 5.47
120+
Wide ResNet-50-2 21.49 5.91
121+
Wide ResNet-101-2 21.16 5.72
117122
MNASNet 1.0 26.49 8.456
118123
================================ ============= =============
119124

@@ -202,6 +207,12 @@ ResNext
202207
.. autofunction:: resnext50_32x4d
203208
.. autofunction:: resnext101_32x8d
204209

210+
Wide ResNet
211+
-----------
212+
213+
.. autofunction:: wide_resnet50_2
214+
.. autofunction:: wide_resnet101_2
215+
205216
MNASNet
206217
--------
207218

hubconf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
66
from torchvision.models.inception import inception_v3
77
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
8-
resnext50_32x4d, resnext101_32x8d
8+
resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
99
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
1010
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
1111
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101

torchvision/models/resnet.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44

55
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6-
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
6+
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7+
'wide_resnet50_2', 'wide_resnet101_2']
78

89

910
model_urls = {
@@ -14,6 +15,8 @@
1415
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
1516
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
1617
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
18+
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
19+
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
1720
}
1821

1922

@@ -294,3 +297,37 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
294297
kwargs['width_per_group'] = 8
295298
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
296299
pretrained, progress, **kwargs)
300+
301+
302+
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
303+
"""Constructs a Wide ResNet-50-2 model.
304+
305+
The model is the same as ResNet except for the bottleneck number of channels
306+
which is twice larger in every block. The number of channels in outer 1x1
307+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
308+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
309+
310+
Args:
311+
pretrained (bool): If True, returns a model pre-trained on ImageNet
312+
progress (bool): If True, displays a progress bar of the download to stderr
313+
"""
314+
kwargs['width_per_group'] = 64 * 2
315+
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
316+
pretrained, progress, **kwargs)
317+
318+
319+
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
320+
"""Constructs a Wide ResNet-101-2 model.
321+
322+
The model is the same as ResNet except for the bottleneck number of channels
323+
which is twice larger in every block. The number of channels in outer 1x1
324+
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
325+
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
326+
327+
Args:
328+
pretrained (bool): If True, returns a model pre-trained on ImageNet
329+
progress (bool): If True, displays a progress bar of the download to stderr
330+
"""
331+
kwargs['width_per_group'] = 64 * 2
332+
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
333+
pretrained, progress, **kwargs)

0 commit comments

Comments
 (0)