在图像分类实验中,经常能看到对数据集进行数据增强操作,其中包括transforms.Normalize(),这个函数的定义如下:
torchvision.transforms.Normalize(mean, std, inplace=False)
功能:针对RGB3个 channel 分别对图像进行标准化
output = ( input - mean ) / std
- mean: 各通道的均值
- std: 各通道的标准差
- inplace: 是否原地操作
通常ImageNet有自己的标准化参数,是通过抽样统计图像的均值方差得到的,那么针对本地特定数据集,如何获取到适合的参数呢?我参考了PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差_紫芝的博客-CSDN博客_pytorch 数据归一化
原文代码有一处错误,需要先把transform设置为transforms.ToTensor(),而不是None,否则会运行错误。以下是改正后的代码:
def getStat(train_data):
'''
Compute mean and variance for training data
:param train_data: 自定义类Dataset(或ImageFolder即可)
:return: (mean, std)
'''
print('Compute mean and variance for training data.')
print(le