广播机制
广播机制是为了解决两个不同尺寸之间的数组(张量)之间的计算问题而设计的一种算法机制,在numpy中就存在这种机制,而Pytorch和Tensorflow为了减少学习成本,也按照numpy的这种机制规则设计了广播机制,深入理解有助于更好地对张量进行操作
广播的基本流程
本质上两个形状(shape)不同的张量之间是不能直接运算的,需要相同的形状才行。而广播就是在运算之前,将两个张量进行匹配扩充的过程。第一步,两个张量扩充成相同形状,本质上是复制扩充;第二步,两个扩充后的形状相同的向量进行运算。下图可以直观地表明这两个过程:
广播的规则
规则1
如果两个数组的维度不相同,那么小维度数组的形状将会在最左边补1.
这个规则的目的是让维度数相同,同时注意是最左边补1
例如
a = np.ones([3,4,6]), b = np.ones([4,1])
这是a有三个维度,b只有两个维度,故b会在最左边扩充1个维度,直到b和a的维度数是相同的
这时 _b = np.ones([1,4,1])
规则2
如果两个数组的形状在任何一个维度上都不匹配,那么数组的形状会沿着维度为1的维度拓展以匹配另外一个数组形状。
这个规则的目标是让每一个维度上的数字相等,同时注意数字是1才能拓展
接着规则1的a和b,b进行扩充一个维度后,a的shape是[3,4,6],b的shape是[1,4,1],从右往左开始一位位地判断。
首先a.shape[2] = 6, b.shape[2] = 1 ,满足了1的拓展条件,直接复制最后一维到6,