目录
一、图像数据处理规范化封装
1. 彩色与灰度图像处理函数
标准化预处理函数
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.io import read_image
def get_transform(mode='rgb'):
    """Build the preprocessing pipeline for colour or grayscale images.

    Args:
        mode: 'rgb' or 'gray' (matched case-insensitively).

    Returns:
        A ``T.Compose`` that converts its input with ``T.ToTensor`` and then
        normalizes it. The 'gray' pipeline first reduces the image to a
        single channel with ``T.Grayscale``.

    Raises:
        ValueError: If ``mode`` is neither 'rgb' nor 'gray'.
    """
    selected = mode.lower()

    if selected == 'rgb':
        # One mean/std entry per colour channel.
        steps = [
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    elif selected == 'gray':
        # Single-channel output, so a single mean/std entry.
        steps = [
            T.Grayscale(num_output_channels=1),
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5]),
        ]
    else:
        raise ValueError("mode必须为'rgb'或'gray'")

    return T.Compose(steps)
# Usage example: build one pipeline per supported mode.
rgb_transform, gray_transform = (
    get_transform(mode_name) for mode_name in ('rgb', 'gray')
)
数据加载封装
from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Dataset that reads image files and applies the standard preprocessing.

    Args:
        file_list: Indexable collection of image file paths.
        mode: 'rgb' or 'gray'; forwarded to ``get_transform``.
    """

    def __init__(self, file_list, mode='rgb'):
        self.files = file_list
        self.transform = get_transform(mode)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return the transformed tensor."""
        decoded = read_image(self.files[idx])  # uint8 tensor, (C, H, W)
        # The pipeline from get_transform starts with T.ToTensor, which only
        # accepts PIL images / ndarrays and raises TypeError on a tensor, so
        # hand it a PIL image instead. Converting to RGB also gives grayscale
        # and RGBA files the three channels the 'rgb' normalization expects;
        # the 'gray' pipeline reduces it back to one channel itself.
        pil_image = T.ToPILImage()(decoded).convert("RGB")
        return self.transform(pil_image)

    def __len__(self):
        """Return the number of image files."""
        return len(self.files)
二、张量展平操作规范
1. 智能展平函数实现
def flatten_tensor(x):
    """Flatten every dimension except the leading batch dimension.

    Unlike ``x.view(x.size(0), -1)`` this also works for non-contiguous
    tensors (e.g. after ``permute``/``transpose``) and for empty batches,
    where ``-1`` cannot be inferred.

    Args:
        x: Input tensor of shape (batch, ...).

    Returns:
        Tensor of shape (batch, features). It shares storage with ``x`` when
        the memory layout allows it and is a copy otherwise.
    """
    if x.dim() < 2:
        # No trailing dimensions to merge: give 1-D input an explicit feature
        # axis of size 1, which is what view(batch, -1) produced.
        return x.reshape(x.size(0), 1)
    return x.flatten(start_dim=1)
# Usage example
tensor_4d = torch.randn((32, 3, 28, 28))  # (batch, channel, height, width)
flattened = flatten_tensor(x=tensor_4d)  # resulting shape: (32, 3*28*28)
2. 展平操作在模型中的应用
class CNNWithFlatten(nn.Module):
    """One convolution followed by a linear layer, joined by ``flatten_tensor``."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        # The unpadded 3x3 kernel shrinks a 28x28 input to 26x26, hence the
        # 16*26*26 input features; other input sizes will not fit this layer.
        self.fc = nn.Linear(in_features=16 * 26 * 26, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for a batch of 3x28x28 inputs."""
        activated = F.relu(self.conv(x))   # (bs, 16, 26, 26)
        flat = flatten_tensor(activated)   # (bs, 16*26*26)
        scores = self.fc(flat)
        return scores
三、Dropout规范化实现
1. Dropout工作原理
2. 正确使用Dropout的模型示例
class DropoutNet(nn.Module):
    """Two-layer perceptron (784 -> 256 -> 10) with dropout on the hidden layer.

    Args:
        dropout_rate: Probability passed to ``nn.Dropout``.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Map a (batch, 784) input to (batch, 10) output scores."""
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        # nn.Dropout only drops units while the module is in training mode.
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
# Usage example
model = DropoutNet()
# Placeholder batch so the example runs on its own: 32 samples with the 784
# features that DropoutNet.fc1 expects.
inputs = torch.randn(32, 784)

# Training phase
model.train()  # enables dropout
output_train = model(inputs)

# Evaluation phase
model.eval()  # disables dropout
with torch.no_grad():
    output_test = model(inputs)
3. Dropout效果验证实验
def verify_dropout(model, input_tensor):
    """Show how dropout changes the output spread in train vs. eval mode.

    Runs 100 forward passes in each mode and prints the variance across
    passes, averaged over all output elements. With dropout present the
    training-mode variance is positive, while the evaluation-mode variance is
    zero for a deterministic model.

    Args:
        model: Module to probe. Its train/eval mode is restored on return.
        input_tensor: Input fed unchanged to every forward pass.
    """
    was_training = model.training
    try:
        # Only the output values matter here, so skip building autograd graphs
        # for the 200 forward passes.
        with torch.no_grad():
            model.train()
            train_outputs = torch.stack(
                [model(input_tensor) for _ in range(100)]
            )
            model.eval()
            eval_outputs = torch.stack(
                [model(input_tensor) for _ in range(100)]
            )
    finally:
        # A diagnostic helper should not leave the caller's model in eval mode.
        model.train(was_training)

    # The printed labels say "variance", so report var rather than std.
    print("训练模式输出方差:", train_outputs.var(dim=0).mean().item())
    print("评估模式输出方差:", eval_outputs.var(dim=0).mean().item())
