原理及步骤
数据处理
PyTorch中对于数据集的处理有三个非常重要的类:Dataset
、Dataloader
、Sampler
,它们均是 torch.utils.data
包下的模块(类)。它们的关系可以这样理解:
Dataset是数据集的类,主要用于定义数据集
Sampler是采样器的类,用于定义从数据集中选出数据的规则,比如是随机取数据还是按照顺序取等等
Dataloader是数据的加载类,它是对于Dataset和Sampler的进一步包装,即其实Dataset和Sampler会作为参数传递给Dataloader,用于实际读取数据,可以理解为它是这个工作的真正实践者,而Dataset和Sampler则负责定义。我们训练、测试所获得的数据也是Dataloader直接给我们的。
Dataset
Dataset 位于 torch.utils.data 下,我们通过定义继承自这个类的子类来自定义数据集。它有两个最重要的方法需要重写,实际上它们都是类的特殊方法:
__getitem__(self, index):传入参数index为下标,返回数据集中对应下标的数据组(数据和标签)
__len__(self):返回数据集的大小
Dataset的目标是根据你输入的索引输出对应的image和label,而且这个功能是要在__getitem__()函数中完成的,所以当你自定义数据集的时候,首先要继承Dataset类,还要复写__getitem__()函数。
Dataloader
Dataloader对Dataset(和Sampler等)打包,完成最后对数据的读取的执行工作,一般不需要自己定义或者重写一个Dataloader的类(或子类),直接使用即可,通过传入参数定制Dataloader,定制化的功能应该在Dataset(和Sampler等)中完成了。
Sampler
Sampler类是一个很抽象的父类,其主要用于设置从一个序列中返回样本的规则,即采样的规则。Sampler是一个可迭代对象,使用step方法可以返回下一个迭代后的结果,因此其主要的类方法就是 __iter__ 方法,定义了迭代后返回的内容。
SequentialSampler
SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样。
RandmSampler
RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:
replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导