MMEngine训练模式详解:EpochBased与IterBased的转换与配置
前言
在深度学习模型训练过程中,训练循环(Training Loop)的组织方式主要有两种:基于轮次(EpochBased)和基于迭代(IterBased)。MMEngine作为深度学习训练框架,为这两种训练模式提供了完善的支持。本文将深入解析这两种训练模式的区别,并详细介绍如何在MMEngine中进行配置和转换。
两种训练模式的概念
EpochBased训练模式
EpochBased训练模式是指以"轮次"为基本单位组织训练过程。一个epoch表示模型完整遍历整个训练数据集一次。这种模式的特点是:
- 训练周期直观,易于理解
- 验证通常在每个epoch结束时进行
- 学习率调整等操作通常以epoch为间隔
IterBased训练模式
IterBased训练模式是指以"迭代次数"为基本单位组织训练过程。一次iteration表示模型处理一个batch的数据并完成一次参数更新。这种模式的特点是:
- 更适合大规模数据集训练
- 验证可以更灵活地设置间隔
- 学习率调整等操作以iteration为间隔
MMEngine中的默认配置
MMEngine中大多数模块默认采用EpochBased模式,包括:
- 参数调度器(ParamScheduler)
- 日志钩子(LoggerHook)
- 检查点钩子(CheckpointHook)
- 日志处理器(log_processor)
配置转换详解
下面我们通过一个完整的示例,展示如何将EpochBased配置转换为IterBased配置。
1. 训练配置(train_cfg)转换
原始EpochBased配置:
train_cfg = dict(
by_epoch=True,
max_epochs=10,
val_interval=2
)
转换为IterBased配置:
train_cfg = dict(
by_epoch=False,
max_iters=10000, # 总迭代次数
val_interval=2000 # 每2000次迭代验证一次
)
2. 钩子(Hooks)配置转换
原始EpochBased配置:
default_hooks = dict(
logger=dict(type='LoggerHook', log_metric_by_epoch=True),
checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)
转换为IterBased配置:
default_hooks = dict(
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
3. 参数调度器(ParamScheduler)配置转换
原始EpochBased配置:
param_scheduler = dict(
type='MultiStepLR',
milestones=[6, 8],
by_epoch=True
)
转换为IterBased配置有两种方式:
方式一:直接指定迭代次数
param_scheduler = dict(
type='MultiStepLR',
milestones=[6000, 8000],
by_epoch=False,
)
方式二:使用自动转换
param_scheduler = dict(
type='MultiStepLR',
milestones=[6, 8],
convert_to_iter_based=True
)
4. 日志处理器(log_processor)配置转换
原始EpochBased配置:
log_processor = dict(
by_epoch=True
)
转换为IterBased配置:
log_processor = dict(
by_epoch=False
)
完整示例对比
下面我们以CIFAR10分类任务为例,展示两种训练模式的完整配置对比:
模型定义
两种模式下模型定义相同:
class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()
def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == 'loss':
return {'loss': F.cross_entropy(x, labels)}
elif mode == 'predict':
return x, labels
数据加载器
数据加载器配置也相同,但需要注意采样器的选择:
train_dataloader = DataLoader(
batch_size=32,
shuffle=True, # 当sampler为None时,MMEngine会根据by_epoch自动选择采样器
dataset=torchvision.datasets.CIFAR10(...)
评估指标
评估指标实现相同:
class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append({
'batch_size': len(gt),
'correct': (score.argmax(dim=1) == gt).sum().cpu(),
})
def compute_metrics(self, results):
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
return dict(accuracy=100 * total_correct / total_size)
注意事项
-
采样器选择:如果基础配置中已经为train_dataloader配置了采样器,需要根据当前训练模式进行调整或设置为None。当sampler为None时:
- by_epoch=True:使用DefaultSampler
- by_epoch=False:使用InfiniteSampler
-
显式指定类型:如果基础配置中已经为train_cfg指定了type(EpochBasedTrainLoop或IterBasedTrainLoop),必须显式覆盖类型,而不仅仅是设置by_epoch参数。
-
参数转换:当从EpochBased转为IterBased时,需要确保总迭代次数与原始配置的训练量相匹配,避免训练不足或过拟合。
总结
MMEngine提供了灵活的配置方式,支持在EpochBased和IterBased训练模式之间轻松切换。理解这两种模式的区别和转换方法,可以帮助开发者根据实际需求选择最合适的训练策略。对于小规模数据集,EpochBased模式更为直观;而对于大规模数据集,IterBased模式通常能提供更好的训练效率和灵活性。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考