PyTorch Metric Learning中的采样器(Samplers)详解
引言
在深度度量学习领域,如何有效地组织训练数据批次对模型性能有着至关重要的影响。PyTorch Metric Learning项目提供了一系列精心设计的采样器(Samplers),专门用于优化度量学习任务的训练过程。本文将深入解析这些采样器的工作原理、适用场景及使用方法。
采样器基础概念
采样器是PyTorch数据加载流程中的关键组件,继承自torch.utils.data.Sampler
类。在度量学习中,采样器的主要职责是:
- 控制批次数据的组织形式
- 实现离线样本对或三元组挖掘
- 优化训练过程中的样本分布
这些采样器可以直接传递给PyTorch的DataLoader,通过sampler
或batch_sampler
参数进行配置。
MPerClassSampler:类别均衡采样器
核心思想
MPerClassSampler
确保每个批次中包含每个类别的固定数量样本,这对于度量学习特别重要,因为它保证了每个批次中都有足够的同类和不同类样本进行比较。
技术细节
samplers.MPerClassSampler(labels, m, batch_size=None, length_before_new_iter=100000)
参数解析:
labels
:数据集的标签列表,labels[x]
对应数据集中第x个元素的标签m
:每次迭代每个类别采样的样本数。如果某类样本数不足m,返回的批次中会有重复样本batch_size
:可选参数。指定后,每个批次保证有m个样本/类。需满足:batch_size
必须是m的倍数length_before_new_iter >= batch_size
m * (唯一标签数) >= batch_size
length_before_new_iter
:创建新可迭代对象前的迭代次数
使用场景
适用于需要严格控制每个类别样本数量的场景,特别是在类别不平衡的数据集上训练时。
HierarchicalSampler:层次化采样器
核心思想
基于《Deep Metric Learning to Rank》论文实现,该采样器引入层次化采样策略,模拟现实世界中的类别层次结构。
技术细节
samplers.HierarchicalSampler(
labels,
batch_size,
samples_per_class,
batches_per_super_tuple=4,
super_classes_per_batch=2,
inner_label=0,
outer_label=1,
)
参数解析:
labels
:二维数组,每行对应一个样本,列对应层次化标签batch_size
:必须指定,且需是super_classes_per_batch
和samples_per_class
的倍数samples_per_class
:每类每批次的样本数。设为"all"可使用类的所有元素(适合少样本学习)batches_per_super_tuple
:为每个超类元组创建的批次数,影响采样器返回的迭代器长度super_classes_per_batch
:每批次的超类数inner_label
:labels
中对应类的列索引outer_label
:labels
中对应超类的列索引
使用场景
特别适合具有自然层次结构的数据,如生物分类、产品目录等。
TuplesToWeightsSampler:基于元组的加权采样器
核心思想
这是一种离线挖掘器,通过分析样本在困难元组中出现的频率来调整采样权重。
技术细节
samplers.TuplesToWeightsSampler(
model,
miner,
dataset,
subset_size=None,
**tester_kwargs
)
参数解析:
model
:用于计算嵌入向量的模型miner
:用于从计算出的嵌入中寻找困难元组dataset
:采样源数据集subset_size
:如果不为None,则使用数据集的随机子集进行挖掘,避免内存不足tester_kwargs
:传递给BaseTester的其他参数
工作流程
- 随机选择数据集子集(如果指定了subset_size)
- 使用指定挖掘器从子集中挖掘元组
- 根据每个元素在元组中出现的频率计算权重
- 使用权重作为概率进行随机采样
使用场景
适用于需要重点关注困难样本的场景,可有效提升模型对边界样本的学习能力。
FixedSetOfTriplets:固定三元组采样器
核心思想
创建固定的三元组集合,特别适用于只有三元组作为真实标签的场景。
技术细节
samplers.FixedSetOfTriplets(labels, num_triplets)
参数解析:
labels
:数据集的标签列表num_triplets
:要创建的三元组数量
使用场景
当评估算法性能且只有三元组作为真实标签时特别有用。
总结
PyTorch Metric Learning提供的采样器为度量学习任务提供了强大的数据组织能力。选择合适的采样器可以:
- 改善模型对类别不平衡的鲁棒性
- 更好地捕捉数据中的层次结构
- 专注于困难样本的学习
- 适应不同的评估场景
开发者应根据具体任务需求和数据特性选择合适的采样策略,以获得最佳的训练效果。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考