PyTorch Tensor 形状变化操作详解
在深度学习中,Tensor 的形状变换是非常常见的操作。PyTorch 提供了丰富的 API 来帮助我们调整 Tensor 的形状,以满足模型输入、计算或数据处理的需求。本文将详细介绍 PyTorch 中常见的 Tensor 形状变换操作,并通过示例代码进行说明。
1. 基础形状操作
1.1 view
和 reshape
- 功能:改变 Tensor 的形状而不改变其数据。
- 区别:
view
要求新形状的总元素数与原形状一致,否则会报错。reshape
更灵活,如果无法直接改变形状,会尝试创建一个新的 Tensor。
- 示例:
tensor = torch.randn(2, 3, 4) # 原形状为 (2, 3, 4)
reshaped_tensor = tensor.view(2, 12) # 改变形状为 (2, 12)
print(reshaped_tensor.shape) # 输出: torch.Size([2, 12])
1.2 squeeze
和 unsqueeze
- 功能:
squeeze
:移除大小为 1 的维度。unsqueeze
:在指定位置插入大小为 1 的维度。
- 示例:
tensor = torch.randn(1, 3, 1, 4) # 原形状为 (1, 3, 1, 4)
squeezed_tensor = tensor.squeeze() # 移除所有大小为 1 的维度
print(squeezed_tensor.shape) # 输出: torch.Size([3, 4])
unsqueezed_tensor = squeezed_tensor.unsqueeze(0) # 在第 0 维插入大小为 1 的维度
print(unsqueezed_tensor.shape) # 输出: torch.Size([1, 3, 4])