MMEngine训练模式详解:EpochBased与IterBased的转换与配置

MMEngine训练模式详解:EpochBased与IterBased的转换与配置

前言

在深度学习模型训练过程中,训练循环(Training Loop)的组织方式主要有两种:基于轮次(EpochBased)和基于迭代(IterBased)。MMEngine作为深度学习训练框架,为这两种训练模式提供了完善的支持。本文将深入解析这两种训练模式的区别,并详细介绍如何在MMEngine中进行配置和转换。

两种训练模式的概念

EpochBased训练模式

EpochBased训练模式是指以"轮次"为基本单位组织训练过程。一个epoch表示模型完整遍历整个训练数据集一次。这种模式的特点是:

  • 训练周期直观,易于理解
  • 验证通常在每个epoch结束时进行
  • 学习率调整等操作通常以epoch为间隔

IterBased训练模式

IterBased训练模式是指以"迭代次数"为基本单位组织训练过程。一次iteration表示模型处理一个batch的数据并完成一次参数更新。这种模式的特点是:

  • 更适合大规模数据集训练
  • 验证可以更灵活地设置间隔
  • 学习率调整等操作以iteration为间隔

MMEngine中的默认配置

MMEngine中大多数模块默认采用EpochBased模式,包括:

  • 参数调度器(ParamScheduler)
  • 日志钩子(LoggerHook)
  • 检查点钩子(CheckpointHook)
  • 日志处理器(log_processor)

配置转换详解

下面我们通过一个完整的示例,展示如何将EpochBased配置转换为IterBased配置。

1. 训练配置(train_cfg)转换

原始EpochBased配置:

train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)

转换为IterBased配置:

train_cfg = dict(
    by_epoch=False,
    max_iters=10000,  # 总迭代次数
    val_interval=2000  # 每2000次迭代验证一次
)

2. 钩子(Hooks)配置转换

原始EpochBased配置:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)

转换为IterBased配置:

default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)

3. 参数调度器(ParamScheduler)配置转换

原始EpochBased配置:

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True
)

转换为IterBased配置有两种方式:

方式一:直接指定迭代次数

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)

方式二:使用自动转换

param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    convert_to_iter_based=True
)

4. 日志处理器(log_processor)配置转换

原始EpochBased配置:

log_processor = dict(
    by_epoch=True
)

转换为IterBased配置:

log_processor = dict(
    by_epoch=False
)

完整示例对比

下面我们以CIFAR10分类任务为例,展示两种训练模式的完整配置对比:

模型定义

两种模式下模型定义相同:

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels

数据加载器

数据加载器配置也相同,但需要注意采样器的选择:

train_dataloader = DataLoader(
    batch_size=32,
    shuffle=True,  # 当sampler为None时,MMEngine会根据by_epoch自动选择采样器
    dataset=torchvision.datasets.CIFAR10(...)

评估指标

评估指标实现相同:

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)

注意事项

  1. 采样器选择:如果基础配置中已经为train_dataloader配置了采样器,需要根据当前训练模式进行调整或设置为None。当sampler为None时:

    • by_epoch=True:使用DefaultSampler
    • by_epoch=False:使用InfiniteSampler
  2. 显式指定类型:如果基础配置中已经为train_cfg指定了type(EpochBasedTrainLoop或IterBasedTrainLoop),必须显式覆盖类型,而不仅仅是设置by_epoch参数。

  3. 参数转换:当从EpochBased转为IterBased时,需要确保总迭代次数与原始配置的训练量相匹配,避免训练不足或过拟合。

总结

MMEngine提供了灵活的配置方式,支持在EpochBased和IterBased训练模式之间轻松切换。理解这两种模式的区别和转换方法,可以帮助开发者根据实际需求选择最合适的训练策略。对于小规模数据集,EpochBased模式更为直观;而对于大规模数据集,IterBased模式通常能提供更好的训练效率和灵活性。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

晏惠娣Elijah

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值