pytorch transform函数
时间: 2025-02-04 07:58:34 浏览: 44
### PyTorch 中 `transform` 函数的用法
在 PyTorch 中,`transform` 是用于数据预处理的重要组件之一。通过组合不同的变换方法可以构建复杂的预处理流水线。
#### 使用 Compose 组合多种变换
为了实现一系列连续的数据变换操作,通常会使用 `transforms.Compose()` 方法将多个变换串联起来。例如,在 CIFAR-10 数据集上应用如下预处理:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(), # 将 PIL 图像转为 Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化图像张量
])
```
这段代码定义了一个包含两个步骤的转换流程:首先是将输入图片转化为张量形式;其次是对其进行标准化处理[^1]。
#### 自定义 Dataset 类并集成 Transform
当自定义数据集时,可以在初始化过程中传入所需的 `transform` 参数以便于后续调用。下面是一个基于面部特征点检测任务的例子:
```python
from torch.utils.data import Dataset
import pandas as pd
from skimage import io, transform
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset."""
def __init__(self, csv_file, root_dir, transform=None):
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:].values.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
```
在这个例子中,如果提供了 `transform` 参数,则会在每次访问特定样本 (`__getitem__`) 的时候自动对该样本施加指定的一系列变换[^2]。
#### DataLoader 加载已处理过的数据
最后一步是利用 `DataLoader` 来高效地批量加载已经过 `Transform` 处理后的数据。这不仅简化了批量化过程还支持多线程加速读取速度:
```python
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2
)
```
这里展示了如何设置一个带有随机打乱顺序(`shuffle=True`)以及并发工作进程数(`num_workers=2`)的训练集加载器。
阅读全文
相关推荐




















