Converting the LiTS2017 dataset from NIfTI (.nii) to PNG, with some simple preprocessing (the code is commented and runs as-is)

```python
import os
import nibabel as nib
import numpy as np
import cv2

# Path to the raw LiTS training data
data_folder = 'H:/LITS2017/lits_train/'

# Output folders for the converted images and labels
output_image_folder = 'D:/python/workplace/segment/experiment/datasets/data/train/images/'
output_label_folder = 'D:/python/workplace/segment/experiment/datasets/data/train/labels/'

# Create the output folders
os.makedirs(output_image_folder, exist_ok=True)
os.makedirs(output_label_folder, exist_ok=True)

# Windowing: clip pixel values to the given window, then rescale to [0, 1]
def windowing(img, window_center, window_width):
    img_min = window_center - (window_width / 2.0)
    img_max = window_center + (window_width / 2.0)
    img = np.clip(img, img_min, img_max)
    img = (img - img_min) / (img_max - img_min)
    return img

# Histogram equalization for an image already scaled to [0, 1]
def histogram_equalization(img):
    hist, bins = np.histogram(img.flatten(), 256, [0, 1])
    cdf = hist.cumsum()
    cdf_normalized = cdf / cdf.max()  # normalize the CDF so the equalized image stays in [0, 1]
    img_equalized = np.interp(img.flatten(), bins[:-1], cdf_normalized)
    return img_equalized.reshape(img.shape)
```
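
The post's actual conversion loop is cut off above. The following is a minimal sketch of it, assuming the standard LiTS file naming (`volume-X.nii` paired with `segmentation-X.nii`) and a typical abdominal window (center 40, width 400); both of these specifics are assumptions, not something recoverable from the original:

```python
# Minimal sketch of the missing conversion loop (assumed, see above): pairs each
# volume-X.nii with its segmentation-X.nii, windows and equalizes every axial
# slice, and writes images and labels as png files.
volume_files = sorted(f for f in os.listdir(data_folder) if f.startswith('volume'))

for volume_name in volume_files:
    label_name = volume_name.replace('volume', 'segmentation')
    volume = nib.load(os.path.join(data_folder, volume_name)).get_fdata()
    label = nib.load(os.path.join(data_folder, label_name)).get_fdata()

    case_id = volume_name.split('-')[1].split('.')[0]  # "volume-12.nii" -> "12"
    for i in range(volume.shape[2]):  # nibabel keeps the slice axis last
        img_slice = windowing(volume[:, :, i], window_center=40, window_width=400)
        img_slice = histogram_equalization(img_slice)
        cv2.imwrite(os.path.join(output_image_folder, f'{case_id}_{i:03d}.png'),
                    (img_slice * 255).astype(np.uint8))
        # labels keep their raw class ids: 0 background, 1 liver, 2 tumor
        cv2.imwrite(os.path.join(output_label_folder, f'{case_id}_{i:03d}.png'),
                    label[:, :, i].astype(np.uint8))
```

Because the label pixels stay at their raw values 0/1/2, the label PNGs look nearly black in an ordinary image viewer; that is expected, and it is the form most segmentation dataloaders want.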
```python
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
    EnsureTyped, Activations, AsDiscrete, Resized, RandZoomd,
    RandGaussianNoised, CenterSpatialCropd
)
from monai.data import list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from monai.data import PersistentDataset
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt

# ======= Configuration =======
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")
max_epochs = 200
batch_size = 1
learning_rate = 1e-4
num_classes = 3  # background, liver, tumor
warmup_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = False  # automatic mixed precision

def get_valid_size(size):
    return tuple([max(32, (s // 32) * 32) for s in size])

base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((64, 64, 32))
print(f"Sizes: Resized={resized_size}, Crop={crop_size}")

# ======= Cross-modal attention fusion module =======
class LightCrossAttentionFusion(nn.Module):
    def __init__(self, img_feat_dim=192, text_feat_dim=512, num_heads=2):
        super().__init__()
        self.num_heads = num_heads
        self.img_feat_dim = img_feat_dim
        self.text_feat_dim = text_feat_dim
        self.query_proj = nn.Linear(img_feat_dim, img_feat_dim)
        self.key_proj = nn.Linear(text_feat_dim, img_feat_dim)
        self.value_proj = nn.Linear(text_feat_dim, img_feat_dim)
        self.out_proj = nn.Linear(img_feat_dim, img_feat_dim)

    def forward(self, img_feat, text_feat):
        B, C, D, H, W = img_feat.shape
        N = D * H * W
        img_feat_flat = img_feat.view(B, C, N).permute(0, 2, 1)  # (B, N, C)
        Q = self.query_proj(img_feat_flat)           # (B, N, C)
        K = self.key_proj(text_feat).unsqueeze(1)    # (B, 1, C)
        V = self.value_proj(text_feat).unsqueeze(1)  # (B, 1, C)
        head_dim = C // self.num_heads
        Q = Q.view(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        K = K.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        V = V.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-2)
        context = torch.matmul(attn, V)
        context = context.permute(0, 2, 1, 3).contiguous().view(B, N, C)
        out = self.out_proj(context)
        out = out.permute(0, 2, 1).view(B, C, D, H, W)
        fused = img_feat + out
        return fused

# ======= Subclass SwinUNETR and add the fusion =======
class SwinUNETRWithCLIPFusion(SwinUNETR):
    def __init__(self, img_size, in_channels, out_channels, feature_size=12,
                 use_checkpoint=False, text_feat_dim=512):
        super().__init__(img_size=img_size, in_channels=in_channels,
                         out_channels=out_channels, feature_size=feature_size,
                         use_checkpoint=use_checkpoint)
        fusion_img_feat_dim = feature_size * 16  # 192 in the default config
        self.fusion = LightCrossAttentionFusion(img_feat_dim=fusion_img_feat_dim,
                                                text_feat_dim=text_feat_dim)
        self.fusion_reduce = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # extra residual connection
        self.residual_conv = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # dropout to reduce overfitting
        self.dropout = nn.Dropout3d(0.2)
        # reduce all skip-connection channels to feature_size as well
        self.skip4_reduce = nn.Conv3d(96, feature_size, kernel_size=1)
        self.skip3_reduce = nn.Conv3d(48, feature_size, kernel_size=1)
        self.skip2_reduce = nn.Conv3d(24, feature_size, kernel_size=1)
        self.skip1_reduce = nn.Conv3d(12, feature_size, kernel_size=1)
        # custom decoder: every block uses in=feature_size, skip=feature_size
        from monai.networks.blocks import UnetUpBlock
        self.decoder1 = UnetUpBlock(
            spatial_dims=3,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,           # conv kernel size
            stride=1,                # no downsampling on the main path
            upsample_kernel_size=2,  # upsampling factor
            norm_name="instance",
        )
        self.decoder2 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder3 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder4 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder5 = nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)

    def forward(self, x, text_feat=None):
        enc_out_list = self.swinViT(x)
        if not hasattr(self, "printed_shapes"):
            print("enc_out_list channel counts:", [feat.shape[1] for feat in enc_out_list])
            print("enc_out_list feature map shapes:", [feat.shape for feat in enc_out_list])
            self.printed_shapes = True
        enc_out = enc_out_list[-1]  # [B, 192, 2, 2, 1]
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)
        enc_out = self.fusion_reduce(enc_out)  # [B, 12, ...]
        # channel-reduced skip connections
        skip4 = self.skip4_reduce(enc_out_list[-2])  # 96 -> 12
        skip3 = self.skip3_reduce(enc_out_list[-3])  # 48 -> 12
        skip2 = self.skip2_reduce(enc_out_list[-4])  # 24 -> 12
        skip1 = self.skip1_reduce(enc_out_list[-5])  # 12 -> 12
        # residual branch
        residual = self.residual_conv(enc_out)
        # dropout along the decoder path
        d1 = self.dropout(self.decoder1(enc_out, skip4))
        d2 = self.dropout(self.decoder2(d1, skip3))
        d3 = self.dropout(self.decoder3(d2, skip2))
        d4 = self.dropout(self.decoder4(d3, skip1))
        d5 = self.decoder5(d4)
        # residual connection
        out = self.out(d5) + residual
        return out

# ======= Data preprocessing =======
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=crop_size,
        pos=1.0,
        neg=1.0,
        num_samples=1,
        image_key="image",
        image_threshold=0,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.1),
    EnsureTyped(keys=["image", "label"]),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),
    EnsureTyped(keys=["image", "label"]),
])

images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)

train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir="./cache/train")
val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir="./cache/val")

train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
    collate_fn=list_data_collate, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=2,
    collate_fn=list_data_collate, pin_memory=True
)

# ======= Load pre-extracted text features =======
clip_text_features = np.load("./clip_text_features.npy")  # shape (num_prompts, 512)
clip_text_features = torch.from_numpy(clip_text_features).float()

def get_text_features_for_batch(batch_size, clip_text_features):
    if clip_text_features.shape[0] >= batch_size:
        return clip_text_features[:batch_size].to(device)
    else:
        return clip_text_features.repeat(batch_size, 1).to(device)

# ======= Model, loss, optimizer, scheduler =======
model = SwinUNETRWithCLIPFusion(
    img_size=crop_size,
    in_channels=1,
    out_channels=num_classes,
    feature_size=12,
    use_checkpoint=True,
    text_feat_dim=512,
).to(device)

class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_function = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True,
    ce_weight=class_weights,
    lambda_dice=0.5,
    lambda_ce=0.5
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes)

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []

os.makedirs("fusion_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    with tqdm(total=len(train_loader), desc=f"Train epoch {epoch+1}") as pbar:
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            batch_size_now = inputs.shape[0]
            text_feat = get_text_features_for_batch(batch_size_now, clip_text_features)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs, text_feat)
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    epoch_loss /= step
    train_loss_history.append(epoch_loss)
    print(f"Mean training loss: {epoch_loss:.4f}")
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current learning rate: {current_lr:.7f}")

    model.eval()
    dice_values = []
    with torch.no_grad(), tqdm(total=len(val_loader), desc=f"Val epoch {epoch+1}") as pbar:
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            batch_size_val = val_images.shape[0]
            text_feat_val = get_text_features_for_batch(batch_size_val, clip_text_features)
            with autocast(enabled=use_amp):
                val_outputs = model(val_images, text_feat_val)
            val_outputs_list = decollate_batch(val_outputs.detach().cpu())
            val_labels_list = decollate_batch(val_labels.detach().cpu())
            val_output_convert = [post_pred(x) for x in val_outputs_list]
            val_label_convert = [post_label(x) for x in val_labels_list]
            dice_metric(y_pred=val_output_convert, y=val_label_convert)
            metric = dice_metric.aggregate().item()
            dice_values.append(metric)
            dice_metric.reset()
            pbar.update(1)
            pbar.set_postfix({"dice": f"{metric:.4f}"})
    avg_metric = np.mean(dice_values)
    val_dice_history.append(avg_metric)
    print(f"Mean validation Dice: {avg_metric:.4f}")

    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(),
                   f"fusion_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
        print(f"Saved new best model! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            'dice': avg_metric
        }, f"fusion_checkpoints/checkpoint_epoch_{epoch+1}.pth")

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Training loss')
plt.title('Training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(val_dice_history, label='Validation Dice', color='orange')
plt.title('Validation Dice')
plt.xlabel('Epoch')
plt.ylabel('Dice')
plt.legend()
plt.tight_layout()
plt.savefig("logs/fusion_training_metrics.png")
plt.close()

print(f"\nTraining finished! Best Dice: {best_metric:.4f} at epoch {best_metric_epoch}")
```

Please fix this directly for me and give a complete, working version of the code.
Below is complete code that converts the LiTS2017 dataset into 2D PNG slices, with the segmentation labels saved as grayscale maps. Note that it requires Python 3.6 or newer (it uses f-strings).

```python
import os
import numpy as np
import SimpleITK as sitk
from PIL import Image

# Dataset path
data_path = '/path/to/LiTS2017'
# Output path
save_path = '/path/to/save'

# Processing function
def process_data(file_path, save_path):
    # Read the nii file
    itk_img = sitk.ReadImage(file_path)
    img_array = sitk.GetArrayFromImage(itk_img)  # (slices, H, W)

    file_name = os.path.basename(file_path)
    if file_name.startswith('segmentation'):
        # Segmentation files already hold the class indices
        # (0 = background, 1 = liver, 2 = liver tumor),
        # so save them directly as grayscale label maps
        out_array = img_array.astype(np.uint8)
        prefix = 'label_'
    else:
        # Volume files: min-max scale the CT intensities to 0-255
        # (shift by the minimum first, since CT values can be negative)
        img_array = img_array.astype(np.float32)
        img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min())
        out_array = (img_array * 255).astype(np.uint8)
        prefix = ''

    # Save every axial slice as a 2D png
    for i in range(out_array.shape[0]):
        img = Image.fromarray(out_array[i])
        img.save(os.path.join(save_path, f'{prefix}{i+1:03d}.png'))

    print(f'{file_path} processed successfully.')

# Walk the dataset and process each nii file
for file_name in os.listdir(data_path):
    if file_name.endswith('.nii'):
        file_path = os.path.join(data_path, file_name)
        save_folder = os.path.join(save_path, file_name.split('.')[0])
        os.makedirs(save_folder, exist_ok=True)
        process_data(file_path, save_folder)
```

Before running the code, install SimpleITK and Pillow:

```
pip install SimpleITK Pillow
```

When it finishes, the dataset has been converted to 2D PNGs, with the labels stored as grayscale maps, under the given output path.
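
As a quick sanity check (this snippet is not part of the original answer), you can open one of the saved label maps and confirm it contains only the expected class values:

```python
import numpy as np
from PIL import Image

# A LiTS label slice should contain only the class ids 0, 1, 2
# (the folder/file names below follow the script above)
label = np.array(Image.open('/path/to/save/segmentation-0/label_001.png'))
print(np.unique(label))  # expect a subset of [0 1 2]
```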