# PyTorch Memo

```python
import torch
import torch.nn as nn
```

## Basic Tensor Operations

```python
# Tensor attributes
x = torch.arange(12)  # a vector counting up from 0
x.shape               # shape; same as x.size()
x.numel()             # total number of elements -> 12
x.reshape(3, 4)       # returns a contiguous reshaped tensor
```

```python
# Creating tensors
torch.zeros((2, 3, 4))
torch.ones((2, 3, 4))
torch.randn(3, 4)  # standard normal distribution
torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
```

```python
# Element-wise operations
# + - * / **
torch.exp(x)
```

```python
# Reshaping and concatenating matrices
X = torch.arange(12, dtype=torch.float32).reshape((3, 4))
Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)  # concatenation
X == Y   # element-wise comparison
X.sum()  # sum; a dim can be specified -> tensor(66.)
```

```python
# In-place operations
Z = torch.zeros_like(Y)
print('id(Z):', id(Z))
Z[:] = X + Y  # writes into Z's existing memory, avoiding a copy
print('id(Z):', id(Z))  # same id as before
```

```python
# Type conversion
A = X.numpy()
B = torch.tensor(A)
type(A), type(B)  # (numpy.ndarray, torch.Tensor)

a = torch.tensor([3.5])
a, a.item(), float(a), int(a)  # convert a scalar tensor to Python types
# -> (tensor([3.5000]), 3.5, 3.5, 3)
```
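One nuance worth noting here (my addition, not from the original memo): `torch.tensor(A)` copies the NumPy array, while `torch.from_numpy(A)` shares its memory. A minimal check:

```python
import numpy as np
import torch

A = np.zeros((2, 3), dtype=np.float32)
B = torch.tensor(A)      # copies the data
C = torch.from_numpy(A)  # shares memory with A

A[0, 0] = 7.0
print(B[0, 0].item())  # 0.0 -- the copy is unaffected
print(C[0, 0].item())  # 7.0 -- the shared view sees the change
```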
```python
# Linear algebra
X = torch.rand((3, 4))
Y = torch.rand((3, 4))
Z = torch.rand((4, 5))

X * Y  # Hadamard product, i.e. element-wise multiplication
X @ Z  # matrix multiplication

X.mean(), X.sum() / X.numel()               # mean over all elements
X.mean(axis=0), X.sum(axis=0) / X.shape[0]  # mean along axis 0 (one value per column)

sum_X = X.sum(axis=1, keepdims=True)  # keepdims preserves the number of dimensions
X.cumsum(axis=0)                      # cumulative sum

# Dot product
x = torch.rand(4, dtype=torch.float32)
y = torch.ones(4, dtype=torch.float32)
x, y, torch.dot(x, y)

torch.mv(X, x)  # matrix-vector product (x is treated as a column vector)
torch.mm(X, Z)  # same as X @ Z

# Norms
torch.norm(X)       # L2 norm
torch.abs(X).sum()  # L1 norm
```

```python
# pandas get_dummies is essentially one-hot encoding: it splits a column into
# as many columns as there are distinct values and marks each element with a 1
import pandas as pd
s = pd.Series(list('abca'))
pd.get_dummies(s)
#    a  b  c
# 0  1  0  0
# 1  0  1  0
# 2  0  0  1
# 3  1  0  0
```

## Common Operations

1. Understanding tensor dimensions

   Think of a tensor of shape `(a, b, c)` as `[[[i for i in range(c)] for _ in range(b)] for _ in range(a)]`:

   * the last dimension stores the data, while the leading dimensions act as indices;
   * this mental model is enough to reason about functions such as `nn.Embedding()`.

2. `a.permute()` reorders the dimensions (see the sketch after the next cell).

```python
a = 10
b = 3
c = 3
[[[i for i in range(c)] for _ in range(b)] for _ in range(a)]
```
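As a quick illustration of item 2, here is a minimal `permute` sketch (my own example, not from the original memo):

```python
a = torch.arange(24).reshape(2, 3, 4)  # shape (2, 3, 4)
b = a.permute(2, 0, 1)                 # move the last dim to the front -> (4, 2, 3)
print(b.shape)            # torch.Size([4, 2, 3])
print(b.is_contiguous())  # False -- permute returns a non-contiguous view
```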
3. `torch.split(a, size)` splits a tensor into chunks of the given size, which effectively behaves like adding a new leading dimension.

```python
a = torch.randint(20, (2, 3, 3))
print(a[0].size())
torch.split(a, 1)[1].size()

for i in a:
    print(i)

for i in torch.split(a, 1):
    print(torch.squeeze(i, dim=0))
```

4. Dimension operations
   * `torch.unsqueeze(input, dim)` adds a dimension (undoes a squeeze)
   * `torch.squeeze()` removes dimensions

   Concretely, `unsqueeze(input, dim)` inserts a new dimension into `input` at position `dim`. Treating the outermost level as dim 0 and each innermost element as a potential new dimension, the operation wraps the elements at the given depth in an extra pair of `[]`, forming a new dimension.

   e.g. in the example below, `x` originally has dims [0, 1]; viewing each innermost scalar as dim 2, passing `dim=2` wraps every element in its own new dimension.

   `squeeze(input, dim)` is the inverse: it removes dimensions of length 1; with no `dim` it removes all of them, otherwise only the specified one.

   This is handy when preparing data for `torch.bmm()`.

```python
x = torch.tensor([[1, 2], [3, 4]])
print(x.size())
print(x)

x_t = torch.unsqueeze(x, 1)
print(x_t.size())
print(x_t)

x_tt = torch.squeeze(x_t)
print(x_tt)
```

5. `torch.bmm()`

   Treats the first dimension as the batch index and multiplies the remaining matrices pairwise.

6. `torch.gather(input: Tensor, dim: int, index: LongTensor)`

   Interprets the entries of `index` as coordinates along `dim`, then selects values from `input` following the shape of `index` to build a new tensor.

   e.g. below, the specified dim is 1 (columns), so the numbers in `index` are column indices; every other dimension keeps its usual meaning (rows stay rows), and the output takes the shape of `index`.
   ```
   >>> t = torch.Tensor([[1,2],[3,4]])
   >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
   1 1
   4 3
   ```

   Trick: in an NLP output task where the softmax result is y_pre = (length, batch, vocab_size) and y = (length, batch), this function computes the cross-entropy terms neatly (see the sketch after the next cell):

   `torch.gather(y_pre, -1, y.unsqueeze(-1)).squeeze(-1).sum(dim=0)`

```python
t = torch.randint(20, (2, 3))
print(t)
print(t.sum(dim=0))

torch.gather(t, 1, torch.LongTensor([[1, 0], [0, 1]]))
```
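A sketch of the cross-entropy trick from item 6, under my own assumptions about the shapes; `F.log_softmax` is added by me so the gathered values are log-probabilities rather than raw scores:

```python
import torch
import torch.nn.functional as F

length, batch, vocab_size = 4, 2, 6
logits = torch.randn(length, batch, vocab_size)  # raw model outputs
y = torch.randint(vocab_size, (length, batch))   # gold token ids

log_p = F.log_softmax(logits, dim=-1)            # plays the role of y_pre above
# Pick each gold token's log-probability, then sum over the sequence dimension
nll = -torch.gather(log_p, -1, y.unsqueeze(-1)).squeeze(-1).sum(dim=0)
print(nll.shape)  # torch.Size([2]): one negative log-likelihood per batch element
```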
## Utility Modules

1. `nn.Embedding()` [doc](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding)

   * `embedding = nn.Embedding(10, 2)` (vocabulary size, embedding dimension)
   * `nn.Embedding.from_pretrained()` loads a pretrained embedding table
   * `embedding(input)` yields a matrix of shape $(*, e)$ (\* is the input shape, e is the embedding size)

```python
embedding = nn.Embedding(10, 2)
# from_pretrained is a @classmethod that loads a pretrained table; it requires a
# weight tensor, e.g. embedding = nn.Embedding.from_pretrained(torch.rand(10, 2))

input = torch.LongTensor([[1, 2, 3, 4], [7, 7, 7, 7]])
embedding(input)
```

2. Parameter initialization with `nn.init.*` [doc](https://pytorch.org/docs/stable/nn.init.html)

```python
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);  # applied to every module inside the Sequential
```

3. Adding weight decay (L2 regularization) and using an optimizer (a runnable sketch follows)

```python
trainer = torch.optim.SGD(
    [{"params": net[0].weight, 'weight_decay': wd},
     {"params": net[0].bias}], lr=lr)

trainer.step()  # update the model parameters
```

```python
# Inspect the L2 norm of the weights; net is a Sequential and net[0] is its first module
print('L2 norm of w:', net[0].weight.norm().item())
```
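A self-contained sketch of the weight-decay setup in item 3; the network, `wd`, `lr`, and the random data are placeholder assumptions of mine:

```python
net = nn.Sequential(nn.Linear(4, 1))
wd, lr = 1e-3, 0.1

trainer = torch.optim.SGD(
    [{"params": net[0].weight, 'weight_decay': wd},  # L2 penalty on the weights only
     {"params": net[0].bias}],                       # bias left unregularized
    lr=lr)

X, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.MSELoss()(net(X), y)
trainer.zero_grad()
loss.backward()
trainer.step()  # update the model parameters

print('L2 norm of w:', net[0].weight.norm().item())
```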
## Manually updating parameters: fitting $(2, 8)$ with the simplest model $y = \omega x$

How to update parameters by hand:
```python
## Update the weights using gradient descent. Each parameter is a Tensor, so
## we can access its gradients like we did before.
## from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
with torch.no_grad():
    for param in model.parameters():
        param -= learning_rate * param.grad
```

```python
loss_fn = nn.MSELoss()

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1, False)  # a single weight, no bias

    def forward(self, x):
        return self.layer(x)

# target point is (2, 8)
x = torch.tensor([2.], requires_grad=True)
model = Model()
```

```python
# Manual updates
learning_rate = 0.1

for i in range(10):
    y = model(x)
    print(f"y={y.item()}")

    loss = loss_fn(y, torch.tensor([8.]))
    print(f"loss = {loss.item():.3f}")

    model.zero_grad()
    loss.backward()

    ## Update the weights using gradient descent. Each parameter is a Tensor, so
    ## we can access its gradients like we did before.
    ## from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad
            # param[:] = param - learning_rate * param.grad

    ## Warning: updating WITHOUT torch.no_grad() is incorrect!
    ## for param in model.parameters():
    ##     param[:] = param - learning_rate * param.grad

    print(list(model.parameters()))
    print("")
```

Over the ten iterations `y` converges from 0.64 through 6.53, 7.71, 7.94, ... to 8.00, and the weight converges to 4.0.

## Using the RNN Modules

1. Input preprocessing: `nn.utils.rnn.pack_padded_sequence(input, lengths)`

   The RNN modules expect input of type `nn.utils.rnn.PackedSequence`, so padded data must be processed with this function.

   It strips the `<pad>` entries from a padded batch and packs the rest tightly, recording how many entries each step contains for the RNN module to use.

   **_A `packed sequence` is literally a compressed sequence; "pack" means to squeeze tight._**

* With the default `batch_first=False`, the expected `input` shape is $(T, B, *)$, where T is the padded sentence length and B the batch size; `lengths` holds each sentence's true length without `<pad>` (i.e. `lengths.shape == (B,)`).

* `enforce_sorted` defaults to `True`, meaning the function expects data already sorted by decreasing length; if it is not, set the option to `False` to have it sorted automatically.

* `batch_sizes` records how many elements each step has, i.e. the batch at each position of the sentences (see the concrete example after the next cells).

2. `pad_packed_sequence()` is the reverse operation: it fills the padding back in and returns `(padded_sequence, lengths)`.

* If the data was auto-sorted, the original order is restored automatically.

* `d.data[d.unsorted_indices]` can also restore it manually.

```python
input = torch.randint(20, (2, 3, 3))
print(input)
lengths = [2 for _ in range(input.shape[1])]
lengths[1] = 1
d = nn.utils.rnn.pack_padded_sequence(input, lengths, enforce_sorted=False)
print(d)
```

```python
nn.utils.rnn.pad_packed_sequence(d)  # note: returns a tuple (padded_sequence, lengths)
```

```python
d.data[d.unsorted_indices]
```
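To make `batch_sizes` concrete, a tiny hand-checkable example of my own, packing three sequences of true lengths 3, 2, 1:

```python
# Padded batch, shape (T=3, B=3, 1); the 0. entries stand in for <pad>
padded = torch.tensor([[[1.], [4.], [6.]],
                       [[2.], [5.], [0.]],
                       [[3.], [0.], [0.]]])
packed = nn.utils.rnn.pack_padded_sequence(padded, lengths=[3, 2, 1])

print(packed.batch_sizes)       # tensor([3, 2, 1]): step 0 has 3 items, step 1 has 2, step 2 has 1
print(packed.data.squeeze(-1))  # tensor([1., 4., 6., 2., 5., 3.]) -- padding stripped, steps packed tight
```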
### 3. RNN module input and output

1. The input is the `PackedSequence` built above.
2. The output is `output, (h_n, c_n)`:
   * `output` holds the **_last layer's_** output at every step, shape `(seq_len, batch, num_directions * hidden_size)`
   * `h_n` and `c_n` hold **_all layers'_** hidden and cell states after the final step, shape `(num_layers * num_directions, batch, hidden_size)`
   * For a single-layer, unidirectional LSTM, the last step of `output` equals `h_n`. For a bidirectional one, `h_n` concatenates the forward direction's state at the last step with the backward direction's state at the first step.

```python
# Single-direction, single-layer consistency demo
model = nn.LSTM(input_size=2, hidden_size=4, bidirectional=False)
data = torch.rand((5, 3, 2))
lengths = [data.size()[0] for _ in range(data.size()[1])]
input = nn.utils.rnn.pack_padded_sequence(data, lengths)

hidden, (h_last, c_last) = model(input)
print(nn.utils.rnn.pad_packed_sequence(hidden)[0][-1])
print(h_last)
```

## `nn.Linear()` (the fully connected layer)

PyTorch's fully connected layer is `nn.Linear`; there is no `nn.Dense` (that is the Keras name).

```python
# Testing contiguity
a = torch.rand((2, 3))
a.is_contiguous()
a.view((3, -1)).is_contiguous()  # -> True
```

```python
a.shape  # -> torch.Size([2, 3])
```
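Related to the contiguity test above (my addition): `view` requires contiguous memory, while `reshape` falls back to copying when needed.

```python
b = torch.rand((2, 3)).t()  # transpose -> a non-contiguous view
print(b.is_contiguous())    # False
# b.view(6) would raise a RuntimeError here, because b is not contiguous
print(b.reshape(6).shape)            # torch.Size([6]) -- reshape copies if it must
print(b.contiguous().view(6).shape)  # torch.Size([6]) -- or make it contiguous first
```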