深度强化学习框架DI-engine

深度强化学习框架DI-engine

一、DI-engine概述:决策智能的通用引擎

DI-engine是由OpenDILab开源的决策智能引擎,基于PyTorch和JAX构建,旨在为强化学习(RL)、模仿学习(IL)、离线学习等场景提供标准化解决方案。其核心目标是统一不同决策智能环境与应用,支持从学术研究到工业原型的全流程开发。

核心特性

  1. 算法多样性
    支持60+算法,覆盖传统DRL(如DQN、PPO、SAC)、多智能体RL(QMIX、COMA)、模仿学习(GAIL、SQIL)、离线学习(CQL、Decision Transformer)、基于模型的RL(MBPO、DreamerV3)等方向。

  2. 环境兼容性
    内置80+环境,包括Atari、MuJoCo、SMAC(多智能体)、D4RL(离线学习)、Gfootball、Metadrive(自动驾驶)等,支持离散/连续/混合动作空间,兼容单智能体与多智能体场景。

  3. 系统级优化

    • TreeTensor:树状嵌套数据结构,统一数据处理流程,支持自动微分和批量操作。
    • 分布式训练:基于DDP和Kubernetes的分布式框架(DI-orchestrator),支持大规模并行训练。
    • 中间件机制:模块化设计(如Replay Buffer、Reward Model),支持自定义流程扩展。
  4. 工程化工具链
    提供配置系统(YAML/JSON)、模型加载/恢复、随机种子管理、单元测试框架(pytest)及Docker镜像(含预构建环境)。

二、功能模块与技术架构

1. 算法分类与支持列表

算法类型典型算法举例文档链接Demo命令示例
基础DRLDQN、PPO、SAC、TD3DQN文档ding -m serial -c cartpole_dqn_config.py -s 0
多智能体RLQMIX、WQMIX、MADDPGQMIX文档ding -m serial -c smac_3s5z_qmix_config.py
模仿学习GAIL、SQIL、BCQGAIL文档ding -m serial_gail -c cartpole_gail_config.py
离线学习CQL、TD3-BC、Decision TransformerCQL文档python3 -u d4rl_cql_main.py
基于模型的RLMBPO、DreamerV3、STEVEMBPO文档python3 -u pendulum_mbpo_config.py
探索机制HER、RND、ICMHER文档python3 -u bitflip_her_dqn.py

2. 环境生态

DI-engine通过dizoo模块集成丰富环境,部分示例如下:

  • 经典控制:CartPole、Pendulum(单智能体离散/连续控制)。
  • Atari:Pong、Asterix(视觉密集型任务)。
  • 多智能体:SMAC(星际争霸微型战役)、PettingZoo(合作/竞争场景)。
  • 离线学习:D4RL(MuJoCo专家轨迹数据集)。
  • 真实应用:Metadrive(自动驾驶)、DI-star(星际争霸II决策AI)。

3. 核心组件

  • Policy:封装算法逻辑,支持自定义网络结构(如CNN、Transformer)。
  • Model:定义神经网络架构,支持PyTorch原生模块扩展。
  • Collector/Learner:分离数据采集与模型训练,支持异步并行。
  • Replay Buffer:支持优先经验回放(PER)、分层回放(PLR)等高级机制。
  • TreeTensor:树状数据结构,示例代码:
    import treetensor.torch as ttorch
    data = ttorch.randn({'obs': (3, 32, 32), 'action': (1,)})  # 嵌套张量
    stacked_data = ttorch.stack([data, data], dim=0)  # 批量操作
    

三、快速入门:安装与实战

1. 安装指南

方式1:Pip安装(推荐)
# 稳定版
pip install di-engine

# 开发版(含额外依赖)
pip install "di-engine[all]"
方式2:Docker镜像

DI-engine提供预构建镜像,包含常见环境:

# CPU基础镜像
docker pull opendilab/ding:nightly

# Atari环境镜像
docker pull opendilab/ding:nightly-atari

# 运行示例(以CartPole为例)
docker run -it opendilab/ding:nightly \
  ding -m serial -c cartpole_dqn_config.py -s 0

2. 入门Demo:DQN玩转CartPole

