知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例

flyfish

知识蒸馏 - 蒸的什么

知识蒸馏 - 通过引入温度参数T调整 Softmax 的输出

知识蒸馏 - 对数函数的单调性

知识蒸馏 - 信息量的公式为什么是对数

知识蒸馏 - 根据真实事件的真实概率分布对其进行编码

知识蒸馏 - 信息熵中的平均为什么是按概率加权的平均

知识蒸馏 - 自信息量是单个事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率为权重的加权平均值

知识蒸馏 - 最小化KL散度与最小化交叉熵是完全等价的

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例

知识蒸馏的步骤如下:

  1. 训练教师模型:使用常规交叉熵损失在训练集上训练深层教师模型,使其在任务上达到较好性能(作为知识的"来源")。

  2. 固定教师模型:将训练好的教师模型设为评估模式(不更新参数),仅用于提供"软标签"知识。

  3. 初始化学生模型:创建轻量级学生模型(与基线学生模型初始化相同,保证公平对比)。

  4. 学生模型蒸馏训练:
    输入数据同时传入教师模型和学生模型;
    教师模型输出logits(不计算梯度),经温度T软化后得到软概率分布(教师的"软标签");
    学生模型输出logits,经相同温度T软化后得到对数概率分布;
    计算学生与教师软分布的KL散度损失(衡量两者差异,即蒸馏损失);
    计算学生与原始硬标签(真实类别)的交叉熵损失;
    总损失为KL散度损失与交叉熵损失的加权和;
    基于总损失更新学生模型参数,教师模型参数保持不变。

  5. 重复训练:迭代多轮,直至学生模型收敛,最终得到通过蒸馏学习了教师知识的轻量级模型。

使用的数据集

CIFAR-10 数据集:

  1. 介绍
    CIFAR-10(Canadian Institute for Advanced Research 10)是由加拿大高级研究所发布的小型图像数据集,广泛用于计算机视觉领域的入门级模型训练和测试。

  2. 数据组成
    包含 10个类别 的彩色图像,类别分别为:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。
    图像尺寸统一为 32×32像素,且为 RGB三通道(每个图像包含 3×32×32=3072 个像素值,范围 0-255)。
    数据集分为两部分:

    • 训练集:50,000 张图像(每个类别 5,000 张);
    • 测试集:10,000 张图像(每个类别 1,000 张)。
  3. 用途
    主要用于图像分类任务的基准测试,适合验证轻量级模型(如简单CNN)的性能,这里用于验证知识蒸馏对轻量级学生模型的性能提升)。

  4. 使用
    通过 datasets.CIFAR10 加载数据集时,指定 root(数据保存路径)、train(是否为训练集)、download(是否自动下载)和 transform(数据预处理方式),即可便捷地获取预处理后的图像数据和对应标签。

模型设计

维度TeacherNet(教师)StudentNet(学生)设计目的
卷积层数量4层(更深)2层(更浅)教师通过更多层提取更复杂的特征,学生通过精简结构降低部署成本。
卷积核数量初始128,后续64/32(更多)固定16(更少)教师用更多卷积核捕捉细节,学生用少量卷积核减少参数和计算量。
参数规模教师拟合能力强(作为知识源),学生轻量(适合边缘设备部署)。
输出维度10类(与学生一致)10类(与教师一致)确保两者输出分布可通过KL散度比较,实现知识迁移。

知识蒸馏的核心前提——用“强模型”指导“弱模型”学习

教师网络(TeacherNet)

TeacherNet被设计为“性能更强”的网络,通过更多的卷积层和卷积核提取更丰富的图像特征,作为知识的“来源”。其结构可分为特征提取部分(self.features)分类部分(self.classifier)

1. 特征提取部分(self.features)

由4个卷积层(Conv2d)、2个池化层(MaxPool2d)和ReLU激活函数组成,逐步提取图像的层级特征:

  • 第一层nn.Conv2d(3, 128, kernel_size=3, padding=1)
    输入:3通道(RGB图像),输出:128个特征图(卷积核),卷积核3x3,边缘填充1(保证输出尺寸与输入一致,32x32)。
    作用:提取最基础的图像特征(如边缘、纹理)。

  • 第二层nn.Conv2d(128, 64, kernel_size=3, padding=1)
    输入128个特征图,输出64个,进一步压缩特征并提炼细节。

  • 第一次池化nn.MaxPool2d(kernel_size=2, stride=2)
    将特征图尺寸从32x32压缩到16x16(减少计算量,保留关键特征)。

  • 第三、四层卷积nn.Conv2d(64, 64, ...)nn.Conv2d(64, 32, ...)
    继续深化特征提取,最终输出32个特征图。

  • 第二次池化nn.MaxPool2d(...)
    特征图尺寸从16x16压缩到8x8。

