【SAM2代码解析】training部分代码详解-训练流程

2.1 初始化

在这里插入图片描述

  • initialize_config_module(“sam2”, version_base=“1.2”) ----> 初始化Hydra配置模块,并指定配置根目录为sam2(sam2目录应包含配置文件来定义模型、训练参数等),声明配置兼容Hydra1.2版本
  • 参数列表
    在这里插入图片描述

2.2 训练流程(main函数)

  • 1、读取sam2.1_hiera_b+_MOSE_finetune.yaml的内容,compose用于加载配置文件
    在这里插入图片描述
  • 2、创建日志文件夹,如果在配置文件中,experiment_log_dir为None则自动创建sam2_logs的文件夹,否则则按你设置的文件夹名创建。(注:这里自带的配置文件里,experiment_log_dir设置的null,如果你不改,那么就会报错,具体请看我主页的报错记录)在这里插入图片描述
  • 2.1、创建日志文件config.yaml并保存当前配置信息
    在这里插入图片描述
  • 3、输出配置细节,注意OmegaConf.to_yaml(cfg)的作用是将配置对象的内容转换成字符串
    在这里插入图片描述
  • 4、处理配置的解析版本
    • 将 OmegaConf 配置对象cfg转换为一个普通的 Python 容器(如字典或列表)。这使得配置可以被其他不支持 OmegaConf 的代码使用。然后又立马转换成配置对象,这样就得到了一个复制的cfg配置对象—cfg_resolved
      在这里插入图片描述
    • 将解析后的配置保存为 config_resolved.yaml 文件,(解析指的是将动态变量变为确定值)
      在这里插入图片描述
      在这里插入图片描述
  • 5、submitit日志记录设置,Submitit 日志​​是使用 Submitit 库提交集群任务时自动生成的记录文件,主要用于跟踪作业状态、调试错误和管理任务。
  • 6、优先使用命令行参数,如果命令行参数中指定了 num_gpus、num_nodes 或 use_cluster,则优先使用这些值,否则使用配置文件中的默认值。
    在这里插入图片描述
  • 7、判断是否使用集群,若是使用则要设置SLURM参数,并打印相关配置,提交任务等
  • 8、判断是否使用集群,若不使用则设置节点数量为1,随机生成一个主节点端口号,调用single_node_runner函数在本地运行任务
    在这里插入图片描述

总结:

  • 加载配置文件并设置实验日志目录。
  • 打印和保存配置文件。
  • 检查是否使用集群。
  • 如果使用集群,配置 SLURM 参数并提交任务。
  • 如果不使用集群,则在本地运行任务。

2.3 single_node_runner

  • 确保配置中指定的节点数量为 1。这是因为这个函数是为单节点训练设计的,如果节点数量不为 1,则会抛出断言错误。

assert cfg.launcher.num_nodes == 1

  • 获取GPU数量,这个值将决定需要启动多少个进程

num_proc = cfg.launcher.gpus_per_node

  • 设置多进程启动方法,设置pytorch多进程的启动方法为spawn

torch.multiprocessing.set_start_method(
“spawn”
) # CUDA runtime does not support fork
spawn 方法会重新启动一个 Python 解释器来运行子进程,这可以避免 CUDA 运行时的兼容性问题。

  • 单GPU情况下启动程序:single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)

2.4 single_proc_run

def single_proc_run(local_rank, main_port, cfg, world_size):
    # 单GOU进程的入口点,用于初始化分布式训练环境并运行训练任务
    # local_rank--当前进程的本地排名,表示当前GPU的索引(从0开始)
    # main_port--主节点端口号,用于进程间通信
    # cfg--配置对象,包含训练任务的配置信息
    # world_size--全局进程总数,表示所有节点上的GPU数量
    """Single GPU process,PyTorch分布式训练所必需的,用于初始化分布式通信的后端"""
    os.environ["MASTER_ADDR"] = "localhost" #主节点的地址,用于分布式训练的通信
    os.environ["MASTER_PORT"] = str(main_port) #主节点的端口号,用于分布式训练的通信
    os.environ["RANK"] = str(local_rank) #当前进程的全局排名,表示当前进程在整个分布式环境中的索引。
    os.environ["LOCAL_RANK"] = str(local_rank) #当前进程的本地排名,表示当前GPU在当前节点上的索引
    os.environ["WORLD_SIZE"] = str(world_size) #全局进程总数,表示所有节点上的GPU数量
    try:
        register_omegaconf_resolvers()
        # 扩展 OmegaConf 的功能,使其能够处理更复杂的动态表达式。
        # 这些解析器允许在配置文件中直接进行数学运算、类型转换、类和方法的动态引用等操作,从而使配置文件更加灵活和强大
    except Exception as e:
        logging.info(e) #将异常对象 e 的信息记录到日志中,通常在发生异常时使用,以便于后续的问题排查和调试

    #instantiate用于根据配置实例化对象。参数 _recursive_=False 表示不递归实例化配置中的嵌套对象。
    trainer = instantiate(cfg.trainer, _recursive_=False)
    trainer.run()#调用训练器的 run 方法,启动训练任务
### SAM2 模型训练方法概述 SAM2 是基于 Segment Anything Model (SAM) 的改进版本,其核心目标仍然是通过提示机制完成高质量的分割任务。以下是有关 SAM2 训练的一些关键信息: #### 数据准备 在训练 SAM2 前,需准备好数据集并对其进行适当处理。通常情况下,可以参考以下流程- **样本制作**:根据实际需求选择变化监测或典型要素提取功能,并调整相关参数[^1]。随后点击“样本制作”,系统将自动裁剪原始数据至适合模型输入的尺寸。 - **数据划分**:为了确保模型性能稳定,在训练前应合理划分数据集为训练集和验证集。可以通过手动方式实现,例如利用 `train_test_split` 函数指定测试集比例(如 20%)作为验证集[^3];或者采用更复杂的策略如 K 折交叉验证。 #### 预训练权重加载 对于 CV 方向的任务而言,初始化阶段引入高质量预训练模型往往能够显著提升收敛速度与最终效果。具体到 SAM 系列产品上,则可直接调用官方发布的 SA-1B 数据集中得到的预训练权重文件来启动项目开发工作[^2]。 #### 超参调节及优化器配置 当上述准备工作完成后便进入正式建模环节——即定义网络结构以及设定各类超参数的过程当中去。这里列举几个较为重要的选项供参考: - Epochs 数量决定了整个过程持续多久; - Batch Size 影响每次迭代所使用的样例数目多少从而间接作用于梯度估计精度高低之间权衡关系之中; - Validation Split 参数用于控制内部保留一部分数据专门用来监控泛化能力情况的变化趋势如何发展等等细节之处均不可忽视掉任何一个方面才行哦! 下面给出一段简单的 Python 实现代码示例展示这部分逻辑框架的大致模样: ```python from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from sklearn.model_selection import train_test_split import numpy as np # 假设已有的特征矩阵 X 和标签向量 Y X = np.random.rand(100, 4) Y = np.random.randint(0, 3, size=(100,)) # 划分训练集和验证集 x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.2, random_state=7) model = Sequential() model.add(Dense(units=8, activation='relu', input_dim=4)) model.add(Dense(units=3, activation='softmax')) model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy']) history = model.fit(x=x_train, y=y_train, epochs=150, batch_size=10, validation_data=(x_val, y_val), verbose=2) ``` 此段脚本主要演示了如何构建一个多层感知机(Multi-Layer Perceptron),并通过 Iris 数据集进行了多分类问题的学习实践操作步骤说明文档链接如下所示。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值