@yrpang
Created March 25, 2021 11:53

Revisions

  1. yrpang revised this gist Mar 25, 2021. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion pytorch-memo.ipynb
    @@ -7,7 +7,8 @@
     "provenance": [],
     "collapsed_sections": [],
     "toc_visible": true,
    -"authorship_tag": "ABX9TyM/sXLkvwCDyjfvdXADq8eS"
    +"authorship_tag": "ABX9TyM/sXLkvwCDyjfvdXADq8eS",
    +"include_colab_link": true
     },
     "kernelspec": {
     "name": "python3",
    @@ -18,6 +19,16 @@
     }
     },
     "cells": [
    +{
    +"cell_type": "markdown",
    +"metadata": {
    +"id": "view-in-github",
    +"colab_type": "text"
    +},
    +"source": [
    +"<a href=\"https://colab.research.google.com/gist/yrpang/5d3f38344fda46e6772d93a4debc52ea/pytorch-memo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    +]
    +},
     {
     "cell_type": "code",
     "metadata": {
  2. yrpang created this gist Mar 25, 2021.
    872 changes: 872 additions & 0 deletions pytorch-memo.ipynb
    @@ -0,0 +1,872 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "colab": {
    "name": "PyTorch Memo.ipynb",
    "provenance": [],
    "collapsed_sections": [],
    "toc_visible": true,
    "authorship_tag": "ABX9TyM/sXLkvwCDyjfvdXADq8eS"
    },
    "kernelspec": {
    "name": "python3",
    "display_name": "Python 3"
    },
    "language_info": {
    "name": "python"
    }
    },
    "cells": [
    {
    "cell_type": "code",
    "metadata": {
    "id": "camxQRqXBrz_"
    },
    "source": [
    "import torch\n",
    "import torch.nn as nn"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "8i2wGC59hOsH"
    },
    "source": [
    "## 基本Tensor操作"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "Jy2ztjWuhXAq",
    "outputId": "1f51ff8e-3270-479e-bf5f-da939f0922e8"
    },
    "source": [
    "# Tensor属性\n",
    "x = torch.arange(12) # 从0开始的向量\n",
    "x.shape # 形状 x.size()\n",
    "x.numel() # 元素总数\n",
    "x.reshape(3, 4) # 返回 contiguous 的reshape后的tensor"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "12"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 13
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "-YUH-Vh-hzLh"
    },
    "source": [
    "# 创建Tensor\n",
    "torch.zeros((2, 3, 4))\n",
    "torch.ones((2, 3, 4))\n",
    "torch.randn(3, 4) # 正态分布\n",
    "torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "SwBsQY5miBof"
    },
    "source": [
    "# element-wised运算\n",
    "# + - * / **\n",
    "torch.exp(x)"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "S2zwzTuuipJ2",
    "outputId": "e645eeb7-ff4e-4b68-9204-fc1c546ed300"
    },
    "source": [
    "# 矩阵reshape和拼接\n",
    "X = torch.arange(12, dtype=torch.float32).reshape((3, 4))\n",
    "Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])\n",
    "\n",
    "torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1) # 拼接\n",
    "X == Y # 逻辑判断\n",
    "X.sum() # 求和 可以指定维度"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "tensor(66.)"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 16
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "tQPbca0sjcLk",
    "outputId": "7466a998-2fd5-48d2-fdc6-5296cb635a8a"
    },
    "source": [
    "# 原地操作\n",
    "Z = torch.zeros_like(Y)\n",
    "print('id(Z):', id(Z))\n",
    "Z[:] = X + Y # 避免复制,减少内存使用\n",
    "print('id(Z):', id(Z))"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "id(Z): 140076925416544\n",
    "id(Z): 140076925416544\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "echMuuUqrpLC",
    "outputId": "239acd36-1edb-4763-afad-8b76d59c181f"
    },
    "source": [
    "# 类型转换\n",
    "A = X.numpy()\n",
    "B = torch.tensor(A)\n",
    "type(A), type(B)\n",
    "\n",
    "a = torch.tensor([3.5])\n",
    "a, a.item(), float(a), int(a) # 标量转化为其它类型"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "(tensor([3.5000]), 3.5, 3.5, 3)"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 19
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "PWbUDSi7tuMI",
    "outputId": "4d4b0a12-c9cb-425d-d180-0260a3ee93e0"
    },
    "source": [
    "# 线性代数运算\n",
    "X = torch.rand((3,4))\n",
    "Y = torch.rand((3,4))\n",
    "Z = torch.rand((4,5))\n",
    "\n",
    "X * Y # Hardmard product 哈达玛积 i.e.按元素相乘\n",
    "X @ Z # 矩阵乘法\n",
    "\n",
    "X.mean(), X.sum() / X.numel() # 求所有元素的均值\n",
    "X.mean(axis=0), X.sum(axis=0) / X.shape[0] # 按行求均值\n",
    "\n",
    "sum_X = X.sum(axis=1, keepdims=True) # keepdims参数保持维度不变\n",
    "X.cumsum(axis=0) # 累加求和\n",
    "\n",
    "# 点积\n",
    "x = torch.rand(4, dtype=torch.float32)\n",
    "y = torch.ones(4, dtype=torch.float32)\n",
    "x, y, torch.dot(x, y)\n",
    "\n",
    "torch.mv(X, x) # 矩阵向量乘法(Ax x当作列向量参与运算)\n",
    "torch.mm(X, Z) # alias X@Z\n",
    "\n",
    "# 范数\n",
    "torch.norm(X) # L2范数\n",
    "torch.abs(X).sum() # L1范数"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "tensor(4.6720)"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 29
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 173
    },
    "id": "y5_0TtCEybX8",
    "outputId": "f355aa61-9241-4fd9-b37a-f1adbfa6fd62"
    },
    "source": [
    "### pandas的get_dummies\n",
    "### 本质上是在做one-hot编码 把某一列中的元素拆分成 set(sny_column)这么多列 每一个元素用一个1表示\n",
    "import pandas as pd\n",
    "s = pd.Series(list('abca'))\n",
    "pd.get_dummies(s)"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/html": [
    "<div>\n",
    "<style scoped>\n",
    " .dataframe tbody tr th:only-of-type {\n",
    " vertical-align: middle;\n",
    " }\n",
    "\n",
    " .dataframe tbody tr th {\n",
    " vertical-align: top;\n",
    " }\n",
    "\n",
    " .dataframe thead th {\n",
    " text-align: right;\n",
    " }\n",
    "</style>\n",
    "<table border=\"1\" class=\"dataframe\">\n",
    " <thead>\n",
    " <tr style=\"text-align: right;\">\n",
    " <th></th>\n",
    " <th>a</th>\n",
    " <th>b</th>\n",
    " <th>c</th>\n",
    " </tr>\n",
    " </thead>\n",
    " <tbody>\n",
    " <tr>\n",
    " <th>0</th>\n",
    " <td>1</td>\n",
    " <td>0</td>\n",
    " <td>0</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>1</th>\n",
    " <td>0</td>\n",
    " <td>1</td>\n",
    " <td>0</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>2</th>\n",
    " <td>0</td>\n",
    " <td>0</td>\n",
    " <td>1</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>3</th>\n",
    " <td>1</td>\n",
    " <td>0</td>\n",
    " <td>0</td>\n",
    " </tr>\n",
    " </tbody>\n",
    "</table>\n",
    "</div>"
    ],
    "text/plain": [
    " a b c\n",
    "0 1 0 0\n",
    "1 0 1 0\n",
    "2 0 0 1\n",
    "3 1 0 0"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 32
    }
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "mz-AHyfSBwxJ"
    },
    "source": [
    "\n",
    "## 常用操作\n",
    "\n",
    "1. Tensor维度的理解\n",
    "\n",
    " 在理解上把 Tensor 的维度`(a,b,c)`理解为`[[[i for i in range(c)]for _ in range(b)] for _ in range(a)]`\n",
    "\n",
    " * 相当于用最后一维存储数据,前面的维度作索引\n",
    " * 对于`nn.Embedding()`这样的函数即可\n",
    "\n",
    "2. `a.permute()` 可以用来改变维度顺序"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "0TCwWsAqBwEg"
    },
    "source": [
    "a = 10\n",
    "b = 3\n",
    "c = 3\n",
    "[[[i for i in range(c)]for _ in range(b)] for _ in range(a)]"
    ],
    "execution_count": null,
    "outputs": []
    },
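    {
    "cell_type": "code",
    "metadata": {},
    "source": [
    "# A minimal sketch (not in the original gist) of a.permute() from point 2 above.\n",
    "# permute returns a view with rearranged strides, so it is usually non-contiguous.\n",
    "a = torch.rand((10, 3, 4))\n",
    "b = a.permute(2, 0, 1) # new shape: (4, 10, 3)\n",
    "print(b.shape)\n",
    "print(b.is_contiguous()) # False: call .contiguous() before .view()"
    ],
    "execution_count": null,
    "outputs": []
    },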
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "qD21_H9HX4jG"
    },
    "source": [
    "3. `torch.split(a, size)` 把tensor按照指定的size分组,相当于在最前面增加了一个维度。\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "XTd1hLB4X6zq"
    },
    "source": [
    "a = torch.randint(20, (2, 3, 3))\n",
    "print(a[0].size())\n",
    "torch.split(a, 1)[1].size()\n",
    "\n",
    "for i in a:\n",
    " print(i)\n",
    "\n",
    "for i in torch.split(a, 1):\n",
    " print(torch.squeeze(i, dim=0))"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "rIRsoU3OYJi3"
    },
    "source": [
    "4. 维度操作\n",
    " * `torch.unsqueeze(input, dim)` 增加维度(取消维度压缩)\n",
    " * `torch.squeeze()` 缩减维度\n",
    "\n",
    " 具体来说`unsqueeze(input, dim)`是把input在指定维度扩展一维度,把最外层看作dim=0,最内层的元素看作一个新维度,那么该操作就相当于把指定的维度的元素用`[]`包起来形成一个新维度。\n",
    "\n",
    " e.g. 比如下面的例子中`x`原来的维度为[0,1],把最内层的单个元素看作第2维,那么如果指定dim=2则相当于把原来的每个元素都包起来形成一个新的维度。\n",
    "\n",
    " `squeeze(input, dim)`为相反的操作,去除长度为1的维度,不指定的话去除所有,指定的话去除指定的。\n",
    "\n",
    " 该操作可以用于`torch.bmm()`的数据预处理"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "nzXoIlOqYJ1P"
    },
    "source": [
    "x = torch.tensor([[1, 2], [3, 4]])\n",
    "print(x.size())\n",
    "print(x)\n",
    "\n",
    "x_t = torch.unsqueeze(x, 1)\n",
    "print(x_t.size())\n",
    "print(x_t)\n",
    "\n",
    "x_tt = torch.squeeze(x_t)\n",
    "print(x_tt)"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "wZtcDNmabj1k"
    },
    "source": [
    "5. `torch.bmm()`\n",
    "\n",
    " 把第一维当作batch_index,把后面的每个矩阵相乘得到结果。\n",
    "\n",
    "6. `torch.gather(input:Tensor, dim:int, index:LongTensor)`\n",
    "\n",
    " 这个函数的作用在于,把index的内容作为dim维度,再根据index的形状去input中选取内容组合成新的tensor。\n",
    "\n",
    " e.g. 下面的例子中指定维度为1,也就是列,那么index的数字就是列的意思,保留index除去这个维度以外其它维度的正常含义(i.e. 对应行依然是行),按照它的形状去生成新的结果。\n",
    " ```\n",
    " >>> t = torch.Tensor([[1,2],[3,4]])\n",
    " >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))\n",
    " 1 1\n",
    " 4 3\n",
    " ```\n",
    "\n",
    " 技巧,在NLP的输出任务中若softmax的结果为y_pre = (length, batch, vocab_size);y = (length, batch), 则可以用这个函数巧妙的计算交叉熵的结果: \n",
    " \n",
    " `torch.gather(y_pre, -1, y.unsqueeze(-1)).squeeze(-1).sum(dim=0)`"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "MwV_rQASekcP"
    },
    "source": [
    "t = torch.randint(20,(2,3))\n",
    "print(t)\n",
    "print(t.sum(dim=0))\n",
    "\n",
    "torch.gather(t, 1, torch.LongTensor([[1, 0], [0, 1]]))"
    ],
    "execution_count": null,
    "outputs": []
    },
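    {
    "cell_type": "code",
    "metadata": {},
    "source": [
    "# A minimal sketch (not in the original gist) of torch.bmm and the gather trick above.\n",
    "A = torch.rand((4, 2, 3)) # batch of 4 matrices, each 2x3\n",
    "B = torch.rand((4, 3, 5)) # batch of 4 matrices, each 3x5\n",
    "print(torch.bmm(A, B).shape) # (4, 2, 5): one matmul per batch index\n",
    "\n",
    "# Picking per-token log-probs with gather; shapes follow the note above\n",
    "length, batch, vocab_size = 7, 4, 10\n",
    "y_pre = torch.log_softmax(torch.rand(length, batch, vocab_size), dim=-1)\n",
    "y = torch.randint(vocab_size, (length, batch))\n",
    "print(torch.gather(y_pre, -1, y.unsqueeze(-1)).squeeze(-1).sum(dim=0).shape) # (batch,)"
    ],
    "execution_count": null,
    "outputs": []
    },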
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "0QKAsZsxCkLG"
    },
    "source": [
    "## 工具模块\n",
    "\n",
    "1. `nn.Embedding()` [doc](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding)\n",
    "\n",
    " * `embedding = nn.Embedding(10, 2)` (词表长度,单词表示维度)\n",
    " * `nn.Embedding.from_pretrained()` 加载与训练词表\n",
    " * `embedding(input)` 得到$(*, e)$维度的矩阵 (\\*是指input_shape, e是指embedding_size)\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "L85FHscXCKaF"
    },
    "source": [
    "embedding = nn.Embedding(10, 2)\n",
    "embedding = nn.Embedding.from_pretrained() # 加载预训练词表 该方法是 @classmethod\n",
    "\n",
    "input = torch.LongTensor([[1, 2, 3, 4], [7, 7, 7, 7]])\n",
    "embedding(input) "
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "cFaTsm9hAjgq"
    },
    "source": [
    "2. 模型参数初始化 `nn.init.*` [doc](https://pytorch.org/docs/stable/nn.init.html)\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "OAnnc2boA1AP"
    },
    "source": [
    "net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))\n",
    "\n",
    "def init_weights(m):\n",
    " if type(m) == nn.Linear:\n",
    " nn.init.normal_(m.weight, std=0.01)\n",
    "\n",
    "net.apply(init_weights); # 对Sequential的每一个model进行应用"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "9v3dTZ7pIJ4P"
    },
    "source": [
    "3. 设置添加正则化项和使用优化器\n",
    "```python\n",
    "trainer = torch.optim.SGD(\n",
    " [{\"params\": net[0].weight,'weight_decay': wd}, \n",
    " {\"params\": net[0].bias}], lr=lr)\n",
    "\n",
    " trainer.step() # 更新模型参数\n",
    "```\n",
    "\n",
    " ```python\n",
    " # 查看L2正则化的值, 这里net为Sequential,net[0]表示序列的第一个模型\n",
    " print('w的L2范数:', net[0].weight.norm().item())\n",
    " ```\n",
    " "
    ]
    },
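    {
    "cell_type": "code",
    "metadata": {},
    "source": [
    "# A runnable, hedged version of the snippet above; net, wd and lr are placeholders.\n",
    "net = nn.Sequential(nn.Linear(4, 1))\n",
    "wd, lr = 0.01, 0.1\n",
    "trainer = torch.optim.SGD(\n",
    " [{\"params\": net[0].weight, 'weight_decay': wd}, # decay only the weights\n",
    " {\"params\": net[0].bias}], lr=lr)\n",
    "\n",
    "loss = net(torch.rand(8, 4)).mean()\n",
    "trainer.zero_grad()\n",
    "loss.backward()\n",
    "trainer.step() # update the model parameters\n",
    "print('L2 norm of w:', net[0].weight.norm().item())"
    ],
    "execution_count": null,
    "outputs": []
    },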
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "yd_YFQZTnyFA"
    },
    "source": [
    "## 手动更新参数构建一个最简单的$y=\\omega x$拟合$(2, 8)$\n",
    "\n",
    "手动更新参数的方法\n",
    "```python\n",
    " ## Update the weights using gradient descent. Each parameter is a Tensor, so\n",
    " ## we can access its gradients like we did before.\n",
    " ## from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html\n",
    " with torch.no_grad():\n",
    " for param in model.parameters():\n",
    " param -= learning_rate * param.grad\n",
    "```"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "H_QY_eULhCWN"
    },
    "source": [
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "class Model(nn.Module):\n",
    " def __init__(self):\n",
    " super(Model, self).__init__()\n",
    " self.layer = nn.Linear(1, 1, False)\n",
    "\n",
    " def forward(self, x):\n",
    " return self.layer(x)\n",
    "\n",
    "# target point is (2, 8)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "model = Model()"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "BBtbuQdOfvyF",
    "outputId": "2c8aa509-f32c-4816-f5e2-f60bf8ad3f39"
    },
    "source": [
    "# 手动的\n",
    "learning_rate = 0.1\n",
    "\n",
    "for i in range(10):\n",
    " y = model(x)\n",
    " print(f\"y={y.item()}\")\n",
    "\n",
    " loss = loss_fn(y, torch.tensor([8.]))\n",
    " print(f\"loss = {l.item():.3f}\")\n",
    "\n",
    " model.zero_grad()\n",
    " loss.backward()\n",
    " \n",
    " ## Update the weights using gradient descent. Each parameter is a Tensor, so\n",
    " ## we can access its gradients like we did before.\n",
    " ## from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html\n",
    " with torch.no_grad():\n",
    " for param in model.parameters():\n",
    " param -= learning_rate * param.grad\n",
    " # param[:] = param - learning_rate * param.grad\n",
    "\n",
    " ## Waring: 这种做法不正确!!!\n",
    " ## for param in model.parameters():\n",
    " ## param[:] = param - learning_rate * param.grad\n",
    "\n",
    " print(list(model.parameters()))\n",
    " print(\"\")"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "y=0.6354889869689941\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.2635]], requires_grad=True)]\n",
    "\n",
    "y=6.527097702026367\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.8527]], requires_grad=True)]\n",
    "\n",
    "y=7.705419540405273\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.9705]], requires_grad=True)]\n",
    "\n",
    "y=7.941083908081055\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.9941]], requires_grad=True)]\n",
    "\n",
    "y=7.988216876983643\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.9988]], requires_grad=True)]\n",
    "\n",
    "y=7.99764347076416\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[3.9998]], requires_grad=True)]\n",
    "\n",
    "y=7.999528884887695\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[4.0000]], requires_grad=True)]\n",
    "\n",
    "y=7.999905586242676\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[4.0000]], requires_grad=True)]\n",
    "\n",
    "y=7.999980926513672\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[4.0000]], requires_grad=True)]\n",
    "\n",
    "y=7.999996185302734\n",
    "loss = 0.000\n",
    "[Parameter containing:\n",
    "tensor([[4.0000]], requires_grad=True)]\n",
    "\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "AGTAjQIiwXwb"
    },
    "source": [
    "## RNN模块的使用\n",
    "\n",
    "1. 输入数据处理`nn.utils.rnn.pack_padded_sequence(input, lengths)`\n",
    "\n",
    " RNN 模块的输入必须为`nn.utils.rnn.PackedSequence`类型,所以得用这个模块进行处理。\n",
    "\n",
    " 他的作用是将padded后的数据去除`<pad>`进行压紧操作,并提供每个step的数据个数,供RNN模块使用。\n",
    "\n",
    " **_所谓`packed sequence`是指压紧的序列,pack本身有压紧挤满的意思_**\n",
    "\n",
    "* 该方法默认`batch_first==False`期待的输入`input`维度为$(T, B, *)$,T为句子长度,B为batch中每个句子去除`<pad>`后的真实长度。(i.e. lengths.shape == B)\n",
    "\n",
    "* `enforce_sorted`默认为`True`,意味着它接收一个已经按照句子长度降序排列过的数据,如果数据并不是这样,将该选项设置为`False`可以让它自动排序。\n",
    "\n",
    "* 这里面的`batch_sizes`是指每个step有多少个元素,i.e. 句子的每个位置的词的batch。\n",
    "\n",
    "\n",
    "2. `pad_packed_sequence()`上面操作的反向操作,把pad重新填充回去返回 (padded_sequence, lengths)\n",
    "\n",
    "* 如果是它自动排序过的,会自动恢复原顺序\n",
    "\n",
    "* `d.data[d.unsorted_indices]`也可以手动恢复\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "EORpXavmJZAR"
    },
    "source": [
    "input = torch.randint(20, (2, 3, 3))\n",
    "print(input)\n",
    "lengths = [2 for _ in range(input.shape[1])]\n",
    "lengths[1]=1\n",
    "d = nn.utils.rnn.pack_padded_sequence(input, lengths, enforce_sorted=False)\n",
    "print(d)"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "Y1ydi8IjizMR"
    },
    "source": [
    "nn.utils.rnn.pad_packed_sequence(d) # 注意返回的是tuple(padded_sequence, lengths)"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "Q_pZPXatk34F"
    },
    "source": [
    "d.data[d.unsorted_indices]"
    ],
    "execution_count": null,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "zBYtEtLd8Ih8"
    },
    "source": [
    "### 3. RNN模块的输入输出\n",
    "\n",
    "1. 输入为上面构建的`PackedSequence`类型\n",
    "2. 输出为`output, (h_n, c_n)`\n",
    " * `output` **_最后一层_** 每个step的输出 `shape(seq_len, batch, num_directions * hidden_size)`\n",
    " * `h_n`和`c_n` **_所有层_** 最后一个step后的hidden状态和cell状态 `shape(num\\_layers * num\\_directions, batch, hidden\\_size)`\n",
    " * 对于单层、单向LSTM而言, `output`的最后一个输出,等于`h_n`的输出。对于双向来说`h_n`是后传最后一个和前传第一个拼接的结果。\n",
    "\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "24hhSZP7phSo"
    },
    "source": [
    "# 单向、单层一致演示\n",
    "model = nn.LSTM(input_size = 2, hidden_size=4, bidirectional=False)\n",
    "data = torch.rand((5, 3, 2))\n",
    "lengths = [data.size()[0] for _ in range(data.size()[1])]\n",
    "input = nn.utils.rnn.pack_padded_sequence(data, lengths)\n",
    "\n",
    "hidden, (h_last, c_last) = model(input)\n",
    "print(nn.utils.rnn.pad_packed_sequence(hidden)[0][-1])\n",
    "print(h_last)"
    ],
    "execution_count": null,
    "outputs": []
    },
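    {
    "cell_type": "code",
    "metadata": {},
    "source": [
    "# A hedged demo (not in the original gist) of the bidirectional claim above:\n",
    "# h_n holds the forward half of output[-1] and the backward half of output[0].\n",
    "model = nn.LSTM(input_size=2, hidden_size=4, bidirectional=True)\n",
    "data = torch.rand((5, 3, 2))\n",
    "out, (h_last, c_last) = model(data)\n",
    "print(torch.allclose(out[-1, :, :4], h_last[0])) # forward direction, last step\n",
    "print(torch.allclose(out[0, :, 4:], h_last[1])) # backward direction, first step"
    ],
    "execution_count": null,
    "outputs": []
    },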
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "BdCvgmxRJ3tt"
    },
    "source": [
    "## `nn.Dense()` 全连接层"
    ]
    },
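    {
    "cell_type": "code",
    "metadata": {},
    "source": [
    "# A minimal sketch (not in the original gist): nn.Linear computes y = x @ W.T + b\n",
    "fc = nn.Linear(3, 5)\n",
    "x = torch.rand((2, 3))\n",
    "print(fc(x).shape) # (2, 5)\n",
    "print(fc.weight.shape, fc.bias.shape) # (5, 3) and (5,)"
    ],
    "execution_count": null,
    "outputs": []
    },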
    {
    "cell_type": "code",
    "metadata": {
    "id": "SQmRb2mRsL4a",
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "outputId": "ee66b3fd-ae81-4af7-c655-9d90b8d05c87"
    },
    "source": [
    "# 测试contiguous\n",
    "a = torch.rand((2,3))\n",
    "a.is_contiguous()\n",
    "a.view((3, -1)).is_contiguous()"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "True"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 8
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "id": "ZkjnWEvHgXVM",
    "outputId": "f1dbac6c-3bcb-4e55-c658-fd230937bc55"
    },
    "source": [
    "a.shape"
    ],
    "execution_count": null,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "torch.Size([2, 3])"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 11
    }
    ]
    }
    ]
    }