Python的numpy库是一个非常有用的数学计算库,其broadcasting机制给我们的矩阵运算带来了极大地方便。
我们先看下面的一个例子:
>>> import numpy as np
>>> a = np.array([1,2,3])
>>> a
array([1, 2, 3])
>>> b = np.array([6,6,6])
>>> b
array([6, 6, 6])
>>> c = a + b
>>> c
array([7, 8, 9])
上面的代码其实就是把数组a和数组b中同样位置的每对元素相加。这里a和b是相同长度的数组。
如果两个数组的长度不一致,这时候broadcasting就可以发挥作用了。
比如下面的代码:
>>> d = a + 5
>>> d
array([6, 7, 8])
broadcasting会把5扩展成[5,5,5],然后上面的代码就变成了对两个同样长度的数组相加。示意图如下(broadcasting不会分配额外的内存来存取被复制的数据,这里只是方面描述):
我们接下来看看多维数组的情况:
>>> e
array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
>>> e + a
array([[2., 3., 4.],
[2., 3., 4.],
[2., 3., 4.]])
在这里一维数组被扩展成了二维数组,和e的尺寸相同。示意图如下所示:
我们再来看一个需要对两个数组都做broadcasting的例子:
>>> b = np.arange(3).reshape((3,1))
>>> b
array([[0],
[1],
[2]])
>>> b + a
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 5]])
在这里a和b都被扩展成相同的尺寸的二维数组。示意图如下所示:
总结
broadcasting的一些规则:
- 如果两个数组维数不相等,维数较低的数组的shape进行填充,直到和高维数组的维数匹配。
- 如果两个数组维数相同,但某些维度的长度不同,那么长度为1的维度会被扩展,和另一数组的同维度的长度匹配。
- 如果两个数组维数相同,但有任一维度的长度不同且不为1,则报错。
>>> a = np.arange(3)
>>> a
array([0, 1, 2])
>>> b = np.ones((2,3))
>>> b
array([[1., 1., 1.],
[1., 1., 1.]])
>>> a.shape
(3,)
>>> a + b
array([[1., 2., 3.],
[1., 2., 3.]])
接下来我们看看报错的例子:
>>> a = np.arange(3)
>>> a
array([0, 1, 2])
>>> b = np.ones((3,2))
>>> b
array([[1., 1.],
[1., 1.],
[1., 1.]])
>>> a + b
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,) (3,2)