2. 分类部分(self.classifier)

将特征提取得到的特征图转换为类别概率:

  • 先通过torch.flatten(x, 1)将8x8的32个特征图展平为一维向量:32 * 8 * 8 = 2048(维度)。
  • 再通过全连接层处理:2048 → 512 → 10(10为CIFAR-10的类别数),中间用ReLU激活和Dropout(0.1)防止过拟合。

学生网络(StudentNet)

StudentNet被设计为“轻量级”网络,参数更少、结构更简单(便于部署),但其输出维度与教师网络一致(均为10类),确保能通过KL散度学习教师的知识。结构同样分为特征提取和分类两部分:

1. 特征提取部分(self.features)

仅含2个卷积层和2个池化层,参数远少于教师网络:

  • 第一层nn.Conv2d(3, 16, kernel_size=3, padding=1)

    • 输入3通道,输出仅16个特征图(远少于教师的128),同样保留32x32尺寸。
  • 第一次池化MaxPool2d将尺寸压缩到16x16。

  • 第二层卷积nn.Conv2d(16, 16, ...)

    • 保持16个特征图,进一步提取特征。
  • 第二次池化:尺寸压缩到8x8。

2. 分类部分(self.classifier)
  • 展平后特征维度:16 * 8 * 8 = 1024(远小于教师的2048)。
  • 全连接层:1024 → 256 → 10(隐藏层维度也远小于教师的512)。

完整的 HelloWorld 示例


"""

知识蒸馏(KL散度实现)
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用{device}设备")

######################################################################
# 数据加载
######################################################################
# 数据预处理
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transforms_cifar
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transforms_cifar
)

# 数据加载器
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2
)

######################################################################
# 模型定义
######################################################################
# 教师模型(较深网络)
class TeacherNet(nn.Module):
    def __init__(self, num_classes=10):
        super(TeacherNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# 学生模型(轻量级网络)
class StudentNet(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

######################################################################
# 核心函数(KL散度蒸馏)
######################################################################
def train_baseline(model, train_loader, epochs, learning_rate, device):
    """常规交叉熵训练(用于训练教师模型)"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        print(f"轮次 {epoch+1}/{epochs}, 损失: {running_loss/len(train_loader):.4f}")

def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, 
                      kl_weight, ce_weight, device):
    """基于KL散度的知识蒸馏训练"""
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    
    teacher.eval()  # 教师模型固定
    student.train() # 学生模型训练
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            # 教师输出(不计算梯度)
            with torch.no_grad():
                teacher_logits = teacher(inputs)
            
            # 学生输出
            student_logits = student(inputs)
            
            # 计算KL散度损失(知识蒸馏核心)
            # 教师分布:softmax(teacher_logits / T)
            # 学生分布:log_softmax(student_logits / T)
            teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
            student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)
            kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
            kl_loss *= T ** 2  # 温度缩放补偿
            
            # 计算硬标签损失
            ce_loss = ce_criterion(student_logits, labels)
            
            # 总损失:KL散度损失 + 交叉熵损失
            total_loss = kl_weight * kl_loss + ce_weight * ce_loss
            
            total_loss.backward()
            optimizer.step()
            
            running_loss += total_loss.item()
        
        print(f"蒸馏轮次 {epoch+1}/{epochs}, 总损失: {running_loss/len(train_loader):.4f}")

