PyTorch实现CNN

本文详细介绍了PyTorch中数据处理的过程,包括Dataset、Sampler和Dataloader的使用,以及如何自定义数据集和采样规则。同时阐述了SequentialSampler、RandomSampler、SubsetRandomSampler和WeightedRandomSampler的功能,并解释了BatchSampler和collate_fn在构建batch时的作用。此外,还提到了PyTorch中的基本对象Tensor和Variable的区别。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原理及步骤

数据处理

       PyTorch中对于数据集的处理有三个非常重要的类:DatasetDataloaderSampler,它们均是 torch.utils.data 包下的模块(类)。它们的关系可以这样理解:

       Dataset是数据集的类,主要用于定义数据集

       Sampler是采样器的类,用于定义从数据集中选出数据的规则,比如是随机取数据还是按照顺序取等等

       Dataloader是数据的加载类,它是对于Dataset和Sampler的进一步包装,即其实Dataset和Sampler会作为参数传递给Dataloader,用于实际读取数据,可以理解为它是这个工作的真正实践者,而Dataset和Sampler则负责定义。我们训练、测试所获得的数据也是Dataloader直接给我们的。

Dataset

       Dataset 位于 torch.utils.data 下,我们通过定义继承自这个类的子类来自定义数据集。它有两个最重要的方法需要重写,实际上它们都是类的特殊方法:

       __getitem__(self, index):传入参数index为下标,返回数据集中对应下标的数据组(数据和标签)
       __len__(self):返回数据集的大小

       Dataset的目标是根据你输入的索引输出对应的image和label,而且这个功能是要在__getitem__()函数中完成的,所以当你自定义数据集的时候,首先要继承Dataset类,还要复写__getitem__()函数。

Dataloader

       Dataloader对Dataset(和Sampler等)打包,完成最后对数据的读取的执行工作,一般不需要自己定义或者重写一个Dataloader的类(或子类),直接使用即可,通过传入参数定制Dataloader,定制化的功能应该在Dataset(和Sampler等)中完成了。

Sampler

       Sampler类是一个很抽象的父类,其主要用于设置从一个序列中返回样本的规则,即采样的规则。Sampler是一个可迭代对象,使用step方法可以返回下一个迭代后的结果,因此其主要的类方法就是 __iter__ 方法,定义了迭代后返回的内容。

SequentialSampler

       SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样。

RandmSampler

       RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:

       replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
       num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值