After the first change, it now errors out:

```
(covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clipfusion.py
使用尺寸: Resized=(128, 128, 64), Crop=(64, 64, 32)
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().
  warn_deprecated(argname, msg, warning_category)
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.losses.dice DiceCELoss.__init__:ce_weight: Argument `ce_weight` has been deprecated since version 1.2. It will be removed in version 1.4. please use `weight` instead.
  warn_deprecated(argname, msg, warning_category)
Epoch 1/200
训练 Epoch 1:   0%|          | 0/104 [00:00<?, ?it/s]
enc_out_list 通道数: [12, 24, 48, 96, 192]
enc_out_list 各层特征图尺寸: [torch.Size([1, 12, 32, 32, 16]), torch.Size([1, 24, 16, 16, 8]), torch.Size([1, 48, 8, 8, 4]), torch.Size([1, 96, 4, 4, 2]), torch.Size([1, 192, 2, 2, 1])]
训练 Epoch 1:   0%|          | 0/104 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clipfusion.py", line 283, in <module>
    outputs = model(inputs, text_feat)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clipfusion.py", line 149, in forward
    residual = self.residual_conv(enc_out)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Given groups=1, weight of size [12, 192, 1, 1, 1], expected input[1, 12, 2, 2, 1] to have 192 channels, but got 12 channels instead
```
Posted: 2025-07-27 07:18:12 · AIGC · Views: 92
<think>
We hit an error in the residual connection: the residual conv was built to map the encoder output (192 channels) down to feature_size (12), but by the time it runs, the fused feature map has already been reduced to 12 channels by `self.fusion_reduce`. Since `residual_conv` still expects the pre-fusion width (192), the dimensions no longer match.

Two possible fixes:
1. Extract the residual *before* the fusion step, since fusion plus `fusion_reduce` changes the channel count.
2. Alternatively, change the residual conv's input channels to the post-fusion width (12).

We go with option 1: extract the residual before fusion.

Changes to `SwinUNETRWithCLIPFusion.forward`:
- Before fusion, apply the residual conv to the raw encoder output `enc_out` (192 channels) to get a 12-channel residual.
- Then run fusion (the fused feature stays at 192 channels) and reduce it to 12 channels with `self.fusion_reduce`.
- At the end of the decoding path, add the residual to the output (after adjusting its spatial size to match).

Note the spatial sizes: the decoder output matches the input crop size (the encoder downsamples four times after patch embedding, and the decoder upsamples correspondingly plus a final transposed conv), while the residual comes from the last encoder stage at (2, 2, 1) — the 64×64×32 crop downsampled 32×. So the residual must be upsampled to the original input size before it can be added. The simplest scheme is to upsample it to the decoder's final feature-map size and add it there:

1. Right after the last encoder stage, apply the residual conv (192 in, 12 out) to get a residual of shape [1, 12, 2, 2, 1].
2. Fuse and reduce the encoder output to [1, 12, 2, 2, 1].
3. After the decoding path, the feature map is [1, 12, 64, 64, 32]; upsample the residual to (64, 64, 32) and add it to the decoder output.
4. Pass the sum through the output conv to get the final segmentation map.

Modified model code:

```python
class SwinUNETRWithCLIPFusion(SwinUNETR):
    def __init__(self, img_size, in_channels, out_channels, feature_size=12,
                 use_checkpoint=False, text_feat_dim=512):
        super().__init__(img_size=img_size, in_channels=in_channels,
                         out_channels=out_channels, feature_size=feature_size,
                         use_checkpoint=use_checkpoint)
        fusion_img_feat_dim = feature_size * 16  # 192
        self.fusion = LightCrossAttentionFusion(img_feat_dim=fusion_img_feat_dim,
                                                text_feat_dim=text_feat_dim)
        self.fusion_reduce = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # Reduce skip-connection channels
        self.skip4_reduce = nn.Conv3d(96, feature_size, kernel_size=1)
        self.skip3_reduce = nn.Conv3d(48, feature_size, kernel_size=1)
        self.skip2_reduce = nn.Conv3d(24, feature_size, kernel_size=1)
        self.skip1_reduce = nn.Conv3d(12, feature_size, kernel_size=1)
        # Residual path: 192 in, 12 out, then upsample to the output size
        self.residual_conv = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        self.residual_upsample = nn.Upsample(size=crop_size, mode='trilinear', align_corners=True)
        # Decoders
        self.decoder1 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1,
                                    upsample_kernel_size=2, norm_name="instance")
        self.decoder2 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1,
                                    upsample_kernel_size=2, norm_name="instance")
        self.decoder3 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1,
                                    upsample_kernel_size=2, norm_name="instance")
        self.decoder4 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1,
                                    upsample_kernel_size=2, norm_name="instance")
        self.decoder5 = nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)
        # Dropout
        self.dropout = nn.Dropout3d(0.2)

    def forward(self, x, text_feat=None):
        enc_out_list = self.swinViT(x)
        enc_out = enc_out_list[-1]  # [B, 192, 2, 2, 1]
        # Extract the residual BEFORE fusion
        residual = self.residual_conv(enc_out)       # [B, 12, 2, 2, 1]
        # Upsample the residual to the final output size (crop_size)
        residual = self.residual_upsample(residual)  # [B, 12, 64, 64, 32]
        # Fusion
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)  # [B, 192, 2, 2, 1]
        enc_out = self.fusion_reduce(enc_out)          # [B, 12, 2, 2, 1]
        # Channel-reduced skip connections
        skip4 = self.skip4_reduce(enc_out_list[-2])  # [B, 12, 4, 4, 2]
        skip3 = self.skip3_reduce(enc_out_list[-3])  # [B, 12, 8, 8, 4]
        skip2 = self.skip2_reduce(enc_out_list[-4])  # [B, 12, 16, 16, 8]
        skip1 = self.skip1_reduce(enc_out_list[-5])  # [B, 12, 32, 32, 16]
        # Decoding path (all at feature_size channels)
        d1 = self.decoder1(enc_out, skip4)  # [B, 12, 4, 4, 2]
        d1 = self.dropout(d1)
        d2 = self.decoder2(d1, skip3)       # [B, 12, 8, 8, 4]
        d2 = self.dropout(d2)
        d3 = self.decoder3(d2, skip2)       # [B, 12, 16, 16, 8]
        d3 = self.dropout(d3)
        d4 = self.decoder4(d3, skip1)       # [B, 12, 32, 32, 16]
        d4 = self.dropout(d4)
        d5 = self.decoder5(d4)              # [B, 12, 64, 64, 32]
        # Add the residual to the final decoder feature map
        d5 = d5 + residual
        out = self.out(d5)                  # [B, num_classes, 64, 64, 32]
        return out
```

The residual upsampling uses `nn.Upsample` with the target size set to `crop_size` (64, 64, 32), so the residual ends up the same size as `d5`. We also add a Dropout layer after each decoder block to curb overfitting. Since the model structure changed, the learning rate or training strategy may need tuning, but let's try running first. We also still need to fix the font warning (use English labels or configure a font) and enable mixed-precision training.

For the fonts, change the plotting code:

```python
# Set fonts before plotting (use English labels to avoid the warning)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']  # a commonly available sans-serif font
plt.rcParams['axes.unicode_minus'] = False

# Plot training loss and validation Dice
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Train Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(val_dice_history, label='Validation Dice', color='orange')
plt.title('Validation Dice Coefficient')
plt.xlabel('Epoch')
plt.ylabel('Dice')
plt.legend()
plt.tight_layout()
plt.savefig("logs/fusion_training_metrics.png")
plt.close()
```

Enable mixed precision by setting `use_amp = True`.

Finally, add early stopping and switch the learning-rate scheduler to `ReduceLROnPlateau`:

```python
# Replace the original LambdaLR
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, verbose=True, min_lr=1e-6)

# Step the scheduler after the validation loop
scheduler.step(avg_metric)
```

Add early stopping:

```python
early_stopping_patience = 10
no_improve_count = 0
for epoch in range(max_epochs):
    # ... training and validation ...
    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        no_improve_count = 0
        # save the model ...
    else:
        no_improve_count += 1
        if no_improve_count >= early_stopping_patience:
            print(f"Early stopping triggered after {early_stopping_patience} epochs without improvement.")
            break
```

Now integrate these changes and rerun. Because the model structure changed, training restarts from epoch 0. If it runs, the dimension mismatch is resolved and training should be more stable.
</think>

Your code hits a dimension-mismatch error in the residual connection. Let's analyze the problem and walk through the fix:
### Error Analysis
The error message:
```
RuntimeError: Given groups=1, weight of size [12, 192, 1, 1, 1], expected input[1, 12, 2, 2, 1] to have 192 channels, but got 12 channels instead
```
Why it happens:
1. The residual connection was added after the fusion step, where the dimensions no longer match.
2. `residual_conv` expects a 192-channel input, but the actual input has only 12 channels.
3. This is because `self.fusion_reduce` has already reduced the channel count from 192 to 12.
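The mismatch is easy to reproduce in isolation. A minimal sketch (plain PyTorch, no MONAI required) showing the same failure with stand-in tensors:

```python
import torch
import torch.nn as nn

# residual_conv as originally built: expects the pre-fusion width (192 channels)
residual_conv = nn.Conv3d(192, 12, kernel_size=1)  # weight shape: [12, 192, 1, 1, 1]

# ...but after fusion_reduce the feature map only has 12 channels
fused = torch.randn(1, 12, 2, 2, 1)
try:
    residual_conv(fused)
except RuntimeError as e:
    print(e)  # "... expected input[1, 12, 2, 2, 1] to have 192 channels, but got 12 ..."

# Applying the same conv to the pre-fusion 192-channel feature works fine
pre_fusion = torch.randn(1, 192, 2, 2, 1)
print(residual_conv(pre_fusion).shape)  # torch.Size([1, 12, 2, 2, 1])
```

This is why moving the residual extraction to before the fusion step resolves the error.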
### Solution
We need to move the residual connection and handle its dimensions properly:
```python
class SwinUNETRWithCLIPFusion(SwinUNETR):
    def __init__(self, img_size, in_channels, out_channels, feature_size=12, use_checkpoint=False, text_feat_dim=512):
        super().__init__(...)
        # keep the existing code ...
        # Changed residual connection: extract features BEFORE fusion
        self.residual_conv = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # Upsampling for the residual path: (2, 2, 1) -> (64, 64, 32) needs five 2x stages
        self.residual_upsample = nn.Sequential(
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)
        )

    def forward(self, x, text_feat=None):
        enc_out_list = self.swinViT(x)
        enc_out = enc_out_list[-1]  # [B, 192, 2, 2, 1]
        # Extract the residual BEFORE fusion
        residual = self.residual_conv(enc_out)       # [B, 12, 2, 2, 1]
        residual = self.residual_upsample(residual)  # upsampled to [B, 12, 64, 64, 32]
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)
        enc_out = self.fusion_reduce(enc_out)  # [B, 12, ...]
        # keep the existing decoding path ...
        # Add the residual just before the final output
        d5 = d5 + residual
        out = self.out(d5)
        return out
```
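As a sanity check on the number of transposed-conv stages: the bottleneck is the 64×64×32 crop downsampled 32× (to 2×2×1), and each stride-2 stage doubles the spatial size, so five stages are needed to reach the decoder's 64×64×32 output; four stages stop at 32×32×16 and the addition to `d5` would fail. The arithmetic in plain Python:

```python
crop_size = (64, 64, 32)
bottleneck = tuple(s // 32 for s in crop_size)  # encoder downsamples 32x -> (2, 2, 1)

def upsampled(shape, stages):
    """Spatial size after `stages` stride-2 transposed convolutions."""
    for _ in range(stages):
        shape = tuple(2 * s for s in shape)
    return shape

print(upsampled(bottleneck, 4))  # (32, 32, 16) -- one stage short of the decoder output
print(upsampled(bottleneck, 5))  # (64, 64, 32) -- matches d5, so the residual can be added
```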
### Complete Fixed Code
```python
class SwinUNETRWithCLIPFusion(SwinUNETR):
    def __init__(self, img_size, in_channels, out_channels, feature_size=12, use_checkpoint=False, text_feat_dim=512):
        super().__init__(img_size=img_size, in_channels=in_channels, out_channels=out_channels,
                         feature_size=feature_size, use_checkpoint=use_checkpoint)
        fusion_img_feat_dim = feature_size * 16  # 192 for feature_size=12
        self.fusion = LightCrossAttentionFusion(img_feat_dim=fusion_img_feat_dim, text_feat_dim=text_feat_dim)
        self.fusion_reduce = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # Reduce all skip-connection channels to feature_size
        self.skip4_reduce = nn.Conv3d(96, feature_size, kernel_size=1)
        self.skip3_reduce = nn.Conv3d(48, feature_size, kernel_size=1)
        self.skip2_reduce = nn.Conv3d(24, feature_size, kernel_size=1)
        self.skip1_reduce = nn.Conv3d(12, feature_size, kernel_size=1)
        # Residual connection -- extract features BEFORE fusion
        self.residual_conv = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # Residual upsampling: (2, 2, 1) -> (64, 64, 32); five 2x stages to match decoder5's output
        self.residual_upsample = nn.Sequential(
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),  # 4x4x2
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),  # 8x8x4
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),  # 16x16x8
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2),  # 32x32x16
            nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)   # 64x64x32
        )
        # Custom decoders
        from monai.networks.blocks import UnetUpBlock
        self.decoder1 = UnetUpBlock(
            spatial_dims=3,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name="instance",
        )
        self.decoder2 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder3 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder4 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder5 = nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)
        # Dropout to curb overfitting
        self.dropout = nn.Dropout3d(0.2)

    def forward(self, x, text_feat=None):
        enc_out_list = self.swinViT(x)
        enc_out = enc_out_list[-1]  # [B, 192, 2, 2, 1]
        # Extract the residual features (pre-fusion) and upsample them
        residual = self.residual_conv(enc_out)       # [B, 12, 2, 2, 1]
        residual = self.residual_upsample(residual)  # [B, 12, 64, 64, 32]
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)
        enc_out = self.fusion_reduce(enc_out)  # [B, 12, 2, 2, 1]
        # Channel-reduced skip connections
        skip4 = self.skip4_reduce(enc_out_list[-2])  # 96 -> 12
        skip3 = self.skip3_reduce(enc_out_list[-3])  # 48 -> 12
        skip2 = self.skip2_reduce(enc_out_list[-4])  # 24 -> 12
        skip1 = self.skip1_reduce(enc_out_list[-5])  # 12 -> 12
        # Decoding path (all at feature_size channels)
        d1 = self.decoder1(enc_out, skip4)
        d1 = self.dropout(d1)
        d2 = self.decoder2(d1, skip3)
        d2 = self.dropout(d2)
        d3 = self.decoder3(d2, skip2)
        d3 = self.dropout(d3)
        d4 = self.decoder4(d3, skip1)
        d4 = self.dropout(d4)
        d5 = self.decoder5(d4)  # [B, 12, 64, 64, 32]
        # Add the residual connection
        d5 = d5 + residual
        out = self.out(d5)  # [B, num_classes, 64, 64, 32]
        return out
```
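If you'd rather not hard-code the number of upsampling stages, a more robust variant (an alternative sketch, not part of the original script) is to interpolate the residual to whatever spatial size `d5` actually has inside `forward`, so it keeps working if the crop size changes. A self-contained shape check with stand-in tensors:

```python
import torch
import torch.nn.functional as F

# Stand-in tensors with the shapes from the forward pass above
residual = torch.randn(1, 12, 2, 2, 1)   # pre-fusion residual
d5 = torch.randn(1, 12, 64, 64, 32)      # final decoder feature map

# Resize the residual to d5's spatial size, whatever it is
residual = F.interpolate(residual, size=d5.shape[2:], mode="trilinear", align_corners=False)
print(residual.shape)  # torch.Size([1, 12, 64, 64, 32])
d5 = d5 + residual     # shapes now match, so the addition is element-wise
```

This trades the learnable transposed convolutions for parameter-free interpolation; either choice is defensible, but interpolation cannot silently drift out of sync with the decoder depth.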
### Other Necessary Fixes
1. **Update the loss function** (fixes the deprecation warning):
```python
# Replace the deprecated ce_weight argument
loss_function = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True,
    weight=class_weights,  # use `weight` instead of `ce_weight`
    lambda_dice=0.5,
    lambda_ce=0.5
)
```
2. **Fix the plotting font issue**:
```python
import matplotlib as mpl
mpl.rcParams['font.sans-serif'] = ['DejaVu Sans']  # a widely available default font
mpl.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
# Or simply use English labels when plotting
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
```
3. **Enable mixed-precision training**:
```python
use_amp = True  # change from False
```
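Setting `use_amp = True` only helps if gradient scaling is wired in alongside `autocast`. A minimal sketch of the standard autocast + `GradScaler` pattern, using a toy linear model as a stand-in (the real script's model, optimizer, and loss function go in the same slots), and falling back cleanly to full precision when no GPU is available:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()   # AMP autocast needs CUDA; disable cleanly on CPU
scaler = GradScaler(enabled=use_amp)  # becomes a no-op when disabled

model = torch.nn.Linear(4, 2)         # toy stand-in for the segmentation model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

x, y = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()
with autocast(enabled=use_amp):       # forward + loss in mixed precision
    loss = loss_fn(model(x), y)
scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                # unscales gradients, then calls optimizer.step()
scaler.update()
```

With `enabled=False` every scaler call degenerates to the ordinary fp32 path, so the same loop works on both CPU and GPU.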
### Training-Loop Adjustments (with shape checks)
```python
# Optional helper for inspecting shapes before the forward pass
def print_tensor_shapes(tensor_dict):
    for name, tensor in tensor_dict.items():
        print(f"{name}: {tensor.shape}")

# In the training loop
for batch_data in train_loader:
    inputs = batch_data["image"].to(device)
    labels = batch_data["label"].to(device)
    # Debug: print input shapes on the first step
    if step == 0:
        print(f"input image shape: {inputs.shape}")
        print(f"label shape: {labels.shape}")
    batch_size_now = inputs.shape[0]
    text_feat = get_text_features_for_batch(batch_size_now, clip_text_features)
    optimizer.zero_grad()
    with autocast(enabled=use_amp):
        outputs = model(inputs, text_feat)
        # Debug: print output shapes on the first step
        if step == 0:
            print(f"model output shape: {outputs.shape}")
            print(f"expected label shape: {labels.shape}")
        loss = loss_function(outputs, labels)
    # rest of the loop unchanged ...
```
These changes should resolve the dimension mismatch while keeping the multimodal fusion intact. The residual is now extracted before the fusion step and upsampled to match the decoder output size, so the shapes stay consistent end to end.