首先要知道线性回归什么,并且线性回归可以解决什么问题。
在统计学中,线性回归(LinearRegression)是利用称为线性回归方程的最小平方函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。只有一个自变量的情况称为简单回归,大于一个自变量情况的叫做多元回归。
回归分析常用于分析几个变量之间的关系,用来做连续值的预测,预测的结果也是一个连续值。
我们就假设蛋糕直径和价格之间存在如下的数据:
训练样本 | 直径 | 价格 |
1 | 6 | 15 |
2 | 8 | 20 |
3 | 10 | 32 |
4 | 14 | 41.5 |
5 | 18 | 48 |
现将这个数据绘制下来。
import numpy
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
font = FontProperties(fname=r'C:\windows\fonts\simsun.ttc',size=10)
#设置汉字格式
fig = plt.figure()
ax = fig.add_subplot(111)
x = [[6],[8],[10],[14],[18]]
y = [[15],[20],[32],[41.5],[48]]
ax.scatter(x,y,c = 'r')
plt.xlabel('直径',FontProperties = font)
plt.ylabel('价格',FontProperties = font)
plt.title('蛋糕直径与价格之间的关系',FontProperties = font)
plt.show()
我们如果想要预测其他尺寸的蛋糕价格,我们就可以使用上面的数据,拟合出一条直线,,然后输入x,得到我们想要得到的价格。
这样我们就引入一般线性回归的表达式:
表示为预测出的结果。
代价函数
上面我们使用的预测函数是:。我们预测出的结果和真实的结果肯定是会有一定的差距的,那我们就可以将这个差距表示出来,这里我们选用误差的平方,因为误差的平方代价函数对于大对数问题,特别是回归问题,都是一个合理的选择。还会有其他的代价函数也能很好的发挥作用,但是平方误差代价函数可能是解决回归问题最常用的手段了。
在这里我们需要预测与真实的结果之间的误差越小越好,那么就是代价函数越小越好。那么我们就需要一种算法能够找到使代价函数取最小值的参数和来。
梯度下降算法
梯度下降算法步骤:
1、随机选择一组
2、不断变化,使最小化。
梯度下降算法公式:
我们先要计算函数的偏导数:
那么我们梯度下降的公式就变成
通过不断的迭代梯度下降的公式,我们就可以得到比较合理的值,然后我们就能够拟合出合适的直线。
例子
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
def getDataSet():
dataSet = []
label = []
with open('data.txt','r') as f:
for line in f.readlines():
line = line.strip().split()
dataSet.append([1.0,float(line[0])])
label.append(float(line[1]))
return dataSet,label
def get_line():
dataSet,label = getDataSet()
fig = plt.figure()
ax = fig.add_subplot(111)
dataSet = np.array(dataSet)
ax.scatter(dataSet[:,1], label, c='r')
label = np.mat(label).transpose()
dataSet = np.mat(dataSet)
m,n = np.shape(dataSet)
weights = np.ones((n,1))
alpha = 0.01
x = np.arange(6,18,0.1) #shape(120,0)
for i in range(10):
error = (dataSet * weights - label)
num = dataSet.transpose()*error
weights = weights - alpha * num / m #shape(2,1)
y = weights[0,0] + weights[1,0] * x
plt.plot(x,y,'--')
plt.plot(x,y)
plt.show()
if __name__ == '__main__':
get_line()
我们最终得到的拟合直线,图中实线是我们最终得到的结果: