Table of Contents
- Losses for different input types
- Input type: [(anchor, positive/negative, label 1/0), ...]; label 1 means a small distance, label 0 a large distance
- Input type: [(sentence1, label1), (sentence2, label2), ...]; sentences with the same label should be close
- Input type: [(sentence1, sentence2, score), ...]; regress the sentence-pair score (between 0 and 1)
- Input type: [(sentence1, sentence2, label), ...]; multi-class classification of sentence pairs
- Input type: [(anchor, positive, negative), ...]; triplet input
- Input type: [(anchor, positive), ...]; positive pairs only
- Input type: [sentence1, sentence2, ...]; unlabeled input
Losses for different input types
Choose an appropriate loss according to your task and data type; see the Sentence Transformers loss overview for details.
Input type: [(anchor, positive/negative, label 1/0), ...]; label 1 means a small distance, label 0 a large distance
ContrastiveLoss (contrastive loss)
For a sample pair A and B:
- For a positive pair (label 1), the distance between the two embeddings should be as small as possible;
- For a negative pair (label 0), the distance should be as large as possible; only negative pairs whose distance is smaller than the margin are penalized, and once the distance exceeds this threshold there is no further penalty.
The default distance_metric is cosine distance and the default margin is 0.5. Per pair, the loss is 0.5 * (y * d(a, b)^2 + (1 - y) * max(margin - d(a, b), 0)^2), where y is the 0/1 label and d is the distance, matching the implementation below.
```python
def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
    # Encode both sides of each pair into sentence embeddings
    reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
    assert len(reps) == 2
    rep_anchor, rep_other = reps
    distances = self.distance_metric(rep_anchor, rep_other)
    # label 1: penalize large distances; label 0: penalize distances smaller than the margin
    losses = 0.5 * (
        labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2)
    )
    return losses.mean() if self.size_average else losses.sum()
```
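As a minimal usage sketch, assuming the classic InputExample / model.fit training API; the model name and the toy sentence pairs are placeholders, not from the original post:

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model name

# label 1: the pair should be pulled together; label 0: pushed apart up to the margin
train_examples = [
    InputExample(texts=["How do I reset my password?", "How can I change my password?"], label=1),
    InputExample(texts=["How do I reset my password?", "What is the weather today?"], label=0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

train_loss = losses.ContrastiveLoss(model=model, margin=0.5)  # cosine distance by default
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
```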
OnlineContrastiveLoss
Essentially the same as ContrastiveLoss, but this loss only computes the loss over the hard pairs within a batch, which usually performs better than the plain contrastive loss.
Hard-pair selection: keep negative pairs whose distance is smaller than the largest positive-pair distance, and positive pairs whose distance is larger than the smallest negative-pair distance; easy pairs outside these bounds are ignored.
```python
def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor, size_average=False) -> Tensor:
    embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
    distance_matrix = self.distance_metric(embeddings[0], embeddings[1])
    negs = distance_matrix[labels == 0]
    poss = distance_matrix[labels == 1]

    # select hard positive and hard negative pairs
    negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
    positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]

    positive_loss = positive_pairs.pow(2).sum()
    negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
    loss = positive_loss + negative_loss
    return loss
```
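As a small sketch of the hard-pair selection above, with made-up distance values for one batch:

```python
import torch

# Toy within-batch distances; labels: 1 = positive pair, 0 = negative pair
distances = torch.tensor([0.20, 0.70, 0.30, 0.90, 0.40])
labels = torch.tensor([1, 1, 0, 0, 0])

poss = distances[labels == 1]  # positive-pair distances: 0.20, 0.70
negs = distances[labels == 0]  # negative-pair distances: 0.30, 0.90, 0.40

# hard negatives: closer than the farthest positive (0.70) -> 0.30, 0.40
hard_negatives = negs[negs < poss.max()]
# hard positives: farther than the closest negative (0.30) -> 0.70
hard_positives = poss[poss > negs.min()]

print(hard_negatives, hard_positives)  # tensor([0.3000, 0.4000]) tensor([0.7000])
```

For training, OnlineContrastiveLoss can be dropped into the earlier ContrastiveLoss sketch unchanged, e.g. train_loss = losses.OnlineContrastiveLoss(model=model, margin=0.5), since it consumes the same pair-plus-binary-label input.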