transforms.normalize如何对特定数据集设定标准化参数

该博客介绍了如何在图像分类实验中为本地特定数据集计算合适的归一化参数。通过提供一个修正后的Python代码示例,展示了如何利用PyTorch的DataLoader计算训练数据的均值和标准差。在训练和验证阶段使用这些参数可以优化模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在图像分类实验中,经常能看到对数据集进行数据增强操作,其中包括transforms.Normalize(),这个函数的定义如下:

torchvision.transforms.Normalize(mean, std, inplace=False)

功能:针对RGB3个 channel 分别对图像进行标准化

output = ( input - mean ) / std

  • mean: 各通道的均值
  • std: 各通道的标准差
  • inplace: 是否原地操作

通常ImageNet有自己的标准化参数,是通过抽样统计图像的均值方差得到的,那么针对本地特定数据集,如何获取到适合的参数呢?我参考了PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差_紫芝的博客-CSDN博客_pytorch 数据归一化

原文代码有一处错误,需要先把transform设置为transforms.ToTensor(),而不是None,否则会运行错误。以下是改正后的代码:

def getStat(train_data):
    '''
    Compute mean and variance for training data
    :param train_data: 自定义类Dataset(或ImageFolder即可)
    :return: (mean, std)
    '''
    print('Compute mean and variance for training data.')
    print(le
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值