知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例
flyfish
知识蒸馏 - 通过引入温度参数T调整 Softmax 的输出
知识蒸馏 - 自信息量是单个事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率为权重的加权平均值
知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例
知识蒸馏的步骤如下:
-
训练教师模型:使用常规交叉熵损失在训练集上训练深层教师模型,使其在任务上达到较好性能(作为知识的"来源")。
-
固定教师模型:将训练好的教师模型设为评估模式(不更新参数),仅用于提供"软标签"知识。
-
初始化学生模型:创建轻量级学生模型(与基线学生模型初始化相同,保证公平对比)。
-
学生模型蒸馏训练:
输入数据同时传入教师模型和学生模型;
教师模型输出logits(不计算梯度),经温度T软化后得到软概率分布(教师的"软标签");
学生模型输出logits,经相同温度T软化后得到对数概率分布;
计算学生与教师软分布的KL散度损失(衡量两者差异,即蒸馏损失);
计算学生与原始硬标签(真实类别)的交叉熵损失;
总损失为KL散度损失与交叉熵损失的加权和;
基于总损失更新学生模型参数,教师模型参数保持不变。 -
重复训练:迭代多轮,直至学生模型收敛,最终得到通过蒸馏学习了教师知识的轻量级模型。
使用的数据集
CIFAR-10 数据集:
-
介绍
CIFAR-10(Canadian Institute for Advanced Research 10)是由加拿大高级研究所发布的小型图像数据集,广泛用于计算机视觉领域的入门级模型训练和测试。 -
数据组成
包含 10个类别 的彩色图像,类别分别为:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)。
图像尺寸统一为 32×32像素,且为 RGB三通道(每个图像包含 3×32×32=3072 个像素值,范围 0-255)。
数据集分为两部分:- 训练集:50,000 张图像(每个类别 5,000 张);
- 测试集:10,000 张图像(每个类别 1,000 张)。
-
用途
主要用于图像分类任务的基准测试,适合验证轻量级模型(如简单CNN)的性能,这里用于验证知识蒸馏对轻量级学生模型的性能提升)。 -
使用
通过datasets.CIFAR10
加载数据集时,指定root
(数据保存路径)、train
(是否为训练集)、download
(是否自动下载)和transform
(数据预处理方式),即可便捷地获取预处理后的图像数据和对应标签。
模型设计
维度 | TeacherNet(教师) | StudentNet(学生) | 设计目的 |
---|---|---|---|
卷积层数量 | 4层(更深) | 2层(更浅) | 教师通过更多层提取更复杂的特征,学生通过精简结构降低部署成本。 |
卷积核数量 | 初始128,后续64/32(更多) | 固定16(更少) | 教师用更多卷积核捕捉细节,学生用少量卷积核减少参数和计算量。 |
参数规模 | 大 | 小 | 教师拟合能力强(作为知识源),学生轻量(适合边缘设备部署)。 |
输出维度 | 10类(与学生一致) | 10类(与教师一致) | 确保两者输出分布可通过KL散度比较,实现知识迁移。 |
知识蒸馏的核心前提——用“强模型”指导“弱模型”学习
教师网络(TeacherNet)
TeacherNet
被设计为“性能更强”的网络,通过更多的卷积层和卷积核提取更丰富的图像特征,作为知识的“来源”。其结构可分为特征提取部分(self.features) 和分类部分(self.classifier):
1. 特征提取部分(self.features)
由4个卷积层(Conv2d
)、2个池化层(MaxPool2d
)和ReLU激活函数组成,逐步提取图像的层级特征:
-
第一层:
nn.Conv2d(3, 128, kernel_size=3, padding=1)
输入:3通道(RGB图像),输出:128个特征图(卷积核),卷积核3x3,边缘填充1(保证输出尺寸与输入一致,32x32)。
作用:提取最基础的图像特征(如边缘、纹理)。 -
第二层:
nn.Conv2d(128, 64, kernel_size=3, padding=1)
输入128个特征图,输出64个,进一步压缩特征并提炼细节。 -
第一次池化:
nn.MaxPool2d(kernel_size=2, stride=2)
将特征图尺寸从32x32压缩到16x16(减少计算量,保留关键特征)。 -
第三、四层卷积:
nn.Conv2d(64, 64, ...)
和nn.Conv2d(64, 32, ...)
继续深化特征提取,最终输出32个特征图。 -
第二次池化:
nn.MaxPool2d(...)
特征图尺寸从16x16压缩到8x8。
2. 分类部分(self.classifier)
将特征提取得到的特征图转换为类别概率:
- 先通过
torch.flatten(x, 1)
将8x8的32个特征图展平为一维向量:32 * 8 * 8 = 2048
(维度)。 - 再通过全连接层处理:
2048 → 512 → 10
(10为CIFAR-10的类别数),中间用ReLU激活和Dropout(0.1)防止过拟合。
学生网络(StudentNet)
StudentNet
被设计为“轻量级”网络,参数更少、结构更简单(便于部署),但其输出维度与教师网络一致(均为10类),确保能通过KL散度学习教师的知识。结构同样分为特征提取和分类两部分:
1. 特征提取部分(self.features)
仅含2个卷积层和2个池化层,参数远少于教师网络:
-
第一层:
nn.Conv2d(3, 16, kernel_size=3, padding=1)
- 输入3通道,输出仅16个特征图(远少于教师的128),同样保留32x32尺寸。
-
第一次池化:
MaxPool2d
将尺寸压缩到16x16。 -
第二层卷积:
nn.Conv2d(16, 16, ...)
- 保持16个特征图,进一步提取特征。
-
第二次池化:尺寸压缩到8x8。
2. 分类部分(self.classifier)
- 展平后特征维度:
16 * 8 * 8 = 1024
(远小于教师的2048)。 - 全连接层:
1024 → 256 → 10
(隐藏层维度也远小于教师的512)。
完整的 HelloWorld 示例
"""
知识蒸馏(KL散度实现)
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用{device}设备")
######################################################################
# 数据加载
######################################################################
# 数据预处理
transforms_cifar = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(
root='./data', train=True, download=True, transform=transforms_cifar
)
test_dataset = datasets.CIFAR10(
root='./data', train=False, download=True, transform=transforms_cifar
)
# 数据加载器
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=128, shuffle=False, num_workers=2
)
######################################################################
# 模型定义
######################################################################
# 教师模型(较深网络)
class TeacherNet(nn.Module):
def __init__(self, num_classes=10):
super(TeacherNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, num_classes)
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# 学生模型(轻量级网络)
class StudentNet(nn.Module):
def __init__(self, num_classes=10):
super(StudentNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
######################################################################
# 核心函数(KL散度蒸馏)
######################################################################
def train_baseline(model, train_loader, epochs, learning_rate, device):
"""常规交叉熵训练(用于训练教师模型)"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model.train()
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"轮次 {epoch+1}/{epochs}, 损失: {running_loss/len(train_loader):.4f}")
def train_distillation(teacher, student, train_loader, epochs, learning_rate, T,
kl_weight, ce_weight, device):
"""基于KL散度的知识蒸馏训练"""
ce_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=learning_rate)
teacher.eval() # 教师模型固定
student.train() # 学生模型训练
for epoch in range(epochs):
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# 教师输出(不计算梯度)
with torch.no_grad():
teacher_logits = teacher(inputs)
# 学生输出
student_logits = student(inputs)
# 计算KL散度损失(知识蒸馏核心)
# 教师分布:softmax(teacher_logits / T)
# 学生分布:log_softmax(student_logits / T)
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
kl_loss *= T ** 2 # 温度缩放补偿
# 计算硬标签损失
ce_loss = ce_criterion(student_logits, labels)
# 总损失:KL散度损失 + 交叉熵损失
total_loss = kl_weight * kl_loss + ce_weight * ce_loss
total_loss.backward()
optimizer.step()
running_loss += total_loss.item()
print(f"蒸馏轮次 {epoch+1}/{epochs}, 总损失: {running_loss/len(train_loader):.4f}")
def test(model, test_loader, device):
"""测试模型准确率"""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"测试准确率: {accuracy:.2f}%")
return accuracy
######################################################################
# 执行流程
######################################################################
if __name__ == "__main__":
# 1. 训练教师模型
torch.manual_seed(42)
teacher = TeacherNet().to(device)
print("===== 训练教师模型 =====")
train_baseline(teacher, train_loader, epochs=10, learning_rate=0.001, device=device)
teacher_acc = test(teacher, test_loader, device)
# 2. 初始化学生模型(与基线对比)
torch.manual_seed(42)
student_baseline = StudentNet().to(device)
print("\n===== 训练基线学生模型(无蒸馏) =====")
train_baseline(student_baseline, train_loader, epochs=10, learning_rate=0.001, device=device)
student_baseline_acc = test(student_baseline, test_loader, device)
# 3. 用KL散度蒸馏训练学生模型
torch.manual_seed(42)
student_distill = StudentNet().to(device)
print("\n===== 用KL散度蒸馏训练学生模型 =====")
train_distillation(
teacher=teacher,
student=student_distill,
train_loader=train_loader,
epochs=10,
learning_rate=0.001,
T=2, # 温度参数
kl_weight=0.25, # KL散度损失权重
ce_weight=0.75, # 交叉熵损失权重
device=device
)
student_distill_acc = test(student_distill, test_loader, device)
# 4. 结果对比
print("\n===== 最终结果 =====")
print(f"教师模型准确率: {teacher_acc:.2f}%")
print(f"基线学生模型准确率: {student_baseline_acc:.2f}%")
print(f"KL蒸馏学生模型准确率: {student_distill_acc:.2f}%")
KL散度的数学定义
对于两个概率分布
P
P
P(教师模型的输出分布)和
Q
Q
Q(学生模型的输出分布),KL散度的公式为:
KL
(
P
∥
Q
)
=
E
P
[
log
P
−
log
Q
]
=
∑
x
P
(
x
)
⋅
(
log
P
(
x
)
−
log
Q
(
x
)
)
\text{KL}(P \parallel Q) = \mathbb{E}_P \left[ \log P - \log Q \right] = \sum_x P(x) \cdot \left( \log P(x) - \log Q(x) \right)
KL(P∥Q)=EP[logP−logQ]=x∑P(x)⋅(logP(x)−logQ(x))
其中:
- P ( x ) P(x) P(x) 是教师模型输出的概率分布(经温度软化后);
- Q ( x ) Q(x) Q(x) 是学生模型输出的概率分布(经温度软化后);
- E P \mathbb{E}_P EP 表示对分布 P P P 取期望(即对所有样本平均)。
代码中KL散度公式的体现
在train_distillation
函数中,kl_loss
的计算完全对应上述公式,具体如下:
- 定义分布 P P P 和 Q Q Q
教师模型的输出(logits)经温度
T
T
T 软化后,通过softmax
得到概率分布
P
P
P:
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1) # P(x)
学生模型的输出(logits)经相同温度
T
T
T 软化后,通过log_softmax
得到
log
Q
(
x
)
\log Q(x)
logQ(x)(因为log_softmax
等价于对softmax
的结果取对数):
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1) # log Q(x)
- 计算KL散度的核心部分 (重点看这里)
代码中通过以下一行实现KL散度的求和与期望计算:
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
teacher_soft.log()
对应
log
P
(
x
)
\log P(x)
logP(x);
student_soft
对应
log
Q
(
x
)
\log Q(x)
logQ(x);
teacher_soft * (teacher_soft.log() - student_soft)
对应
P
(
x
)
⋅
(
log
P
(
x
)
−
log
Q
(
x
)
)
P(x) \cdot (\log P(x) - \log Q(x))
P(x)⋅(logP(x)−logQ(x));
torch.sum(...)
对应公式中的求和
∑
x
\sum_x
∑x;
除以 inputs.size(0)
(批量大小)对应取期望
E
P
\mathbb{E}_P
EP(对批量内样本平均)。
- 温度补偿
最后乘以 T 2 T^2 T2 是为了抵消温度对梯度的影响(Hinton原始论文中的标准处理),但不改变KL散度的核心公式:
kl_loss *= T **2 # 温度缩放补偿
重点说完了,说总体,让学生同时学习“教师的软标签知识”和“真实硬标签任务”
1. 函数参数:控制蒸馏过程的关键变量
def train_distillation(teacher, student, train_loader, epochs, learning_rate, T,
kl_weight, ce_weight, device):
teacher
:已训练好的教师模型(提供知识的“强模型”);student
:需要训练的学生模型(需要学习知识的“弱模型”);train_loader
:训练数据加载器(提供输入和真实标签);epochs
:训练轮次(遍历数据集的次数);learning_rate
:学生模型的学习率;T
:温度参数(控制教师/学生输出分布的“平滑度”,值越大分布越平缓);kl_weight
:KL散度损失(蒸馏损失)的权重;ce_weight
:交叉熵损失(硬标签损失)的权重;device
:训练设备(GPU/CPU)。
2. 初始化:损失函数与优化器
ce_criterion = nn.CrossEntropyLoss() # 硬标签损失函数(真实标签任务)
optimizer = optim.Adam(student.parameters(), lr=learning_rate) # 只优化学生模型参数
- 用
CrossEntropyLoss
计算学生模型对“真实硬标签”的损失(保证学生不偏离基础任务); - 优化器仅针对
student.parameters()
,因为教师模型参数固定,不需要更新。
3. 模型模式设置:固定教师,训练学生
teacher.eval() # 教师设为评估模式(关闭dropout等训练特有的层,输出稳定)
student.train() # 学生设为训练模式(启用参数更新)
- 教师模型处于
eval
模式:确保其输出稳定(不受训练时随机因素影响),且不计算梯度(节省资源); - 学生模型处于
train
模式:允许其参数通过反向传播更新。
4. 核心训练循环:每轮迭代训练
for epoch in range(epochs): # 按轮次遍历
running_loss = 0.0 # 累计每轮总损失
for inputs, labels in train_loader: # 按批量遍历数据
# 数据转移到设备
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad() # 清零梯度(避免上一批次梯度残留)
- 外层循环控制训练轮次(
epochs
),确保模型多次学习数据; - 内层循环按批量处理数据,每次处理一个
batch
的输入(图像)和标签(真实类别); optimizer.zero_grad()
:清除上一批次计算的梯度,避免干扰当前批次。
5. 教师模型输出:提供“软标签知识”
# 教师输出(不计算梯度,固定参数)
with torch.no_grad():
teacher_logits = teacher(inputs)
with torch.no_grad()
:禁用教师模型的梯度计算(教师参数不更新,节省内存和计算资源);teacher_logits
:教师模型的原始输出(未经过softmax的“对数几率”),后续将用于生成“软标签”。
6. 学生模型输出:学习并生成待优化的预测
# 学生输出(需要计算梯度,更新参数)
student_logits = student(inputs)
student_logits
:学生模型的原始输出(对数几率),后续用于计算与教师的差异(蒸馏损失)和与真实标签的差异(硬标签损失)。
7. 计算KL散度损失:蒸馏的核心(学习教师知识)
# 教师分布:softmax(teacher_logits / T) → 软化的概率分布(软标签)
teacher_soft = nn.functional.softmax(teacher_logits / T, dim=-1)
# 学生分布:log_softmax(student_logits / T) → 软化的对数概率分布
student_soft = nn.functional.log_softmax(student_logits / T, dim=-1)
# 计算KL散度:衡量学生分布与教师分布的差异
kl_loss = torch.sum(teacher_soft * (teacher_soft.log() - student_soft)) / inputs.size(0)
kl_loss *= T ** 2 # 温度缩放补偿(Hinton论文标准操作)
- 为什么用温度T?
温度越高(如T=10),softmax输出越平滑(小概率值被放大),教师的“知识细节”(如类别间的相似性)更明显;温度为1时,等价于普通softmax(硬标签)。 - KL散度公式对应:
严格遵循KL散度定义 KL ( P ∥ Q ) = ∑ P ⋅ ( log P − log Q ) \text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q) KL(P∥Q)=∑P⋅(logP−logQ),其中 P P P 是教师软分布(teacher_soft
), Q Q Q 是学生软分布(student_soft
对应的概率)。 - 除以
inputs.size(0)
:对批量内样本取平均,得到单样本的KL损失; - 乘以
T²
:补偿温度对梯度的影响(温度升高会使梯度变小,乘以T²可抵消这一效应)。
8. 计算硬标签损失:保证学生不偏离真实任务
# 学生对真实标签的交叉熵损失
ce_loss = ce_criterion(student_logits, labels)
- 用学生的原始输出(
student_logits
)与真实标签(labels
)计算交叉熵损失,确保学生在学习教师知识的同时,不忘记“正确答案”。
9. 总损失:平衡教师知识与真实任务
# 加权求和:KL损失(教师知识) + 交叉熵损失(真实标签)
total_loss = kl_weight * kl_loss + ce_weight * ce_loss
- 通过
kl_weight
和ce_weight
控制两者的重要性(例如代码中kl_weight=0.25
、ce_weight=0.75
,表示更侧重真实标签)。
10. 反向传播与参数更新:只优化学生
total_loss.backward() # 计算总损失对学生参数的梯度
optimizer.step() # 根据梯度更新学生参数
- 梯度仅从学生模型计算(教师模型无梯度),最终只有学生模型的参数被优化。
11. 损失监控:跟踪训练过程
running_loss += total_loss.item() # 累计批量损失
# 每轮结束后打印平均损失
print(f"蒸馏轮次 {epoch+1}/{epochs}, 总损失: {running_loss/len(train_loader):.4f}")
- 实时监控损失变化,判断模型是否在有效学习(通常损失应逐渐下降并趋于稳定)。
类型 | 本质 | 代码对应 | 作用 |
---|---|---|---|
教师的软标签知识 | 教师输出的平滑概率分布 | teacher_soft (经T软化的softmax) | 传递类别间的相似性知识,帮助学生学习更鲁棒的特征 |
真实硬标签任务 | 数据集中的真实类别标签 | labels 和ce_loss | 保证学生不偏离基础分类任务,避免完全“模仿教师错误” |
还可以继续的学习
“瞎折腾”参数和模型,看看知识蒸馏到底在什么情况下最管用,以及它的“能力边界”在哪儿
1. 调“温度”试试 温度参数 T T T 对蒸馏效果的影响
做法:把代码里的 T
改成1、5、10这些不同的数,其他别动,看看学生模型准确率变不变。 例如固定其他参数(如kl_weight=0.25
),测试不同温度值(如
T
=
1
,
2
,
5
,
10
,
20
T=1, 2, 5, 10, 20
T=1,2,5,10,20)下学生模型的准确率。
目的:温度就像“老师说话的委婉程度”——温度低,老师只说“这个一定对”(硬邦邦);温度高,老师会说“这个可能对,那个也有点像”(更细腻)。试试哪种“说话方式”能让学生学得更好。 理解温度如何调节教师“软标签”的平滑度。
2. 调“听老师”和“听标准答案”的比例 损失权重( k l _ w e i g h t kl\_weight kl_weight 和 c e _ w e i g h t ce\_weight ce_weight)的平衡实验
做法:把 kl_weight
(听老师的权重)和 ce_weight
(听标准答案的权重)换成不同组合,看学生成绩。 例如固定
T
=
2
T=2
T=2,测试不同权重组合(如 (0.1, 0.9), (0.5, 0.5), (0.9, 0.1)
)对学生性能的影响
目的:学生既要学老师的“经验”,又不能完全忘了“标准答案”。试试多听老师点好,还是多信标准答案点好。 分析“教师知识”与“真实标签”的权重如何影响学生学习。
3. 老师自己学得好不好,影响学生吗? 教师模型性能对蒸馏的影响
做法:先让老师少学几轮(比如只学5轮,学得差点),或者多学几轮(比如学20轮,学得好点),再让它教学生,看学生成绩差多少。
目的:想知道“老师越厉害,学生是不是一定越厉害”。比如老师自己考80分,能不能教出考75分的学生?老师考90分,学生能到85分吗?
4. 学生太“笨”了,老师还能教好吗? 学生模型复杂度的极限测试
做法:把学生模型改得更简单(比如少一层卷积,少点参数),再用蒸馏训练,看它还能不能比自己学(不蒸馏)强。
目的:测试蒸馏的“底线”——如果学生太简单(比如只有一层神经网络),老师再厉害,是不是也教不会?
5. 换种“衡量学生和老师差异的方式”行不行? 与其他蒸馏损失函数的对比
做法:不用现在的KL散度,换成“均方误差”(MSE)来算学生和老师的差异,其他不变,看学生成绩变不变。
目的:现在用的KL散度就像“比较两个概率分布像不像”,换种方式(比如直接比数值差),会不会更好用?
6. 给数据“加戏”,学生学得更好吗? 数据增强对蒸馏的影响
做法:训练时给图片加些变化(比如随机裁剪一块、左右翻转),再做蒸馏,看学生在测试集上表现会不会更好。
目的:就像学生平时做难题练习,考试时更从容。给训练数据加变化,是不是能让学生和老师都学更扎实?
7. 学生学几轮效果最好? 蒸馏轮次的影响
做法:把蒸馏的 epochs
改成5、20这些数,看看学5轮、10轮、20轮,学生成绩是不是一直涨,还是学太久反而变差。
目的:避免“学过头”——就像人做题,做10套题可能进步快,做100套可能记住答案了,但换题就不会了。
8. 学生是不是真的“又小又能打”? 模型轻量化指标的验证
做法:算一算老师、普通学生、蒸馏学生的“参数数量”(模型大小)和“做题速度”(每秒处理多少张图)。
目的:蒸馏的核心是“让小模型有大模型的本事”。得验证一下:蒸馏后的学生是不是确实比老师小很多,但成绩接近;同时跑起来比老师快。
9. 学生能“举一反三”吗? 跨数据集泛化测试
做法:用CIFAR-10训练好的蒸馏学生,去做类似的题(比如CIFAR-100里的部分类别),看它比普通学生表现好多少。
目的:好的学生不仅会做学过的题,还能应付新题。蒸馏是不是能让学生学到更通用的“解题思路”?