def test(model, test_loader, device):
    """测试模型准确率"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"测试准确率: {accuracy:.2f}%")
    return accuracy

######################################################################
# 执行流程
######################################################################
if __name__ == "__main__":
    # 1. 训练教师模型
    torch.manual_seed(42)
    teacher = TeacherNet().to(device)
    print("===== 训练教师模型 =====")
    train_baseline(teacher, train_loader, epochs=10, learning_rate=0.001, device=device)
    teacher_acc = test(teacher, test_loader, device)
    
    # 2. 初始化学生模型(与基线对比)
    torch.manual_seed(42)
    student_baseline = StudentNet().to(device)
    print("\n===== 训练基线学生模型(无蒸馏) =====")
    train_baseline(student_baseline, train_loader, epochs=10, learning_rate=0.001, device=device)
    student_baseline_acc = test(student_baseline, test_loader, device)
    
    # 3. 用KL散度蒸馏训练学生模型
    torch.manual_seed(42)
    student_distill = StudentNet().to(device)
    print("\n===== 用KL散度蒸馏训练学生模型 =====")
    train_distillation(
        teacher=teacher,
        student=student_distill,
        train_loader=train_loader,
        epochs=10,
        learning_rate=0.001,
        T=2,  # 温度参数
        kl_weight=0.25,  # KL散度损失权重
        ce_weight=0.75,  # 交叉熵损失权重
        device=device
    )
    student_distill_acc = test(student_distill, test_loader, device)
    
    # 4. 结果对比
    print("\n===== 最终结果 =====")
    print(f"教师模型准确率: {teacher_acc:.2f}%")
    print(f"基线学生模型准确率: {student_baseline_acc:.2f}%")
    print(f"KL蒸馏学生模型准确率: {student_distill_acc:.2f}%")

KL散度的数学定义

对于两个概率分布 P P P(教师模型的输出分布)和 Q Q Q(学生模型的输出分布),KL散度的公式为:
KL ( P ∥ Q ) = E P [ log ⁡ P − log ⁡ Q ] = ∑ x P ( x ) ⋅ ( log ⁡ P ( x ) − log ⁡ Q ( x ) ) \text{KL}(P \parallel Q) = \mathbb{E}_P \left[ \log P - \log Q \right] = \sum_x P(x) \cdot \left( \log P(x) - \log Q(x) \right) KL(PQ)=EP[logPlogQ]=xP(x)(logP(x)logQ(x))
其中:

  • P ( x ) P(x) P(x) 是教师模型输出的概率分布(经温度软化后);
  • Q ( x ) Q(x) Q(x) 是学生模型输出的概率分布(经温度软化后);
  • E P \mathbb{E}_P EP 表示对分布 P P P 取期望(即对所有样本平均)。

代码中KL散度公式的体现

train_distillation函数中,kl_loss的计算完全对应上述公式,具体如下:

  1. 定义分布 P P P Q Q Q

教师模型的输出(logits)经温度 T T T 软化后,通过softmax得到概率分布 P P P

teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)  # P(x)

学生模型的输出(logits)经相同温度 T T T 软化后,通过log_softmax得到 log ⁡ Q ( x ) \log Q(x) logQ(x)(因为log_softmax等价于对softmax的结果取对数):

student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)  # log Q(x)
  1. 计算KL散度的核心部分 (重点看这里)

代码中通过以下一行实现KL散度的求和与期望计算:

   kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)

teacher_soft.log() 对应 log ⁡ P ( x ) \log P(x) logP(x)
student_soft 对应 log ⁡ Q ( x ) \log Q(x) logQ(x)
teacher_soft * (teacher_soft.log() - student_soft) 对应 P ( x ) ⋅ ( log ⁡ P ( x ) − log ⁡ Q ( x ) ) P(x) \cdot (\log P(x) - \log Q(x)) P(x)(logP(x)logQ(x))
torch.sum(...) 对应公式中的求和 ∑ x \sum_x x
除以 inputs.size(0)(批量大小)对应取期望 E P \mathbb{E}_P EP(对批量内样本平均)。

  1. 温度补偿

最后乘以 T 2 T^2 T2 是为了抵消温度对梯度的影响(Hinton原始论文中的标准处理),但不改变KL散度的核心公式:

   kl_loss *= T **2  # 温度缩放补偿

重点说完了,说总体,让学生同时学习“教师的软标签知识”和“真实硬标签任务”

1. 函数参数:控制蒸馏过程的关键变量

def train_distillation(teacher, student, train_loader, epochs, learning_rate, T, 
                      kl_weight, ce_weight, device):
  • teacher:已训练好的教师模型(提供知识的“强模型”);
  • student:需要训练的学生模型(需要学习知识的“弱模型”);
  • train_loader:训练数据加载器(提供输入和真实标签);
  • epochs:训练轮次(遍历数据集的次数);
  • learning_rate:学生模型的学习率;
  • T:温度参数(控制教师/学生输出分布的“平滑度”,值越大分布越平缓);
  • kl_weight:KL散度损失(蒸馏损失)的权重;
  • ce_weight:交叉熵损失(硬标签损失)的权重;
  • device:训练设备(GPU/CPU)。

2. 初始化:损失函数与优化器

ce_criterion = nn.CrossEntropyLoss()  # 硬标签损失函数(真实标签任务)
optimizer = optim.Adam(student.parameters(), lr=learning_rate)  # 只优化学生模型参数
  • CrossEntropyLoss计算学生模型对“真实硬标签”的损失(保证学生不偏离基础任务);
  • 优化器仅针对student.parameters(),因为教师模型参数固定,不需要更新。

3. 模型模式设置:固定教师,训练学生

teacher.eval()  # 教师设为评估模式(关闭dropout等训练特有的层,输出稳定)
student.train()  # 学生设为训练模式(启用参数更新)
  • 教师模型处于eval模式:确保其输出稳定(不受训练时随机因素影响),且不计算梯度(节省资源);
  • 学生模型处于train模式:允许其参数通过反向传播更新。

4. 核心训练循环:每轮迭代训练

for epoch in range(epochs):  # 按轮次遍历
    running_loss = 0.0  # 累计每轮总损失
    for inputs, labels in train_loader:  # 按批量遍历数据
        # 数据转移到设备
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # 清零梯度(避免上一批次梯度残留)
  • 外层循环控制训练轮次(epochs),确保模型多次学习数据;
  • 内层循环按批量处理数据,每次处理一个batch的输入(图像)和标签(真实类别);
  • optimizer.zero_grad():清除上一批次计算的梯度,避免干扰当前批次。

5. 教师模型输出:提供“软标签知识”

# 教师输出(不计算梯度,固定参数)
with torch.no_grad():
    teacher_logits = teacher(inputs)
  • with torch.no_grad():禁用教师模型的梯度计算(教师参数不更新,节省内存和计算资源);
  • teacher_logits:教师模型的原始输出(未经过softmax的“对数几率”),后续将用于生成“软标签”。

6. 学生模型输出:学习并生成待优化的预测

# 学生输出(需要计算梯度,更新参数)
student_logits = student(inputs)
  • student_logits:学生模型的原始输出(对数几率),后续用于计算与教师的差异(蒸馏损失)和与真实标签的差异(硬标签损失)。

7. 计算KL散度损失:蒸馏的核心(学习教师知识)

# 教师分布:softmax(teacher_logits / T) → 软化的概率分布(软标签)
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
# 学生分布:log_softmax(student_logits / T) → 软化的对数概率分布
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)

# 计算KL散度:衡量学生分布与教师分布的差异
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
kl_loss *= T ** 2  # 温度缩放补偿(Hinton论文标准操作)
  • 为什么用温度T?
    温度越高(如T=10),softmax输出越平滑(小概率值被放大),教师的“知识细节”(如类别间的相似性)更明显;温度为1时,等价于普通softmax(硬标签)。
  • KL散度公式对应
    严格遵循KL散度定义 KL ( P ∥ Q ) = ∑ P ⋅ ( log ⁡ P − log ⁡ Q ) \text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q) KL(PQ)=P(logPlogQ),其中 P P P 是教师软分布(teacher_soft), Q Q Q 是学生软分布(student_soft对应的概率)。
  • 除以inputs.size(0):对批量内样本取平均,得到单样本的KL损失;
  • 乘以:补偿温度对梯度的影响(温度升高会使梯度变小,乘以T²可抵消这一效应)。

8. 计算硬标签损失:保证学生不偏离真实任务

# 学生对真实标签的交叉熵损失
ce_loss = ce_criterion(student_logits, labels)
  • 用学生的原始输出(student_logits)与真实标签(labels)计算交叉熵损失,确保学生在学习教师知识的同时,不忘记“正确答案”。

9. 总损失:平衡教师知识与真实任务

# 加权求和:KL损失(教师知识) + 交叉熵损失(真实标签)
total_loss = kl_weight * kl_loss + ce_weight * ce_loss
  • 通过kl_weightce_weight控制两者的重要性(例如代码中kl_weight=0.25ce_weight=0.75,表示更侧重真实标签)。

10. 反向传播与参数更新:只优化学生

total_loss.backward()  # 计算总损失对学生参数的梯度
optimizer.step()  # 根据梯度更新学生参数
  • 梯度仅从学生模型计算(教师模型无梯度),最终只有学生模型的参数被优化。

11. 损失监控:跟踪训练过程

running_loss += total_loss.item()  # 累计批量损失
# 每轮结束后打印平均损失
print(f"蒸馏轮次 {epoch+1}/{epochs}, 总损失: {running_loss/len(train_loader):.4f}")
  • 实时监控损失变化,判断模型是否在有效学习(通常损失应逐渐下降并趋于稳定)。
类型本质代码对应作用
教师的软标签知识教师输出的平滑概率分布teacher_soft(经T软化的softmax)传递类别间的相似性知识,帮助学生学习更鲁棒的特征
真实硬标签任务数据集中的真实类别标签labelsce_loss保证学生不偏离基础分类任务,避免完全“模仿教师错误”

还可以继续的学习

“瞎折腾”参数和模型,看看知识蒸馏到底在什么情况下最管用,以及它的“能力边界”在哪儿

1. 调“温度”试试 温度参数 T T T 对蒸馏效果的影响

做法:把代码里的 T 改成1、5、10这些不同的数,其他别动,看看学生模型准确率变不变。 例如固定其他参数(如kl_weight=0.25),测试不同温度值(如 T = 1 , 2 , 5 , 10 , 20 T=1, 2, 5, 10, 20 T=1,2,5,10,20)下学生模型的准确率。
目的:温度就像“老师说话的委婉程度”——温度低,老师只说“这个一定对”(硬邦邦);温度高,老师会说“这个可能对,那个也有点像”(更细腻)。试试哪种“说话方式”能让学生学得更好。 理解温度如何调节教师“软标签”的平滑度。

2. 调“听老师”和“听标准答案”的比例 损失权重( k l _ w e i g h t kl\_weight kl_weight c e _ w e i g h t ce\_weight ce_weight)的平衡实验

做法:把 kl_weight(听老师的权重)和 ce_weight(听标准答案的权重)换成不同组合,看学生成绩。 例如固定 T = 2 T=2 T=2,测试不同权重组合(如 (0.1, 0.9), (0.5, 0.5), (0.9, 0.1))对学生性能的影响
目的:学生既要学老师的“经验”,又不能完全忘了“标准答案”。试试多听老师点好,还是多信标准答案点好。 分析“教师知识”与“真实标签”的权重如何影响学生学习。

3. 老师自己学得好不好,影响学生吗? 教师模型性能对蒸馏的影响

做法:先让老师少学几轮(比如只学5轮,学得差点),或者多学几轮(比如学20轮,学得好点),再让它教学生,看学生成绩差多少。
目的:想知道“老师越厉害,学生是不是一定越厉害”。比如老师自己考80分,能不能教出考75分的学生?老师考90分,学生能到85分吗?

4. 学生太“笨”了,老师还能教好吗? 学生模型复杂度的极限测试

做法:把学生模型改得更简单(比如少一层卷积,少点参数),再用蒸馏训练,看它还能不能比自己学(不蒸馏)强。
目的:测试蒸馏的“底线”——如果学生太简单(比如只有一层神经网络),老师再厉害,是不是也教不会?

5. 换种“衡量学生和老师差异的方式”行不行? 与其他蒸馏损失函数的对比

做法:不用现在的KL散度,换成“均方误差”(MSE)来算学生和老师的差异,其他不变,看学生成绩变不变。
目的:现在用的KL散度就像“比较两个概率分布像不像”,换种方式(比如直接比数值差),会不会更好用?

6. 给数据“加戏”,学生学得更好吗? 数据增强对蒸馏的影响

做法:训练时给图片加些变化(比如随机裁剪一块、左右翻转),再做蒸馏,看学生在测试集上表现会不会更好。
目的:就像学生平时做难题练习,考试时更从容。给训练数据加变化,是不是能让学生和老师都学更扎实?

7. 学生学几轮效果最好? 蒸馏轮次的影响

做法:把蒸馏的 epochs 改成5、20这些数,看看学5轮、10轮、20轮,学生成绩是不是一直涨,还是学太久反而变差。
目的:避免“学过头”——就像人做题,做10套题可能进步快,做100套可能记住答案了,但换题就不会了。

8. 学生是不是真的“又小又能打”? 模型轻量化指标的验证

做法:算一算老师、普通学生、蒸馏学生的“参数数量”(模型大小)和“做题速度”(每秒处理多少张图)。
目的:蒸馏的核心是“让小模型有大模型的本事”。得验证一下:蒸馏后的学生是不是确实比老师小很多,但成绩接近;同时跑起来比老师快。

9. 学生能“举一反三”吗? 跨数据集泛化测试

做法:用CIFAR-10训练好的蒸馏学生,去做类似的题(比如CIFAR-100里的部分类别),看它比普通学生表现好多少。
目的:好的学生不仅会做学过的题,还能应付新题。蒸馏是不是能让学生学到更通用的“解题思路”?

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

二分掌柜的

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值