# Smoke test with a single random 784-feature sample.
dummy_input = torch.randn((1, 784))
verify_dropout(model=DropoutNet(), input_tensor=dummy_input)
四、综合应用示例
完整图像分类管道
class ImageClassifier(nn.Module):
    """Small CNN classifier: two conv/pool stages followed by an MLP head.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Size of the output layer.
        dropout_rate: Probability used by both dropout layers.
    """

    def __init__(self, in_channels=3, num_classes=10, dropout_rate=0.2):
        super().__init__()

        conv_stack = [
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout2d(p=dropout_rate),  # spatial (per-channel) dropout
        ]
        # 32*7*7 input features correspond to 28x28 images: the padded convs
        # keep the spatial size and each pooling stage halves it (28 -> 14 -> 7).
        head = [
            nn.Linear(in_features=32 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features=128, out_features=num_classes),
        ]

        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return (batch, num_classes) scores for a batch of images."""
        feature_maps = self.features(x)
        flat = flatten_tensor(feature_maps)
        return self.classifier(flat)
def train_model(model, dataloader, epochs=10):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module producing class scores of shape (batch, classes).
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``dataloader``.
    """
    model.train()
    adam = torch.optim.Adam(model.parameters())  # default hyper-parameters

    for _ in range(epochs):
        for batch_images, batch_labels in dataloader:
            adam.zero_grad()
            batch_loss = F.cross_entropy(model(batch_images), batch_labels)
            batch_loss.backward()
            adam.step()
def evaluate(model, dataloader):
    """Compute and print the classification accuracy of ``model``.

    The accuracy is based on the samples actually yielded by ``dataloader``,
    so it stays correct with ``drop_last=True`` or a subset sampler, and it
    also accepts any plain iterable of batches.

    Args:
        model: Module producing class scores of shape (batch, classes).
        dataloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Accuracy in percent, or 0.0 when ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in dataloader:
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.numel()
    # Guard the division so an empty loader reports 0% instead of raising.
    accuracy = 100 * correct / seen if seen else 0.0
    print(f"准确率: {accuracy:.2f}%")
    return accuracy
五、关键注意事项
1. 图像处理规范
- 通道顺序一致性:确保训练和测试时使用相同的通道顺序(RGB或BGR)
- 归一化参数统一:使用相同的均值和标准差进行归一化
- 灰度转换时机:在数据加载时而非模型内部进行灰度化
2. 展平操作陷阱
| 错误做法 | 正确做法 |
|---|---|
| `x.view(-1)` | `x.view(x.size(0), -1)` |
| 手动计算特征维度 | 使用 `x.shape` 动态获取 |
| 在卷积层前展平 | 在全连接层前展平 |
3. Dropout最佳实践
- 概率选择:通常0.2-0.5之间,深层网络可适当增加
- 位置选择:通常在激活函数之后
- 模式切换:验证/测试前务必调用 `model.eval()`
- 特定变体:
  - `Dropout2d`:适合卷积层,按通道丢弃
  - `AlphaDropout`:适合SELU激活函数
六、调试技巧
1. 形状调试工具
def print_shapes(model, input_shape):
    """Print the output shape of each direct child of ``model``.

    The shapes are recorded with forward hooks during one real forward pass,
    so anything ``model.forward`` does between layers (flattening, functional
    activations, ...) is respected. Feeding the children to each other in
    registration order instead fails as soon as ``forward`` does such work,
    e.g. at the first linear layer that follows a convolution.

    Args:
        model: Module whose direct children are reported, in call order.
        input_shape: Shape of the random dummy input, batch dimension included.
    """
    dummy = torch.randn(input_shape)
    print(f"{'Layer':<20} {'Output Shape':<20}")
    print("-"*40)

    def _make_reporter(name):
        # Bind the child's name now; the hook fires right after its forward.
        def _report(_module, _inputs, output):
            if torch.is_tensor(output):
                shape = list(output.shape)
            else:
                # Non-tensor outputs (e.g. tuples) have no single shape.
                shape = type(output).__name__
            print(f"{name:<20} {str(shape):<20}")
        return _report

    handles = [
        layer.register_forward_hook(_make_reporter(name))
        for name, layer in model.named_children()
    ]
    try:
        with torch.no_grad():
            model(dummy)
    finally:
        # Always detach the hooks so later forward passes print nothing.
        for handle in handles:
            handle.remove()
2. Dropout激活检查
def is_dropout_active(m):
    """Return True if ``m`` is a dropout layer that is in training mode.

    All dropout variants exposed by ``torch.nn`` are recognised, not only
    ``nn.Dropout``; ``nn.Dropout2d`` and ``nn.AlphaDropout`` are separate
    classes and would otherwise be reported as inactive.

    Args:
        m: Any module, typically an element of ``model.modules()``.

    Returns:
        ``m.training`` for dropout layers, ``False`` for everything else.
    """
    dropout_types = tuple(
        getattr(nn, cls_name)
        for cls_name in (
            "Dropout",
            "Dropout1d",  # not available in older torch releases
            "Dropout2d",
            "Dropout3d",
            "AlphaDropout",
            "FeatureAlphaDropout",
        )
        if hasattr(nn, cls_name)
    )
    return isinstance(m, dropout_types) and m.training
# Usage example: a fresh module starts in training mode; eval() switches it off.
model = DropoutNet()
print("训练模式:", any(map(is_dropout_active, model.modules())))
model.eval()
print("评估模式:", any(map(is_dropout_active, model.modules())))