self [ index[i,j] , j ] += src [ i , j ] # if dim == 0
self [ i , index[i,j] ] += src[ i, j ] # if dim == 1
理解scatter_add()函数,看index就行了,index有多少个,self坐标就会变多少次。
self是一个二维的数组,self[第一维,第二维],dim==0,就是将src对应坐标,对应到 index 坐标里面的值,放置到self的第一维中。
例如src[i,j]对应到index[i,j],假设 index[i,j] ==0,则self[第一维,第二维] 为self[0,j],只改变第一维,第二维的值和src第二维一样。然后self[0,j]的值就会变为 self[0,j]=self[0,j]+src[i,j]
代码中 self=torch.zeros(3,5), dim=0, index=[0,1,2,0,0], src=torch.ones(2,5)
我们只看 src,当 src[0,0]=1, index[0,0]=0, self[0,0]=self[0,0]+src[0,0]=1。
当src[0,1]=1, index[0,1]=1, self[1,1]=self[1,1]+src[0,1]=1, self的第一维是index的值决定的为1,第二维是src的第二维坐标决定也为1。当index的值没有时,就停止变换,self没有变换过的坐标值就保持不变。