# 自定义模型 (custom model)
import torch
from torch import nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.forward1 = nn.Sequential(
nn.Conv3d(1, 8, kernel_size=3, padding=1),
nn.BatchNorm3d(8),
nn.ReLU(inplace=True),
nn.Conv3d(8, 8, kernel_size=3, padding=1),
nn.BatchNorm3d(8)<