【Masked Video Distillation蒸馏损失函数】

在这里插入图片描述

Masked Video Distillation蒸馏损失函数伪码实现:

# f: student encoder  # encoder
# g_img: decoder for reconstructing spatial features  # decoder for image
# g_vid: decoder for reconstructing spatial-temporal  # decoder for video
features
# t_m: learnable mask tokens # 掩码token
# h_img: image teacher model  #image teacher
# h_vid: video teacher model  #video teacher
for x, m in loader: # x: video data, m: mask
x_pe = patch_emb(x) # patch embedding of input #patch embedding
x_vis = mask_select(x_pe, 1 - m) # masking tokens # 可见token
q_vis = f(x_vis) # visible local patch features  # 编码结果
# reconstruction of target features
p_img = g_img(concat(q_vis, t_m)) # 重建image结果
p_vid = g_vid(concat(q_vis, t_m)) # 重建video结果
# compute target features with teacher models
k_img = h_img(x) # target spatial features  # image teacher 预测结果
k_vid = h_vid(x) # target spatial-temporal features # video teacher 预测结果
# compute reconstruction loss
loss_img = smooth_L1_loss(p_img ? m, k_img ? m) # image loss
loss_vid = smooth_L1_loss(p_vid ? m, k_vid ? m) # video loss
loss = λ1 * loss_img + λ2 * loss_vid #总体损失函数
loss.backward()
optimizer.step() # optimizer update```

Masked Video Distillation蒸馏损失函数源码实现:

        with torch.cuda.amp.autocast():
            output_features, output_video_features = model(videos, bool_masked_pos)
            with torch.no_grad():
                image_teacher_model.eval()  #训练image teacher
                if time_stride_loss:
                    teacher_features = image_teacher_model(  #得到image teacher的预测结果
                        rearrange(videos_for_teacher[:, :, ::tubelet_size, :, :], 'b c t h w -> (b t) c h w'),
                    )
                    teacher_features = rearrange(teacher_features, '(b t) l c -> b (t l) c', t=T//tubelet_size)
                else:
                    teacher_features = image_teacher_model(
                        rearrange(videos_for_teacher, 'b c t h w -> (b t) c h w'),
                    )
                    teacher_features = rearrange(teacher_features, '(b t d) l c -> b (t l) (d c)', t=T//tubelet_size, d=tubelet_size)
                if norm_feature:
                    teacher_features = LN_img(teacher_features)

                video_teacher_model.eval()  # 训练video teacher
                videos_for_video_teacher = videos if args.video_teacher_input_size == args.input_size \
                    else videos_for_teacher

                video_teacher_features = video_teacher_model(videos_for_video_teacher)#得到video teacher的预测结果
                if norm_feature:
                    video_teacher_features = LN_vid(video_teacher_features)

            B, _, D = output_features.shape
            loss_img_feat = loss_func_img_feat(#image teacher 的 loss
                input=output_features,
                target=teacher_features[bool_masked_pos].reshape(B, -1, D)
            )
            loss_value_img_feat = loss_img_feat.item()

            B, _, D = output_video_features.shape
            loss_vid_feat = loss_func_vid_feat(#video teacher 的 loss
                input=output_video_features,
                target=video_teacher_features[bool_masked_pos].reshape(B, -1, D)
            )
            loss_value_vid_feat = loss_vid_feat.item()

            loss = image_loss_weight * loss_img_feat + video_loss_weight * loss_vid_feat#总的损失函数
            ```

### Masked Language Model 损失函数详解 在掩码语言模型(MLM)中,损失函数的设计至关重要。该类模型的目标是在给定一部分被遮蔽的词的情况下,尽可能准确地预测这些词。为了实现这一目标,通常采用交叉熵作为损失函数。 对于每一个被遮蔽的位置 \(i\) ,假设词汇表大小为 \(V\) 。那么,在位置 \(i\) 的真实标签记作 \(y_i\) ,而模型对该位置上各个可能单词的概率分布预测则表示为 \(\hat{p}_i = (p_{i,1}, p_{i,2}, ..., p_{i,V})\) 。 因此,针对单个样本而言,其对应的交叉熵损失可以表达如下: \[ L_i = -\log(p_{i,y_i}) \] 当考虑整个批次的数据时,则总的损失可以通过求平均值得到: \[ L = \frac{1}{N}\sum^{N}_{n=1} L_n \] 其中 \(N\) 表示批处理数量。这种形式化的描述有助于理解如何衡量模型预测的质量以及指导参数更新过程[^2]。 值得注意的是,在实际应用过程中,并不是所有的token都会参与计算损失;只有那些真正被mask掉的部分才会贡献于最终的梯度下降优化方向。这不仅提高了训练效率,也使得模型能够更专注于学习上下文之间的关系而不是简单记忆输入序列[^1]。 ```python import torch.nn.functional as F def masked_language_model_loss(logits, labels, mask): """ logits: Tensor of shape [batch_size, seq_len, vocab_size], output from the model. labels: Tensor of shape [batch_size, seq_len], ground truth token ids. mask: Tensor of shape [batch_size, seq_len], binary tensor indicating where tokens are masked. Returns: loss: Scalar value representing average cross entropy loss over all masked positions. """ # Flatten everything to work with standard CrossEntropyLoss function which expects input and target tensors of size (minibatch,C). active_logits = logits.view(-1, logits.size(-1)) active_labels = labels.view(-1) active_mask = mask.view(-1) # Select only those elements that correspond to actual masks. filtered_logits = active_logits[active_mask == 1] filtered_labels = active_labels[active_mask == 1] # Compute mean cross-entropy loss across all non-zero entries in `filtered_labels`. loss = F.cross_entropy(filtered_logits, filtered_labels) return loss ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值