split函数的学习
import torch
weight = torch.randn(3,2,2)
print(weight)
a = torch.split(weight, 1, dim=0) # 沿着dim这个维度切分,然后每个tensor的切割单位都是1。
print(a)
split函数的学习
import torch
weight = torch.randn(3,2,2)
print(weight)
a = torch.split(weight, 1, dim=0) # 沿着dim这个维度切分,然后每个tensor的切割单位都是1。
print(a)