设置随机种子(torch.manual_seed(params.seed))的作用

本文探讨了在神经网络训练中设置初始化种子的重要性,确保结果的一致性和可重复性。通过介绍如何在PyTorch中使用固定种子,如`set_seed`函数,以及其在多GPU环境中的应用,揭示了种子设置对于算法稳定性的关键作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在神经网络中,参数默认是随机初始化的。如果不设置随机种子,每次训练时的初始化都不同,导致结果无法复现。如果设置了固定的随机种子,则每次运行得到的初始化都是相同的。注意:在GPU上要做到完全可复现,还需要让cuDNN只使用确定性算法(见下文的 torch.backends.cudnn.deterministic)。
# Experiment initialization: only seed the RNGs when a non-negative seed is configured.
_seed = getattr(params, 'seed', -1)
if _seed >= 0:
    # Seed NumPy's global RNG.
    np.random.seed(_seed)
    # Seed PyTorch's RNG (the PyTorch docs state this covers all devices,
    # so the CUDA-specific call below is belt-and-braces).
    torch.manual_seed(_seed)
    if params.cuda:
        # Seeds only the *current* CUDA device.
        torch.cuda.manual_seed(_seed)
# With several GPUs, torch.cuda.manual_seed_all() seeds every device explicitly.

或者

