PyTorch的torch.cat、squeeze()、unsqueeze()和size()函数

本文介绍了PyTorch中用于深度学习的几个关键函数:squeeze()用于压缩维度,unsqueeze()用于增加维度,size()用于获取张量的尺寸信息,以及torch.cat()用于在指定维度上连接多个张量。每个函数通过实例演示了其工作原理和使用方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

一、sequeeze()函数

二、unsequeeze()函数 

三、size()函数

四、torch.cat函数


在Pytorch做深度学习过程中,CNN的卷积和池化过程中会用到torch.cat、squeeze()、unsqueeze()和size()函数,下面分别做讲解:

一、sequeeze()函数

x.squeeze(dim)

用途:进行维度压缩,去掉tensor中维数为1的维度

参数设置:如果设置dim=a,就是去掉指定维度中维数为1的

示例: 

import torch
x = torch.tensor([[[1],[2]],[[3],[4]]])
print('x:',x)
print(x.shape)
x1 = x.squeeze()
print('x1:',x1)
print(x1.shape)
x2 = x.squeeze(2)
print('x2:',x2)
print(x2.shape)

输出:

x: tensor([[[1],
         [2]],
 
        [[3],
         [4]]])
torch.Size([2, 2, 1]
<think>嗯,用户想了解PyTorch中张量形状操作的方法,特别是view、unsqueezesqueezecatstack这几个函数。首先,我需要回忆这些方法的具体用法区别。根据引用[3][4],catstack的主要区别在于是否新增维度。cat是在现有维度上拼接,不增加维度,而stack会创建新维度。比如,两个2x3的张量用cat在dim=0拼接会变成4x3,而用stack在dim=0会变成2x2x3。 接下来是view,它类似于numpy的reshape,用来改变张量的形状,但必须保证元素总数不变。比如,一个4x3的张量可以用view(3,4)变成3x4。不过要注意连续性,可能需要先contiguous()再view。 unsqueezesqueeze是处理维度的,unsqueeze增加一个大小为1的维度,而squeeze去掉所有大小为1的维度。例如,一个2x3的张量,unsqueeze(0)变成1x2x3,squeeze()如果原本有1的维度就去掉。 用户可能还会关心这些方法的应用场景,比如数据预处理时需要调整维度,或者在神经网络层之间调整形状。比如,全连接层之前通常需要将多维数据展平,这时候可以用view。而处理图像数据时可能需要增加通道维度,用unsqueeze。 需要确保回答结构清晰,分点解释每个方法,给出示例代码,并对比catstack的区别。同时,根据引用中的信息,正确引用来源。最后生成相关问题,帮助用户深入理解。</think>### PyTorch张量形状变换方法详解 #### 1. `view()`:维度重塑 - **作用**:调整张量形状,元素总数必须保持不变 - **示例**: ```python a = torch.rand(2, 3) b = a.view(3, 2) # 形状变为3×2 c = a.view(-1) # 展平为6元素向量 ``` - **注意**:当张量内存不连续时,需先调用`.contiguous()`[^3] #### 2. `unsqueeze()`:增加维度 - **作用**:在指定位置插入大小为1的新维度 - **示例**: ```python a = torch.ones(2,3) b = a.unsqueeze(0) # 形状变为1×2×3 c = a.unsqueeze(2) # 形状变为2×3×1 ``` #### 3. `squeeze()`:压缩维度 - **作用**:移除所有大小为1的维度 - **示例**: ```python a = torch.ones(1,2,3,1) b = a.squeeze() # 形状变为2×3 c = a.squeeze(0) # 指定移除第0维 ``` #### 4. `cat()`:拼接张量 - **特点**:沿现有维度拼接,**不增加新维度**[^4] - **示例**: ```python a = torch.eye(2,3) b = torch.cat([a,a], dim=0) # 4×3张量 c = torch.cat([a,a], dim=1) # 2×6张量 ``` #### 5. `stack()`:堆叠张量 - **特点**:创建新维度进行拼接,**增加维度数量**[^1] - **示例**: ```python a = torch.eye(2,3) b = torch.stack([a,a], dim=0) # 2×2×3张量 c = torch.stack([a,a], dim=2) # 2×3×2张量 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dr.Petrichor

作者逐个题目分析的噢

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值