import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import monai
from monai.transforms import (
Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
RandCropByPosNegLabeld, RandFlipd, RandRotate90d, EnsureTyped, Activations,
AsDiscrete, Resized, RandZoomd, RandGaussianNoised, CenterSpatialCropd
)
from monai.data import list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from monai.data import PersistentDataset
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
# ======= Configuration =======
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")
max_epochs = 200
batch_size = 1
learning_rate = 1e-4
num_classes = 3  # background, liver, tumor
warmup_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = False  # automatic mixed precision (AMP) training
def get_valid_size(size):
return tuple([max(32, (s // 32) * 32) for s in size])
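# Example: get_valid_size((128, 128, 64)) -> (128, 128, 64), while get_valid_size((100, 100, 50))
# snaps down to (96, 96, 32); every dimension stays a multiple of 32 (and at least 32), which the
# SwinUNETR encoder's five 2x downsampling stages require.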
base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((64, 64, 32))
print(f"使用尺寸: Resized={resized_size}, Crop={crop_size}")
# ======= 跨模态注意力融合模块 =======
class LightCrossAttentionFusion(nn.Module):
def __init__(self, img_feat_dim=192, text_feat_dim=512, num_heads=2):
super().__init__()
self.num_heads = num_heads
self.img_feat_dim = img_feat_dim
self.text_feat_dim = text_feat_dim
self.query_proj = nn.Linear(img_feat_dim, img_feat_dim)
self.key_proj = nn.Linear(text_feat_dim, img_feat_dim)
self.value_proj = nn.Linear(text_feat_dim, img_feat_dim)
self.out_proj = nn.Linear(img_feat_dim, img_feat_dim)
def forward(self, img_feat, text_feat):
B, C, D, H, W = img_feat.shape
N = D * H * W
img_feat_flat = img_feat.view(B, C, N).permute(0, 2, 1) # (B, N, C)
Q = self.query_proj(img_feat_flat) # (B, N, C)
K = self.key_proj(text_feat).unsqueeze(1) # (B, 1, C)
V = self.value_proj(text_feat).unsqueeze(1) # (B, 1, C)
head_dim = C // self.num_heads
Q = Q.view(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
K = K.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
V = V.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)  # (B, heads, N, 1)
        # NOTE: with a single text token, a standard key-wise softmax (dim=-1) would be all ones;
        # normalizing over the spatial/query axis instead turns the scores into a per-voxel gate
        # on the text value.
        attn = torch.softmax(scores, dim=-2)
        context = torch.matmul(attn, V)  # (B, heads, N, head_dim)
context = context.permute(0, 2, 1, 3).contiguous().view(B, N, C)
out = self.out_proj(context)
out = out.permute(0, 2, 1).view(B, C, D, H, W)
fused = img_feat + out
return fused
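# Shape sketch for the configuration used below (crop (64, 64, 32), feature_size=12):
# img_feat is the SwinUNETR bottleneck of shape (B, 192, 2, 2, 1), text_feat is a (B, 512)
# CLIP embedding, and the fused output keeps the (B, 192, 2, 2, 1) image-feature shape.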
# ======= SwinUNETR subclass with cross-modal CLIP fusion =======
class SwinUNETRWithCLIPFusion(SwinUNETR):
def __init__(self, img_size, in_channels, out_channels, feature_size=12, use_checkpoint=False, text_feat_dim=512):
super().__init__(img_size=img_size, in_channels=in_channels, out_channels=out_channels,
feature_size=feature_size, use_checkpoint=use_checkpoint)
        fusion_img_feat_dim = feature_size * 16  # 192 for the default feature_size=12
self.fusion = LightCrossAttentionFusion(img_feat_dim=fusion_img_feat_dim, text_feat_dim=text_feat_dim)
self.fusion_reduce = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # 1x1 conv used as a residual path around the bottleneck channel reduction
        self.residual_conv = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # dropout to mitigate overfitting
        self.dropout = nn.Dropout3d(0.2)
        # reduce every skip connection to feature_size channels
        self.skip4_reduce = nn.Conv3d(feature_size * 8, feature_size, kernel_size=1)
        self.skip3_reduce = nn.Conv3d(feature_size * 4, feature_size, kernel_size=1)
        self.skip2_reduce = nn.Conv3d(feature_size * 2, feature_size, kernel_size=1)
        self.skip1_reduce = nn.Conv3d(feature_size, feature_size, kernel_size=1)
        # custom decoder: every block takes in=feature_size and skip=feature_size
from monai.networks.blocks import UnetUpBlock
self.decoder1 = UnetUpBlock(
spatial_dims=3,
in_channels=feature_size,
out_channels=feature_size,
            kernel_size=3,            # convolution kernel size
            stride=1,                 # no downsampling on the main path
            upsample_kernel_size=2,   # 2x upsampling factor
norm_name="instance",
)
self.decoder2 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
self.decoder3 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
self.decoder4 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3, stride=1, upsample_kernel_size=2, norm_name="instance")
self.decoder5 = nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)
def forward(self, x, text_feat=None):
enc_out_list = self.swinViT(x)
if not hasattr(self, "printed_shapes"):
print("enc_out_list 通道数:", [feat.shape[1] for feat in enc_out_list])
print("enc_out_list 各层特征图尺寸:", [feat.shape for feat in enc_out_list])
self.printed_shapes = True
enc_out = enc_out_list[-1] # [B, 192, 2, 2, 1]
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)
        # residual path around the channel reduction (both branches output feature_size channels)
        residual = self.residual_conv(enc_out)
        enc_out = self.fusion_reduce(enc_out) + residual  # [B, feature_size, ...]
        # reduce skip connections to feature_size channels
        skip4 = self.skip4_reduce(enc_out_list[-2])  # 8*feature_size -> feature_size
        skip3 = self.skip3_reduce(enc_out_list[-3])  # 4*feature_size -> feature_size
        skip2 = self.skip2_reduce(enc_out_list[-4])  # 2*feature_size -> feature_size
        skip1 = self.skip1_reduce(enc_out_list[-5])  # feature_size -> feature_size
        # decoder path with dropout
d1 = self.dropout(self.decoder1(enc_out, skip4))
d2 = self.dropout(self.decoder2(d1, skip3))
d3 = self.dropout(self.decoder3(d2, skip2))
d4 = self.dropout(self.decoder4(d3, skip1))
d5 = self.decoder5(d4)
        # segmentation head; the bottleneck residual was already applied before decoding,
        # so the head output is returned directly (adding it here would mismatch shapes)
        out = self.out(d5)
return out
# ======= Data preprocessing =======
train_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
RandCropByPosNegLabeld(
keys=["image", "label"], label_key="label", spatial_size=crop_size,
pos=1.0, neg=1.0, num_samples=1, image_key="image", image_threshold=0,
),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.1),
EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),
EnsureTyped(keys=["image", "label"]),
])
images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)
train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir="./cache/train")
val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir="./cache/val")
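# Note: PersistentDataset caches the output of the deterministic preprocessing transforms on disk;
# if the transform pipelines above change, clear ./cache/train and ./cache/val so stale tensors
# are not reloaded.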
train_loader = DataLoader(
train_ds, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=list_data_collate, pin_memory=True
)
val_loader = DataLoader(
val_ds, batch_size=1, shuffle=False, num_workers=2, collate_fn=list_data_collate, pin_memory=True
)
# ======= Load pre-extracted CLIP text features =======
clip_text_features = np.load("./clip_text_features.npy") # shape (num_prompts, 512)
clip_text_features = torch.from_numpy(clip_text_features).float()
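# Minimal sanity check, assuming the .npy file stores one 512-d CLIP text embedding per prompt
# (matching text_feat_dim=512 used by the fusion module).
assert clip_text_features.ndim == 2 and clip_text_features.shape[1] == 512, \
    f"expected (num_prompts, 512) text features, got {tuple(clip_text_features.shape)}"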
def get_text_features_for_batch(batch_size, clip_text_features):
    # return one 512-d text feature per sample: shape (batch_size, text_feat_dim)
    if clip_text_features.shape[0] >= batch_size:
        return clip_text_features[:batch_size].to(device)
    # fewer prompts than samples: repeat the first prompt so the shape stays (batch_size, 512)
    return clip_text_features[:1].repeat(batch_size, 1).to(device)
# ======= Model, loss, optimizer, scheduler =======
model = SwinUNETRWithCLIPFusion(
img_size=crop_size,
in_channels=1,
out_channels=num_classes,
feature_size=12,
use_checkpoint=True,
text_feat_dim=512,
).to(device)
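# With crop_size=(64, 64, 32) the model maps (B, 1, 64, 64, 32) CT patches to
# (B, num_classes, 64, 64, 32) logits; the CLIP text feature is passed separately as
# model(inputs, text_feat).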
class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_function = DiceCELoss(
to_onehot_y=True,
softmax=True,
include_background=True,
ce_weight=class_weights,
lambda_dice=0.5,
lambda_ce=0.5
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
def lr_lambda(epoch):
if epoch < warmup_epochs:
return (epoch + 1) / warmup_epochs
progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
return 0.5 * (1 + np.cos(np.pi * progress))
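# Worked example of the schedule: epoch 0 scales the base lr by 0.1, epoch 9 by 1.0
# (end of warmup), then cosine decay brings the factor close to 0 by epoch 199.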
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes)
scaler = GradScaler(enabled=use_amp)
best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []
os.makedirs("fusion_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)
for epoch in range(max_epochs):
print(f"\nEpoch {epoch+1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
    with tqdm(total=len(train_loader), desc=f"Train Epoch {epoch+1}") as pbar:
for batch_data in train_loader:
step += 1
inputs = batch_data["image"].to(device)
labels = batch_data["label"].to(device)
batch_size_now = inputs.shape[0]
text_feat = get_text_features_for_batch(batch_size_now, clip_text_features)
optimizer.zero_grad()
with autocast(enabled=use_amp):
outputs = model(inputs, text_feat)
loss = loss_function(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
epoch_loss += loss.item()
pbar.update(1)
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
epoch_loss /= step
train_loss_history.append(epoch_loss)
print(f"训练平均损失: {epoch_loss:.4f}")
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']
print(f"当前学习率: {current_lr:.7f}")
model.eval()
dice_values = []
    with torch.no_grad(), tqdm(total=len(val_loader), desc=f"Val Epoch {epoch+1}") as pbar:
for val_data in val_loader:
val_images = val_data["image"].to(device)
val_labels = val_data["label"].to(device)
batch_size_val = val_images.shape[0]
text_feat_val = get_text_features_for_batch(batch_size_val, clip_text_features)
with autocast(enabled=use_amp):
val_outputs = model(val_images, text_feat_val)
val_outputs_list = decollate_batch(val_outputs.detach().cpu())
val_labels_list = decollate_batch(val_labels.detach().cpu())
val_output_convert = [post_pred(x) for x in val_outputs_list]
val_label_convert = [post_label(x) for x in val_labels_list]
dice_metric(y_pred=val_output_convert, y=val_label_convert)
metric = dice_metric.aggregate().item()
dice_values.append(metric)
dice_metric.reset()
pbar.update(1)
pbar.set_postfix({"dice": f"{metric:.4f}"})
avg_metric = np.mean(dice_values)
val_dice_history.append(avg_metric)
print(f"验证平均Dice: {avg_metric:.4f}")
if avg_metric > best_metric:
best_metric = avg_metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), f"fusion_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
print(f"保存新的最佳模型! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")
if (epoch + 1) % 10 == 0:
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': epoch_loss,
'dice': avg_metric
}, f"fusion_checkpoints/checkpoint_epoch_{epoch+1}.pth")
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Training loss')
plt.title('Training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(val_dice_history, label='Validation Dice', color='orange')
plt.title('Validation Dice coefficient')
plt.xlabel('Epoch')
plt.ylabel('Dice')
plt.legend()
plt.tight_layout()
plt.savefig("logs/fusion_training_metrics.png")
plt.close()
print(f"\n训练完成! 最佳Dice: {best_metric:.4f} at epoch {best_metric_epoch}")