def set_seed(args):
    """Seed every RNG used during training and make cuDNN reproducible.

    Args:
        args: object exposing ``seed`` (the seed value) and ``n_gpu``
            (number of GPUs in use; CUDA RNGs are seeded when it is > 0).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        # Seed the RNG on every GPU. Safe to call when CUDA is unavailable:
        # in that case the call is silently ignored.
        torch.cuda.manual_seed_all(args.seed)
    # Restrict cuDNN to deterministic convolution algorithms...
    torch.backends.cudnn.deterministic = True
    # ...and disable benchmark mode, which may auto-select a different kernel
    # on each run and thereby break run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False

SEED = 123  # random-seed setup
torch.manual_seed(SEED)  # fix PyTorch's RNG
# NOTE(review): deterministic=False lets cuDNN use faster, possibly
# nondeterministic kernels, so GPU runs may not be bit-for-bit reproducible
# even with the seeds fixed. Set it to True if exact reproducibility matters.
torch.backends.cudnn.deterministic = False  # allow CUDA optimizations
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)  # fix NumPy's RNG


def weights_init_normal(m):
    """Initialize one module's parameters in place (meant for ``model.apply``).

    - ``nn.Conv2d`` / ``nn.Conv1d``: weights drawn from N(0, 0.02).
    - ``nn.BatchNorm1d``: weights drawn from N(1, 0.02), bias set to 0.

    Other module types are left untouched. Matching is on the exact type,
    so subclasses of these layers are not re-initialized.
    """
    if type(m) in (nn.Conv2d, nn.Conv1d):
        # 2D and 1D convolutions share the same small-std normal init.
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif type(m) is nn.BatchNorm1d:
        # A BatchNorm layer built with affine=False has weight/bias = None,
        # so guard before initializing them.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)


def main(config, fold_id):
    """Build model, loss, metrics, optimizer and data for one fold, then train.

    Args:
        config: presumably the project's ``ConfigParser`` instance (indexed like
            a dict, with ``get_logger`` and ``init_obj``) — confirm.
        fold_id: index into ``folds_data`` selecting the fold to train on.

    NOTE: reads the module-level ``folds_data``, which is only assigned in the
    ``__main__`` section below. ``module_arch``, ``module_loss``,
    ``module_metric``, ``data_generator_np``, ``calc_class_weight`` and
    ``Trainer`` come from project modules that are not imported in this snippet.
    """
    batch_size = config["data_loader"]["args"]["batch_size"]
    logger = config.get_logger('train')

    # build model architecture, initialize weights, then print to console
    model = config.init_obj('arch', module_arch)
    model.apply(weights_init_normal)
    logger.info(model)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer over the parameters that require gradients
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)

    data_loader, valid_data_loader, data_count = data_generator_np(
        folds_data[fold_id][0], folds_data[fold_id][1], batch_size)
    weights_for_each_class = calc_class_weight(data_count)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      fold_id=fold_id,
                      valid_data_loader=valid_data_loader,
                      class_weights=weights_for_each_class)
    trainer.train()


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default="config.json", type=str,
                      help='config file path (default: config.json)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default="0", type=str,
                      help='indices of GPUs to enable (default: 0)')
    # -f and -da are required: without them int(None) and `"shhs" in None`
    # below would crash with a TypeError instead of a usage message.
    args.add_argument('-f', '--fold_id', type=str, required=True, help='fold_id')
    args.add_argument('-da', '--np_data_dir', type=str, required=True,
                      help='Directory containing numpy files')

    # Template hook for extra CLI options; unused here since `options` is empty.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = []

    args2 = args.parse_args()
    fold_id = int(args2.fold_id)
    config = ConfigParser.from_args(args, fold_id, options)

    # Data directories whose path contains "shhs" use the SHHS-specific loader.
    num_folds = config["data_loader"]["args"]["num_folds"]
    if "shhs" in args2.np_data_dir:
        folds_data = load_folds_data_shhs(args2.np_data_dir, num_folds)
    else:
        folds_data = load_folds_data(args2.np_data_dir, num_folds)

    main(config, fold_id)

# Commenter's question (translated from Chinese): "Please explain this code line
# by line, mark the statements that call into other files, and say whether the
# training process calls code from other parts."
03-08
import argparse
import collections
import numpy as np
from data_loader.data_loaders import *
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer
from utils.util import *
import torch
import torch.nn as nn

# fix random seeds for reproducibility
SEED = 123
torch.manual_seed(SEED)
# NOTE(review): deterministic=False lets cuDNN use faster, possibly
# nondeterministic kernels, which conflicts with the reproducibility goal
# stated above. Set it to True if exact GPU reproducibility matters.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)


def weights_init_normal(m):
    """Initialize one module's parameters in place (meant for ``model.apply``).

    - ``nn.Conv2d`` / ``nn.Conv1d``: weights drawn from N(0, 0.02).
    - ``nn.BatchNorm1d``: weights drawn from N(1, 0.02), bias set to 0.

    Other module types are left untouched. Matching is on the exact type,
    so subclasses of these layers are not re-initialized.
    """
    if type(m) in (nn.Conv2d, nn.Conv1d):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif type(m) is nn.BatchNorm1d:
        # A BatchNorm layer built with affine=False has weight/bias = None,
        # so guard before initializing them.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)


def main(config, fold_id):
    """Build model, loss, metrics, optimizer and data for one fold, then train.

    Args:
        config: the ``ConfigParser`` instance created in ``__main__`` (indexed
            like a dict, with ``get_logger`` and ``init_obj``).
        fold_id: index into ``folds_data`` selecting the fold to train on.

    NOTE: reads the module-level ``folds_data``, which is only assigned in the
    ``__main__`` section below. ``data_generator_np`` and ``calc_class_weight``
    presumably come from the wildcard imports of ``data_loader.data_loaders`` /
    ``utils.util`` — confirm.
    """
    batch_size = config["data_loader"]["args"]["batch_size"]
    logger = config.get_logger('train')

    # build model architecture, initialize weights, then print to console
    model = config.init_obj('arch', module_arch)
    model.apply(weights_init_normal)
    logger.info(model)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)

    data_loader, valid_data_loader, data_count = data_generator_np(
        folds_data[fold_id][0], folds_data[fold_id][1], batch_size)
    weights_for_each_class = calc_class_weight(data_count)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      data_loader=data_loader,
                      fold_id=fold_id,
                      valid_data_loader=valid_data_loader,
                      class_weights=weights_for_each_class)
    trainer.train()


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default="config.json", type=str,
                      help='config file path (default: config.json)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default="0", type=str,
                      help='indices of GPUs to enable (default: 0)')
    # -f and -da are required: without them int(None) and `"shhs" in None`
    # below would crash with a TypeError instead of a usage message.
    args.add_argument('-f', '--fold_id', type=str, required=True, help='fold_id')
    args.add_argument('-da', '--np_data_dir', type=str, required=True,
                      help='Directory containing numpy files')

    # Template hook for extra CLI options; unused here since `options` is empty.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = []

    args2 = args.parse_args()
    fold_id = int(args2.fold_id)
    config = ConfigParser.from_args(args, fold_id, options)

    # Data directories whose path contains "shhs" use the SHHS-specific loader.
    num_folds = config["data_loader"]["args"]["num_folds"]
    if "shhs" in args2.np_data_dir:
        folds_data = load_folds_data_shhs(args2.np_data_dir, num_folds)
    else:
        folds_data = load_folds_data(args2.np_data_dir, num_folds)

    main(config, fold_id)

# Commenter's question (translated from Chinese): "What does each line of this
# code mean? Please explain it to me."
03-08
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值