【Python 语义分割中Dice Loss】Dice Loss 图像分割损失

 

实现

import torch
import torch.nn as nn
import numpy as np

# PyTorch实现Dice Loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth  # 平滑因子,用于避免除零错误

    def forward(self, pred, target):
        # pred 是模型的预测输出,target 是真实的分割标签
        # 假设 pred 和 target 的形状是 (batch_size, num_classes, height, width)
        # 这里简化为二分类情况,直接对张量进行扁平化处理
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)

        # 计算交集
        intersection = (pred * target).sum()

        # 计算Dice系数
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)

        # 返回Dice Loss,即1减去Dice系数
        return 1 - dice

# NumPy实现Dice Loss
def dice_loss_numpy(pred, target, smooth=1.0):
    # pred 和 target 是 NumPy 数组
    # 假设 pred 和 target 的形状是 (num_samples, height, width)
    # 这里简化为二分类情况,将数组扁平化处理
    pred = pred.reshape(-1)
    target = target.reshape(-1)

    # 计算交集
    intersection = np.sum(pred * target)

    # 计算Dice系数
    dice = (2. * intersection + smooth) / (np.sum(pred) + np.sum(target) + smooth)

    # 返回Dice Loss,即1减去Dice系数
    return 1 - dice

# 测试PyTorch实现的Dice Loss
if __name__ == "__main__":
    # 假设 batch_size = 2,高度和宽度都是 3(简单示例)
    pred_torch = torch.tensor([[[[0.7, 0.3, 0.4],
                                 [0.2, 0.8, 0.5],
                                 [0.6, 0.9, 0.1]]]], dtype=torch.float32)
    target_torch = torch.tensor([[[[1, 0, 0],
                                  [0, 1, 1],
                                  [1, 1, 0]]]], dtype=torch.float32)

    dice_loss_torch = DiceLoss()
    loss_torch = dice_loss_torch(pred_torch, target_torch)
    print(f"PyTorch Dice Loss: {loss_torch.item()}")

    # 测试NumPy实现的Dice Loss
    pred_numpy = np.array([[[0.7, 0.3, 0.4],
                            [0.2, 0.8, 0.5],
                            [0.6, 0.9, 0.1]]], dtype=np.float32)
    target_numpy = np.array([[[1, 0, 0],
                             [0, 1, 1],
                             [1, 1, 0]]], dtype=np.float32)

    loss_numpy = dice_loss_numpy(pred_numpy, target_numpy)
    print(f"NumPy Dice Loss: {loss_numpy}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值