《PyTorch深度学习实践》--3梯度下降算法

本文介绍了如何在PyTorch中使用线性模型求解w的最优值,通过梯度下降和随机梯度下降算法,并展示了具体实现过程。讲解了梯度下降的原理、代码示例,以及随机梯度下降在避免局部最优上的优势。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

视频教程:《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

一、.在第二节中的线性模型中,求解w的最优值(使得MSE最小的w)问题。

从图中可以看出:w=2时,MSE最小。(即最优)

二、求解最优w问题的方法

2.1梯度下降(Gradient Descent)算法:

w按梯度下降方向移动,这样一定次数的移动后就会移动到最优解。

(a为学习因子,影响每次移动的步长,越小越精确但时间复杂度也会变高)

通过求导,可以求出具体的表达式,根据表达式就可以写出代码。

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0

def forward(x):
    return x * w

#mse
def cost(xs,ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost/ len(xs)

#梯度
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2*x*(x*w - y)
    return grad / len(xs)

print('Predict (before training)',4,forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val //更新w
    print('Epoch:',epoch, 'w=', w, 'loss=', cost_val)
print('Predict (after traning)', 4, forward(4))

 

 

(结果应该是收敛的,如果不收敛可能是a值过大。)

 

 

2.2 随机梯度下降(Stochastic Gradient Descent )

类似梯度下降,但是这里用的是随机某个样本(而不是整体)的梯度。

这样的好处是由于单个样本一般有噪声,具有随机性,可能帮助走出鞍点从而进入最优解。

坏处是计算依赖上次结果,多个样本x无法并行,时间复杂度高。因此会有一个中间的方法,Mini-Batch(或称Batch)。将若干个样本点分成一组,每次用一组来更新w。

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0

def forward(x):
    return x * w

#mse,单个样本
def loss(x,y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

#梯度,单个样本
def gradient(x, y):
    return 2*x* (x*w - y)

print('Predict (before training)',4,forward(4))
for epoch in range(100):
    for x,y in zip(x_data, y_data):
        grad_val = gradient(x, y)
        w -= 0.01 * grad_val
        print('\tgrad:',x,y,grad_val)
        loss_val = loss(x,y)
    print("progress:", epoch, 'w=', w, 'loss=', loss_val)

print('Predict (after traning)', 4, forward(4))

 

 

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一只大鸽子

如有帮助,欢迎关注同名公众号

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值