pytorch基础(六)-优化器

本文详细介绍了PyTorch中的优化器概念,包括SGD优化器的使用,如设置超参数、管理参数组、step()方法、梯度清零、添加新参数组以及保存和加载优化器状态。重点讲解了梯度下降过程中的关键变量如动量和学习率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

基本属性

default:

优化器超参数

state:

优化器缓存

params_groups:

管理的参数组

_step_count:

记录更新参数(学习率调整中使用)

基本方法

step()

更新一次参数

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))

optimizer = optim.SGD([weight], lr=0.1)
print("weight before step:{}".format(weight.data))
optimizer.step()
print("weight after step:{}".format(weight.data))

zero_grad()

清空所有梯度

optimizer.step() 

 add_param_group()
    x_1 = optimizer.param_groups
    w2 = torch.randn((3, 3), requires_grad=True)

    optimizer.add_param_group({"params": w2, 'lr': 0.0001})
    x_2 = optimizer.param_groups

 

state_dict()

获取优化器当前的字典信息

optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()

#  ...梯度下降

torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

 

load_state_dict()

加载优化器信息

state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
optimizer.load_state_dict(state_dict)

state_dict before load state:
 {'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0]}]}
state_dict after load state:
 {'state': {0: {'momentum_buffer': tensor([[6.5132, 6.5132],
        [6.5132, 6.5132]])}}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0]}]}


梯度下降

v_i = m*v_{i-1}+g(w_i)\\w_{i+1}=w_i-lr*v_i

v_i:当前梯度下降的更新量

m:动量

v_i-1:上一次梯度下降的更新量

g(w_i):当前梯度下降的梯度

lr:学习率

w_i:当前时刻权重

w_i+1:更新后权重

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值