步骤1:编写配置文件(cartpole_dqn_config.py
from ding.config import compile_config
from ding.policy import DQNPolicy
from ding.envs import GymEnv

cfg = dict(
    env=dict(
        env_id='CartPole-v1',
        collector_env_num=8,
        evaluator_env_num=4,
    ),
    policy=dict(
        cuda=False,
        model=dict(
            obs_shape=4,
            action_shape=2,
        ),
    ),
)
cfg = compile_config(cfg, seed=0, env=GymEnv, policy=DQNPolicy)
步骤2:运行训练
ding -m serial -c cartpole_dqn_config.py -s 0
步骤3:评估与可视化
from ding.utils import VideoRecorder

env = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')

obs = env.reset()
while True:
    action = policy.predict(obs)
    obs, reward, done, info = env.step(action)
    video_recorder.record(obs)
    if done:
        break
video_recorder.close()

3. 多智能体示例:QMIX在SMAC环境

# SMAC 3s5z场景(3陆战队vs5 zealots)
ding -m serial -c smac_3s5z_qmix_config.py -s 0

四、进阶功能与最佳实践

1. 自定义环境迁移

若需接入自定义环境,需实现以下接口:

from ding.envs import BaseEnv

class MyEnv(BaseEnv):
    def __init__(self, env_id):
        super().__init__()
        # 初始化环境逻辑

    def reset(self):
        # 返回初始观测值
        return obs

    def step(self, action):
        # 执行动作,返回obs, reward, done, info
        return obs, reward, done, info

    @property
    def observation_space(self):
        # 定义观测空间(gym.Space格式)
        return Box(low=0, high=255, shape=(3, 32, 32))

    @property
    def action_space(self):
        # 定义动作空间
        return Discrete(10)

2. 模型自定义与网络设计

通过继承ding.model.BaseModel实现自定义网络:

import torch.nn as nn
from ding.model import BaseModel

class CustomModel(BaseModel):
    def __init__(self, obs_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, 3),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(32*14*14, action_shape)

    def forward(self, x):
        x = self.cnn(x)
        return self.fc(x)

3. 分布式训练

使用DDP模式启动分布式训练:

ding -m distributed_ddp -c ppo_lunarlander_config.py -n 4

五、生态与社区支持

1. 工具与资源

2. 社区参与

  • 问题反馈:在GitHub Issues提交BUG或功能请求。
  • 贡献代码:参考CONTRIBUTING.md,参与算法实现或文档完善。
  • 交流渠道
    • 微信社群:扫码添加“DI小助手”(见GitHub README)。
    • Discord/ Slack:链接

3. 引用与许可

若在研究中使用DI-engine,请引用:

@misc{ding,
    title={DI-engine: A Universal AI System/Engine for Decision Intelligence},
    author={Niu, Yazhe et al.},
    howpublished={\url{https://github.com/opendilab/DI-engine}},
    year={2021}
}

DI-engine采用Apache 2.0许可证,允许商业使用与修改。

DI-engine核心算法实战指南

一、DQN(离散动作空间经典算法)

算法特性

  • 适用场景:离散动作空间、单智能体、稀疏奖励场景
  • 核心思想:基于Q-Learning,使用神经网络近似Q值函数,经验回放+目标网络稳定训练
  • DI-engine实现:支持Double DQN、Dueling DQN等变体,集成PER(优先经验回放)

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制CartPole-v1、CliffWalking离散CartPole文档
Atari游戏Pong、Breakout、Asterix离散Atari环境指南
文本决策TabMWP(数学文字题解答)离散TabMWP文档

快速上手:CartPole场景

1. 配置文件(cartpole_dqn_config.py
from ding.config import compile_config
from ding.envs import GymEnv
from ding.policy import DQNPolicy

cfg = dict(
    env=dict(
        env_id='CartPole-v1',
        collector_env_num=8,  # 并行采集环境数
        evaluator_env_num=4,   # 并行评估环境数
        n_evaluator_episode=20,  # 评估 episodes 数
    ),
    policy=dict(
        cuda=False,
        model=dict(
            obs_shape=4,         # 观测空间维度
            action_shape=2,      # 动作空间大小(离散)
            encoder_hidden_size_list=[128, 128],  # 网络结构
        ),
        learn=dict(
            update_per_collect=10,  # 每次采集后更新次数
            batch_size=32,         # 批量大小
            learning_rate=0.001,   # 学习率
        ),
        collect=dict(
            n_sample=100,        # 每次采集样本数
            random_collect_size=1000,  # 随机初始化样本数
        ),
        eval=dict(evaluator=dict(eval_freq=500, ))  # 评估频率
    ),
)
cfg = compile_config(cfg, env=GymEnv, policy=DQNPolicy, seed=0)
2. 运行训练
# 单机训练
ding -m serial -c cartpole_dqn_config.py -s 0

# 分布式训练(4进程)
ding -m distributed_ddp -c cartpole_dqn_config.py -n 4
3. 评估与可视化
from ding.utils import VideoRecorder
env = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('output/ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')

obs = env.reset()
while True:
    action = policy.predict(obs)  # 推理动作
    obs, reward, done, info = env.step(action)
    video_recorder.record(obs)
    if done:
        break
video_recorder.close()

二、PPO(连续/离散通用策略梯度算法)

算法特性

  • 适用场景:离散/连续动作空间、单智能体/多智能体(MAPPO变体)
  • 核心思想:近端策略优化,通过clip机制平衡策略更新步长
  • DI-engine扩展:支持GAE(广义优势估计)、分层回放(PLR)

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制LunarLanderContinuous-v2连续LunarLander文档
MuJoCoHalfCheetah-v4、Ant-v4连续MuJoCo环境指南
多智能体SMAC(3s5z场景)、PettingZoo离散/MARLSMAC文档

实战案例:LunarLander连续控制

1. 配置要点(连续动作特化)
cfg.policy.model = dict(
    obs_shape=8,
    action_shape=4,
    action_space='continuous',  # 显式声明连续空间
    actor_head_hidden_size=256,
    critic_head_hidden_size=256,
)
cfg.policy.learn = dict(
    epoch_per_collect=10,  # 每个采集周期训练轮数
    batch_size=2048,
    clip_ratio=0.2,        # PPO核心超参数
)
2. 运行命令
# 连续动作版示例(需指定环境ID)
ding -m serial_onpolicy -c lunarlander_ppo_continuous_config.py -s 0
3. 多智能体扩展(MAPPO)
# SMAC 3s5z多智能体场景
ding -m serial -c smac_3s5z_mappo_config.py -s 0

三、SAC(连续动作空间最大熵强化学习)

算法特性

  • 适用场景:连续动作空间、探索性强的环境(如机器人控制)
  • 核心思想:结合最大熵原理,同时优化策略熵与累积奖励
  • DI-engine优化:支持自动熵调整、多Q网络正则化

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制Pendulum-v1连续Pendulum文档
MuJoCoSwimmer-v4、Walker2d-v4连续MuJoCo环境列表
离线学习D4RL(HalfCheetah-medium)连续(离线)D4RL文档

代码示例:Pendulum摆锤控制

1. 关键配置
cfg.policy.model = dict(
    obs_shape=3,
    action_shape=1,
    twin_critic=True,       # 使用双Q网络
    action_space='continuous',
)
cfg.policy.learn = dict(
    target_entropy='auto',  # 自动计算目标熵
    discount_factor=0.99,
)
2. 运行与评估
# 训练命令
ding -m serial -c pendulum_sac_config.py -s 0

# 离线学习(D4RL数据集)
python3 -u d4rl_sac_main.py --env_id halfcheetah-medium-v0

四、TD3(连续动作空间延迟策略更新算法)

算法特性

  • 适用场景:连续动作空间、高维动作空间(如机械臂控制)
  • 核心改进:双Q网络+策略延迟更新+动作噪声,缓解过估计问题
  • DI-engine集成:支持与HER(indsight experience replay)结合

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制MountainCarContinuous-v0连续MountainCar文档
PyBulletHumanoidBulletEnv-v0连续PyBullet环境指南
自定义环境机械臂控制(需迁移)连续环境迁移教程

快速运行:MountainCarContinuous

# 配置文件路径:ding/example/td3_mountaincar_continuous_config.py
ding -m serial -c td3_mountaincar_continuous_config.py -s 0
高级用法:结合HER探索
# 在配置中启用HER
cfg.policy.other = dict(
    replay_buffer=dict(
        type='her',
        her_type='future',  # HER类型(future/episode等)
    )
)

五、环境支持总表(DQN/PPO/SAC/TD3适用场景)

算法离散动作空间连续动作空间多智能体场景离线学习场景典型环境示例
DQN✅(全支持)CartPole、Breakout
PPO✅(离散版)✅(连续版)✅(MAPPO)LunarLander、SMAC
SAC✅(全支持)✅(D4RL)Pendulum、D4RL-HalfCheetah
TD3✅(高维优先)MountainCarContinuous、Ant

六、进阶技巧:算法调优与环境适配

1. 离散vs连续动作空间关键差异

  • 动作处理
    • 离散:输出logits或Q值,通过argmax选择动作
    • 连续:输出均值+标准差(如SAC)或直接映射(如TD3),需限制动作范围
  • 网络结构
    • 离散:单分支输出(动作维度)
    • 连续:Actor-Critic双分支(如PPO/SAC)

2. 环境迁移三步法

  1. 继承BaseEnv:实现reset()step(action)observation_space等接口
  2. 适配动作格式
    • 离散:动作需为0~n-1的整数(np.int64
    • 连续:动作需为numpy数组,范围与action_space一致
  3. 测试兼容性:使用ding.utils.env_test验证环境接口
from ding.utils import env_test
env = MyCustomEnv()
env_test(env, coll_num=100)  # 测试采集兼容性

3. 超参数调优建议

算法关键超参数调优方向
DQNlearning_rategamma小学习率(1e-4~1e-3),gamma=0.99
PPOclip_ratiogae_lambdaclip_ratio=0.1~0.3,lambda=0.95
SACtarget_entropytau自动熵或手动设置(如-env_dim)
TD3policy_noisenoise_clip噪声=0.2,clip=0.5

七、总结:选择合适算法的决策树

离散
连续
需要探索
高维动作
多智能体
离线数据
问题类型
动作空间类型
DQN/PPO
SAC/PPO/TD3
SAC
TD3
MAPPO
CQL/TD3-BC

通过以上指南,可快速在DI-engine中落地主流RL算法。更多细节可参考官方文档:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

LetsonH

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值