01 Pytorch Workflow - Ipynb
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/01_pytorch_workflow.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"[View Source Code](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/01_pytorch_workflow.ipynb) | [View Slides](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/slides/01_pytorch_workflow.pdf) | [Watch Video Walkthrough](https://youtu.be/Z_ikDlimN6A?t=15419) "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OgYkrRCRec0r"
},
"source": [
"# 01. PyTorch Workflow Fundamentals\n",
"\n",
"The essence of machine learning and deep learning is to take some data from the past, build an algorithm (like a neural network) to discover patterns in it and use the discovered patterns to predict the future.\n",
"\n",
"There are many ways to do this and many new ways are being discovered all the time.\n",
"\n",
"But let's start small.\n",
"\n",
"How about we start with a straight line?\n",
"\n",
"Then we'll see if we can build a PyTorch model that learns the pattern of the straight line and matches it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "51Ug7Ug123Ip"
},
"source": [
"## What we're going to cover\n",
"\n",
"In this module we're going to cover a standard PyTorch workflow (it can be chopped and changed as necessary, but it covers the main outline of steps).\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/01_a_pytorch_workflow.png\" width=900 alt=\"a pytorch workflow flowchart\"/>\n",
"\n",
"For now, we'll use this workflow to predict a simple straight line, but the workflow steps can be repeated and changed depending on the problem you're working on.\n",
"\n",
"Specifically, we're going to cover:\n",
"\n",
"| **Topic** | **Contents** |\n",
"| ----- | ----- |\n",
"| **1. Getting data ready** | Data can be almost anything, but to get started we're going to create a simple straight line. |\n",
"| **2. Building a model** | Here we'll create a model to learn patterns in the data; we'll also choose a **loss function** and an **optimizer** and build a **training loop**. | \n",
"| **3. Fitting the model to data (training)** | We've got data and a model; now let's let the model (try to) find patterns in the (**training**) data. |\n",
"| **4. Making predictions and evaluating a model (inference)** | Our model has found patterns in the data; let's compare its findings to the actual (**testing**) data. |\n",
"| **5. Saving and loading a model** | You may want to use your model elsewhere or come back to it later; here we'll cover how. |\n",
"| **6. Putting it all together** | Let's take all of the above and combine it. |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kKC3ugfM25e_"
},
"source": [
"\n",
"## Where can you get help?\n",
"\n",
"All of the materials for this course are [available on GitHub](https://github.com/mrdbourke/pytorch-deep-learning).\n",
"\n",
"And if you run into trouble, you can ask a question on the [Discussions page](https://github.com/mrdbourke/pytorch-deep-learning/discussions) there too.\n",
"\n",
"There's also the [PyTorch developer forums](https://discuss.pytorch.org/), a very helpful place for all things PyTorch. \n",
"\n",
"Let's start by putting what we're covering into a dictionary to reference later.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "nGM1dEsYec0u"
},
"outputs": [],
"source": [
"what_were_covering = {1: \"data (prepare and load)\",\n",
" 2: \"build model\",\n",
" 3: \"fitting the model to data (training)\",\n",
" 4: \"making predictions and evaluating a model (inference)\",\n",
" 5: \"saving and loading a model\",\n",
" 6: \"putting it all together\"\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L9EOt5cbod6l"
},
"source": [
"And now let's import what we'll need for this module.\n",
"\n",
"We're going to get `torch`, `torch.nn` (`nn` stands for neural network and this package contains the building blocks for creating neural networks in PyTorch) and `matplotlib`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "ZT_ikDC-ec0w",
"outputId": "1f0b19d0-6e96-4cc9-b8e6-7adcb3f1da27"
},
"outputs": [
{
"data": {
"text/plain": [
"'1.12.1+cu113'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from torch import nn # nn contains all of PyTorch's building blocks for neural networks\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Check PyTorch version\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ci_-geIdec0w"
},
"source": [
"## 1. Data (preparing and loading)\n",
"\n",
"I want to stress that \"data\" in machine learning can be almost anything you can imagine: a table of numbers (like a big Excel spreadsheet), images of any kind, videos (YouTube has lots of data!), audio files like songs or podcasts, protein structures, text and more.\n",
"\n",
"\n",
"\n",
"Machine learning is a game of two parts: \n",
"1. Turn your data, whatever it is, into numbers (a representation).\n",
"2. Pick or build a model to learn the representation as best as possible.\n",
"\n",
"Sometimes one and two can be done at the same time.\n",
"\n",
"But what if you don't have data?\n",
"\n",
"Well, that's where we're at now.\n",
"\n",
"No data.\n",
"\n",
"But we can create some.\n",
"\n",
"Let's create our data as a straight line.\n",
"\n",
"We'll use [linear regression](https://en.wikipedia.org/wiki/Linear_regression) to create the data with known **parameters** (things that can be learned by a model) and then we'll use PyTorch to see if we can build a model to estimate these parameters using [**gradient descent**](https://en.wikipedia.org/wiki/Gradient_descent).\n",
"\n",
"Don't worry if the terms above don't mean much now; we'll see them in action and I'll put extra resources below where you can learn more.\n",
"\n"
]
},
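A worked instance of the straight-line formula, using the known parameters created in the next cell (`weight = 0.7`, `bias = 0.3`), plus the generic gradient descent update that later gets applied to learnable parameters. The symbols $w$, $b$, $\theta$, $\eta$ (learning rate) and $\mathcal{L}$ (loss) are notation chosen here for illustration, not names used in the notebook.

$$
y = w\,x + b, \qquad w = 0.7,\; b = 0.3,\; x = 0.02 \;\Rightarrow\; y = 0.7 \times 0.02 + 0.3 = 0.314
$$

$$
\theta \leftarrow \theta - \eta\,\frac{\partial \mathcal{L}}{\partial \theta}, \qquad \theta \in \{w, b\}
$$

The value 0.314 should match the second entry of `y` printed by the next cell.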
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HmZWVNjGec0x",
"outputId": "ef7c9d50-31d6-47b6-add9-2cd51694298f"
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.0000],\n",
" [0.0200],\n",
" [0.0400],\n",
" [0.0600],\n",
" [0.0800],\n",
" [0.1000],\n",
" [0.1200],\n",
" [0.1400],\n",
" [0.1600],\n",
" [0.1800]]),\n",
" tensor([[0.3000],\n",
" [0.3140],\n",
" [0.3280],\n",
" [0.3420],\n",
" [0.3560],\n",
" [0.3700],\n",
" [0.3840],\n",
" [0.3980],\n",
" [0.4120],\n",
" [0.4260]]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create *known* parameters\n",
"weight = 0.7\n",
"bias = 0.3\n",
"\n",
"# Create data\n",
"start = 0\n",
"end = 1\n",
"step = 0.02\n",
"X = torch.arange(start, end, step).unsqueeze(dim=1)\n",
"y = weight * X + bias\n",
"\n",
"X[:10], y[:10]"
]
},
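A quick sanity check of the shapes and values above, written as a standalone sketch that rebuilds the same tensors rather than reusing the notebook's variables. `unsqueeze(dim=1)` turns the 1-D range of 50 values into a column of shape `[50, 1]` (one feature per row), and broadcasting keeps `y` the same shape. The names `X_flat`, `i` and `expected` are introduced here for illustration.

```python
import torch

# Same known parameters and range as the notebook cell above
weight, bias = 0.7, 0.3
X_flat = torch.arange(0, 1, 0.02)   # 1-D tensor, shape [50]
X = X_flat.unsqueeze(dim=1)          # column tensor, shape [50, 1]
y = weight * X + bias                # broadcasting keeps shape [50, 1]

print(X_flat.shape, X.shape, y.shape)

# Spot-check one sample against the formula y = weight * x + bias
i = 1
expected = weight * X[i, 0].item() + bias
print(X[i, 0].item(), y[i, 0].item(), expected)
```

Having each sample on its own row is the layout most PyTorch layers expect (batch dimension first, features second).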
{
"cell_type": "markdown",
"metadata": {
"id": "dzNigr8dtW2Y"
},
"source": [
"Beautiful! Now we're going to move towards building a model that can learn the relationship between `X` (**features**) and `y` (**labels**). "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YApM7diprjP0"
},
"source": [
"### Split data into training and test sets \n",
"\n",
"We've got some data.\n",
"\n",
"But before we build a model we need to split it up.\n",
"\n",
"One of the most important steps in a machine learning project is creating a training and test set (and, when required, a validation set).\n",
"\n",
"Each split of the dataset serves a specific purpose:\n",
"\n",
"| Split | Purpose | Amount of total data | How often is it used? |\n",
"| ----- | ----- | ----- | ----- |\n",
"| **Training set** | The model learns from this data (like the course materials you study during the semester). | ~60-80% | Always |\n",
"| **Validation set** | The model gets tuned on this data (like the practice exam you take before the final exam). | ~10-20% | Often but not always |\n",
"| **Testing set** | The model gets evaluated on this data to test what it has learned (like the final exam you take at the end of the semester). | ~10-20% | Always |\n",
"\n",
"For now, we'll use just a training and test set; this means we'll have one dataset for our model to learn from and another to evaluate it on.\n",
"\n",
"We can create them by splitting our `X` and `y` tensors.\n",
"\n",
"> **Note:** When dealing with real-world data, this step is typically done right at the start of a project (the test set should always be kept separate from all other data). We want our model to learn from training data and then evaluate it on test data to get an indication of how well it **generalizes** to unseen examples.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BpyB7JgHec0y",
"outputId": "a859f5c1-37ed-4a9a-b139-20a1107077ed"
},
"outputs": [
{
"data": {
"text/plain": [
"(40, 40, 10, 10)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create train/test split\n",
"train_split = int(0.8 * len(X)) # 80% of data used for training set, 20% for testing \n",
"X_train, y_train = X[:train_split], y[:train_split]\n",
"X_test, y_test = X[train_split:], y[train_split:]\n",
"\n",
"len(X_train), len(y_train), len(X_test), len(y_test)"
]
},
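Because `X` is sorted, slicing at `train_split` puts the largest `X` values in the test set. A shuffled split is a common alternative; the sketch below uses `torch.randperm` to draw a random 80/20 split while keeping each feature paired with its label. This is an assumption for illustration, not what this notebook does.

```python
import torch

torch.manual_seed(42)  # make the shuffle reproducible

X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3

# Shuffle the indices once, then take the first 80% for training
indices = torch.randperm(len(X))
train_split = int(0.8 * len(X))
train_idx, test_idx = indices[:train_split], indices[train_split:]

# Index X and y with the same indices so feature/label pairs stay aligned
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

print(len(X_train), len(y_train), len(X_test), len(y_test))
```

Whichever split you use, the key property is the same: no sample appears in both the training and the test set.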
{
"cell_type": "markdown",
"metadata": {
"id": "ua1y5hFjtLxC"
},
"source": [
"Wonderful, we've got 40 samples for training (`X_train` & `y_train`) and 10 samples for testing (`X_test` & `y_test`).\n",
"\n",
"The model we create is going to try to learn the relationship between `X_train` & `y_train`, and then we will evaluate what it learns on `X_test` and `y_test`.\n",
"\n",
"But right now our data is just numbers on a page.\n",
"\n",
"Let's create a function to visualize it."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "w9Ep0T-Dec0y"
},
"outputs": [],
"source": [
"def plot_predictions(train_data=X_train, \n",
"                     train_labels=y_train, \n",
"                     test_data=X_test, \n",
"                     test_labels=y_test, \n",
"                     predictions=None):\n",
"    \"\"\"\n",
"    Plots training data, test data and compares predictions.\n",
"    \"\"\"\n",
"    plt.figure(figsize=(10, 7))\n",
"\n",
"    # Plot training data in blue\n",
"    plt.scatter(train_data, train_labels, c=\"b\", s=4, label=\"Training data\")\n",
"    \n",
"    # Plot test data in green\n",
"    plt.scatter(test_data, test_labels, c=\"g\", s=4, label=\"Testing data\")\n",
"\n",
"    if predictions is not None:\n",
"        # Plot the predictions in red (predictions were made on the test data)\n",
"        plt.scatter(test_data, predictions, c=\"r\", s=4, label=\"Predictions\")\n",
"\n",
"    # Show the legend\n",
"    plt.legend(prop={\"size\": 14});"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 428
},
"id": "xTaIwydGec0z",
"outputId": "0d02d134-f6de-4e6f-c904-b081c7d6b8b1"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_predictions();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mdElzVUJuWRe"
},
"source": [
"Epic!\n",
"\n",
"Now instead of just being numbers on a page, our data is a straight line.\n",
"\n",
"> **Note:** Now's a good time to introduce you to the data explorer's motto... \"visualize, visualize, visualize!\"\n",
"> \n",
"> Think of this whenever you're working with data and turning it into numbers: if you can visualize something, it can do wonders for understanding.\n",
">\n",
"> Machines love numbers and we humans like numbers too, but we also like to look at things."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0eFsorRHec00"
},
"source": [
"## 2. Build model\n",
"\n",
"Now we've got some data, let's build a model to use the blue dots to predict the green dots.\n",
"\n",
"We're going to jump right in.\n",
"\n",
"We'll write the code first and then explain everything. \n",
"\n",
"Let's replicate a standard linear regression model using pure PyTorch."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "jhcUJBFuec00"
},
"outputs": [],
"source": [
"# Create a Linear Regression model class\n",
"class LinearRegressionModel(nn.Module): # <- almost everything in PyTorch is a nn.Module (think of this as neural network lego blocks)\n",
"    def __init__(self):\n",
"        super().__init__() \n",
"        self.weights = nn.Parameter(torch.randn(1, # <- start with random weights (this will get adjusted as the model learns)\n",
"                                                dtype=torch.float), # <- PyTorch loves float32 by default\n",
"                                   requires_grad=True) # <- can we update this value with gradient descent?\n",
"\n",
"        self.bias = nn.Parameter(torch.randn(1, # <- start with random bias (this will get adjusted as the model learns)\n",
"                                             dtype=torch.float), # <- PyTorch loves float32 by default\n",
"                                requires_grad=True) # <- can we update this value with gradient descent?\n",
"\n",
"    # Forward defines the computation in the model\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor: # <- \"x\" is the input data (e.g. training/testing features)\n",
"        return self.weights * x + self.bias # <- this is the linear regression formula (y = m*x + b)"
]
},
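To see `forward()` in action, here is a minimal standalone sketch that repeats the class above and passes a few made-up inputs through it. Calling `model(x)` goes through `nn.Module.__call__`, which runs `forward()` (plus any registered hooks), so you normally call the model rather than `forward()` directly. `torch.inference_mode()` is used here only to skip gradient tracking for this quick check; the inputs in `x` are arbitrary.

```python
import torch
from torch import nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1, dtype=torch.float), requires_grad=True)
        self.bias = nn.Parameter(torch.randn(1, dtype=torch.float), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * x + self.bias

torch.manual_seed(42)
model = LinearRegressionModel()

x = torch.tensor([[0.0], [0.5], [1.0]])  # arbitrary inputs, shape [3, 1]

with torch.inference_mode():
    preds = model(x)                          # calls forward() under the hood
    manual = model.weights * x + model.bias   # same formula written out by hand

print(preds.shape)                 # weights of shape [1] broadcast against [3, 1]
print(torch.equal(preds, manual))
```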
{
"cell_type": "markdown",
"metadata": {
"id": "xhu5wxVO7s_q"
},
"source": [
"Alright, there's a fair bit going on above, but let's break it down bit by bit.\n",
"\n",
"> **Resource:** We'll be using Python classes to create bits and pieces for building neural networks. If you're unfamiliar with Python class notation, I'd recommend reading [Real Python's Object-Oriented Programming in Python 3 guide](https://realpython.com/python3-object-oriented-programming/) a few times.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iRRq3a0Gvvnl"
},
"source": [
"### PyTorch model building essentials\n",
"\n",
"PyTorch has four (give or take) essential modules you can use to create almost any kind of neural network you can imagine.\n",
"\n",
"They are [`torch.nn`](https://pytorch.org/docs/stable/nn.html), [`torch.optim`](https://pytorch.org/docs/stable/optim.html), [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) and [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html). For now, we'll focus on the first two and get to the other two later (though you may be able to guess what they do).\n",
"\n",
"| PyTorch module | What does it do? |\n",
"| ----- | ----- |\n",
"| [`torch.nn`](https://pytorch.org/docs/stable/nn.html) | Contains all of the building blocks for computational graphs (essentially a series of computations executed in a particular way). |\n",
"| [`torch.nn.Parameter`](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html#parameter) | Stores tensors that can be used with `nn.Module`. If `requires_grad=True`, gradients (used for updating model parameters via [**gradient descent**](https://ml-cheatsheet.readthedocs.io/en/latest/gradient_descent.html)) are calculated automatically; this is often referred to as \"autograd\". | \n",
"| [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module) | The base class for all neural network modules; all the building blocks for neural networks are subclasses of it. If you're building a neural network in PyTorch, your models should subclass `nn.Module`. Requires a `forward()` method to be implemented. | \n",
"| [`torch.optim`](https://pytorch.org/docs/stable/optim.html) | Contains various optimization algorithms (these tell the model parameters stored in `nn.Parameter` how best to change in order to reduce the loss, e.g. via gradient descent). | \n",
"| `def forward()` | All `nn.Module` subclasses require a `forward()` method; this defines the computation that will take place on the data passed to the particular `nn.Module` (e.g. the linear regression formula above). |\n",
"\n",
"If the above sounds complex, think of it like this (almost everything in a PyTorch neural network comes from `torch.nn`):\n",
"* `nn.Module` contains the larger building blocks (layers)\n",
"* `nn.Parameter` contains the smaller parameters like weights and biases (put these together to make `nn.Module`(s))\n",
"* `forward()` tells the larger blocks how to make calculations on inputs (tensors full of data) within `nn.Module`(s)\n",
"* `torch.optim` contains optimization methods on how to improve the parameters within `nn.Parameter` to better represent input data \n",
"\n",
"\n",
"*Basic building blocks of creating a PyTorch model by subclassing `nn.Module`. For objects that subclass `nn.Module`, the `forward()` method must be defined.*\n",
"\n",
"> **Resource:** See more of these essential modules and their use cases in the [PyTorch Cheat Sheet](https://pytorch.org/tutorials/beginner/ptcheat.html). \n"
]
},
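The sketch below ties the table together: `nn.Module` registers each `nn.Parameter`, autograd fills in gradients on `backward()`, and an optimizer from `torch.optim` uses those gradients to update the parameters. The loss function (`nn.L1Loss`), the optimizer (`torch.optim.SGD`) and the learning rate of 0.01 are assumptions chosen for this illustration, not choices this section makes.

```python
import torch
from torch import nn

# Same structure as LinearRegressionModel above (nn.Parameter defaults to requires_grad=True)
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1, dtype=torch.float))
        self.bias = nn.Parameter(torch.randn(1, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * x + self.bias

torch.manual_seed(42)
model = LinearRegressionModel()

X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3

# nn.Module registers every nn.Parameter assigned as an attribute, in assignment order
names = [name for name, _ in model.named_parameters()]
print(names)

# Assumed choices for this sketch: L1 loss and plain SGD with lr=0.01
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Autograd: backward() fills .grad on every parameter with requires_grad=True
optimizer.zero_grad()
loss = loss_fn(model(X), y)
loss.backward()

w_before = model.weights.detach().clone()
grad_w = model.weights.grad.detach().clone()

# torch.optim applies the update; for plain SGD (no momentum) this is w <- w - lr * grad
optimizer.step()
print(torch.allclose(model.weights.detach(), w_before - 0.01 * grad_w))
```

Repeating the zero_grad / loss / backward / step sequence many times is, in essence, a training loop.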
{
"cell_type": "markdown",
"metadata": {
"id": "HYt5sKsgufG7"
},
"source": [
"\n",
"### Checking the contents of a PyTorch model\n",
"Now we've got these out of the way, let's create a model instance with the class we've made and check its parameters using [`.parameters()`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.parameters). "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CsEKA3A_ec01",
"outputId": "cd999f12-2efd-4fe7-e449-d51ff98e5242"
},
"outputs": [
{
"data": {
"text/plain": [
"[Parameter containing:\n",
" tensor([0.3367], requires_grad=True),\n",
" Parameter containing:\n",
" tensor([0.1288], requires_grad=True)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set manual seed since nn.Parameter are randomly initialized\n",
"torch.manual_seed(42)\n",
"\n",
"# Create an instance of the model (this is a subclass of nn.Module that
contains nn.Parameter(s))\n",
"model_0 = LinearRegressionModel()\n",
"\n",
"# Check the nn.Parameter(s) within the nn.Module subclass we created\n",
"list(model_0.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CNOmcQdSq34e"
},
"source": [
"We can also get the state (what the model contains) of the model using
[`.state_dict()`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#to
rch.nn.Module.state_dict)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XC1N_1Qrec01",
"outputId": "7e35b61c-371e-4d28-ae02-c1981afc1bbb"
},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# List named parameters \n",
"model_0.state_dict()"
]
},
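`state_dict()` and `.parameters()` are two views of the same tensors. Here's a minimal sketch, using the built-in `nn.Linear` as a stand-in for `LinearRegressionModel` (so its keys are `weight` and `bias` rather than `weights` and `bias`):

```python
import torch
from torch import nn

torch.manual_seed(42)
layer = nn.Linear(in_features=1, out_features=1)  # stand-in model with a weight and a bias

state = layer.state_dict()
print(list(state.keys()))  # ['weight', 'bias']

# named_parameters() yields the same names, paired with the live nn.Parameter objects
param_names = [name for name, _ in layer.named_parameters()]
print(param_names)  # ['weight', 'bias']

# By default, state_dict() hands back detached tensors (no gradient tracking),
# while the parameters themselves still require gradients
print(state["weight"].requires_grad, layer.weight.requires_grad)  # False True

# Changing a parameter in place shows up in the state_dict tensor, since they share memory
with torch.no_grad():
    layer.bias.fill_(0.3)
print(state["bias"])  # tensor([0.3000])
```

The last lines show a handy gotcha: because `state_dict()` tensors share memory with the live parameters, take a copy (e.g. with `copy.deepcopy`) if you want a snapshot that won't change as training continues.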
{
"cell_type": "markdown",
"metadata": {
"id": "tdTEPSwSec02"
},
"source": [
"Notice how the values for `weights` and `bias` from `model_0.state_dict()`
come out as random float tensors?\n",
"\n",
"This is because we initialized them above using `torch.randn()`.\n",
"\n",
"Essentially we want to start from random parameters and get the model to
update them towards parameters that fit our data best (the hardcoded `weight` and
`bias` values we set when creating our straight line data).\n",
"\n",
"> **Exercise:** Try changing the `torch.manual_seed()` value two cells above,
see what happens to the weights and bias values. \n",
"\n",
"Because our model starts with random values, right now it'll have poor
predictive power.\n",
"\n"
]
},
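Here's one way to see what the seed exercise above is getting at. It's a minimal sketch: `make_params` is a hypothetical helper that creates a weight and a bias the same way (`torch.randn()` wrapped in `nn.Parameter`) right after seeding.

```python
import torch
from torch import nn

def make_params(seed: int) -> torch.Tensor:
    """Hypothetical helper: seed the RNG, then create a weight and a bias."""
    torch.manual_seed(seed)
    weights = nn.Parameter(torch.randn(1))
    bias = nn.Parameter(torch.randn(1))
    return torch.cat([weights.detach(), bias.detach()])

a = make_params(42)
b = make_params(42)
c = make_params(7)
print(torch.equal(a, b))  # True: same seed, same "random" starting values
print(torch.equal(a, c))  # False: a different seed gives different starting values
```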
{
"cell_type": "markdown",
"metadata": {
"id": "BDKdLN7nuheb"
},
"source": [
"### Making predictions using `torch.inference_mode()` \n",
"To check this we can pass it the test data `X_test` to see how closely it
predicts `y_test`.\n",
"\n",
"When we pass data to our model, it'll go through the model's `forward()`
method and produce a result using the computation we've defined. \n",
"\n",
"Let's make some predictions. "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "-ITlZgU5ec02"
},
"outputs": [],
"source": [
"# Make predictions with model\n",
"with torch.inference_mode(): \n",
" y_preds = model_0(X_test)\n",
"\n",
"# Note: in older PyTorch code you might also see torch.no_grad()\n",
"# with torch.no_grad():\n",
"# y_preds = model_0(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L_Bx5I1FsIS0"
},
"source": [
"Hmm?\n",
"\n",
"You probably noticed we used
[`torch.inference_mode()`](https://pytorch.org/docs/stable/generated/
torch.inference_mode.html) as a [context manager](https://realpython.com/python-
with-statement/) (that's what the `with torch.inference_mode():` is) to make the
predictions.\n",
"\n",
"As the name suggests, `torch.inference_mode()` is used when using a model for
inference (making predictions).\n",
"\n",
"`torch.inference_mode()` turns off a bunch of things (like gradient tracking,
which is necessary for training but not for inference) to make **forward-passes**
(data going through the `forward()` method) faster.\n",
"\n",
 **Note:** In older">
"> **Note:** In older PyTorch code, you may also see `torch.no_grad()` being used for inference. While `torch.inference_mode()` and `torch.no_grad()` do similar things, `torch.inference_mode()` is newer, potentially faster and preferred. See this [Tweet from PyTorch](https://twitter.com/PyTorch/status/1437838231505096708?s=20) for more.\n",
"\n",
"We've made some predictions; let's see what they look like. "
]
},
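A quick way to check what `torch.inference_mode()` turns off is to look at the outputs themselves. Here's a minimal sketch, using a stand-in `nn.Linear(1, 1)` model and random inputs rather than `model_0` and `X_test`:

```python
import torch
from torch import nn

model = nn.Linear(1, 1)  # stand-in model for illustration
X = torch.rand(10, 1)

# Normal forward pass: autograd records the computation, so the output requires grad
y_tracked = model(X)

with torch.inference_mode():
    y_inf = model(X)      # no gradient tracking

with torch.no_grad():
    y_nograd = model(X)   # also no gradient tracking

print(y_tracked.requires_grad)                          # True
print(y_inf.requires_grad, y_inf.is_inference())        # False True
print(y_nograd.requires_grad, y_nograd.is_inference())  # False False
print(torch.allclose(y_inf, y_nograd))                  # True: same numbers either way
```

Both context managers give the same numbers; the difference is the extra bookkeeping PyTorch skips. Tensors created under `inference_mode()` are flagged as inference tensors and can't later be saved for a backward pass, so keep them for predictions only.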
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "k4xCScCvec02",
"outputId": "2ce37ea3-6bc4-4e50-91ef-dcf53277dde7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of testing samples: 10\n",
"Number of predictions made: 10\n",
"Predicted values:\n",
"tensor([[0.3982],\n",
" [0.4049],\n",
" [0.4116],\n",
" [0.4184],\n",
" [0.4251],\n",
" [0.4318],\n",
" [0.4386],\n",
" [0.4453],\n",
" [0.4520],\n",
" [0.4588]])\n"
]
}
],
"source": [
"# Check the predictions\n",
"print(f\"Number of testing samples: {len(X_test)}\") \n",
"print(f\"Number of predictions made: {len(y_preds)}\")\n",
"print(f\"Predicted values:\\n{y_preds}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FnSwGbQEupZs"
},
"source": [
"Notice how there's one prediction value per testing sample.\n",
"\n",
"This is because of the kind of data we're using. For our straight line, one
`X` value maps to one `y` value. \n",
"\n",
"However, machine learning models are very flexible. You could have 100 `X`
values mapping to one, two, three or 10 `y` values. It all depends on what you're
working on.\n",
"\n",
"Our predictions are still numbers on a page, let's visualize them with our
`plot_predictions()` function we created above."
]
},
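The number of outputs per sample is set by the model, not by the data. Here's a minimal sketch using `nn.Linear` (a built-in layer, not our `LinearRegressionModel`) to show the shapes for a one-to-one mapping like ours and a hypothetical 100-to-3 mapping:

```python
import torch
from torch import nn

torch.manual_seed(42)

# 10 samples with 1 feature each, like the straight-line test data: shape [10, 1]
X_small = torch.linspace(0.8, 0.98, steps=10).unsqueeze(dim=1)
one_to_one = nn.Linear(in_features=1, out_features=1)
preds = one_to_one(X_small)
print(X_small.shape, preds.shape)  # torch.Size([10, 1]) torch.Size([10, 1])

# 10 samples with 100 features each, mapped to 3 outputs per sample
X_wide = torch.rand(10, 100)
hundred_to_three = nn.Linear(in_features=100, out_features=3)
preds_wide = hundred_to_three(X_wide)
print(X_wide.shape, preds_wide.shape)  # torch.Size([10, 100]) torch.Size([10, 3])
```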
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 428
},
"id": "pwjxLWZTec02",
"outputId": "56bf8a4d-2365-4539-a8b7-9bfe606f5b93"
},
"outputs": [
{
"data": {
"image/png":
"iVBORw0KGgoAAAANSUhEUgAAAlMAAAGbCAYAAADgEhWsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIH
ZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/
YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtv0lEQVR4nO3de3RU9bn/
8c9DwiUQbkoAIQgIqCCiQkSx5aLiQQUWtdWCWAW1Gn7AObqOF6iegnhtFWttTWvUKt6qVgVrkaKWHwi2Ikl
QqBDwIF4AIwT7WyhYxSTP749J0yQkmQlzn3m/1pqV7L2/
s+dJdgiffPfez5i7CwAAAIenRbwLAAAASGaEKQAAgDAQpgAAAMJAmAIAAAgDYQoAACAMmfF64S5dunifPn3
i9fIAAAAhKykp2evuOQ1ti1uY6tOnj4qLi+P18gAAACEzs48b28ZpPgAAgDAQpgAAAMJAmAIAAAgDYQoAAC
AMhCkAAIAwBL2bz8welTRB0h53H9zAdpN0v6TzJX0labq7rw+3sC++
+EJ79uzRt99+G+6ukAbatWun3NxctWjB3wcAgNgKpTXCIkkPSHqike3nSRpQ/
ThN0m+rPx62L774Qrt371bPnj2VlZWlQF4DGlZVVaVdu3Zp79696tq1a7zLAQCkmaB/xrv7akn/
aGLIJElPeMBaSZ3M7KhwitqzZ4969uyptm3bEqQQVIsWLdStWzft27cv3qUAANJQJM6J9JS0o9byzup1h+3
bb79VVlZWWEUhvbRs2VIVFRXxLgMAkIYiEaYamjryBgeaXW1mxWZWXF5e3vROmZFCM/
DzAgCIl0iEqZ2SetVazpX0aUMD3f0hd89z97ycnAbf3gYAACCpRCJMvSzpMgs4XdI+dy+LwH4BAAASXtAwZ
WbPSHpL0nFmttPMrjSzGWY2o3rIMknbJW2T9LCkmVGrNg1Nnz5dEyZMaNZzxowZo9mzZ0epoqbNnj1bY8aM
ictrAwAQD0FbI7j7xUG2u6RZEasoSQW7ZmfatGlatGhRs/d7//33K/
AtDt3ixYvVsmXLZr9WPHz00Ufq27evioqKlJeXF+9yAABotlD6TCEEZWX/
PrO5dOlSXXXVVXXW1b878dtvvw0p8HTs2LHZtRxxxBHNfg4AADg8tIuOkO7du9c8OnXqVGfd119/
rU6dOumZZ57RWWedpaysLBUWFurzzz/XxRdfrNzcXGVlZemEE07QY489Vme/9U/
zjRkzRjNnztRNN92kLl26qGvXrrr++utVVVVVZ0zt03x9+vTR7bffrvz8fHXo0EG5ubm655576rzO+++/
r9GjR6tNmzY67rjjtGzZMmVnZzc5m1ZZWanrr79enTt3VufOnXXttdeqsrKyzpjly5dr5MiR6ty5s4444gi
NGzdOpaWlNdv79u0rSTr11FNlZjWnCIuKivQf//
Ef6tKlizp06KDvfve7euutt4IfCABAWpn1yixl3pqpWa/E7yQZYSqGfvKTn2jmzJnavHmzvve97+nrr7/
W0KFDtXTpUm3atEnXXHON8vPztWLFiib38/TTTyszM1N/+9vf9MADD+iXv/ylnnvuuSafc9999+nEE0/
U+vXrNWfOHN1444014aSqqkoXXHCBMjMztXbtWi1atEgLFizQN9980+Q+7733Xj388MMqLCzUW2+9pcrKSj
399NN1xhw4cEDXXnut1q1bp1WrVqljx46aOHGiDh48KElat26dpEDoKisr0+LFiyVJX375pS699FKtWbNG6
9at08knn6zzzz9fe/fubbImAEB6KSwpVKVXqrCkMH5FuHtcHsOGDfPGbN68udFtzTVzpntGRuBjrDz//
PMe+NYGfPjhhy7JFy5cGPS5kydP9iuvvLJmedq0aT5+/Pia5dGjR/
vpp59e5zljx46t85zRo0f7rFmzapZ79+7tU6ZMqfOc/v37+2233ebu7suXL/eMjAzfuXNnzfa//
vWvLskfe+yxRms96qij/Pbbb69Zrqys9AEDBvjo0aMbfc7+/fu9RYsWvmbNGnf/9/
emqKio0ee4u1dVVXn37t39ySefbHRMJH9uAADJYebSmZ6xIMNnLo3uf/SSir2RTJPyM1OFhVJlZeBjvNW/
wLqyslJ33HGHhgwZoiOPPFLZ2dlavHixPvnkkyb3M2TIkDrLPXr00J49ew77OVu2bFGPHj3Us+e/
G9efeuqpTb5p8L59+1RWVqYRI0bUrGvRooVOO63u2zJ+8MEHmjp1qvr166cOHTqoW7duqqqqCvo17tmzR/
n5+Tr22GPVsWNHtW/fXnv27An6PABAeikYX6CKeRUqGF8QtxpS/gL0/
PxAkMrPj3clUrt27eosL1y4UPfee6/uv/
9+nXjiicrOztZNN90UNBjVv3DdzOpcM9Xc57h71DqIT5w4UT179lRhYaF69uypzMxMDRo0qOY0X2OmTZum3
bt367777lOfPn3UunVrnX322UGfBwBArKV8mCooCDwS0ZtvvqmJEyfq0ksvlRQINe+//
37NBeyxMnDgQO3atUuffvqpevToIUkqLi5uMqB17NhRRx11lNauXauzzjpLUqD+devW6aijAu9z/
fnnn6u0tFQFBQU688wzJUnr16+v8x56rVq1kqRDLlx/88039atf/Urjx4+XJO3evbvO3ZEAACSKlD/
Nl8iOPfZYrVixQm++
+aa2bNmi2bNn68MPP4x5Heecc46OO+44TZs2TRs2bNDatWv13//938rMzGxyxuqaa67R3XffrRdeeEFbt27
VtddeWyfwdO7cWV26dNHDDz+sbdu26Y033tCMGTOUmfnvDN+1a1dlZWXp1Vdf1e7du7Vv3z5Jge/
NU089pc2bN6uoqEhTpkypCV4AACQSwlQc/c///
I+GDx+u8847T6NGjVK7du10ySWXxLyOFi1aaMmSJfrmm280fPhwTZs2TTfffLPMTG3atGn0edddd50uv/
xy/
fjHP9Zpp52mqqqqOvW3aNFCzz33nDZu3KjBgwdr1qxZuu2229S6deuaMZmZmfrVr36lRx55RD169NCkSZMk
SY8+
+qj279+vYcOGacqUKbriiivUp0+fqH0PAACJIxHaHTSHeTO7a0dKXl6eFxcXN7ittLRUAwcOjHFFqG3Dhg0
6+eSTVVxcrGHDhsW7nJDwcwMAqSHz1kxVeqUyLEMV8yqCPyEGzKzE3Rt8qw5mpiBJWrJkiV577TV9+OGHWr
lypaZPn66TTjpJQ4cOjXdpAIA0kz8sXxmWofxhCXD3WAhS/gJ0hObLL7/
UnDlztGPHDnXu3FljxozRfffdF7W7/
AAAaEzB+IK4tjpoLsIUJEmXXXaZLrvssniXAQBA0uE0HwAAQBgIUwAAAGEgTAEAgJhItpYHoSJMAQCAmCgs
KVSlV6qwJAHeMDeCCFMAACAmkq3lQai4mw8AAMREsrU8CBUzU0msT58+WrhwYVxee8KECZo+fXpcXhsAgER
CmIoQM2vyEU7wuOWWWzR48OBD1hcVFWnmzJlhVB07q1atkplp79698S4FAICI4jRfhJSVldV8vnTpUl111V
V11mVlZUX8NXNyciK+TwAA0DzMTEVI9+7dax6dOnU6ZN3q1as1bNgwtWnTRn379tXNN9+sgwcP1jx/
8eLFGjJkiLKysnTEEUdo9OjR2r17txYtWqQFCxZo06ZNNbNcixYtknToaT4z00MPPaSLLrpI7dq10zHHHKO
nnnqqTp1vv/22hg4dqjZt2uiUU07RsmXLZGZatWpVo1/
bV199penTpys7O1vdunXTnXfeeciYp556Sqeeeqrat2+vrl276qKLLtKuXbskSR999JHOPPNMSYEAWHumbv
ny5Ro5cqQ6d+6sI444QuPGjVNpaWlzv/
0AgDhK1ZYHoSJMxcCrr76qSy65RLNnz9amTZv06KOP6oUXXtBNN90kSfrss880ZcoUTZs2TaWlpVq9erUuv
fRSSdLkyZN13XXX6bjjjlNZWZnKyso0efLkRl/r1ltv1aRJk7RhwwZNnjxZV1xxhT7+
+GNJ0v79+zVhwgQdf/zxKikp0d13360bbrghaP3XX3+9Xn/
9db344otasWKF3nnnHa1evbrOmIMHD2rBggXasGGDli5dqr179+riiy+WJPXq1UsvvviiJGnTpk0qKyvT/
fffL0k6cOCArr32Wq1bt06rVq1Sx44dNXHixDpBEwCQ2FK15UHI3D0uj2HDhnljNm/
e3Oi25pq5dKZnLMjwmUtnRmyfwTz//PMe+NYGjBw50m+99dY6Y5YsWeLt2rXzqqoqLykpcUn+0UcfNbi/
+fPn+wknnHDI+t69e/s999xTsyzJ586dW7P87bffelZWlj/55JPu7v7ggw96586d/auvvqoZ8/
TTT7skX7lyZYOv/eWXX3qrVq38qaeeqrOuY8eOPm3atEa/
B6WlpS7Jd+zY4e7uK1eudEleXl7e6HPc3ffv3+8tWrTwNWvWNDmuIZH8uQEAhC4e/
9fGmqRibyTTpPzMVCKk5ZKSEt1xxx3Kzs6ueUydOlUHDhzQZ599ppNOOkljx47V4MGD9YMf/EC//
e1vVV5eflivNWTIkJrPMzMzlZOToz179kiStmzZosGDB9e5fuu0005rcn8ffPCBDh48qBEjRtSsy87O1okn
nlhn3Pr16zVp0iT17t1b7du3V15eniTpk08+Cbr/qVOnql+/
furQoYO6deumqqqqoM8DACSOgvEFqphXkZJtD0KR8mEqERqEVVVVaf78+Xr33XdrHhs3btT//u//
KicnRxkZGXrttdf02muvaciQIfrd736nAQMGaMOGDc1+rZYtW9ZZNjNVVVVJCsxCmlmz9hcI4007cOCAxo0
bp7Zt2+rJJ59UUVGRli9fLklBT9dNnDhR5eXlKiws1Ntvv6133nlHmZmZnOYDACSNlL+bLxEahA0dOlRbtm
xR//79Gx1jZhoxYoRGjBihefPm6YQTTtBzzz2nk046Sa1atVJlZWXYdQwcOFBPPPGE/
vnPf9bMTq1bt67J5/Tv318tW7bU2rVrdcwxx0gKhKf33ntP/fr1kxSY8dq7d6/
uvPNO9e3bV1LggvraWrVqJUl1vo7PP/9cpaWlKigoqLlAff369aqoqAj7awUAIFZSfmYqEcybN0+///
3vNW/
ePL333nvasmWLXnjhBd14442SpLVr1+r2229XUVGRPvnkE7388svasWOHBg0aJClw197HH3+s9evXa+/
evfrmm28Oq45LLrlEGRkZuuqqq7R582b95S9/qbkzr7EZq+zsbF155ZWaM2eOXn/
9dW3atElXXHFFnVB09NFHq3Xr1nrggQe0fft2vfLKK/rpT39aZz+9e/eWmemVV15ReXm59u/
fr86dO6tLly56+OGHtW3bNr3xxhuaMWOGMjNTPuMDAFIIYSoGxo0bp1deeUUrV67U8OHDNXz4cP3sZz/
T0UcfLUnq2LGj/vrXv2rChAkaMGCArrvuOv30pz/Vj370I0nSD37wA51//vk6+
+yzlZOTo2eeeeaw6sjOztaf/
vQnbdq0SaeccopuuOEG3XLLLZKkNm3aNPq8hQsX6swzz9QFF1ygM888U4MHD9aoUaNqtufk5Ojxxx/
XSy+9pEGDBmnBggX6xS9+UWcfPXv21IIFC3TzzTerW7dumj17tlq0aKHnnntOGzdu1ODBgzVr1izddtttat
269WF9fQCAyEn3dgfNYaFcExMNeXl5Xlxc3OC20tJSDRw4MMYVpac//
vGPuuCCC7Rnzx516dIl3uWEhZ8bAIiczFszVemVyrAMVczj8gszK3H3vIa2MTOVZh5//
HGtWbNGH330kZYuXaprr71WEydOTPogBQCIrES4gStZcHFKmtm9e7fmz5+vsrIyde/eXePHj9fPf/
7zeJcFAEgwiXADV7IgTKWZG2+8sebCdwAAED5O8wEAAISBMAUAABAGwhQAAGmElgeRR5gCACCNJMJ71qYaw
hQAAGmElgeRx918AACkEVoeRB4zU0nohRdeqPNeeosWLVJ2dnZY+1y1apXMTHv37g23PAAA0gphKoKmT58u
M5OZqWXLljrmmGN0/fXX68CBA1F93cmTJ2v79u0hj+/
Tp48WLlxYZ90ZZ5yhsrIyHXnkkZEuDwCAlBZSmDKzc81sq5ltM7O5DWzvbGZLzGyjma0zs8GRLzU5jB07Vm
VlZdq+fbtuv/12/eY3v9H1119/
yLiKigpF6n0Rs7Ky1LVr17D20apVK3Xv3r3OjBcAAAguaJgyswxJBZLOkzRI0sVmNqjesJskvevuQyRdJun
+SBeaLFq3bq3u3burV69emjp1qi655BK99NJLuuWWWzR48GAtWrRI/
fr1U+vWrXXgwAHt27dPV199tbp27ar27dtr9OjRqv8G0E888YR69+6ttm3basKECdq9e3ed7Q2d5nvllVd0
2mmnKSsrS0ceeaQmTpyor7/+WmPGjNHHH3+sG264oWYWTWr4NN/
ixYt14oknqnXr1urVq5fuuOOOOgGwT58+uv3225Wfn68OHTooNzdX99xzT506CgsLdeyxx6pNmzbKycnRuH
HjVFHBG2YCQKTR8iB+QpmZGi5pm7tvd/
eDkp6VNKnemEGSVkiSu2+R1MfMukW00iSVlZWlb7/9VpL04Ycf6ve//72ef/
55bdiwQa1bt9b48eO1a9cuLV26VO+8845GjRqls846S2VlZZKkt99+W9OnT9fVV1+td999VxMnTtS8efOaf
M3ly5dr0qRJOuecc1RSUqKVK1dq9OjRqqqq0uLFi5Wbm6t58+aprKys5nXqKykp0UUXXaTvf//7+vvf/
66f/exnuuuuu/TAAw/
UGXfffffpxBNP1Pr16zVnzhzdeOONeuuttyRJxcXFmjVrlubPn6+tW7fqL3/5i84999xwv6UAgAbQ8iCO3L
3Jh6QLJT1Sa/lSSQ/UG3OnpF9Ufz5cUoWkYQ3s62pJxZKKjz76aG/
M5s2bG93WbDNnumdkBD5G2bRp03z8+PE1y2+//bYfeeSR/sMf/tDnz5/vmZmZ/
tlnn9VsX7Fihbdr186/+uqrOvs56aST/Oc//
7m7u1988cU+duzYOtuvvPJKDxy6gMcee8zbtWtXs3zGGWf45MmTG62zd+/
efs8999RZt3LlSpfk5eXl7u4+depUP/PMM+uMmT9/vvfs2bPOfqZMmVJnTP/+/f22225zd/
cXX3zRO3To4F988UWjtURSRH9uACDJzFw60zMWZPjMpdH//y4dSSr2RrJSKDNTDV1EU/
9in59J6mxm70r6T0nvVAeq+sHtIXfPc/
e8nJycEF46AgoLpcrKwMcYWL58ubKzs9WmTRuNGDFCo0aN0q9//WtJUm5urrp1+/
eEXUlJib766ivl5OQoOzu75vHee+/pgw8+kCSVlpZqxIgRdV6j/nJ977zzjs4++
+ywvo7S0lJ95zvfqbPuu9/9rnbt2qUvvviiZt2QIUPqjOnRo4f27NkjSTrnnHPUu3dv9e3bV5dccokef/
xxffnll2HVBQBoWMH4AlXMq6DtQRyE0mdqp6RetZZzJX1ae4C7fyHpckmywEU4H1Y/4i8/
PxCk8mPTnGzUqFF66KGH1LJlS/
Xo0UMtW7as2dauXbs6Y6uqqtStWzetWbPmkP106NBBkiJ2kXpzuXujF6PXXl/76/vXtqqqKklS+/
bttX79eq1evVqvv/667rrrLt10000qKipSjx49olc8AAAxFMrMVJGkAWbW18xaSZoi6eXaA8ysU/
U2SfqxpNXVASv+CgqkiorAxxho27at+vfvr969ex8SNOobOnSodu/erRYtWqh///51Hv+6O2/
QoEFau3ZtnefVX67vlFNO0YoVKxrd3qpVK1VWVja5j0GDBunNN9+ss+7NN99Ubm6u2rdv3+Rza8vMzNRZZ5
2lu+66Sxs3btSBAwe0dOnSkJ8PAECiCzoz5e4VZjZb0quSMiQ96u6bzGxG9fYHJQ2U9ISZVUraLOnKKNacM
saOHavvfOc7mjRpku6++24df/zx+uyzz7R8+XKNHTtWI0eO1H/913/
pjDPO0F133aULL7xQq1at0pIlS5rc780336yJEyeqf//
+mjp1qtxdr732mvLz89W2bVv16dNHa9as0Y9+9CO1bt1aXbp0OWQf1113nU499VTdcsstmjp1qoqKinTvvf
fqzjvvDPnrW7p0qT744AONGjVKRxxxhFauXKkvv/xSAwcObPb3CgCARBVSnyl3X+bux7p7P3e/
o3rdg9VBSu7+lrsPcPfj3f377v7/oll0qjAzLVu2TGeddZauuuoqHXfccfrhD3+orVu31pwGO/300/
W73/1Ov/3tbzVkyBAtXrxYt9xyS5P7Pf/887VkyRL9+c9/1imnnKLRo0dr5cqVatEicLhvvfVW7dixQ/
369VNj164NHTpUzz//
vF588UUNHjxYc+fO1dy5czV79uyQv75OnTrppZde0tixY3X88cdr4cKFeuSRRzRy5MiQ9wEA6Yx2B8nB4nV
NTl5entfvp/QvpaWlzF6g2fi5AZBqMm/NVKVXKsMyVDGPHn3xZGYl7p7X0DbeTgYAgASVPyxfGZah/
GGxuYkKhyeUu/
kAAEAcFIwvoNVBEmBmCgAAIAyEKQAAgDAkbJj6V+NHIBTxupECAICEDFPt2rXTrl27dPDgQf6TRFDurs8//
1xt2rSJdykAEBJaHqSWhGyNUFVVpb1792rfvn2qqOBWUATXpk0b5ebmBu06DwCJgJYHyaep1ggJeTdfixYt
1LVr15q3VAEAIJXkD8tXYUkhLQ9SRELOTAEAACQSmnYCAABECWEKAAAgDIQpAACAMBCmAACIEFoepCfCFAA
AEVJYUqhKr1RhSWG8S0EMEaYAAIiQ/
GH5yrAMWh6kGVojAAAABEFrBAAAgCghTAEAAISBMAUAABAGwhQAAE2YNUvKzAx8BBpCmAIAoAmFhVJlZeAj
0BDCFAAATcjPlzIyAh+BhtAaAQAAIAhaIwAAAEQJYQoAACAMhCkAAIAwEKYAAGmJlgeIFMIUACAt0fIAkUK
YAgCkJVoeIFJojQAAABAErREAAACihDAFAAAQBsIUAABAGAhTAICUQssDxBphCgCQUmh5gFgjTAEAUgotDx
BrtEYAAAAIgtYIAAAAUUKYAgAACANhCgAAIAwhhSkzO9fMtprZNjOb28D2jmb2JzPbYGabzOzyyJcKAEhXt
DtAIgt6AbqZZUh6X9I5knZKKpJ0sbtvrjXmJkkd3X2OmeVI2iqpu7sfbGy/
XIAOAAhVZmag3UFGhlRREe9qkI7CvQB9uKRt7r69Ohw9K2lSvTEuqb2ZmaRsSf+QxI87ACAiaHeARBZKmOo
paUet5Z3V62p7QNJASZ9K+ruka9y9qv6OzOxqMys2s+Ly8vLDLBkAkG4KCgIzUgUF8a4EOFQoYcoaWFf/
3OA4Se9K6iHpZEkPmFmHQ57k/
pC757l7Xk5OTjNLBQAASDyhhKmdknrVWs5VYAaqtsslLfaAbZI+lHR8ZEoEAABIXKGEqSJJA8ysr5m1kjRF
0sv1xnwi6WxJMrNuko6TtD2ShQIAACSioGHK3SskzZb0qqRSSX9w901mNsPMZlQPu03SGWb2d0krJM1x973
RKhoAkBpoeYBUwHvzAQDihpYHSBa8Nx8AICHR8gCpgJkpAACAIJiZAgAAiBLCFAAAQBgIUwAAAGEgTAEAIo
6WB0gnhCkAQMQVFgZaHhQWxrsSIPoIUwCAiKPlAdIJrREAAACCoDUCAABAlBCmAAAAwkCYAgAACANhCgAAI
AyEKQBASOgdBTSMMAUACAm9o4CGEaYAACGhdxTQMPpMAQAABEGfKQAAgCghTAEAAISBMAUAABAGwhQApDla
HgDhIUwBQJqj5QEQHsIUAKQ5Wh4A4aE1AgAAQBC0RgAAAIgSwhQAAEAYCFMAAABhIEwBQAqi3QEQO4QpAEh
BtDsAYocwBQApiHY
HQOzQGgEAACAIWiMAAABECWEKAAAgDIQpAACAMBCmACCJ0PIASDyEKQBIIrQ8ABIPYQoAkggtD4DEQ2sEAA
CAIGiNAAAAECWEKQAAgDAQpgAAAMJAmAKABEDLAyB5hRSmzOxcM9tqZtvMbG4D228ws3erH+
+ZWaWZHRH5cgEgNdHyAEheQcOUmWVIKpB0nqRBki42s0G1x7j7Pe5+srufLOknkt5w939EoV4ASEm0PACSV
ygzU8MlbXP37e5+UNKzkiY1Mf5iSc9EojgASBcFBVJFReAjgOQSSpjqKWlHreWd1esOYWZtJZ0r6cVGtl9t
ZsVmVlxeXt7cWgEAABJOKGHKGljXWKfPiZL+2tgpPnd/
yN3z3D0vJycn1BoBAAASVihhaqekXrWWcyV92sjYKeIUHwAASCOhhKkiSQPMrK+ZtVIgML1cf5CZdZQ0WtI
fI1siACQn2h0A6SFomHL3CkmzJb0qqVTSH9x9k5nNMLMZtYZeIOk1dz8QnVIBILnQ7gBID5mhDHL3ZZKW1V
v3YL3lRZIWRaowAEh2+fmBIEW7AyC1mXtj15JHV15enhcXF8fltQEAAJrDzErcPa+hbbydDAAAQBgIUwAAA
GEgTAEAAISBMAUAzUTLAwC1EaYAoJloeQCgNsIUADRTfr6UkUHLAwABtEYAAAAIgtYIAAAAUUKYAgAACANh
CgAAIAyEKQCoRssDAIeDMAUA1Wh5AOBwEKYAoBotDwAcDlojAAAABEFrBAAAgCghTAEAAISBMAUAABAGwhS
AlEa7AwDRRpgCkNJodwAg2ghTAFIa7Q4ARButEQAAAIKgNQIAAECUEKYAAADCQJgCAAAIA2EKQFKi5QGARE
GYApCUaHkAIFEQpgAkJVoeAEgUtEYAAAAIgtYIAAAAUUKYAgAACANhCgAAIAyEKQAJhZYHAJINYQpAQqHlA
YBkQ5gCkFBoeQAg2dAaAQAAIAhaIwAAAEQJYQoAACAMhCkAAIAwEKYARB3tDgCkMsIUgKij3QGAVBZSmDKz
c81sq5ltM7O5jYwZY2bvmtkmM3sjsmUCSGa0OwCQyoK2RjCzDEnvSzpH0k5JRZIudvfNtcZ0kvQ3See6+yd
m1tXd9zS1X1ojAACAZBFua4Thkra5+3Z3PyjpWUmT6o2ZKmmxu38iScGCFAAAQKoIJUz1lLSj1vLO6nW1HS
ups5mtMrMSM7usoR2Z2dVmVmxmxeXl5YdXMQAAQAIJJUxZA+vqnxvMlDRM0nhJ4yT91MyOPeRJ7g+5e5675
+Xk5DS7WAAAgEQTSpjaKalXreVcSZ82MGa5ux9w972SVks6KTIlAkhUtDwAgNDCVJGkAWbW18xaSZoi6eV6
Y/
4oaaSZZZpZW0mnSSqNbKkAEg0tDwAghDDl7hWSZkt6VYGA9Ad332RmM8xsRvWYUknLJW2UtE7SI+7+XvTKB
pAIaHkAACG0RogWWiMAAIBkEW5rBAAAADSCMAUAABAGwhQAAEAYCFMADkHLAwAIHWEKwCFoeQAAoSNMATgE
LQ8AIHS0RgAAAAiC1ggAAABRQpgCAAAIA2EKAAAgDIQpIE3Q7gAAooMwBaQJ2h0AQHQQpoA0QbsDAIgOWiM
AAAAEQWsEAACAKCFMAQAAhIEwBQAAEAbCFJDkaHkAAPFFmAKSHC0PACC+CFNAkqPlAQDEF60RAAAAgqA1Ag
AAQJQQpgAAAMJAmAIAAAgDYQpIULQ8AIDkQJgCEhQtDwAgORCmgARFywMASA60RgAAAAiC1ggAAABRQpgCA
AAIA2EKAAAgDIQpAACAMBCmgBiidxQApB7CFBBD9I4CgNRDmAJiiN5RAJB66DMFAAAQBH2mAAAAooQwBQAA
EAbCFAAAQBgIU0AE0PIAANIXYQqIAFoeAED6IkwBEUDLAwBIXyGFKTM718y2mtk2M5vbwPYxZrbPzN6tfsy
LfKlA4iookCoqAh8BAOklM9gAM8uQVCDpHEk7JRWZ2cvuvrne0DXuPiEKNQIAACSsUGamhkva5u7b3f2gpG
clTYpuWQAAAMkhlDDVU9KOWss7q9fVN8LMNpjZn83shIZ2ZGZXm1mxmRWXl5cfRrkAAACJJZQwZQ2sq/
8eNOsl9Xb3kyT9WtJLDe3I3R9y9zx3z8vJyWlWoUCs0e4AABCKUMLUTkm9ai3nSvq09gB3/8Ld91d/
vkxSSzPrErEqgTig3QEAIBShhKkiSQPMrK+ZtZI0RdLLtQeYWXczs+rPh1fv9/
NIFwvEEu0OAAChCHo3n7tXmNlsSa9KypD0qLtvMrMZ1dsflHShpP9jZhWS/
ilpirvXPxUIJJWCAlodAACCs3hlnry8PC8uLo7LawMAADSHmZW4e15D2+iADgAAEAbCFAAAQBgIU0g7tDwA
AEQSYQpph5YHAIBIIkwh7dDyAAAQSdzNBwAAEAR38wEAAEQJYQoAACAMhCkAAIAwEKaQMmh5AACIB8IUUgY
tDwAA8UCYQsqg5QEAIB5ojQAAABAErREAAEBqSoALZglTAAAgeSXABbOEKQAAkLwS4IJZwhQSWgLM3gIAEl
lBgVRREfgYJ4QpJLQEmL0FAMRakv0lTZhCQkuA2VsAQKwl2V/ShCkktASYvQUAxFqS/
SVNmAIAALER6um7JPtLmjAFAABiI8lO34WKMAUAAGIjyU7fhYowhbhIshs1AACRkGSn70JFmEJcpOhMLwCk
pzT/C5kwhbhI0ZleAEhPaf4XMmEKcZGiM70AkJ7S/C9kwhQAADhUc07dpflfyIQpAABwqDQ/
ddcchCkAAHCoND911xyEKURUmt/QAQCJL0W7kMeTuXtcXjgvL8+Li4vj8tqInszMwKxwRkbg3yAAIMHwi/
qwmFmJu+c1tI2ZKUQUs8IAkOD4RR1xzEwBAAAEwcwUAACpjotW44YwBQBAKqCVQdwQpgAASAVcCxU3hCkEx
cwxAMQJXciTAhegIyjuogWAOOEXcMLgAnSEhZljAIgTfgEnBWamAAAAggh7ZsrMzjWzrWa2zczmNjHuVDOr
NLMLD7dYAABSHhejppSgYcrMMiQVSDpP0iBJF5vZoEbG/
VzSq5EuEgCAlEIbg5QSyszUcEnb3H27ux+U9KykSQ2M+09JL0raE8H6AABIPVwLlVJCCVM9Je2otbyzel0N
M+sp6QJJDza1IzO72syKzay4vLy8ubUiwphlBoAIC/UXK20MUkooYcoaWFf/
qvVfSprj7pVN7cjdH3L3PHfPy8nJCbFERAuzzAAQYfxiTUuhhKmdknrVWs6V9Gm9MXmSnjWzjyRdKOk3Zva
9SBSI6GGWGQAijF+saSloawQzy5T0vqSzJe2SVCRpqrtvamT8IklL3f2FpvZLawQAAJAsmmqNkBnsye5eYW
azFbhLL0PSo+6+ycxmVG9v8jopAACAVBY0TEmSuy+TtKzeugZDlLtPD78sAACA5MDbyQAAAISBMJWCaHkAA
EDsEKZSEHfmAgAQO4SpFMSduQAAxE7Q1gjRQmsEAACQLJpqjcDMFAAAQBgIUwAAAGEgTAEAAISBMJUkaHcA
AEBiIkwlCdodAACQmAhTSYJ2BwAAJCZaIwAAAARBawQAAIAoIUwBAACEgTAFAAAQBsJUnNHyAACA5EaYijN
aHgAAkNwIU3FGywMAAJIbrREAAACCoDUCAABAlBCmAAAAwkCYAgAACANhKkpoeQAAQHogTEUJLQ8AAEgPhK
kooeUBAADpgdYIAAAAQdAaAQAAIEoIUwAAAGEgTAEAAISBMNUMtDsAAAD1EaaagXYHAACgPsJUM9DuAAAA1
EdrBAAAgCBojQAAABAlhCkAAIAwEKYAAADCQJgSLQ8AAMDhI0yJlgcAAODwEaZEywMAAHD4aI0AAAAQBK0R
AAAAoiSkMGVm55rZVjPbZmZzG9g+ycw2mtm7ZlZsZt+NfKkAAACJJzPYADPLkFQg6RxJOyUVmdnL7r651rA
Vkl52dzezIZL+IOn4aBQMAACQSEKZmRouaZu7b3f3g5KelTSp9gB33+//
vviqnaT4XIgFAAAQY6GEqZ6SdtRa3lm9rg4zu8DMtkh6RdIVkSkvPPSPAgAA0RZKmLIG1h0y8+TuS9z9eEn
fk3Rbgzsyu7r6mqri8vLyZhV6OOgfBQAAoi2UMLVTUq9ay7mSPm1ssLuvltTPzLo0sO0hd89z97ycnJxmF9
tc9I8CAADRFkqYKpI0wMz6mlkrSVMkvVx7gJn1NzOr/nyopFaSPo90sc1VUCBVVAQ+AgAAREPQu/
ncvcLMZkt6VVKGpEfdfZOZzaje/
qCkH0i6zMy+lfRPSZM9Xt1AAQAAYogO6AAAAEHQAR0AACBKCFMAAABhIEwBAACEgTAFAAAQBsIUAABAGAhT
AAAAYSBMAQAAhIEwBQAAEAbCFAAAQBgIUwAAAGEgTAEAAISBMAUAABCGuL3RsZmVS/o4Bi/
VRdLeGLwOmo9jk9g4PomLY5PYOD6JK5xj09vdcxraELcwFStmVtzYuzwjvjg2iY3jk7g4NomN45O4onVsOM
0HAAAQBsIUAABAGNIhTD0U7wLQKI5NYuP4JC6OTWLj+CSuqByblL9mCgAAIJrSYWYKAAAgaghTAAAAYUiJM
GVm55rZVjPbZmZzG9huZvar6u0bzWxoPOpMVyEcn0uqj8tGM/
ubmZ0UjzrTUbBjU2vcqWZWaWYXxrK+dBfK8TGzMWb2rpltMrM3Yl1jugrh91pHM/
uTmW2oPjaXx6POdGRmj5rZHjN7r5Htkc8E7p7UD0kZkj6QdIykVpI2SBpUb8z5kv4sySSdLunteNedLo8Qj
88ZkjpXf34exydxjk2tcf9X0jJJF8a77nR5hPhvp5OkzZKOrl7uGu+60+ER4rG5SdLPqz/
PkfQPSa3iXXs6PCSNkjRU0nuNbI94JkiFmanhkra5+3Z3PyjpWUmT6o2ZJOkJD1grqZOZHRXrQtNU0OPj7n
9z9/9XvbhWUm6Ma0xXofzbkaT/
lPSipD2xLA4hHZ+pkha7+yeS5O4co9gI5di4pPZmZpKyFQhTFbEtMz25+2oFvt+NiXgmSIUw1VPSjlrLO6v
XNXcMoqO53/
srFfiLAdEX9NiYWU9JF0h6MIZ1ISCUfzvHSupsZqvMrMTMLotZdektlGPzgKSBkj6V9HdJ17h7VWzKQxARz
wSZYZWTGKyBdfX7PYQyBtER8vfezM5UIEx9N6oV4V9COTa/
lDTH3SsDf2AjhkI5PpmShkk6W1KWpLfMbK27vx/t4tJcKMdmnKR3JZ0lqZ+k181sjbt/
EeXaEFzEM0EqhKmdknrVWs5V4C+B5o5BdIT0vTezIZIekXSeu38eo9rSXSjHJk/
Ss9VBqouk882swt1fikmF6S3U32173f2ApANmtlrSSZIIU9EVyrG5XNLPPHCRzjYz+1DS8ZLWxaZENCHimS
AVTvMVSRpgZn3NrJWkKZJerjfmZUmXVV/Bf7qkfe5eFutC01TQ42NmR0taLOlS/
qKOqaDHxt37unsfd+8j6QVJMwlSMRPK77Y/
ShppZplm1lbSaZJKY1xnOgrl2HyiwIyhzKybpOMkbY9plWhMxDNB0s9MuXuFmc2W9KoCd1g86u6bzGxG9fY
HFbgL6XxJ2yR9pcBfDIiBEI/PPElHSvpN9QxIhfOO61EX4rFBnIRyfNy91MyWS9ooqUrSI+7e4O3giJwQ/
+3cJmmRmf1dgdNKc9x9b9yKTiNm9oykMZK6mNlOSfMltZSilwl4OxkAAIAwpMJpPgAAgLghTAEAAISBMAUA
ABAGwhQAAEAYCFMAAABhIEwBAACEgTAFAAAQhv8Plb9avrm9+PsAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_predictions(predictions=y_preds)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JLJWVANkhY3-",
"outputId": "ed29f680-d66f-4bbd-b1b3-b35655ca4fec"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.4618],\n",
" [0.4691],\n",
" [0.4764],\n",
" [0.4836],\n",
" [0.4909],\n",
" [0.4982],\n",
" [0.5054],\n",
" [0.5127],\n",
" [0.5200],\n",
" [0.5272]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test - y_preds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lxt8WUzdv1qS"
},
"source": [
"Woah! Those predictions look pretty bad...\n",
"\n",
"This makes sense though, when you remember our model is just using random
parameter values to make predictions.\n",
"\n",
"It hasn't even looked at the blue dots to try to predict the green dots.\n",
"\n",
"Time to change that."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZZpa-fXLec03"
},
"source": [
"## 3. Train model\n",
"\n",
"Right now our model is making predictions using random parameters to make
calculations, it's basically guessing (randomly).\n",
"\n",
"To fix that, we can update its internal parameters (I also refer to
*parameters* as patterns), the `weights` and `bias` values we set randomly using
`nn.Parameter()` and `torch.randn()` to be something that better represents the
data.\n",
"\n",
"We could hard code this (since we know the default values `weight=0.7` and
`bias=0.3`) but where's the fun in that?\n",
"\n",
"Much of the time you won't know what the ideal parameters are for a model.\n",
"\n",
"Instead, it's much more fun to write code to see if the model can try and
figure them out itself.\n",
"\n"
]
},
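For the curious, here's what "hard coding" the parameters would look like. It's a minimal sketch, assuming a stand-in `nn.Linear(1, 1)` model rather than `model_0` and freshly generated straight-line data with the same `weight=0.7` and `bias=0.3`:

```python
import torch
from torch import nn

torch.manual_seed(42)
weight, bias = 0.7, 0.3  # the values used to create the straight-line data
X = torch.linspace(0, 1, steps=50).unsqueeze(dim=1)
y = weight * X + bias

model = nn.Linear(in_features=1, out_features=1)  # stand-in with randomly initialized parameters
loss_fn = nn.L1Loss()

with torch.inference_mode():
    loss_before = loss_fn(model(X), y).item()

# Hard code the known values (no_grad so autograd doesn't track the in-place copy)
with torch.no_grad():
    model.weight.fill_(weight)
    model.bias.fill_(bias)

with torch.inference_mode():
    loss_after = loss_fn(model(X), y).item()

print(f"MAE with random parameters: {loss_before:.4f}")
print(f"MAE with hard-coded parameters: {loss_after:.6f}")  # effectively zero
```

This only works because we made the data ourselves; with real data there are no known values to copy in, which is why we train.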
{
"cell_type": "markdown",
"metadata": {
"id": "aD8pnhJUyZUT"
},
"source": [
"### Creating a loss function and optimizer in PyTorch\n",
"\n",
"For our model to update its parameters on its own, we'll need to add a few
more things to our recipe.\n",
"\n",
"And that's a **loss function** as well as an **optimizer**.\n",
"\n",
"The rolls of these are: \n",
"\n",
"| Function | What does it do? | Where does it live in PyTorch? | Common values
|\n",
"| ----- | ----- | ----- | ----- |\n",
"| **Loss function** | Measures how wrong your model's predictions (e.g.
`y_preds`) are compared to the truth labels (e.g. `y_test`). Lower the better. |
PyTorch has plenty of built-in loss functions in
[`torch.nn`](https://pytorch.org/docs/stable/nn.html#loss-functions). | Mean
absolute error (MAE) for regression problems
([`torch.nn.L1Loss()`](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.ht
ml)). Binary cross entropy for binary classification problems
([`torch.nn.BCELoss()`](https://pytorch.org/docs/stable/generated/
torch.nn.BCELoss.html)). |\n",
"| **Optimizer** | Tells your model how to update its internal parameters to
best lower the loss. | You can find various optimization function implementations
in [`torch.optim`](https://pytorch.org/docs/stable/optim.html). | Stochastic
gradient descent
([`torch.optim.SGD()`](https://pytorch.org/docs/stable/generated/torch.optim.SGD.ht
ml#torch.optim.SGD)). Adam optimizer
([`torch.optim.Adam()`](https://pytorch.org/docs/stable/generated/
torch.optim.Adam.html#torch.optim.Adam)). | \n",
"\n",
"Let's create a loss function and an optimizer we can use to help improve our
model.\n",
"\n",
"Depending on what kind of problem you're working on will depend on what loss
function and what optimizer you use.\n",
"\n",
"However, there are some common values, that are known to work well such as the
SGD (stochastic gradient descent) or Adam optimizer. And the MAE (mean absolute
error) loss function for regression problems (predicting a number) or binary cross
entropy loss function for classification problems (predicting one thing or
another). \n",
"\n",
"For our problem, since we're predicting a number, let's use MAE (which is
under `torch.nn.L1Loss()`) in PyTorch as our loss function. \n",
"\n",
"\n",
"*Mean absolute error (MAE, in PyTorch: `torch.nn.L1Loss`) measures the
absolute difference between two points (predictions and labels) and then takes the
mean across all examples.*\n",
"\n",
"And we'll use SGD, `torch.optim.SGD(params, lr)` where:\n",
"\n",
"* `params` is the target model parameters you'd like to optimize (e.g. the
`weights` and `bias` values we randomly set before).\n",
"* `lr` is the **learning rate** you'd like the optimizer to update the
parameters at, higher means the optimizer will try larger updates (these can
sometimes be too large and the optimizer will fail to work), lower means the
optimizer will try smaller updates (these can sometimes be too small and the
optimizer will take too long to find the ideal values). The learning rate is
considered a **hyperparameter** (because it's set by a machine learning engineer).
Common starting values for the learning rate are `0.01`, `0.001`, `0.0001`,
however, these can also be adjusted over time (this is called [learning rate
scheduling](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-
rate)). \n",
"\n",
"Woah, that's a lot, let's see it in code."
]
},
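To check that `torch.nn.L1Loss()` really is MAE (the mean of `|y_pred - y_true|` over all examples), we can compute it by hand and compare. Here's a minimal sketch with made-up numbers, not the notebook's `y_preds` and `y_test`:

```python
import torch
from torch import nn

# Made-up predictions and labels, for illustration only
y_pred = torch.tensor([[0.40], [0.41], [0.42]])
y_true = torch.tensor([[0.86], [0.87], [0.88]])

# MAE by hand: absolute difference per example, then the mean across all examples
manual_mae = torch.mean(torch.abs(y_pred - y_true))

# The built-in equivalent (reduction="mean" is the default)
builtin_mae = nn.L1Loss()(y_pred, y_true)

print(manual_mae.item(), builtin_mae.item())  # both about 0.46
```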
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "P3T7hpNPec03"
},
"outputs": [],
"source": [
"# Create the loss function\n",
"loss_fn = nn.L1Loss() # MAE loss is same as L1Loss\n",
"\n",
"# Create the optimizer\n",
"optimizer = torch.optim.SGD(params=model_0.parameters(), # parameters of
target model to optimize\n",
" lr=0.01) # learning rate (how much the optimizer
should change parameters at each step, higher=more (less stable), lower=less (might
take a long time))"
]
},
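What does `optimizer.step()` actually do with `lr`? For plain SGD (no momentum and no weight decay, which are the defaults for `torch.optim.SGD`), each parameter moves by `-lr * parameter.grad`. Here's a minimal sketch that checks this on a single step, using hypothetical standalone parameters `w` and `b` instead of `model_0`:

```python
import torch
from torch import nn

torch.manual_seed(0)
w = nn.Parameter(torch.randn(1))
b = nn.Parameter(torch.randn(1))
X = torch.linspace(0, 1, steps=20).unsqueeze(dim=1)
y = 0.7 * X + 0.3

lr = 0.01
optimizer = torch.optim.SGD(params=[w, b], lr=lr)

loss = nn.L1Loss()(w * X + b, y)
optimizer.zero_grad()
loss.backward()

# Plain SGD should move each parameter by -lr * gradient
with torch.no_grad():
    expected_w = w - lr * w.grad
    expected_b = b - lr * b.grad

optimizer.step()
print(torch.allclose(w, expected_w), torch.allclose(b, expected_b))  # True True
```

This is why a larger `lr` means larger updates: the step size is the gradient scaled by the learning rate.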
{
"cell_type": "markdown",
"metadata": {
"id": "aFcKCsPcRfnA"
},
"source": [
"### Creating an optimization loop in PyTorch\n",
"\n",
"Woohoo! Now we've got a loss function and an optimizer, it's now time to
create a **training loop** (and **testing loop**).\n",
"\n",
"The training loop involves the model going through the training data and
learning the relationships between the `features` and `labels`.\n",
"\n",
"The testing loop involves going through the testing data and evaluating how
good the patterns are that the model learned on the training data (the model never
sees the testing data during training).\n",
"\n",
"Each of these is called a \"loop\" because we want our model to look (loop
through) at each sample in each dataset.\n",
"\n",
"To create these we're going to write a Python `for` loop in the theme of the
[unofficial PyTorch optimization loop
song](https://twitter.com/mrdbourke/status/1450977868406673410?s=20) (there's a
[video version too](https://youtu.be/Nutpusq_AFw)).\n",
"\n",
"\n",
"*The unofficial PyTorch optimization loops song, a fun way to remember the
steps in a PyTorch training (and testing) loop.*\n",
"\n",
"There will be a fair bit of code but nothing we can't handle.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "agXn72H-sgyd"
},
"source": [
"\n",
"\n",
"### PyTorch training loop\n",
"For the training loop, we'll build the following steps:\n",
"\n",
"| Number | Step name | What does it do? | Code example |\n",
"| ----- | ----- | ----- | ----- |\n",
"| 1 | Forward pass | The model goes through all of the training data once,
performing its `forward()` function calculations. | `model(x_train)` |\n",
"| 2 | Calculate the loss | The model's outputs (predictions) are compared to
the ground truth and evaluated to see how wrong they are. | `loss = loss_fn(y_pred,
y_train)` | \n",
"| 3 | Zero gradients | The optimizers gradients are set to zero (they are
accumulated by default) so they can be recalculated for the specific training step.
| `optimizer.zero_grad()` |\n",
"| 4 | Perform backpropagation on the loss | Computes the gradient of the loss
with respect for every model parameter to be updated (each parameter with
`requires_grad=True`). This is known as **backpropagation**, hence \"backwards\".
| `loss.backward()` |\n",
"| 5 | Update the optimizer (**gradient descent**) | Update the parameters with
`requires_grad=True` with respect to the loss gradients in order to improve them. |
`optimizer.step()` |\n",
"\n",
"\n",
"\n",
"> **Note:** The above is just one example of how the steps could be ordered or
described. With experience you'll find making PyTorch training loops can be quite
flexible.\n",
">\n",
"> And on the ordering of things, the above is a good default order but you may
see slightly different orders. Some rules of thumb: \n",
"> * Calculate the loss (`loss = ...`) *before* performing backpropagation on
it (`loss.backward()`).\n",
"> * Zero gradients (`optimizer.zero_grad()`) *before* computing the gradients
of the loss with respect to every model parameter (`loss.backward()`).\n",
"> * Step the optimizer (`optimizer.step()`) *after* performing backpropagation
on the loss (`loss.backward()`).\n",
"\n",
"For resources to help understand what's happening behind the scenes with
backpropagation and gradient descent, see the extra-curriculum section.\n"
]
},
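Step 3 is the one that's easiest to forget, so here's why it matters. A minimal sketch with a single hypothetical parameter `w` and a toy loss whose gradient is always 3:

```python
import torch
from torch import nn

w = nn.Parameter(torch.tensor([2.0]))
optimizer = torch.optim.SGD(params=[w], lr=0.1)

def compute_loss():
    # Toy loss (hypothetical): its gradient with respect to w is always 3
    return (3 * w).sum()

compute_loss().backward()
first = w.grad.clone()         # tensor([3.])

compute_loss().backward()      # second backward pass without zeroing
accumulated = w.grad.clone()   # tensor([6.]): added to the old gradient, not replaced

optimizer.zero_grad()          # clear the stored gradients
compute_loss().backward()
fresh = w.grad.clone()         # tensor([3.]) again

print(first, accumulated, fresh)
```

Without `optimizer.zero_grad()`, each epoch's `loss.backward()` would add onto the previous epoch's gradients, so `optimizer.step()` would update the parameters based on a running total rather than the current loss.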
{
"cell_type": "markdown",
"metadata": {
"id": "OXHDdlfjssDc"
},
"source": [
"\n",
"### PyTorch testing loop\n",
"\n",
"As for the testing loop (evaluating our model), the typical steps include:\n",
"\n",
"| Number | Step name | What does it do? | Code example |\n",
"| ----- | ----- | ----- | ----- |\n",
"| 1 | Forward pass | The model goes through all of the testing data once, performing its `forward()` function calculations. | `model(X_test)` |\n",
"| 2 | Calculate the loss | The model's outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are. | `loss = loss_fn(y_pred, y_test)` |\n",
"| 3 | Calculate evaluation metrics (optional) | Alongside the loss value, you may want to calculate other evaluation metrics, such as accuracy, on the test set. | Custom functions |\n",
"\n",
"Notice the testing loop doesn't perform backpropagation (`loss.backward()`) or step the optimizer (`optimizer.step()`). This is because no parameters in the model are changed during testing; they've already been learned during training. For testing, we're only interested in the output of the forward pass through the model.\n",
"\n",
"\n",
"\n",
"Let's put all of the above together and train our model for 100 **epochs** (full passes over the training data). We'll evaluate it on the test data every epoch and log the training and test loss every 10 epochs.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "k1DfhyJ7ec03",
"outputId": "333f9780-c103-4e81-95da-9f721c80b617"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 | MAE Train Loss: 0.31288138031959534 | MAE Test Loss: 0.48106518387794495 \n",
"Epoch: 10 | MAE Train Loss: 0.1976713240146637 | MAE Test Loss: 0.3463551998138428 \n",
"Epoch: 20 | MAE Train Loss: 0.08908725529909134 | MAE Test Loss: 0.21729660034179688 \n",
"Epoch: 30 | MAE Train Loss: 0.053148526698350906 | MAE Test Loss: 0.14464017748832703 \n",
"Epoch: 40 | MAE Train Loss: 0.04543796554207802 | MAE Test Loss: 0.11360953003168106 \n",
"Epoch: 50 | MAE Train Loss: 0.04167863354086876 | MAE Test Loss: 0.09919948130846024 \n",
"Epoch: 60 | MAE Train Loss: 0.03818932920694351 | MAE Test Loss: 0.08886633068323135 \n",
"Epoch: 70 | MAE Train Loss: 0.03476089984178543 | MAE Test Loss: 0.0805937647819519 \n",
"Epoch: 80 | MAE Train Loss: 0.03132382780313492 | MAE Test Loss: 0.07232122868299484 \n",
"Epoch: 90 | MAE Train Loss: 0.02788739837706089 | MAE Test Loss: 0.06473556160926819 \n"
]
}
],
"source": [
"torch.manual_seed(42)\n",
"\n",
"# Set the number of epochs (how many times the model will pass over the training data)\n",
"epochs = 100\n",
"\n",
"# Create empty loss lists to track values\n",
"train_loss_values = []\n",
"test_loss_values = []\n",
"epoch_count = []\n",
"\n",
"for epoch in range(epochs):\n",
"    ### Training\n",
"\n",
"    # Put model in training mode (this is the default state of a model)\n",
"    model_0.train()\n",
"\n",
"    # 1. Forward pass on train data using the forward() method inside the model\n",
"    y_pred = model_0(X_train)\n",
"    # print(y_pred)\n",
"\n",
"    # 2. Calculate the loss (how different are our model's predictions from the ground truth)\n",
"    loss = loss_fn(y_pred, y_train)\n",
"\n",
"    # 3. Zero the optimizer's gradients (they accumulate by default)\n",
"    optimizer.zero_grad()\n",
"\n",
"    # 4. Backpropagation: compute the gradients of the loss w.r.t. the parameters\n",
"    loss.backward()\n",
"\n",
"    # 5. Step the optimizer (update the parameters)\n",
"    optimizer.step()\n",
"\n",
"    ### Testing\n",
"\n",
"    # Put the model in evaluation mode\n",
"    model_0.eval()\n",
"\n",
"    with torch.inference_mode():\n",
"        # 1. Forward pass on test data\n",
"        test_pred = model_0(X_test)\n",
"\n",
"        # 2. Calculate loss on test data\n",
"        test_loss = loss_fn(test_pred, y_test.type(torch.float)) # predictions come in torch.float datatype, so comparisons need to be done with tensors of the same type\n",
"\n",
"        # Print out what's happening\n",
"        if epoch % 10 == 0:\n",
"            epoch_count.append(epoch)\n",
"            train_loss_values.append(loss.detach().numpy())\n",
"            test_loss_values.append(test_loss.detach().numpy())\n",
"            print(f\"Epoch: {epoch} | MAE Train Loss: {loss} | MAE Test Loss: {test_loss} \")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1krgBqXBdYHc"
},
"source": [
"Oh, would you look at that! Looks like our loss is going down as training progresses. Let's plot it to get a better look."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 295
},
"id": "FPXfvPLkau72",
"outputId": "2f6b88b4-4c8e-48ad-eb99-27abd941993d"
},
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG omitted: plot titled 'Training and test loss curves' (Loss vs. Epochs, Train loss and Test loss lines)>",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the loss curves\n",
"plt.plot(epoch_count, train_loss_values, label=\"Train loss\")\n",
"plt.plot(epoch_count, test_loss_values, label=\"Test loss\")\n",
"plt.title(\"Training and test loss curves\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.xlabel(\"Epochs\")\n",
"plt.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lmqQE8Kpec04"
},
"source": [
"Nice! The **loss curves** show the loss going down over time. Remember, loss is the measure of how *wrong* your model is, so the lower the better.\n",
"\n",
"But why did the loss go down?\n",
"\n",
"Well, thanks to our loss function and optimizer, the model's internal parameters (`weights` and `bias`) were updated to better reflect the underlying patterns in the data.\n",
"\n",
"Let's inspect our model's [`.state_dict()`](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html) to see how close its learned values are to the original values we set for weights and bias.\n",
"\n"
]
},
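For a standalone look at what `.state_dict()` returns, here's a sketch. It uses a hypothetical `LinearRegressionModel` with the same parameter names as `model_0` (`weights` and `bias`, each created with `torch.randn(1)`). The `compare_to_ideal()` helper is also hypothetical, and the ideal values 0.7 and 0.3 are the ones used to create the data. One detail worth knowing: the tensors in a state dict share memory with the model's live parameters, so the sketch uses `.clone()` to keep a snapshot that won't change during training.

```python
import torch
from torch import nn

class LinearRegressionModel(nn.Module):
    """Hypothetical model using the same parameter names as model_0."""
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.weights * x + self.bias

def compare_to_ideal(state_dict, ideal):
    """Absolute error between each parameter value and its ideal value."""
    return {name: abs(tensor.item() - ideal[name]) for name, tensor in state_dict.items()}

torch.manual_seed(42)
model = LinearRegressionModel()

state = model.state_dict()
print(list(state.keys()))  # ['weights', 'bias']
print(compare_to_ideal(state, {"weights": 0.7, "bias": 0.3}))

# state_dict() tensors share storage with the live parameters, so clone them for a snapshot
snapshot = {name: tensor.clone() for name, tensor in state.items()}
with torch.no_grad():
    model.weights.add_(1.0)  # simulate an in-place optimizer update

print(torch.equal(state["weights"], model.weights.detach()))     # True: same storage
print(torch.equal(snapshot["weights"], model.weights.detach()))  # False: independent copy
```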
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ci0W7kn5ec04",
"outputId": "2c27ba8b-e388-484e-c59e-464fdb53d73e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model learned the following values for weights and bias:\n",
"OrderedDict([('weights', tensor([0.5784])), ('bias', tensor([0.3513]))])\n",
"\n",
"And the original values for weights and bias are:\n",
"weights: 0.7, bias: 0.3\n"
]
}
],
"source": [
"# Find our model's learned parameters\n",
"print(\"The model learned the following values for weights and bias:\")\n",
"print(model_0.state_dict())\n",
"print(\"\\nAnd the original values for weights and bias are:\")\n",
"print(f\"weights: {weight}, bias: {bias}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BZyBa9rMelBv"
},
"source": [
"Wow! How cool is that?\n",
"\n",
"Our model got fairly close to the original values for `weight` and `bias` (0.5784 vs. 0.7 and 0.3513 vs. 0.3), and it would probably get even closer if we trained it for longer.\n",
"\n",
"> **Exercise:** Try changing the `epochs` value above to 200. What happens to the loss curves and to the model's weights and bias parameter values?\n",
"\n",
"It'd likely never guess them *perfectly* (especially when using more complicated datasets), but that's okay; often you can do very cool things with a close approximation.\n",
"\n",
"This is the whole idea of machine learning and deep learning: **there are some ideal values that describe our data**, and rather than figuring them out by hand, **we can train a model to figure them out programmatically**."
]
},
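To check your answer to the exercise without editing the training cell, here's a sketch that retrains from the same starting point for 100 and for 200 epochs and compares the results. It assumes a setup like this notebook's:

- Data created from `weight = 0.7` and `bias = 0.3`, with an 80/20 train/test split.
- `nn.L1Loss()` as the loss function.
- `torch.optim.SGD` with `lr=0.01`.

The `LinearRegressionModel` class and the `train_and_evaluate()` helper are hypothetical.

```python
import torch
from torch import nn

# Assumed data setup: a straight line with known parameters, split 80/20
weight, bias = 0.7, 0.3
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = weight * X + bias
split = int(0.8 * len(X))
X_train, y_train = X[:split], y[:split]
X_test, y_test = X[split:], y[split:]

class LinearRegressionModel(nn.Module):
    """Hypothetical model with a single weight and bias."""
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.weights * x + self.bias

def train_and_evaluate(epochs, seed=42):
    """Train a fresh model for `epochs` full-batch steps; return (parameters, test loss)."""
    torch.manual_seed(seed)  # same starting parameters for every run
    model = LinearRegressionModel()
    loss_fn = nn.L1Loss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(epochs):
        model.train()
        loss = loss_fn(model(X_train), y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.inference_mode():
        test_loss = loss_fn(model(X_test), y_test).item()
    params = {name: round(t.item(), 4) for name, t in model.state_dict().items()}
    return params, test_loss

for epochs in (100, 200):
    params, test_loss = train_and_evaluate(epochs)
    print(f"{epochs} epochs | params: {params} | test loss: {test_loss:.4f}")
```

Re-seeding inside the helper means both runs start from identical parameters, so any difference comes only from the extra epochs.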
{
"cell_type": "markdown",
"metadata": {
"id": "c-VBDFd2ec05"
},
"source": [
"## 4. Making predictions with a trained PyTorch model (inference)\n",
"\n",
"Once you've trained a model, you'll likely want to make predictions with it.\n",
"\n",
"We've already seen a glimpse of this in the training and testing code above; the steps to do it outside of the training/testing loop are similar.\n",
"\n",
"There are three things to remember when making predictions (also called performing inference) with a PyTorch model:\n",
"\n",
"1. Set the model in evaluation mode (`model.eval()`).\n",
"2. Make the predictions using the inference mode context manager (`with torch.inference_mode(): ...`).\n",
"3. All predictions should be made with objects on the same device (e.g. data and model on GPU only or data and model on CPU only).\n",
"\n",
"The first two items turn off settings and calculations that PyTorch uses behind the scenes during training but that aren't needed for inference. `model.eval()` switches layers such as dropout and batch normalization into evaluation mode. `torch.inference_mode()` disables gradient tracking, which results in faster computation. The third item ensures that you won't run into cross-device errors."
]
},
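The notebook hasn't set up device-agnostic code yet, so everything so far runs on the CPU. As a sketch of item 3, here's one common device-agnostic pattern. It uses a hypothetical `nn.Linear` model and made-up inputs `X_new` in place of `model_0` and `X_test`. The pattern picks `"cuda"` when `torch.cuda.is_available()` returns `True`, moves both the model and the data to that device, and then predicts.

```python
import torch
from torch import nn

# Use a GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(42)
model = nn.Linear(in_features=1, out_features=1)  # hypothetical trained model
X_new = torch.tensor([[0.80], [0.82], [0.84]])     # hypothetical new inputs

# Item 3: model and data on the same device
model = model.to(device)
X_new = X_new.to(device)

# Item 1: evaluation mode
model.eval()

# Item 2: inference mode (no gradient tracking)
with torch.inference_mode():
    y_preds = model(X_new)

# Bring predictions back to the CPU, e.g. before plotting them with matplotlib
y_preds_cpu = y_preds.cpu()
print(y_preds_cpu.shape, y_preds.requires_grad, y_preds.device.type == device)
```

Moving predictions back with `.cpu()` matters when the model runs on a GPU, because libraries such as NumPy and matplotlib can only work with CPU tensors.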
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xKKxSBVuec05",
"outputId": "7a637fab-186e-4269-85a7-6dc28ee690e0"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.8141],\n",
" [0.8256],\n",
" [0.8372],\n",
" [0.8488],\n",
" [0.8603],\n",
" [0.8719],\n",
" [0.8835],\n",
" [0.8950],\n",
" [0.9066],\n",
" [0.9182]])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1. Set the model in evaluation mode\n",
"model_0.eval()\n",
"\n",
"# 2. Set up the inference mode context manager\n",
"with torch.inference_mode():\n",
"    # 3. Make sure the calculations are done with the model and data on the same device\n",
"    # in our case, we haven't set up device-agnostic code yet so our data and model are\n",
"    # on the CPU by default.\n",
"    # model_0.to(device)\n",
"    # X_test = X_test.to(device)\n",
"    y_preds = model_0(X_test)\n",
"y_preds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Cn21JvzmjbBO"
},
"source": [
"Nice! We've made some predictions with our trained model. Now, how do they look?"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 428
},
"id": "b_kBqpCfec05",
"outputId": "b2e3870b-dfdf-4dbc-877c-a940cb732859"
},
"outputs": [
{
"data": {
"image/png":
"iVBORw0KGgoAAAANSUhEUgAAAlMAAAGbCAYAAADgEhWsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIH
ZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/
YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtvElEQVR4nO3de3RU9bn/
8c9DwiUQbkoAIQgIqCCiQkSx5aZ4UIFFbbUgVkGthh9wjq7jBaqnIN4vWGtrWlGreKtaFaxFilp+INiKJEG
hQsCDeAGMEOxvoWAVkzy/
PyZNk5BkJsx95v1aa1ay9/7OniezA3zY+7ufMXcXAAAADk+zeBcAAACQzAhTAAAAYSBMAQAAhIEwBQAAEAb
CFAAAQBgy4/
XCnTp18l69esXr5QEAAEJWXFy8191z6tsWtzDVq1cvFRUVxevlAQAAQmZmnzS0jct8AAAAYSBMAQAAhIEwB
QAAEAbCFAAAQBgIUwAAAGEIejefmT0mabykPe4+sJ7tJukBSedJ+lrSNHdfH25hX375pfbs2aPvvvsu3F0h
DbRp00a5ublq1oz/HwAAYiuU1giLJD0o6ckGtp8rqV/V4zRJv636eti+/
PJL7d69W927d1dWVpYCeQ2oX2VlpXbt2qW9e/eqc+fO8S4HAJBmgv433t1XS/
pHI0MmSnrSA9ZK6mBmR4VT1J49e9S9e3e1bt2aIIWgmjVrpi5dumjfvn3xLgUAkIYicU2ku6QdNZZ3Vq07b
N99952ysrLCKgrppXnz5iovL493GQCANBSJMFXfqSOvd6DZVWZWZGZFZWVlje+UM1JoAn5fAADxEokwtVNS
jxrLuZI+q2+guz/s7nnunpeTU+/
H2wAAACSVSISpVyRdagGnS9rn7qUR2C8AAEDCCxqmzOxZSW9LOs7MdprZFWY23cymVw1ZJmm7pG2SHpE0I2
rVpqFp06Zp/PjxTXrOqFGjNGvWrChV1LhZs2Zp1KhRcXltAADiIWhrBHe/
KMh2lzQzYhUlqWBzdqZOnapFixY1eb8PPPCAAm9x6BYvXqzmzZs3+bXi4eOPP1bv3r1VWFiovLy8eJcDAEC
ThdJnCiEoLf33lc2lS5fqyiuvrLWu7t2J3333XUiBp3379k2u5YgjjmjycwAAwOGhXXSEdO3atfrRoUOHWu
u++eYbdejQQc8+
+6zOPPNMZWVlaeHChfriiy900UUXKTc3V1lZWTrhhBP0+OOP19pv3ct8o0aN0owZM3TjjTeqU6dO6ty5s66
77jpVVlbWGlPzMl+vXr102223KT8/X+3atVNubq7uvffeWq/
zwQcfaOTIkWrVqpWOO+44LVu2TNnZ2Y2eTauoqNB1112njh07qmPHjrrmmmtUUVFRa8zy5cs1fPhwdezYUU
cccYTGjh2rkpKS6u29e/eWJJ166qkys+pLhIWFhfqP//gPderUSe3atdP3v/
99vf3228EPBAAgrcx8daYyb8nUzFfjd5GMMBVDP/vZzzRjxgxt3rxZP/jBD/
TNN99o8ODBWrp0qTZt2qSrr75a+fn5WrFiRaP7eeaZZ5SZmam//e1vevDBB/XLX/5Szz//fKPPuf/+
+3XiiSdq/fr1mj17tm644YbqcFJZWanzzz9fmZmZWrt2rRYtWqT58+fr22+/
bXSf9913nx555BEtXLhQb7/9tioqKvTMM8/
UGnPgwAFdc801WrdunVatWqX27dtrwoQJOnjwoCRp3bp1kgKhq7S0VIsXL5YkffXVV7rkkku0Zs0arVu3Ti
ABAGwhQAAEAYCFMAAABhIEwBAACEgTAFAAAQhv8P+PKB+Xi+DiYAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_predictions(predictions=y_preds)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fEHGrjLgji6E"
},
"source": [
"Woohoo! Those red dots are looking far closer than they were before!\n",
"\n",
"Let's get onto saving and reloading a model in PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8NRng9aEec05"
},
"source": [
"## 5. Saving and loading a PyTorch model\n",
"\n",
"If you've trained a PyTorch model, chances are you'll want to save it and
export it somewhere.\n",
"\n",
"As in, you might train it on Google Colab or your local machine with a GPU but
you'd like to now export it to some sort of application where others can use it. \
n",
"\n",
"Or maybe you'd like to save your progress on a model and come back and load it
back later.\n",
"\n",
"For saving and loading models in PyTorch, there are three main methods you
should be aware of (all of below have been taken from the [PyTorch saving and
loading models
guide](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-
loading-model-for-inference)):\n",
"\n",
"| PyTorch method | What does it do? | \n",
"| ----- | ----- |\n",
"| [`torch.save`](https://pytorch.org/docs/stable/torch.html?
highlight=save#torch.save) | Saves a serialized object to disk using Python's
[`pickle`](https://docs.python.org/3/library/pickle.html) utility. Models, tensors
and various other Python objects like dictionaries can be saved using `torch.save`.
| \n",
"| [`torch.load`](https://pytorch.org/docs/stable/torch.html?highlight=torch
%20load#torch.load) | Uses `pickle`'s unpickling features to deserialize and load
pickled Python object files (like models, tensors or dictionaries) into memory. You
can also set which device to load the object to (CPU, GPU etc). |\n",
"| [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/
generated/torch.nn.Module.html?
highlight=load_state_dict#torch.nn.Module.load_state_dict)| Loads a model's
parameter dictionary (`model.state_dict()`) using a saved `state_dict()` object.
| \n",
"\n",
 **Note:** As stated in [Python's `pickle`">
"> **Note:** As stated in [Python's `pickle` documentation](https://docs.python.org/3/library/pickle.html), the `pickle` module **is not secure**. That means you should only ever unpickle (load) data you trust. That goes for loading PyTorch models as well. Only ever use saved PyTorch models from sources you trust.\n"
]
},
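To see how the three methods fit together, here's a minimal sketch. It assumes a tiny `nn.Linear` layer as a stand-in for a trained model and writes to an in-memory `io.BytesIO` buffer rather than a file, so nothing lands on disk; both are illustration choices, not part of the original workflow. PyTorch 1.13 and later also accept `torch.load(..., weights_only=True)`, which restricts unpickling to tensors, primitive types and dictionaries and is a sensible extra safeguard given the `pickle` warning above. It's left out here so the sketch also runs on the 1.12 release used in this notebook.

```python
import io

import torch
from torch import nn

# Stand-in model: a single linear layer (1 input feature -> 1 output feature)
model = nn.Linear(in_features=1, out_features=1)

# torch.save: serialize the state_dict (an ordered dict of parameter tensors)
buffer = io.BytesIO()  # in-memory "file" so the sketch leaves nothing on disk
torch.save(model.state_dict(), buffer)

# torch.load: deserialize the bytes back into a dict, placing tensors on the CPU
buffer.seek(0)  # rewind so torch.load reads from the start of the buffer
state_dict = torch.load(buffer, map_location="cpu")

# load_state_dict: copy the saved parameters into a freshly (randomly) initialized model
new_model = nn.Linear(in_features=1, out_features=1)
print(new_model.load_state_dict(state_dict))  # <All keys matched successfully>
print(list(state_dict.keys()))  # ['weight', 'bias']
```

Passing `map_location` is how you choose the device mentioned in the `torch.load` row of the table; `"cpu"` is a safe choice when a model was trained on a GPU but is being loaded on a machine without one.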
{
"cell_type": "markdown",
"metadata": {
"id": "SdAGcH2aec05"
},
"source": [
"### Saving a PyTorch model's `state_dict()`\n",
"\n",
"The [recommended
way](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-
loading-model-for-inference) for saving and loading a model for inference (making
predictions) is by saving and loading a model's `state_dict()`.\n",
"\n",
"Let's see how we can do that in a few steps:\n",
"\n",
"1. We'll create a directory for saving models to called `models` using
Python's `pathlib` module.\n",
"2. We'll create a file path to save the model to.\n",
"3. We'll call `torch.save(obj, f)` where `obj` is the target model's
`state_dict()` and `f` is the filename of where to save the model.\n",
"\n",
 **Note:** It's common convention">
"> **Note:** It's a common convention for saved PyTorch models or objects to end with `.pt` or `.pth`, like `saved_model_01.pth`.\n"
]
},
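The three steps above could be wrapped in a small reusable function. `save_model_state` is a hypothetical helper (it is not part of PyTorch); the sketch also checks the `.pt`/`.pth` convention and saves into a temporary directory so running it leaves no files behind.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def save_model_state(model: nn.Module, target_dir: str, model_name: str) -> Path:
    """Hypothetical helper: save model.state_dict() to target_dir/model_name."""
    # Enforce the common .pt / .pth file extension convention
    if not model_name.endswith((".pt", ".pth")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # 1. Create the target directory (parents included, no error if it exists)
    target_path = Path(target_dir)
    target_path.mkdir(parents=True, exist_ok=True)

    # 2. Build the file path to save to
    save_path = target_path / model_name

    # 3. Save only the learned parameters, not the whole model object
    torch.save(obj=model.state_dict(), f=save_path)
    return save_path


with tempfile.TemporaryDirectory() as tmp_dir:
    saved_path = save_model_state(nn.Linear(1, 1), tmp_dir, "example_model_0.pth")
    saved_ok = saved_path.is_file() and saved_path.stat().st_size > 0
    print(saved_path.name, saved_ok)  # example_model_0.pth True
```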
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qsQhY2S2jv90",
"outputId": "a897070c-a843-4a7c-a06e-e6406206412c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving model to: models/01_pytorch_workflow_model_0.pth\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"# 1. Create models directory \n",
"MODEL_PATH = Path(\"models\")\n",
"MODEL_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# 2. Create model save path \n",
"MODEL_NAME = \"01_pytorch_workflow_model_0.pth\"\n",
"MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n",
"\n",
"# 3. Save the model state dict \n",
"print(f\"Saving model to: {MODEL_SAVE_PATH}\")\n",
"torch.save(obj=model_0.state_dict(), # only saving the state_dict() only saves
the models learned parameters\n",
" f=MODEL_SAVE_PATH) "
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mpQc45zwec06",
"outputId": "50e1b51b-1b98-41f1-ca36-ce9cb5682064"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-rw-rw-r-- 1 daniel daniel 1063 Nov 10 16:07
models/01_pytorch_workflow_model_0.pth\n"
]
}
],
"source": [
"# Check the saved file path\n",
"!ls -l models/01_pytorch_workflow_model_0.pth"
]
},
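`!ls -l` is an IPython/Jupyter shell escape, so it only works in a notebook on a system with a Unix-style `ls`. A portable alternative is `pathlib.Path.stat()`. The sketch below uses a hypothetical helper, `check_saved_file`, and first saves a tiny `nn.Linear` state_dict to a temporary file so it runs on its own; in the notebook you would pass it `MODEL_SAVE_PATH` instead.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def check_saved_file(path: Path) -> int:
    """Hypothetical helper: return a saved file's size in bytes, failing loudly if it's missing."""
    if not path.is_file():
        raise FileNotFoundError(f"No saved model found at: {path}")
    size_in_bytes = path.stat().st_size
    print(f"{path.name}: {size_in_bytes} bytes")
    return size_in_bytes


with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "demo_model.pth"
    torch.save(nn.Linear(1, 1).state_dict(), demo_path)  # create something to check
    demo_size = check_saved_file(demo_path)
```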
{
"cell_type": "markdown",
"metadata": {
"id": "jFQpRoH5ec06"
},
"source": [
"### Loading a saved PyTorch model's `state_dict()`\n",
"\n",
"Since we've now got a saved model `state_dict()` at
`models/01_pytorch_workflow_model_0.pth` we can now load it in using
`torch.nn.Module.load_state_dict(torch.load(f))` where `f` is the filepath of our
saved model `state_dict()`.\n",
"\n",
"Why call `torch.load()` inside `torch.nn.Module.load_state_dict()`? \n",
"\n",
"Because we only saved the model's `state_dict()` which is a dictionary of
learned parameters and not the *entire* model, we first have to load the
`state_dict()` with `torch.load()` and then pass that `state_dict()` to a new
instance of our model (which is a subclass of `nn.Module`).\n",
"\n",
"Why not save the entire model?\n",
"\n",
"[Saving the entire
model](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-
entire-model) rather than just the `state_dict()` is more intuitive, however, to
quote the PyTorch documentation (italics mine):\n",
"\n",
 The disadvantage of this approach">
"> The disadvantage of this approach *(saving the whole model)* is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved...\n",
">\n",
 Because of this, your code can break">
"> Because of this, your code can break in various ways when used in other projects or after refactors.\n",
"\n",
"So instead, we're using the flexible method of saving and loading just the
`state_dict()`, which again is basically a dictionary of model parameters.\n",
"\n",
"Let's test it out by creating another instance of `LinearRegressionModel()`,
which is a subclass of `torch.nn.Module` and will hence have the in-built method
`load_state_dict()`."
]
},
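Here's a sketch of why the receiving model matters. It assumes a stand-in `LinearRegressionModel` with a `weights` and a `bias` parameter (mirroring the model built earlier in this notebook, but redefined here so the example runs on its own), saves its `state_dict()` to an in-memory buffer and loads it into a fresh instance. It then shows that a model with different parameter names (a plain `nn.Linear`, which uses `weight` rather than `weights`) raises a `RuntimeError` instead.

```python
import io

import torch
from torch import nn


# Stand-in for the notebook's model: a single weight and bias parameter
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1, dtype=torch.float))
        self.bias = nn.Parameter(torch.randn(1, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * x + self.bias


trained_model = LinearRegressionModel()  # imagine this one has been trained

# Save only the state_dict (parameter names -> tensors) into an in-memory buffer
buffer = io.BytesIO()
torch.save(trained_model.state_dict(), buffer)
buffer.seek(0)

# A new instance starts with its own random parameters...
loaded_model = LinearRegressionModel()
# ...which load_state_dict() overwrites with the saved ones
print(loaded_model.load_state_dict(torch.load(buffer)))  # <All keys matched successfully>

# A model with different parameter names ('weight' vs 'weights') can't receive it
mismatched_model = nn.Linear(in_features=1, out_features=1)
try:
    mismatched_model.load_state_dict(trained_model.state_dict())
except RuntimeError as error:
    print("Loading failed:", type(error).__name__)  # Loading failed: RuntimeError
```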
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1xnh3cFDec06",
"outputId": "7ef66bf8-122e-476a-ee86-b1c388d6167c"
},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate a new instance of our model (this will be instantiated with
random weights)\n",
"loaded_model_0 = LinearRegressionModel()\n",
"\n",
"# Load the state_dict of our saved model (this will update the new instance of
our model with trained weights)\n",
"loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vK8PRtY7Qgpz"
},
"source": [
"Excellent! It looks like things matched up.\n",
"\n",
"Now to test our loaded model, let's perform inference with it (make
predictions) on the test data.\n",
"\n",
"Remember the rules for performing inference with PyTorch models?\n",
"\n",
"If not, here's a refresher:\n",
"\n",
"<details>\n",
" <summary>PyTorch inference rules</summary>\n",
" <ol>\n",
" <li> Set the model in evaluation mode (<code>model.eval()</code>).
</li>\n",
" <li> Make the predictions using the inference mode context manager
(<code>with torch.inference_mode(): ...</code>). </li>\n",
" <li> All predictions should be made with objects on the same device
(e.g. data and model on GPU only or data and model on CPU only).</li>\n",
" </ol> \n",
"</details>\n",
"\n"
]
},
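The three rules can be bundled into one function. `predict` below is a hypothetical helper, not a PyTorch API; it reads the model's device from its first parameter (assuming the model has at least one parameter) and moves the input there before the forward pass. The `nn.Linear` model and random inputs are stand-ins.

```python
import torch
from torch import nn


def predict(model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper applying the three inference rules above."""
    # Rule 3: put the data on the same device as the model's parameters
    device = next(model.parameters()).device
    X = X.to(device)

    # Rule 1: switch off training-only behaviour (e.g. dropout)
    model.eval()

    # Rule 2: turn off gradient tracking for faster, lighter forward passes
    with torch.inference_mode():
        return model(X)


model = nn.Linear(in_features=1, out_features=1)
X_test = torch.rand(10, 1)  # 10 samples, 1 feature each
preds = predict(model, X_test)
print(preds.shape, preds.requires_grad)  # torch.Size([10, 1]) False
```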
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "Ps-AuJqkec06"
},
"outputs": [],
"source": [
"# 1. Put the loaded model into evaluation mode\n",
"loaded_model_0.eval()\n",
"\n",
"# 2. Use the inference mode context manager to make predictions\n",
"with torch.inference_mode():\n",
" loaded_model_preds = loaded_model_0(X_test) # perform a forward pass on
the test data with the loaded model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e81XpN8WSSqn"
},
"source": [
"Now we've made some predictions with the loaded model, let's see if they're
the same as the previous predictions."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "il9gqj6Nec06",
"outputId": "56210de9-9888-4e90-d2e7-6cd0de47f823"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compare previous model predictions with loaded model predictions (these
should be the same)\n",
"y_preds == loaded_model_preds"
]
},
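`y_preds == loaded_model_preds` compares element by element, which is why the output is a tensor of `True` values. For a single yes/no answer, `torch.equal()` checks for an exact match, while `torch.allclose()` tolerates tiny floating point differences (which can show up, for example, when the same model runs on different hardware). The prediction values below are made up for illustration.

```python
import torch

# Made-up predictions standing in for y_preds / loaded_model_preds
y_preds = torch.tensor([[0.86], [0.87], [0.89]])
loaded_model_preds = y_preds.clone()

# Element-wise == gives one bool per prediction; .all() collapses it to one value
print((y_preds == loaded_model_preds).all().item())  # True

# torch.equal: True only if shapes and every value match exactly
print(torch.equal(y_preds, loaded_model_preds))  # True

# torch.allclose: accepts tiny differences (default rtol=1e-05, atol=1e-08)
nudged_preds = loaded_model_preds + 1e-7
print(torch.equal(y_preds, nudged_preds), torch.allclose(y_preds, nudged_preds))  # False True
```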
{
"cell_type": "markdown",
"metadata": {
"id": "9Y4ZcxxfNcVu"
},
"source": [
"Nice! \n",
"\n",
"It looks like the loaded model predictions are the same as the previous model
predictions (predictions made prior to saving). This indicates our model is saving
and loading as expected.\n",
"\n",
 **Note:** There are more methods to save and load">
"> **Note:** There are more methods to save and load PyTorch models but I'll leave these for extra-curriculum and further reading. See the [PyTorch guide for saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-and-loading-models) for more."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FeAITvLXec06"
},
"source": [
"## 6. Putting it all together \n",
"\n",
"We've covered a fair bit of ground so far. \n",
"\n",
"But once you've had some practice, you'll be performing the above steps like
dancing down the street.\n",
"\n",
"Speaking of practice, let's put everything we've done so far together. \n",
"\n",
"Except this time we'll make our code device agnostic (so if there's a GPU
available, it'll use it and if not, it will default to the CPU). \n",
"\n",
"There'll be far less commentary in this section than above since what we're
going to go through has already been covered.\n",
"\n",
"We'll start by importing the standard libraries we need.\n",
"\n",
 **Note:** If you're using Google Colab, to setup a GPU">
"> **Note:** If you're using Google Colab, to set up a GPU, go to Runtime -> Change runtime type -> Hardware accelerator -> GPU. If you do this, it will reset the Colab runtime and you will lose saved variables."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "8hZ3CWhAIpUF",
"outputId": "60b4e98b-8d83-4573-cbe2-131df190b223"
},
"outputs": [
{
"data": {
"text/plain": [
"'1.12.1+cu113'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import PyTorch and matplotlib\n",
"import torch\n",
"from torch import nn # nn contains all of PyTorch's building blocks for neural
networks\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Check PyTorch version\n",
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bT-krbNMIw0d"
},
"source": [
"Now let's start making our code device agnostic by setting `device=\"cuda\"`
if it's available, otherwise it'll default to `device=\"cpu\"`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sx2Zpb5sec06",
"outputId": "88323445-9070-4b3d-a62a-3d924d8d6898"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"# Setup device agnostic code\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"print(f\"Using device: {device}\")"
]
},
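Choosing `device` is only half the job: the model and the data also have to be moved onto it. Here's a minimal sketch, assuming a single `nn.Linear` layer and random inputs as stand-ins. Note the asymmetry: `nn.Module.to()` moves the model's parameters in place, whereas `Tensor.to()` returns a tensor on the target device and leaves the original untouched, so its result needs to be kept.

```python
import torch
from torch import nn

# Same device-agnostic setup as above
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(in_features=1, out_features=1)
model.to(device)  # modifies the model in place (it also returns the model)

X = torch.rand(5, 1)
X_on_device = X.to(device)  # X itself stays on the CPU; keep the returned tensor

# Model and data now live on the same device, so the forward pass works
print(model(X_on_device).device.type == device)  # True
```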
{
"cell_type": "markdown",
"metadata": {
"id": "G1t0Ek0GJq6T"
},
"source": [
"If you've got access to a GPU, the above should've printed out:\n",
"\n",
"```\n",
"Using device: cuda\n",
"```\n",
"Otherwise, you'll be using a CPU for the following computations. This is fine
for our small dataset but it will take longer for larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DmilLp3Vec07"
},
"source": [
"### 6.1 Data\n",
"\n",
"Let's create some data just like before.\n",
"\n",
"First, we'll hard-code some `weight` and `bias` values.\n",
"\n",
"Then we'll make a range of numbers between 0 and 1, these will be our `X`
values.\n",
"\n",
"Finally, we'll use the `X` values, as well as the `weight` and `bias` values
to create `y` using the linear regression formula (`y = weight * X + bias`)."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fJqgDWUfec07",
"outputId": "62d07f54-bb59-4327-a153-79be9ada83d7"
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[0.0000],\n",
" [0.0200],\n",
" [0.0400],\n",
" [0.0600],\n",
" [0.0800],\n",
" [0.1000],\n",
" [0.1200],\n",
" [0.1400],\n",
" [0.1600],\n",
" [0.1800]]),\n",
" tensor([[0.3000],\n",
" [0.3140],\n",
" [0.3280],\n",
" [0.3420],\n",
" [0.3560],\n",
" [0.3700],\n",
" [0.3840],\n",
" [0.3980],\n",
" [0.4120],\n",
" [0.4260]]))"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create weight and bias\n",
"weight = 0.7\n",
"bias = 0.3\n",
"\n",
"# Create range values\n",
"start = 0\n",
"end = 1\n",
"step = 0.02\n",
"\n",
"# Create X and y (features and labels)\n",
"X = torch.arange(start, end, step).unsqueeze(dim=1) # without unsqueeze,
errors will happen later on (shapes within linear layers)\n",
"y = weight * X + bias \n",
"X[:10], y[:10]"
]
},
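The `unsqueeze(dim=1)` call matters because of shapes. `torch.arange` returns a flat vector; `unsqueeze(dim=1)` turns it into a column of 50 samples with one feature each, which is the `[batch, features]` layout `nn.Linear` works with. A flat `[50]` input would instead be read as a single sample with 50 features and clash with `in_features=1`. A quick check, using a stand-in `nn.Linear` layer:

```python
import torch

X_flat = torch.arange(0, 1, 0.02)    # shape [50]: a flat vector of 50 values
X_column = X_flat.unsqueeze(dim=1)   # shape [50, 1]: 50 samples, 1 feature each
print(X_flat.shape, X_column.shape)  # torch.Size([50]) torch.Size([50, 1])

# nn.Linear treats the last dimension as the feature dimension
layer = torch.nn.Linear(in_features=1, out_features=1)
print(layer(X_column).shape)         # torch.Size([50, 1])
```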
{
"cell_type": "markdown",
"metadata": {
"id": "Oaar6rDGLGaQ"
},
"source": [
"Wonderful!\n",
"\n",
"Now we've got some data, let's split it into training and test sets.\n",
"\n",
"We'll use an 80/20 split with 80% training data and 20% testing data."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lQoo65evec07",
"outputId": "80c3f9b7-4d1d-4aef-fc19-7abceaf93eb2"
},
"outputs": [
{
"data": {
"text/plain": [
"(40, 40, 10, 10)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Split data\n",
"train_split = int(0.8 * len(X))\n",
"X_train, y_train = X[:train_split], y[:train_split]\n",
"X_test, y_test = X[train_split:], y[train_split:]\n",
"\n",
"len(X_train), len(y_train), len(X_test), len(y_test)"
]
},
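The split above can be sketched as a small reusable function. `sequential_split` is a hypothetical helper; like the notebook, it keeps the original order rather than shuffling, so the test set is the last 20% of the `X` range. For data without a natural order, you'd usually shuffle before splitting.

```python
import torch


def sequential_split(X: torch.Tensor, y: torch.Tensor, train_fraction: float = 0.8):
    """Hypothetical helper: split features/labels in order (no shuffling)."""
    train_split = int(train_fraction * len(X))  # e.g. int(0.8 * 50) = 40
    return (X[:train_split], y[:train_split]), (X[train_split:], y[train_split:])


# Recreate the same straight-line data as above
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3
(X_train, y_train), (X_test, y_test) = sequential_split(X, y)
print(len(X_train), len(X_test))  # 40 10
```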
{
"cell_type": "markdown",
"metadata": {
"id": "INW8-McyLeFE"
},
"source": [
"Excellent, let's visualize them to make sure they look okay."
]
},
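The notebook visualizes the split with its `plot_predictions()` function defined earlier. For reference, here's a minimal stand-alone sketch of a similar plot (not the original function); it assumes matplotlib's non-interactive `Agg` backend and renders to an in-memory PNG so it also runs without a display.

```python
import io

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: no window needed
import matplotlib.pyplot as plt
import torch

# Recreate the same data and 80/20 split as above
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 0.7 * X + 0.3
X_train, y_train, X_test, y_test = X[:40], y[:40], X[40:], y[40:]

fig, ax = plt.subplots(figsize=(10, 7))
ax.scatter(X_train.numpy(), y_train.numpy(), c="b", s=4, label="Training data")
ax.scatter(X_test.numpy(), y_test.numpy(), c="g", s=4, label="Testing data")
ax.legend(prop={"size": 14})

# Render to an in-memory PNG instead of showing a window
png_buffer = io.BytesIO()
fig.savefig(png_buffer, format="png")
plt.close(fig)
```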
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 428
},
"id": "gxhc0zCdec07",
"outputId": "cc3cb921-0d25-4cec-d681-da102547bdb9"
},
"outputs": [
{
"data": {
"image/png":
"iVBORw0KGgoAAAANSUhEUgAAAlMAAAGbCAYAAADgEhWsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIH
wszuamV7rpm9bGbbzWyzmY2I/qgAAADJJ2yYMrNsSUWSrpQ0XNIsMxveYre7Jb3n7iMlXS/
p0WgPCgAA2kblQeJEcmZqrKQKd9/
l7kclPS9pWot9hkt6TZLcfaekwWbWL6qTAgCANhVvKVad16l4S3GiR8k4kYSpfEl7myxXNqxrapuk/
yVJZjZW0iBJA1s+kJndYmalZlZaXV19YhMDAIDjFI4pVLZlU3mQAJH0TFkr61r2KfxI0qNm9p6kP0t6V9Jx
TWDuvlTSUilUjdChSQEAQJuKJhepaHJRosfISJGEqUpJpzVZHijp06Y7uPtnkm6QJDMzSR833AAAANJaJC/
zlUgaamZDzKyzpJmSVjbdwcx6NWyTpO9K2tAQsAAAANJa2DNT7l5rZgskrZOULWmZu5eZ2dyG7U9IGibpKT
Ork7RD0k0xnBkAACBpRPTZfO6+RtKaFuueaPL925KGRnc0AAAy2/zV8xs/R4/
roZIXDegAACQp6g5SA2EKAIAkRd1BajD3xDQUFBQUeGlpaUKeGwAAoCPMbIu7F7S2jTNTAAAAARCmAAAAAi
BMAQAABECYAgAgzuavnq+c+3M0f/
X8RI+CKCBMAQAQZ1QepBfCFAAAcUblQXqhGgEAACAMqhEAAABihDAFAAAQAGEKAAAgAMIUAABRQuVBZiJMA
QAQJVQeZCbCFAAAUULlQWaiGgEAACAMqhEAAABihDAFAAAQAGEKAAAgAMIUAADtmD9fyskJfQVaQ5gCAKAd
xcVSXV3oK9AawhQAAO0oLJSys0NfgdZQjQAAABAG1QgAAAAxQpgCAAAIgDAFAAAQAGEKAJCRqDxAtBCmAAA
ZicoDRAthCgCQkag8QLRQjQAAABAG1QgAAAAxQpgCAAAIgDAFAAAQAGEKAJBWqDxAvBGmAABphcoDxBthCg
CQVqg8QLxRjQAAABAG1QgAAAAxQpgCAAAIgDAFAAAQQERhysyuMLMPzKzCzO5qZXtPM/
ujmW0zszIzuyH6owIAMhV1B0hmYS9AN7NsSR9KulxSpaQSSbPcfUeTfe6W1NPdF5lZnqQPJPV396NtPS4Xo
AMAIpWTE6o7yM6WamsTPQ0yUdAL0MdKqnD3XQ3h6HlJ01rs45JOMTOT1F3SXyXx1x0AEBXUHSCZRRKm8iXt
bbJc2bCuqcclDZP0qaQ/
S7rV3etbPpCZ3WJmpWZWWl1dfYIjAwAyTVFR6IxUUVGiJwGOF0mYslbWtXxtcJKk9yQNkHSBpMfNrMdxd3J
f6u4F7l6Ql5fXwVEBAACSTyRhqlLSaU2WByp0BqqpGySt8JAKSR9LOjc6IwIAACSvSMJUiaShZjbEzDpLmi
lpZYt99ki6TJLMrJ+kcyTtiuagAAAAyShsmHL3WkkLJK2TVC7pt+5eZmZzzWxuw24PSPonM/
uzpNckLXL3mlgNDQBID1QeIB3w2XwAgISh8gCpgs/
mAwAkJSoPkA44MwUAABAGZ6YAAABihDAFAAAQAGEKAAAgAMIUACDqqDxAJiFMAQCirrg4VHlQXJzoSYDYI0
wBAKKOygNkEqoRAAAAwqAaAQAAIEYIUwAAAAEQpgAAAAIgTAEAAARAmAIARITuKKB1hCkAQETojgJaR5gCA
ESE7iigdfRMAQAAhEHPFAAAQIwQpgAAAAIgTAEAAARAmAKADEflARAMYQoAMhyVB0AwhCkAyHBUHgDBUI0A
AAAQBtUIAAAAMUKYAgAACIAwBQAAEABhCgDSEHUHQPwQpgAgDVF3AMQPYQoA0hB1B0D8UI0AAAAQBtUIAAA
AMUKYAgAACIAwBQAAEABhCgBSCJUHQPIhTAFACqHyAEg+hCkASCFUHgDJh2oEAACAMKhGAAAAiBHCFAAAQA
CEKQAAgAAIUwCQBKg8AFJXRGHKzK4wsw/
MrMLM7mpl+x1m9l7D7X0zqzOz3tEfFwDSE5UHQOoKG6bMLFtSkaQrJQ2XNMvMhjfdx90fcvcL3P0CSd+T9I
a7/zUG8wJAWqLyAEhdkZyZGiupwt13uftRSc9LmtbO/
rMkPReN4QAgUxQVSbW1oa8AUkskYSpf0t4my5UN645jZidLukLSS21sv8XMSs2stLq6uqOzAgAAJJ1IwpS1
sq6tps+pkt5q6yU+d1/
q7gXuXpCXlxfpjAAAAEkrkjBVKem0JssDJX3axr4zxUt8AAAgg0QSpkokDTWzIWbWWaHAtLLlTmbWU9J4SX
+I7ogAkJqoOwAyQ9gw5e61khZIWiepXNJv3b3MzOaa2dwmu06X9Iq7H4nNqACQWqg7ADJDTiQ7ufsaSWtar
HuixfJyScujNRgApLrCwlCQou4ASG/
m3ta15LFVUFDgpaWlCXluAACAjjCzLe5e0No2Pk4GAAAgAMIUAABAAIQpAACAAAhTANBBVB4AaIowBQAdRO
UBgKYIUwDQQYWFUnY2lQcAQqhGAAAACINqBAAAgBghTAEAAARAmAIAAAiAMAUADag8AHAiCFMA0IDKAwAng
jAFAA2oPABwIqhGAAAACINqBAAAgBghTAEAAARAmAIAAAiAMAUgrVF3ACDWCFMA0hp1BwBijTAFIK1RdwAg
1qhGAAAACINqBAAAgBghTAEAAARAmAIAAAiAMAUgJVF5ACBZEKYApCQqDwAkC8IUgJRE5QGAZEE1AgAAQBh
UIwAAAMQIYQoAACAAwhQAAEAAhCkASYXKAwCphjAFIKlQeQAg1RCmACQVKg8ApBqqEQAAAMKgGgEAACBGCF
MAAAABEKYAAAACIEwBiDnqDgCkM8IUgJij7gBAOosoTJnZFWb2gZlVmNldbewzwczeM7MyM3sjumMCSGXUH
QBIZ2GrEcwsW9KHki6XVCmpRNIsd9/RZJ9ekv4k6Qp332Nmfd29qr3HpRoBAACkiqDVCGMlVbj7Lnc/
Kul5SdNa7HONpBXuvkeSwgUpAACAdBFJmMqXtLfJcmXDuqbOlpRrZuvNbIuZXd/
aA5nZLWZWamal1dXVJzYxAABAEokkTFkr61q+NpgjaYykyZImSbrXzM4+7k7uS929wN0L8vLyOjwsAABAso
kkTFVKOq3J8kBJn7ayz1p3P+LuNZI2SBoVnREBJCsqDwAgsjBVImmomQ0xs86SZkpa2WKfP0i62MxyzOxkS
RdJKo/uqACSDZUHABBBmHL3WkkLJK1TKCD91t3LzGyumc1t2Kdc0lpJ2yVtlvRLd38/
dmMDSAZUHgBABNUIsUI1AgAASBVBqxEAAADQBsIUAABAAIQpAACAAAhTAI5D5QEARI4wBeA4VB4AQOQIUwC
OQ+UBAESOagQAAIAwqEYAAACIEcIUAABAAIQpAACAAAhTQIag7gAAYoMwBWQI6g4AIDYIU0CGoO4AAGKDag
QAAIAwqEYAAACIEcIUAABAAIQpAACAAAhTQIqj8gAAEoswBaQ4Kg8AILEIU0CKo/
IAABKLagQAAIAwqEYAAACIEcIUAABAAIQpAACAAAhTQJKi8gAAUgNhCkhSVB4AQGogTAFJisoDAEgNVCMAA
ACEQTUCAABAjBCmAAAAAiBMAQAABECYAgAACIAwBcQR3VEAkH4IU0Ac0R0FAOmHMAXEEd1RAJB+6JkCAAAI
g54pAACAGCFMAQAABECYAgAACIAwBUQBlQcAkLkIU0AUUHkAAJmLMAVEAZUHAJC5IgpTZnaFmX1gZhVmdlc
r2yeY2SEze6/
htjj6owLJq6hIqq0NfQUAZJaccDuYWbakIkmXS6qUVGJmK919R4tdN7r7lBjMCAAAkLQiOTM1VlKFu+9y96
OSnpc0LbZjAQAApIZIwlS+pL1Nlisb1rU0zsy
2mdl/mtl5rT2Qmd1iZqVmVlpdXX0C4wIAACSXSMKUtbKu5WfQbJU0yN1HSfo/
kn7f2gO5+1J3L3D3gry8vA4NCsQbdQcAgEhEEqYqJZ3WZHmgpE+b7uDun7n74Ybv10jqZGZ9ojYlkADUHQA
AIhFJmCqRNNTMhphZZ0kzJa1suoOZ9Tcza/
h+bMPjHoz2sEA8UXcAAIhE2HfzuXutmS2QtE5StqRl7l5mZnMbtj8h6VuS/
reZ1Ur6m6SZ7t7ypUAgpRQVUXUAAAjPEpV5CgoKvLS0NCHPDQAA0BFmtsXdC1rbRgM6AABAAIQpAACAAAhT
yDhUHgAAookwhYxD5QEAIJoIU8g4VB4AAKKJd/
MBAACEwbv5AAAAYoQwBQAAEABhCgAAIADCFNIGlQcAgEQgTCFtUHkAAEgEwhTSBpUHAIBEoBoBAAAgDKoRA
AAAYoQwBQAAEABhCgAAIADCFJIadQcAgGRHmEJSo+4AAJDsCFNIatQdAACSHdUIAAAAYVCNAAAAECOEKQAA
gAAIUwAAAAEQppAQVB4AANIFYQoJQeUBACBdEKaQEFQeAADSBdUIAAAAYVCNAAAAECOEKQAAgAAIUwAAAAE
QphBVVB4AADINYQpRReUBACDTEKYQVVQeAAAyDdUIAAAAYVCNAAAAECOEKQAAgAAIUwAAAAEQphAWdQcAAL
SNMIWwqDsAAKBthCmERd0BAABtoxoBAAAgjMDVCGZ2hZl9YGYVZnZXO/
tdaGZ1ZvatEx0WAAAglYQNU2aWLalI0pWShkuaZWbD29jvx5LWRXtIAACAZBXJmamxkircfZe7H5X0vKRpr
ez3L5JeklQVxfkAAACSWiRhKl/
S3ibLlQ3rGplZvqTpkp5o74HM7BYzKzWz0urq6o7Oiiij8gAAgOAiCVPWyrqWV63/
XNIid69r74Hcfam7F7h7QV5eXoQjIlaoPAAAILhIwlSlpNOaLA+U9GmLfQokPW9muyV9S9K/
m9k3ozEgYofKAwAAggtbjWBmOZI+lHSZpH2SSiRd4+5lbey/
XNIqd3+xvcelGgEAAKSK9qoRcsLd2d1rzWyBQu/Sy5a0zN3LzGxuw/
Z2r5MCAABIZ2HDlCS5+xpJa1qsazVEufuc4GMBAACkBj5OBgAAIADCVBqi8gAAgPghTKUhKg8AAIgfwlQao
vIAAID4CVuNECtUIwAAgFTRXjUCZ6YAAAACIEwBAAAEQJgCAAAIgDCVIqg7AAAgORGmUgR1BwAAJCfCVIqg
7gAAgORENQIAAEAYVCMAAADECGEKAAAgAMIUAABAAISpBKPyAACA1EaYSjAqDwAASG2EqQSj8gAAgNRGNQI
AAEAYVCMAAADECGEKAAAgAMIUAABAAISpGKHyAACAzECYihEqDwAAyAyEqRih8gAAgMxANQIAAEAYVCMAAA
DECGEKAAAgAMIUAABAAISpDqDuAAAAtESY6gDqDgAAQEuEqQ6g7gAAALRENQIAAEAYVCMAAADECGEKAAAgA
MIUAABAAIQpUXkAAABOHGFKVB4AAIATR5gSlQcAAODEUY0AAAAQBtUIAAAAMRJRmDKzK8zsAzOrMLO7Wtk+
zcy2m9l7ZlZqZl+P/qgAAADJJyfcDmaWLalI0uWSKiWVmNlKd9/
RZLfXJK10dzezkZJ+K+ncWAwMAACQTCI5MzVWUoW773L3o5KelzSt6Q7uftj/
cfFVN0mJuRALAAAgziIJU/
mS9jZZrmxY14yZTTeznZJWS7oxOuMFQ38UAACItUjClLWy7rgzT+7+srufK+mbkh5o9YHMbmm4pqq0urq6Q
4OeCPqjAABArEUSpiolndZkeaCkT9va2d03SDrTzPq0sm2puxe4e0FeXl6Hh+0o+qMAAECsRRKmSiQNNbMh
ZtZZ0kxJK5vuYGZnmZk1fD9aUmdJB6M9bEcVFUm1taGvAAAAsRD23XzuXmtmCyStk5QtaZm7l5nZ3IbtT0i
6WtL1ZnZM0t8kzfBEtYECAADEEQ3oAAAAYdCADgAAECOEKQAAgAAIUwAAAAEQpgAAAAIgTAEAAARAmAIAAA
iAMAUAABAAYQoAACAAwhQAAEAAhCkAAIAACFMAAAABEKYAAAACSNgHHZtZtaRP4vBUfSTVxOF50HEcm+TG8
UleHJvkxvFJXkGOzSB3z2ttQ8LCVLyYWWlbn/KMxOLYJDeOT/Li2CQ3jk/
yitWx4WU+AACAAAhTAAAAAWRCmFqa6AHQJo5NcuP4JC+OTXLj+CSvmBybtL9mCgAAIJYy4cwUAABAzBCmAA
AAAkiLMGVmV5jZB2ZWYWZ3tbLdzOyxhu3bzWx0IubMVBEcn2sbjst2M/
uTmY1KxJyZKNyxabLfhWZWZ2bfiud8mS6S42NmE8zsPTMrM7M34j1jporgv2s9zeyPZrat4djckIg5M5GZL
TOzKjN7v43t0c8E7p7SN0nZkj6SdIakzpK2SRreYp+rJP2nJJP0NUnvJHruTLlFeHz+SVJuw/
dXcnyS59g02e//
Sloj6VuJnjtTbhH+2+klaYek0xuW+yZ67ky4RXhs7pb044bv8yT9VVLnRM+eCTdJl0gaLen9NrZHPROkw5m
psZIq3H2Xux+V9LykaS32mSbpKQ/ZJKmXmZ0a70EzVNjj4+5/cvf/17C4SdLAOM+YqSL5tyNJ/
yLpJUlV8RwOER2fayStcPc9kuTuHKP4iOTYuKRTzMwkdVcoTNXGd8zM5O4bFPp5tyXqmSAdwlS+pL1Nlisb
1nV0H8RGR3/2Nyn0GwNiL+yxMbN8SdMlPRHHuRASyb+dsyXlmtl6M9tiZtfHbbrMFsmxeVzSMEmfSvqzpFv
dvT4+4yGMqGeCnEDjJAdrZV3LvodI9kFsRPyzN7OJCoWpr8d0IvyPSI7NzyUtcve60C/
YiKNIjk+OpDGSLpPUVdLbZrbJ3T+M9XAZLpJjM0nSe5IulXSmpFfNbKO7fxbj2RBe1DNBOoSpSkmnNVkeqN
BvAh3dB7ER0c/
ezEZK+qWkK939YJxmy3SRHJsCSc83BKk+kq4ys1p3/31cJsxskf63rcbdj0g6YmYbJI2SRJiKrUiOzQ2Sfu
Shi3QqzOxjSedK2hyfEdGOqGeCdHiZr0TSUDMbYmadJc2UtLLFPislXd9wBf/
XJB1y9/3xHjRDhT0+Zna6pBWSruM36rgKe2zcfYi7D3b3wZJelDSPIBU3kfy37Q+SLjazHDM7WdJFksrjPG
cmiuTY7FHojKHMrJ+kcyTtiuuUaEvUM0HKn5ly91ozWyBpnULvsFjm7mVmNrdh+xMKvQvpKkkVkr5Q6DcGx
EGEx2expK9I+veGMyC1zieux1yExwYJEsnxcfdyM1srabukekm/
dPdW3w6O6Inw384Dkpab2Z8VellpkbvXJGzoDGJmz0maIKmPmVVKuk9SJyl2mYCPkwEAAAggHV7mAwAASBj
CFAAAQACEKQAAgAAIUwAAAAEQpgAAAAIgTAEAAARAmAIAAAjg/wOpTIj28IK1hAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Note: If you've reset your runtime, this function won't work,\n",
"# you'll have to rerun the cell above where it's defined.\n",
"plot_predictions(X_train, y_train, X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X0ycBrxIec07"
},
"source": [
"### 6.2 Building a PyTorch linear model\n",
"\n",
"We've got some data; now it's time to make a model.\n",
"\n",
"We'll create the same style of model as before, except this time, instead of defining the weight and bias parameters of our model manually using `nn.Parameter()`, we'll use [`nn.Linear(in_features, out_features)`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) to do it for us.\n",
"\n",
"Here `in_features` is the number of dimensions (features) in each input sample and `out_features` is the number of dimensions you'd like each output to have.\n",
"\n",
"In our case, both of these are `1` since our data has `1` input feature (`X`) per label (`y`).\n",
"\n",
"\n",
"*Creating a linear regression model using `nn.Parameter` versus using `nn.Linear`. There are plenty more examples of where the `torch.nn` module has pre-built computations, including many popular and useful neural network layers.*\n"
]
},
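To make the `nn.Parameter` versus `nn.Linear` comparison concrete, here is a minimal sketch (assuming only that PyTorch is installed; `ManualLinear` is a hypothetical name used for illustration) that copies the randomly initialized values from an `nn.Linear(1, 1)` layer into a hand-written module and checks that both produce the same output. Per the PyTorch documentation, `nn.Linear` computes `x @ weight.T + bias` and stores `weight` with shape `(out_features, in_features)`.

```python
import torch
from torch import nn

torch.manual_seed(42)

# nn.Linear(1, 1) holds a weight of shape (out_features, in_features) = (1, 1)
# and a bias of shape (out_features,) = (1,)
linear = nn.Linear(in_features=1, out_features=1)


class ManualLinear(nn.Module):
    """Hypothetical hand-written equivalent built from nn.Parameter."""

    def __init__(self, weight: torch.Tensor, bias: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())
        self.bias = nn.Parameter(bias.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x @ W^T + b, the same formula nn.Linear applies
        return x @ self.weight.T + self.bias


# Copy the random values so both modules start from identical parameters
manual = ManualLinear(linear.weight.detach(), linear.bias.detach())

X = torch.linspace(0, 1, 50).unsqueeze(dim=1)  # 50 samples, 1 feature each
with torch.inference_mode():
    out_linear = linear(X)
    out_manual = manual(X)

print(linear.weight.shape, linear.bias.shape)  # torch.Size([1, 1]) torch.Size([1])
print(torch.allclose(out_linear, out_manual))  # True
```

The design point is simply that `nn.Linear` bundles parameter creation, initialization and the `x @ W.T + b` computation into one layer.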
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6iOwqtFqec08",
"outputId": "f7aabd1d-55a7-4f1e-c9b9-9db73d178aef"
},
"outputs": [
{
"data": {
"text/plain": [
"(LinearRegressionModelV2(\n",
" (linear_layer): Linear(in_features=1, out_features=1, bias=True)\n",
" ),\n",
" OrderedDict([('linear_layer.weight', tensor([[0.7645]])),\n",
" ('linear_layer.bias', tensor([0.8300]))]))"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Subclass nn.Module to make our model\n",
"class LinearRegressionModelV2(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Use nn.Linear() for creating the model parameters\n",
"        self.linear_layer = nn.Linear(in_features=1, \n",
"                                      out_features=1)\n",
"    \n",
"    # Define the forward computation (input data x flows through nn.Linear())\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        return self.linear_layer(x)\n",
"\n",
"# Set the manual seed when creating the model (this isn't always needed but is used for demonstrative purposes, try commenting it out and seeing what happens)\n",
"torch.manual_seed(42)\n",
"model_1 = LinearRegressionModelV2()\n",
"model_1, model_1.state_dict()"
]
},
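The seed comment in the cell above invites an experiment. A minimal sketch of it (repeating the `LinearRegressionModelV2` definition so it runs on its own; `params_equal` is a hypothetical helper) shows that resetting the seed reproduces the same initial parameters, while creating another model without reseeding gives different ones. It also counts the trainable values: one weight and one bias.

```python
import torch
from torch import nn


class LinearRegressionModelV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(in_features=1, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layer(x)


def params_equal(a: nn.Module, b: nn.Module) -> bool:
    """Hypothetical helper: True if both state dicts hold identical tensors."""
    sd_a, sd_b = a.state_dict(), b.state_dict()
    return sd_a.keys() == sd_b.keys() and all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


torch.manual_seed(42)
model_a = LinearRegressionModelV2()
torch.manual_seed(42)
model_b = LinearRegressionModelV2()  # same seed -> same initial weight and bias
model_c = LinearRegressionModelV2()  # no reseed -> random number generator has moved on

n_params = sum(p.numel() for p in model_a.parameters())
print(n_params)                        # 2 (one weight, one bias)
print(params_equal(model_a, model_b))  # True
print(params_equal(model_a, model_c))  # False
```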
{
"cell_type": "markdown",
"metadata": {
"id": "4vLN2pPXNXUs"
},
"source": [
"Notice in the output of `model_1.state_dict()` that the `nn.Linear()` layer created random `weight` and `bias` parameters for us.\n",
"\n",
"Now let's put our model on the GPU (if it's available).\n",
"\n",
"We can change the device our PyTorch objects are on using `.to(device)`.\n",
"\n",
"First let's check the model's current device."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HhCvYNpAec08",
"outputId": "4d0d2c5f-4a9c-44a0-bda5-fd54d16cfa51"
},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check model device\n",
"next(model_1.parameters()).device"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZqalUGW5N93K"
},
"source": [
"Wonderful, looks like the model's on the CPU by default.\n",
"\n",
"Let's change it to be on the GPU (if it's available)."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JfTYec5Rec08",
"outputId": "b0d331ba-56b9-4f18-f93d-de7965de41dd"
},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set model to GPU if it's available, otherwise it'll default to CPU\n",
"model_1.to(device) # the device variable was set above to be \"cuda\" if available or \"cpu\" if not\n",
"next(model_1.parameters()).device"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qHs0bL5_Oc1k"
},
"source": [
"Nice! Because of our device-agnostic code, the above cell will work regardless of whether a GPU is available or not.\n",
"\n",
"If you do have access to a CUDA-enabled GPU, you should see an output of something like:\n",
"\n",
"```\n",
"device(type='cuda', index=0)\n",
"```"
]
},
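One detail of `.to(device)` is worth spelling out. As a minimal sketch, assuming the same "cuda"/"cpu" device string as above: `nn.Module.to()` moves a model's parameters in place (and returns the module itself), whereas `torch.Tensor.to()` returns a tensor on the target device and leaves the original tensor where it was. That is why the model can be moved with a bare `model_1.to(device)` call but data tensors need to be reassigned, e.g. `X_train = X_train.to(device)`.

```python
import torch
from torch import nn

# Device-agnostic setup: use the GPU if PyTorch can see one, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Linear(in_features=1, out_features=1)
returned = model.to(device)  # nn.Module.to() moves the parameters in place...
print(returned is model)     # ...and returns the same module object: True

x_cpu = torch.zeros(3, 1)
x_moved = x_cpu.to(device)  # Tensor.to() returns a tensor on `device`
print(x_cpu.device)         # cpu: the original tensor has not moved
print(x_moved.device.type == next(model.parameters()).device.type)  # True
```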
{
"cell_type": "markdown",
"metadata": {
"id": "jwTeP_vkec08"
},
"source": [
"### 6.3 Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vPFOV3wUec09"
},
"source": [
"Time to build a training and testing loop.\n",
"\n",
"First we'll need a loss function and an optimizer.\n",
"\n",
"Let's use the same functions we used earlier, `nn.L1Loss()` and `torch.optim.SGD()`.\n",
"\n",
"We'll have to pass the new model's parameters (`model_1.parameters()`) to the optimizer for it to adjust them during training.\n",
"\n",
"The learning rate of `0.01` worked well before too, so let's use that again.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "ZRgqFKrNec09"
},
"outputs": [],
"source": [
"# Create loss function\n",
"loss_fn = nn.L1Loss()\n",
"\n",
"# Create optimizer\n",
"optimizer = torch.optim.SGD(params=model_1.parameters(), # optimize newly created model's parameters\n",
"                            lr=0.01)"
]
},
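As a quick check of what these two objects do, here is a minimal sketch on a small made-up batch (ten points from the same `weight=0.7`, `bias=0.3` line; the batch itself is illustrative): `nn.L1Loss()` matches the mean absolute error computed by hand, and one `optimizer.step()` of plain SGD (no momentum or weight decay, which are the defaults) moves each parameter by `-lr * grad`.

```python
import torch
from torch import nn

torch.manual_seed(42)
model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.L1Loss()  # mean absolute error: mean(|y_pred - y_true|)
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

X = torch.linspace(0, 1, 10).unsqueeze(dim=1)
y = 0.7 * X + 0.3  # same straight-line formula as the notebook's data

y_pred = model(X)
loss = loss_fn(y_pred, y)
manual_mae = (y_pred - y).abs().mean()
print(torch.allclose(loss, manual_mae))  # True

# One optimization step: plain SGD updates p <- p - lr * p.grad
optimizer.zero_grad()
loss.backward()
expected = [p.detach() - 0.01 * p.grad for p in model.parameters()]
optimizer.step()
print(all(torch.allclose(p.detach(), e) for p, e in zip(model.parameters(), expected)))  # True
```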
{
"cell_type": "markdown",
"metadata": {
"id": "NxuBdoWRP2nU"
},
"source": [
"Beautiful, loss function and optimizer ready! Now let's train and evaluate our model using a training and testing loop.\n",
"\n",
"The only thing we'll do differently in this step compared to the previous training loop is put the data on the target `device`.\n",
"\n",
"We've already put our model on the target `device` using `model_1.to(device)`.\n",
"\n",
"And we can do the same with the data.\n",
"\n",
"That way if the model is on the GPU, the data is on the GPU (and vice versa).\n",
"\n",
"Let's step things up a notch this time and set `epochs=1000`.\n",
"\n",
"If you need a reminder of the PyTorch training loop steps, see below.\n",
"\n",
"<details>\n",
" <summary>PyTorch training loop steps</summary>\n",
" <ol>\n",
"  <li><b>Forward pass</b> - The model goes through all of the training data once, performing its <code>forward()</code> function calculations (<code>model(X_train)</code>).</li>\n",
"  <li><b>Calculate the loss</b> - The model's outputs (predictions) are compared to the ground truth and evaluated to see how wrong they are (<code>loss = loss_fn(y_pred, y_train)</code>).</li>\n",
"  <li><b>Zero gradients</b> - The optimizer's gradients are set to zero (they are accumulated by default) so they can be recalculated for the specific training step (<code>optimizer.zero_grad()</code>).</li>\n",
"  <li><b>Perform backpropagation on the loss</b> - Computes the gradient of the loss with respect to every model parameter to be updated (each parameter with <code>requires_grad=True</code>). This is known as <b>backpropagation</b>, hence \"backwards\" (<code>loss.backward()</code>).</li>\n",
"  <li><b>Step the optimizer (gradient descent)</b> - Update the parameters with <code>requires_grad=True</code> with respect to the loss gradients in order to improve them (<code>optimizer.step()</code>).</li>\n",
" </ol>\n",
"</details>"
]
},
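Step 3 (zero gradients) matters because PyTorch adds each new gradient to whatever is already stored in `.grad`. A minimal sketch with a single scalar parameter shows the accumulation, and the effect of zeroing the stored gradient first (here with `w.grad.zero_()`; `optimizer.zero_grad()` does the equivalent job for every parameter it manages).

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# d(3 * w)/dw = 3. Two backward passes without zeroing add up: 3 + 3
(3 * w).backward()
(3 * w).backward()
accumulated = w.grad.clone()
print(accumulated)  # tensor(6.)

# Zero the stored gradient first, then the next backward pass gives the fresh value
w.grad.zero_()
(3 * w).backward()
print(w.grad)  # tensor(3.)
```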
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JDOHzX8lec09",
"outputId": "23ee6dda-7145-463c-e684-d65ba6874757"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 | Train loss: 0.5551779866218567 | Test loss: 0.5739762187004089\n",
"Epoch: 100 | Train loss: 0.006215683650225401 | Test loss: 0.014086711220443249\n",
"Epoch: 200 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 300 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 400 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 500 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 600 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 700 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 800 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n",
"Epoch: 900 | Train loss: 0.0012645035749301314 | Test loss: 0.013801801018416882\n"
]
}
],
"source": [
"torch.manual_seed(42)\n",
"\n",
"# Set the number of epochs\n",
"epochs = 1000\n",
"\n",
"# Put data on the available device\n",
"# Without this, an error will happen (not all model/data on the same device)\n",
"X_train = X_train.to(device)\n",
"X_test = X_test.to(device)\n",
"y_train = y_train.to(device)\n",
"y_test = y_test.to(device)\n",
"\n",
"for epoch in range(epochs):\n",
"    ### Training\n",
"    model_1.train() # train mode is on by default after construction\n",
"\n",
"    # 1. Forward pass\n",
"    y_pred = model_1(X_train)\n",
"\n",
"    # 2. Calculate loss\n",
"    loss = loss_fn(y_pred, y_train)\n",
"\n",
"    # 3. Zero grad optimizer\n",
"    optimizer.zero_grad()\n",
"\n",
"    # 4. Loss backward\n",
"    loss.backward()\n",
"\n",
"    # 5. Step the optimizer\n",
"    optimizer.step()\n",
"\n",
"    ### Testing\n",
"    model_1.eval() # put the model in evaluation mode for testing (inference)\n",
"    # 1. Forward pass\n",
"    with torch.inference_mode():\n",
"        test_pred = model_1(X_test)\n",
"    \n",
"        # 2. Calculate the loss\n",
"        test_loss = loss_fn(test_pred, y_test)\n",
"\n",
"    if epoch % 100 == 0:\n",
"        print(f\"Epoch: {epoch} | Train loss: {loss} | Test loss: {test_loss}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nt-b2Y131flk"
},
"source": [
"> **Note:** Due to the random nature of machine learning (and small differences in how CPUs and GPUs carry out floating-point calculations), you will likely get slightly different results (different loss and prediction values) depending on whether your model was trained on CPU or GPU. This is true even if you use the same random seed on either device. If the difference is large, you may want to look for errors; if it is small (ideally it is), you can ignore it.\n",
"\n",
"Nice! That loss looks pretty low.\n",
"\n",
"Let's check the parameters our model has learned and compare them to the original parameters we hard-coded."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TP_tFn5rec09",
"outputId": "53b6c53a-1bab-4f13-e09a-c9473200af39"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model learned the following values for weights and bias:\n",
"OrderedDict([('linear_layer.weight', tensor([[0.6968]], device='cuda:0')),\n",
" ('linear_layer.bias', tensor([0.3025], device='cuda:0'))])\n",
"\n",
"And the original values for weights and bias are:\n",
"weights: 0.7, bias: 0.3\n"
]
}
],
"source": [
"# Find our model's learned parameters\n",
"from pprint import pprint # pprint = pretty print, see: https://docs.python.org/3/library/pprint.html \n",
"print(\"The model learned the following values for weights and bias:\")\n",
"pprint(model_1.state_dict())\n",
"print(\"\\nAnd the original values for weights and bias are:\")\n",
"print(f\"weights: {weight}, bias: {bias}\")"
]
},
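The note about small CPU/GPU differences comes down to floating-point arithmetic: changing the order of additions can change the last bits of a result. A minimal sketch shows this, and why comparisons are best made with a tolerance (`torch.allclose` / `torch.isclose`) rather than exact equality. The `learned` values are copied from the printed output above; the `atol=0.01` tolerance is an arbitrary choice for illustration.

```python
import torch

a = torch.tensor(0.1, dtype=torch.float64)
b = torch.tensor(0.2, dtype=torch.float64)
c = torch.tensor(0.3, dtype=torch.float64)

# Same numbers, different order of addition
left = (a + b) + c
right = a + (b + c)
print(torch.equal(left, right))     # False: the results differ in the last bits
print(torch.allclose(left, right))  # True: equal within the default tolerance

# Compare learned parameters to the known targets with a tolerance
learned = {"weight": 0.6968, "bias": 0.3025}  # copied from the output above
target = {"weight": 0.7, "bias": 0.3}
for name in target:
    close = torch.isclose(torch.tensor(learned[name]), torch.tensor(target[name]), atol=0.01)
    print(name, bool(close))  # weight True, then bias True
```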
{
"cell_type": "markdown",
"metadata": {
"id": "rDZo0vEU1_-1"
},
"source": [
"Ho ho! Now that's pretty darn close to a perfect model.\n",
"\n",
"Remember though, in practice, it's rare that you'll know the perfect parameters ahead of time.\n",
"\n",
"And if you knew the parameters your model had to learn ahead of time, what would be the fun of machine learning?\n",
"\n",
"Plus, in many real-world machine learning problems, the number of parameters can well exceed tens of millions.\n",
"\n",
"I don't know about you, but I'd rather write code for a computer to figure those out than do it by hand."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mBR1qvqhec09"
},
"source": [
"### 6.4 Making predictions\n",
"\n",
"Now we've got a trained model, let's turn on its evaluation mode and make some predictions."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ksqG5N5Iec09",
"outputId": "a0d4a51f-e1d9-4038-fd8a-0bbf4386f36a"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.8600],\n",
" [0.8739],\n",
" [0.8878],\n",
" [0.9018],\n",
" [0.9157],\n",
" [0.9296],\n",
" [0.9436],\n",
" [0.9575],\n",
" [0.9714],\n",
" [0.9854]], device='cuda:0')"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Turn model into evaluation mode\n",
"model_1.eval()\n",
"\n",
"# Make predictions on the test data\n",
"with torch.inference_mode():\n",
"    y_preds = model_1(X_test)\n",
"y_preds"
]
},
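To see what `model_1.eval()` and `torch.inference_mode()` change, here is a minimal sketch with a fresh `nn.Linear(1, 1)` and made-up test inputs (both assumptions for illustration). Outside the context manager the output is tracked by autograd; inside it, no graph is built and the output is marked as an inference tensor.

```python
import torch
from torch import nn

torch.manual_seed(42)
model = nn.Linear(in_features=1, out_features=1)
X_test = torch.linspace(0.8, 1.0, 10).unsqueeze(dim=1)  # made-up inputs, shape (10, 1)

# A normal forward pass: autograd records the computation
preds_tracked = model(X_test)

# Evaluation mode switches layers like dropout/batch norm to inference behaviour
# (nn.Linear has none, but the module's training flag still changes)
model.eval()
with torch.inference_mode():
    preds = model(X_test)  # no autograd graph is built here

print(model.training)               # False
print(preds_tracked.requires_grad)  # True
print(preds.requires_grad)          # False
print(preds.is_inference())         # True
```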
{
"cell_type": "markdown",
"metadata": {
"id": "NtOoVnbi2ysL"
},
"source": [
"If you're making predictions with data on the GPU, you might notice the output of the above has `device='cuda:0'` towards the end. That means the data is on CUDA device 0 (the first GPU your system has access to, due to zero-indexing). If you end up using multiple GPUs in the future, this number may be higher.\n",
"\n",
"Now let's plot our model's predictions.\n",
"\n",
"> **Note:** Many data science libraries such as pandas, matplotlib and NumPy aren't capable of using data that is stored on the GPU, so you might run into some issues when trying to use a function from one of these libraries with tensor data not stored on the CPU. To fix this, you can call [`.cpu()`](https://pytorch.org/docs/stable/generated/torch.Tensor.cpu.html) on your target tensor to return a copy of your target tensor on the CPU."
]
},
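A minimal sketch of the `.cpu()` fix for NumPy-based code (the small `preds` and `w` tensors are made up for illustration): copy the tensor to the CPU first, and if it is still attached to an autograd graph, call `.detach()` as well, since `.numpy()` refuses tensors that require gradients.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.tensor([[0.86], [0.87]], device=device)

# NumPy (and libraries that convert data through it) can only read CPU memory,
# so copy the tensor to the CPU before converting
preds_np = preds.cpu().numpy()
print(type(preds_np), preds_np.shape)  # <class 'numpy.ndarray'> (2, 1)

# A tensor that is part of an autograd graph also needs .detach() before .numpy()
w = torch.tensor([1.0, 2.0], device=device, requires_grad=True)
w_np = (w * 2).detach().cpu().numpy()
print(w_np)  # [2. 4.]
```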
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 428
},
"id": "Z4dmfr2bec09",
"outputId": "dd68d5a7-1733-4385-c1cb-7d7b44085813"
},
"outputs": [
{
"data": {
"image/png":
"iVBORw0KGgoAAAANSUhEUgAAAlMAAAGbCAYAAADgEhWsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIH
ZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/
YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtPElEQVR4nO3de3RU9bn/
8c+ThEsg3JQAQhAQUMGIChFFy0XFg3JZ1FbLrQpqNfyQc3QdL1A9Be9XrLWVVtQq3qpWBWuRopYfCFqRBBQ
EAh7EC2CEYH8LBauQ5Pn9MTFNQpKZsGcyk8z7tdasyd77u/d+MjuBT/
be84y5uwAAAHB4UuJdAAAAQENGmAIAAAiAMAUAABAAYQoAACAAwhQAAEAAafHacfv27b179+7x2j0AAEDE1
qxZs8fdM6tbFrcw1b17d+Xn58dr9wAAABEzs89qWsZlPgAAgAAIUwAAAAEQpgAAAAIgTAEAAARAmAIAAAgg
7Lv5zOxxSaMl7Xb37GqWm6QHJY2U9K2kKe6+NmhhX3/9tXbv3q2DBw8G3RSSQMuWLZWVlaWUFP4+AADUr0h
aI8yX9JCkp2pYfr6k3mWP0yT9oez5sH399dfatWuXunTpovT0dIXyGlC90tJS7dy5U3v27FGHDh3iXQ4AIM
mE/
TPe3VdI+mctQ8ZKespDVklqa2ZHBSlq9+7d6tKli1q0aEGQQlgpKSnq2LGj9u7dG+9SAABJKBrXRLpI2l5h
ekfZvMN28OBBpaenByoKyaVJkyYqLi6OdxkAgCQUjTBV3akjr3ag2ZVmlm9m+UVFRbVvlDNSqAN+XgAA8RK
NMLVDUtcK01mSvqhuoLs/
4u457p6TmVntx9sAAAA0KNEIU69KusRCTpe0190Lo7BdAACAhBc2TJnZc5LelXScme0ws8vNbKqZTS0bslj
SNklbJT0qaVrMqk1CU6ZM0ejRo+u0zrBhwzR9+vQYVVS76dOna9iwYXHZNwAA8RC2NYK7Twiz3CVdFbWKGq
hw9+xMnjxZ8+fPr/
N2H3zwQYVe4sgtWLBATZo0qfO+4uHTTz9Vjx49lJeXp5ycnHiXAwBAnUXSZwoRKCz895XNRYsW6Yorrqg0r
+q7Ew8ePBhR4GnTpk2dazniiCPqvA4AADg8tIuOkk6dOpU/
2rZtW2ned999p7Zt2+q5557T2WefrfT0dM2bN09fffWVJkyYoKysLKWnp+uEE07QE088UWm7VS/
zDRs2TNOmTdONN96o9u3bq0OHDrruuutUWlpaaUzFy3zdu3fX7bffrtzcXLVu3VpZWVm67777Ku3no48+0t
ChQ9W8eXMdd9xxWrx4sTIyMmo9m1ZSUqLrrrtO7dq1U7t27XTNNdeopKSk0pglS5Zo8ODBateunY444giNG
DFCBQUF5ct79OghSTr11FNlZuWXCPPy8vQf//Efat++vVq3bq0f/
ehHevfdd8MfCABAUnlr9IkqTjG9NfrEuNVAmKpHv/zlLzVt2jRt2rRJP/7xj/Xdd9+pf//
+WrRokTZu3Kirr75aubm5Wrp0aa3befbZZ5WWlqZ//OMfeuihh/Sb3/
xGL7zwQq3rPPDAAzrxxBO1du1azZgxQzfccEN5OCktLdUFF1ygtLQ0rVq1SvPnz9ctt9yi77//
vtZt3n///Xr00Uc1b948vfvuuyopKdGzzz5bacz+/
ft1zTXXaPXq1Vq+fLnatGmjMWPG6MCBA5Kk1atXSwqFrsLCQi1YsECS9M033+jiiy/
WypUrtXr1ap188skaOXKk9uzZU2tNAIDkcubiDUrz0HPcuHtcHgMGDPCabNq0qcZldTVtmntqaui5vrz44o
seemlDPvnkE5fkc+bMCbvuuHHj/
PLLLy+fnjx5so8aNap8eujQoX766adXWmf48OGV1hk6dKhfddVV5dPdunXz8ePHV1qnV69eftttt7m7+5Il
Szw1NdV37NhRvvydd95xSf7EE0/UWOtRRx3lt99+e/l0SUmJ9+7d24cOHVrjOvv27fOUlBRfuXKlu//
7tcnLy6txHXf30tJS79Spkz/99NM1jonmzw0AoGFYPirbD5p8+ajsmO5HUr7XkGka/
ZmpefOkkpLQc7xVvcG6pKREd9xxh/r166cjjzxSGRkZWrBggT7//
PNat9OvX79K0507d9bu3bsPe53Nmzerc+fO6tLl343rTz311Fo/
NHjv3r0qLCzUoEGDyuelpKTotNMqfyzjxx9/
rIkTJ6pnz55q3bq1OnbsqNLS0rDf4+7du5Wbm6tjjz1Wbdq0UatWrbR79+6w6wEAksvQRR8qrdQ1dNGHcau
h0d+AnpsbClK5ufGuRGrZsmWl6Tlz5uj+++/
Xgw8+qBNPPFEZGRm68cYbwwajqjeum1mle6bquo67x6yD+JgxY9SlSxfNmzdPXbp0UVpamvr27Vt+ma8mky
dP1q5du/TAAw+oe/fuatasmc4555yw6wEAUN8afZiaOzf0SERvv/
22xowZo4svvlhSKNR89NFH5Tew15c+ffpo586d+uKLL9S5c2dJUn5+fq0BrU2bNjrqqKO0atUqnX322ZJC9
a9evVpHHRX6nOuvvvpKBQUFmjt3rs466yxJ0tq1ayt9hl7Tpk0l6ZAb199++2399re/
1ahRoyRJu3btqvTuSAAAEkWjv8yXyI499lgtXbpUb7/9tjZv3qzp06frk08+qfc6zj33XB133HGaPHmy1q1
bp1WrVum///u/lZaWVusZq6uvvlr33nuvXnrpJW3ZskXXXHNNpcDTrl07tW/fXo8+
+qi2bt2qt956S1OnTlVa2r8zfIcOHZSenq7XX39du3bt0t69eyWFXptnnnlGmzZtUl5ensaPH18evAAASCS
EqTj6n//5Hw0cOFDnn3++hgwZopYtW2rSpEn1XkdKSooWLlyo77//
XgMHDtTkyZN10003yczUvHnzGte79tprdemll+oXv/
iFTjvtNJWWllaqPyUlRS+88ILWr1+v7OxsXXXVVbrtttvUrFmz8jFpaWn67W9/
q8cee0ydO3fW2LFjJUmPP/649u3bpwEDBmj8+PG67LLL1L1795i9BgCAxJEI7Q7qwryO3bWjJScnx/
Pz86tdVlBQoD59+tRzRaho3bp1Ovnkk5Wfn68BAwbEu5yI8HMDAI1DcYopzaVik9JK45NTqjKzNe5e7Ud1c
GYKkqSFCxfqjTfe0CeffKJly5ZpypQpOumkk9S/f/
94lwYASDLvjMxWsYWeG4JGfwM6IvPNN99oxowZ2r59u9q1a6dhw4bpgQceiNm7/
AAAqMkPbQ6GxrmOSBGmIEm65JJLdMkll8S7DAAAGhwu8wEAAARAmAIAAAiAMAUAAOpFQ2t5ECnCFAAAqBdn
Lt6gNA89NyaEKQAAUC8aWsuDSPFuPgAAUC8aWsuDSHFmqgHr3r275syZE5d9jx49WlOmTInLvgEASCSEqSg
xs1ofQYLHzTffrOzsQ0+J5uXladq0aQGqrj/
Lly+XmWnPnj3xLgUAgKjiMl+UFBYWln+9aNEiXXHFFZXmpaenR32fmZmZUd8mAACoG85MRUmnTp3KH23btj
1k3ooVKzRgwAA1b95cPXr00E033aQDBw6Ur79gwQL169dP6enpOuKIIzR06FDt2rVL8+fP1y233KKNGzeWn
+WaP3++pEMv85mZHnnkEV100UVq2bKljjnmGD3zzDOV6nzvvffUv39/NW/eXKeccooWL14sM9Py5ctr/
N6+/
fZbTZkyRRkZGerYsaPuvPPOQ8Y888wzOvXUU9WqVSt16NBBF110kXbu3ClJ+vTTT3XWWWdJCgXAimfqlixZ
osGDB6tdu3Y64ogjNGLECBUUFNT15QcAxFFjbXkQKcJUPXj99dc1adIkTZ8+XRs3btTjjz+ul156STfeeKM
k6csvv9T48eM1efJkFRQUaMWKFbr44oslSePGjdO1116r4447ToWFhSosLNS4ceNq3Nett96qsWPHat26dR
o3bpwuu+wyffbZZ5Kkffv2afTo0Tr++OO1Zs0a3Xvvvbr++uvD1n/dddfpzTff1Msvv6ylS5fq/
fff14oVKyqNOXDggG655RatW7dOixYt0p49ezRhwgRJUteuXfXyyy9LkjZu3KjCwkI9+OCDkqT9+/
frmmuu0erVq7V8+XK1adNGY8aMqRQ0AQCJrbG2PIiYu8flMWDAAK/
Jpk2balxWV9MWTfPUW1J92qJpUdtmOC++
+KKHXtqQwYMH+6233lppzMKFC71ly5ZeWlrqa9ascUn+6aefVru92bNn+wknnHDI/
G7duvl9991XPi3JZ86cWT598OBBT09P96efftrd3R9++GFv166df/
vtt+Vjnn32WZfky5Ytq3bf33zzjTdt2tSfeeaZSvPatGnjkydPrvE1KCgocEm+fft2d3dftmyZS/
KioqIa13F337dvn6ekpPjKlStrHVedaP7cAAAit3xUth80+fJR2fEuJWYk5XsNmabRn5mat2aeSrxE89bMi
1sNa9as0R133KGMjIzyx8SJE7V//359+eWXOumkkzR8+HBlZ2frpz/9qf7whz+oqKjosPbVr1+/8q/
T0tKUmZmp3bt3S5I2b96s7OzsSvdvnXbaabVu7+OPP9aBAwc0aNCg8nkZGRk68cTKp3LXrl2rsWPHqlu3bm
rVqpVycnIkSZ9//nnY7U+cOFE9e/
ZU69at1bFjR5WWloZdDwCQOIYu+lBppV7e+iDZNPowlTsgV6mWqtwBuXGrobS0VLNnz9YHH3xQ/li/
fr3+93//V5mZmUpNTdUbb7yhN954Q/369dMf//
hH9e7dW+vWravzvpo0aVJp2sxUWloqKXQW0szqtL1QGK/d/v37NWLECLVo0UJPP/
208vLytGTJEkkKe7luzJgxKioq0rx58/Tee+/p/fffV1paGpf5AAANRqN/N9/cUXM1d9TcuNbQv39/
bd68Wb169apxjJlp0KBBGjRokGbNmqUTTjhBL7zwgk466SQ1bdpUJSUlgevo06ePnnrqKf3rX/
8qPzu1evXqWtfp1auXmjRpolWrVumYY46RFApPGzZsUM+ePSWFznjt2bNHd955p3r06CEpdEN9RU2bNpWkS
t/HV199pYKCAs2dO7f8BvW1a9equLg48PcKAEB9afRnphLBrFmz9Kc//
UmzZs3Shg0btHnzZr300ku64YYbJEmrVq3S7bffrry8PH3+
+ed69dVXtX37dvXt21dS6F17n332mdauXas9e/bo+++/P6w6Jk2apNTUVF1xxRXatGmT/v73v5e/
M6+mM1YZGRm6/
PLLNWPGDL355pvauHGjLrvsskqh6Oijj1azZs300EMPadu2bXrttdf0q1/9qtJ2unXrJjPTa6+9pqKiIu3b
t0/t2rVT+/
bt9eijj2rr1q166623NHXqVKWlNfqMDwBoRAhT9WDEiBF67bXXtGzZMg0cOFADBw7U3XffraOPPlqS1KZNG
73zzjsaPXq0evfurWuvvVa/+tWv9POf/1yS9NOf/lQjR47UOeeco8zMTD333HOHVUdGRob++te/
auPGjTrllFN0/fXX6+abb5YkNW/evMb15syZo7POOksXXHCBzjrrLGVnZ2vIkCHlyzMzM/Xkk0/qlVdeUd+
+fXXLLbfo17/+daVtdOnSRbfccotuuukmdezYUdOnT1dKSopeeOEFrV+/XtnZ2brqqqt02223qVmzZof1/
QEAoifZ2x3UhUVyT0ws5OTkeH5+frXLCgoK1KdPn3quKDn95S9/0QUXXKDdu3erffv28S4nEH5uACB6ilNM
aS4Vm5RWGp+skEjMbI2751S3jDNTSebJJ5/
UypUr9emnn2rRokW65pprNGbMmAYfpAAA0fXOyGwVW+gZtePmlCSza9cuzZ49W4WFherUqZNGjRqle+65J9
5lAQASzA9tDobGuY6GgDCVZG644YbyG98BAEBwXOYDAAAIgDAFAAAQAGEKAIAkQsuD6CNMAQCQRM5cvEFpH
npGdBCmAABIIrQ8iD7ezQcAQBKh5UH0cWaqAXrppZcqfZbe/
PnzlZGREWiby5cvl5lpz549QcsDACCpEKaiaMqUKTIzmZmaNGmiY445Rtddd532798f0/2OGzdO27Zti3h8
9+7dNWfOnErzzjjjDBUWFurII4+MdnkAADRqEYUpMzvPzLaY2VYzm1nN8nZmttDM1pvZajNL2guxw4cPV2F
hobZt26bbb79dv//
973XdddcdMq64uFjR+lzE9PR0dejQIdA2mjZtqk6dOlU64wUAAMILG6bMLFXSXEnnS+oraYKZ9a0y7EZJH7
h7P0mXSHow2oU2FM2aNVOnTp3UtWtXTZw4UZMmTdIrr7yim2++WdnZ2Zo/f7569uypZs2aaf/+/
dq7d6+uvPJKdejQQa1atdLQoUNV9QOgn3rqKXXr1k0tWrTQ6NGjtWvXrkrLq7vM99prr+m0005Tenq6jjzy
SI0ZM0bfffedhg0bps8+
+0zXX399+Vk0qfrLfAsWLNCJJ56oZs2aqWvXrrrjjjsqBcDu3bvr9ttvV25urlq3bq2srCzdd999leqYN2+
ejj32WDVv3lyZmZkaMWKEiouLo/
JaAwD+jZYH8RPJmamBkra6+zZ3PyDpeUljq4zpK2mpJLn7ZkndzaxjVCttoNLT03Xw4EFJ0ieffKI//
elPevHFF7Vu3To1a9ZMo0aN0s6dO7Vo0SK9//77GjJkiM4+
+2wVFhZKkt577z1NmTJFV155pT744AONGTNGs2bNqnWfS5Ys0dixY3XuuedqzZo1WrZsmYYOHarS0lItWLB
AWVlZmjVrlgoLC8v3U9WaNWt00UUX6Sc/+Yk+/
PBD3X333brrrrv00EMPVRr3wAMP6MQTT9TatWs1Y8YM3XDDDXr33XclSfn5+brqqqs0e/ZsbdmyRX//
+9913nnnBX1JAQDVoOVBHLl7rQ9JF0p6rML0xZIeqjLmTkm/
Lvt6oKRiSQOq2daVkvIl5R999NFek02bNtW4rM6mTXNPTQ09x9jkyZN91KhR5dPvvfeeH3nkkf6zn/
3MZ8+e7Wlpaf7ll1+WL1+6dKm3bNnSv/3220rbOemkk/yee+5xd/cJEyb48OHDKy2//
PLLPXToQp544glv2bJl+fQZZ5zh48aNq7HObt26+X333Vdp3rJly1ySFxUVubv7xIkT/
ayzzqo0Zvbs2d6lS5dK2xk/fnylMb169fLbbrvN3d1ffvllb926tX/
99dc11hJNUf25AYAGZvmobD9o8uWjsuNdSqMkKd9ryEqRnJmq7iaaqjf73C2pnZl9IOk/
Jb1fFqiqBrdH3D3H3XMyMzMj2HUUzJsnlZSEnuvBkiVLlJGRoebNm2vQoEEaMmSIfve730mSsrKy1LHjv0/
YrVmzRt9++60yMzOVkZFR/tiwYYM+/
vhjSVJBQYEGDRpUaR9Vp6t6//33dc455wT6PgoKCnTmmWdWmvejH/1IO3fu1Ndff10+r1+/
fpXGdO7cWbt375YknXvuuerWrZt69OihSZMm6cknn9Q333wTqC4AQPWGLvpQaaVe3voA9SeSPlM7JHWtMJ0
l6YuKA9z9a0mXSpKFbsL5pOwRf7m5oSCVm1svuxsyZIgeeeQRNWnSRJ07d1aTJk3Kl7Vs2bLS2NLSUnXs2F
ErV648ZDutW7eWpKjdpF5X7l7jzegV51f8/n5YVlpaKklq1aqV1q5dqxUrVujNN9/
UXXfdpRtvvFF5eXnq3Llz7IoHAKAeRXJmKk9SbzPrYWZNJY2X9GrFAWbWtmyZJP1C0oqygBV/
c+dKxcWh53rQokUL9erVS926dTskaFTVv39/7dq1SykpKerVq1elxw/
vzuvbt69WrVpVab2q01WdcsopWrp0aY3LmzZtqpKSklq30bdvX7399tuV5r399tvKyspSq1atal23orS0NJ
199tm66667tH79eu3fv1+LFi2KeH0AABJd2DNT7l5sZtMlvS4pVdLj7r7RzKaWLX9YUh9JT5lZiaRNki6PY
c2NxvDhw3XmmWdq7Nixuvfee3X88cfryy+/1JIlSzR8+HANHjxY//Vf/
6UzzjhDd911ly688EItX75cCxcurHW7N910k8aMGaNevXpp4sSJcne98cYbys3NVYsWLdS9e3etXLlSP//
5z9WsWTO1b9/+kG1ce+21OvXUU3XzzTdr4sSJysvL0/33368777wz4u9v0aJF+vjjjzVkyBAdccQRWrZsmb
755hv16dOnzq8VAACJKqI+U+6+2N2Pdfee7n5H2byHy4KU3P1dd+/t7se7+0/c/f/
FsujGwsy0ePFinX322briiit03HHH6Wc/+5m2bNlSfhns9NNP1x//+Ef94Q9/UL9+/
bRgwQLdfPPNtW535MiRWrhwof72t7/plFNO0dChQ7Vs2TKlpIQO96233qrt27erZ8+equnetf79++vFF1/
Uyy+/rOzsbM2cOVMzZ87U9OnTI/7+2rZtq1deeUXDhw/
X8ccfrzlz5uixxx7T4MGDI94GACQz2h00DBave3JycnK8aj+lHxQUFHD2AnXGzw2AxqY4xZTmUrFJaaXx+f
8aIWa2xt1zqlvGx8kAAJCg3hmZrWILPSNxRfJuPgAAEAc/
tDkYGuc6UDvOTAEAAARAmAIAAAggYcPUD40fgUjE640UAAAkZJhq2bKldu7cqQMHDvCfJMJyd3311Vdq3rx
5vEsBgIjQ8qBxScjWCKWlpdqzZ4/27t2r4uJDPuIPOETz5s2VlZUVtus8ACQCWh40PLW1RkjId/
OlpKSoQ4cO5R+pAgBAY/
LOyGyduXiD3hmZzTv1GoGEDFMAADRmtDxoXBLynikAAICGgjAFAAAQAGEKAAAgAMIUAABRQsuD5ESYAgAgS
s5cvEFpHnpG8iBMAQAQJe+MzFaxhZ6RPGiNAABAlNDyIDlxZgoAACAAwhQAAEAAhCkAAIAACFMAANTiqquk
tLTQM1AdwhQAALWYN08qKQk9A9UhTAEAUIvcXCk1NfQMVMfcPS47zsnJ8fz8/
LjsGwAAoC7MbI2751S3jDNTAAAAARCmAAAAAiBMAQAABECYAgAkJVoeIFoIUwCApETLA0QLYQoAkJRoeYBo
oTUCAABAGLRGAAAAiBHCFAAAQACEKQAAgAAIUwCARoWWB6hvhCkAQKNCywPUN8IUAKBRoeUB6hutEQAAAMK
gNQIAAECMEKYAAAACIEwBAAAEEFGYMrPzzGyLmW01s5nVLG9jZn81s3VmttHMLo1+qQCAZEW7AySysDegm1
mqpI8knStph6Q8SRPcfVOFMTdKauPuM8wsU9IWSZ3c/
UBN2+UGdABApNLSQu0OUlOl4uJ4V4NkFPQG9IGStrr7trJw9LyksVXGuKRWZmaSMiT9UxI/
7gCAqKDdARJZJGGqi6TtFaZ3lM2r6CFJfSR9IelDSVe7e2nVDZnZlWaWb2b5RUVFh1kyACDZzJ0bOiM1d26
8KwEOFUmYsmrmVb02OELSB5I6SzpZ0kNm1vqQldwfcfccd8/
JzMysY6kAAACJJ5IwtUNS1wrTWQqdgaroUkkLPGSrpE8kHR+dEgEAABJXJGEqT1JvM+thZk0ljZf0apUxn0
s6R5LMrKOk4yRti2ahAAAAiShsmHL3YknTJb0uqUDSn919o5lNNbOpZcNuk3SGmX0oaamkGe6+J1ZFAwAaB
1oeoDHgs/
kAAHFDywM0FHw2HwAgIdHyAI0BZ6YAAADC4MwUAABAjBCmAAAAAiBMAQAABECYAgBEHS0PkEwIUwCAqJs3L
9TyYN68eFcCxB5hCgAQdbQ8QDKhNQIAAEAYtEYAAACIEcIUAABAAIQpAACAAAhTAAAAARCmAAARoXcUUD3C
FAAgIvSOAqpHmAIARITeUUD16DMFAAAQBn2mAAAAYoQwBQAAEABhCgAAIADCFAAkOVoeAMEQpgAgydHyAAi
GMAUASY6WB0AwtEYAAAAIg9Y
IAAAAMUKYAgAACIAwBQAAEABhCgAaIdodAPWHMAUAjRDtDoD6Q5gCgEaIdgdA/
aE1AgAAQBi0RgAAAIgRwhQAAEAAhCkAAIAACFMA0IDQ8gBIPIQpAGhAaHkAJB7CFAA0ILQ8ABIPrREAAADC
oDUCAABAjBCmAAAAAiBMAQAABECYAoAEQMsDoOGKKEyZ2XlmtsXMtprZzGqWX29mH5Q9NphZiZkdEf1yAaB
xouUB0HCFDVNmlipprqTzJfWVNMHM+lYc4+73ufvJ7n6ypF9Kesvd/
xmDegGgUaLlAdBwRXJmaqCkre6+zd0PSHpe0thaxk+Q9Fw0igOAZDF3rlRcHHoG0LBEEqa6SNpeYXpH2bxD
mFkLSedJermG5VeaWb6Z5RcVFdW1VgAAgIQTSZiyaubV1OlzjKR3arrE5+6PuHuOu+dkZmZGWiMAAEDCiiR
M7ZDUtcJ0lqQvahg7XlziAwAASSSSMJUnqbeZ9TCzpgoFplerDjKzNpKGSvpLdEsEgIaJdgdAcggbpty9WN
J0Sa9LKpD0Z3ffaGZTzWxqhaEXSHrD3ffHplQAaFhodwAkh7RIBrn7YkmLq8x7uMr0fEnzo1UYADR0ubmhI
EW7A6BxM/ea7iWPrZycHM/Pz4/LvgEAAOrCzNa4e051y/
g4GQAAgAAIUwAAAAEQpgAAAAIgTAFAHdHyAEBFhCkAqCNaHgCoiDAFAHWUmyulptLyAEAIrREAAADCoDUCA
ABAjBCmAAAAAiBMAQAABECYAoAytDwAcDgIUwBQhpYHAA4HYQoAytDyAMDhoDUCAABAGLRGAAAAiBHCFAAA
QACEKQAAgAAIUwAaNdodAIg1whSARo12BwBijTAFoFGj3QGAWKM1AgAAQBi0RgAAAIgRwhQAAEAAhCkAAIA
ACFMAGiRaHgBIFIQpAA0SLQ8AJArCFIAGiZYHABIFrREAAADCoDUCAABAjBCmAAAAAiBMAQAABECYApBQaH
kAoKEhTAFIKLQ8ANDQEKYAJBRaHgBoaGiNAAAAEAatEQAAAGKEMAUAABAAYQoAACAAwhSAmKPdAYDGjDAFI
OZodwCgMYsoTJnZeWa2xcy2mtnMGsYMM7MPzGyjmb0V3TIBNGS0OwDQmIVtjWBmqZI+knSupB2S8iRNcPdN
Fca0lfQPSee5+
+dm1sHdd9e2XVojAACAhiJoa4SBkra6+zZ3PyDpeUljq4yZKGmBu38uSeGCFAAAQGMRSZjqIml7hekdZfMq
OlZSOzNbbmZrzOyS6jZkZleaWb6Z5RcVFR1exQAAAAkkkjBl1cyrem0wTdIASaMkjZD0KzM79pCV3B9x9xx
3z8nMzKxzsQAAAIkmkjC1Q1LXCtNZkr6oZswSd9/
v7nskrZB0UnRKBJCoaHkAAJGFqTxJvc2sh5k1lTRe0qtVxvxF0mAzSzOzFpJOk1QQ3VIBJBpaHgBABGHK3Y
slTZf0ukIB6c/
uvtHMpprZ1LIxBZKWSFovabWkx9x9Q+zKBpAIaHkAABG0RogVWiMAAICGImhrBAAAANSAMAUAABAAYQoAAC
AAwhSAQ9DyAAAiR5gCcAhaHgBA5AhTAA5BywMAiBytEQAAAMKgNQIAAECMEKYAAAACIEwBAAAEQJgCkgTtD
gAgNghTQJKg3QEAxAZhCkgStDsAgNigNQIAAEAYtEYAAACIEcIUAABAAIQpAACAAAhTQANHywMAiC/
CFNDA0fIAAOKLMAU0cLQ8AID4ojUCAABAGLRGAAAAiBHCFAAAQACEKQAAgAAIU0CCouUBADQMhCkgQdHyAA
AaBsIUkKBoeQAADQOtEQAAAMKgNQIAAECMEKYAAAACIEwBAAAEQJgCAAAIgDAF1CN6RwFA40OYAuoRvaMAo
PEhTAH1iN5RAND40GcKAAAgDPpMAQAAxAhhCgAAIADCFAAAQACEKSAKaHkAAMmLMAVEAS0PACB5EaaAKKDl
AQAkr4jClJmdZ2ZbzGyrmc2sZvkwM9trZh+UPWZFv1Qgcc2dKxUXh54BAMklLdwAM0uVNFfSuZJ2SMozs1f
dfVOVoSvdfXQMagQAAEhYkZyZGihpq7tvc/cDkp6XNDa2ZQEAADQMkYSpLpK2V5jeUTavqkFmts7M/
mZmJ1S3ITO70szyzSy/
qKjoMMoFAABILJGEKatmXtXPoFkrqZu7nyTpd5JeqW5D7v6Iu+e4e05mZmadCgXqG+0OAACRiCRM7ZDUtcJ
0lqQvKg5w96/dfV/
Z14slNTGz9lGrEogD2h0AACIRSZjKk9TbzHqYWVNJ4yW9WnGAmXUyMyv7emDZdr+KdrFAfaLdAQAgEmHfze
fuxWY2XdLrklIlPe7uG81satnyhyVdKOn/mFmxpH9JGu/
uVS8FAg3K3Lm0OgAAhGfxyjw5OTmen58fl30DAADUhZmtcfec6pbRAR0AACAAwhQAAEAAhCkkHVoeAACiiT
CFpEPLAwBANBGmkHRoeQAAiCbezQcAABAG7+YDAACIEcIUAABAAIQpAACAAAhTaDRoeQAAiAfCFBoNWh4AA
OKBMIVGg5YHAIB4oDUCAABAGLRGAAAAiBHCFAAAQACEKQAAgAAIU0hotDsAACQ6whQSGu0OAACJjjCFhEa7
AwBAoqM1AgAAQBi0RgAAAIgRwhQAAEAAhCkAAIAACFOIC1oeAAAaC8IU4oKWBwCAxoIwhbig5QEAoLGgNQI
AAEAYtEYAAACIEcIUAABAAIQpAACAAAhTiCpaHgAAkg1hClFFywMAQLIhTCGqaHkAAEg2tEYAAAAIg9YIAA
AAMUKYAgAACIAwBQAAEABhCmHR7gAAgJoRphAW7Q4AAKgZYQph0e4AAICa0RoBAAAgjMCtEczsPDPbYmZbz
WxmLeNONbMSM7vwcIsFAABoSMKGKTNLlTRX0vmS+kqaYGZ9axh3j6TXo10kAABAoorkzNRASVvdfZu7H5D0
vKSx1Yz7T0kvS9odxfoAAAASWiRhqouk7RWmd5TNK2dmXSRdIOnh2jZkZleaWb6Z5RcVFdW1VkQZLQ8AAAg
ukjBl1cyretf6byTNcPeS2jbk7o+4e46752RmZkZYImKFlgcAAAQXSZjaIalrheksSV9UGZMj6Xkz+1TShZ
J+b2Y/jkaBiB1aHgAAEFzY1ghmlibpI0nnSNopKU/
SRHffWMP4+ZIWuftLtW2X1ggAAKChqK01Qlq4ld292MymK/
QuvVRJj7v7RjObWra81vukAAAAGrOwYUqS3H2xpMVV5lUbotx9SvCyAAAAGgY+TgYAACAAwlQjRMsDAADqD
2GqEaLlAQAA9Ycw1QjR8gAAgPoTtjVCrNAaAQAANBS1tUbgzBQAAEAAhCkAAIAACFMAAAABEKYaCNodAACQ
mAhTDQTtDgAASEyEqQaCdgcAACQmWiMAAACEQWsEAACAGCFMAQAABECYAgAACIAwFWe0PAAAoGEjTMUZLQ8
AAGjYCFNxRssDAAAaNlojAAAAhEFrBAAAgBghTAEAAARAmAIAAAiAMBUjtDwAACA5EKZihJYHAAAkB8JUjN
DyAACA5EBrBAAAgDBojQAAABAjhCkAAIAACFMAAAABEKbqgHYHAACgKsJUHdDuAAAAVEWYqgPaHQAAgKpoj
QAAABAGrREAAABihDAFAAAQAGEKAAAgAMKUaHkAAAAOH2FKtDwAAACHjzAlWh4AAIDDR2sEAACAMGiNAAAA
ECMRhSkzO8/MtpjZVjObWc3ysWa23sw+MLN8M/
tR9EsFAABIPGnhBphZqqS5ks6VtENSnpm96u6bKgxbKulVd3cz6yfpz5KOj0XBAAAAiSSSM1MDJW11923uf
kDS85LGVhzg7vv83zdftZQUnxuxAAAA6lkkYaqLpO0VpneUzavEzC4ws82SXpN0WXTKC4b+UQAAINYiCVNW
zbxDzjy5+0J3P17SjyXdVu2GzK4su6cqv6ioqE6FHg76RwEAgFiLJEztkNS1wnSWpC9qGuzuKyT1NLP21Sx
7xN1z3D0nMzOzzsXWFf2jAABArEUSpvIk9TazHmbWVNJ4Sa9WHGBmvczMyr7uL6mppK+iXWxdzZ0rFReHng
EAAGIh7Lv53L3YzKZLel1SqqTH3X2jmU0tW/
6wpJ9KusTMDkr6l6RxHq9uoAAAAPWIDugAAABh0AEdAAAgRghTAAAAARCmAAAAAiBMAQAABECYAgAACIAwB
QAAEABhCgAAIADCFAAAQACEKQAAgAAIUwAAAAEQpgAAAAIgTAEAAAQQtw86NrMiSZ/Vw67aS9pTD/
tB3XFsEhvHJ3FxbBIbxydxBTk23dw9s7oFcQtT9cXM8mv6lGfEF8cmsXF8EhfHJrFxfBJXrI4Nl/
kAAAACIEwBAAAEkAxh6pF4F4AacWwSG8cncXFsEhvHJ3HF5Ng0+numAAAAYikZzkwBAADEDGEKAAAggEYRp
szsPDPbYmZbzWxmNcvNzH5btny9mfWPR53JKoLjM6nsuKw3s3+Y2UnxqDMZhTs2FcadamYlZnZhfdaX7CI5
PmY2zMw+MLONZvZWfdeYrCL4d62Nmf3VzNaVHZtL41FnMjKzx81st5ltqGF59DOBuzfoh6RUSR9LOkZSU0n
rJPWtMmakpL9JMkmnS3ov3nUnyyPC43OGpHZlX5/
P8UmcY1Nh3P+VtFjShfGuO1keEf7utJW0SdLRZdMd4l13MjwiPDY3Srqn7OtMSf+U1DTetSfDQ9IQSf0lba
hhedQzQWM4MzVQ0lZ33+buByQ9L2lslTFjJT3lIasktTWzo+q70CQV9vi4+z/c/
f+VTa6SlFXPNSarSH53JOk/Jb0saXd9FoeIjs9ESQvc/
XNJcneOUf2I5Ni4pFZmZpIyFApTxfVbZnJy9xUKvd41iXomaAxhqouk7RWmd5TNq+sYxEZdX/
vLFfqLAbEX9tiYWRdJF0h6uB7rQkgkvzvHSmpnZsvNbI2ZXVJv1SW3SI7NQ5L6SPpC0oeSrnb30vopD2FEP
ROkBSonMVg186r2e4hkDGIj4tfezM5SKEz9KKYV4QeRHJvfSJrh7iWhP7BRjyI5PmmSBkg6R1K6pHfNbJW7
fxTr4pJcJMdmhKQPJJ0tqaekN81spbt/
HePaEF7UM0FjCFM7JHWtMJ2l0F8CdR2D2IjotTezfpIek3S+u39VT7Ulu0iOTY6k58uCVHtJI82s2N1fqZc
Kk1uk/
7btcff9kvab2QpJJ0kiTMVWJMfmUkl3e+gmna1m9omk4yWtrp8SUYuoZ4LGcJkvT1JvM+thZk0ljZf0apUx
r0q6pOwO/tMl7XX3wvouNEmFPT5mdrSkBZIu5i/qehX22Lh7D3fv7u7dJb0kaRpBqt5E8m/
bXyQNNrM0M2sh6TRJBfVcZzKK5Nh8rtAZQ5lZR0nHSdpWr1WiJlHPBA3+zJS7F5vZdEmvK/
QOi8fdfaOZTS1b/
rBC70IaKWmrpG8V+osB9SDC4zNL0pGSfl92BqTY+cT1mIvw2CBOIjk+7l5gZkskrZdUKukxd6/27eCIngh/
d26TNN/
MPlTostIMd98Tt6KTiJk9J2mYpPZmtkPSbElNpNhlAj5OBgAAIIDGcJkPAAAgbghTAAAAARCmAAAAAiBMAQ
AABECYAgAACIAwBQAAEABhCgAAIID/D3PFxscZJbEwAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# plot_predictions(predictions=y_preds) # -> won't work... data not on CPU\n",
"\n",
"# Put data on the CPU and plot it\n",
"plot_predictions(predictions=y_preds.cpu())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DxZa-5-Tec0-"
},
"source": [
"Woah! Look at those red dots, they line up almost perfectly with the green dots. I guess the extra epochs helped.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K8jCHl1gec0-"
},
"source": [
"### 6.5 Saving and loading a model\n",
"\n",
"We're happy with our model's predictions, so let's save it to file so it can be used later.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DcQo4JqL7eSU",
"outputId": "e43ada0c-c074-4b50-9207-fa01581b1d5f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving model to: models/01_pytorch_workflow_model_1.pth\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"# 1. Create models directory \n",
"MODEL_PATH = Path(\"models\")\n",
"MODEL_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# 2. Create model save path \n",
"MODEL_NAME = \"01_pytorch_workflow_model_1.pth\"\n",
"MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME\n",
"\n",
"# 3. Save the model state dict \n",
"print(f\"Saving model to: {MODEL_SAVE_PATH}\")\n",
"torch.save(obj=model_1.state_dict(), # saving only the state_dict() stores just the model's learned parameters\n",
" f=MODEL_SAVE_PATH) "
]
},
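{
"cell_type": "markdown",
"metadata": {},
"source": [
"To double-check what actually ended up in the file, one option (a quick sketch, not part of the original workflow) is to load the raw saved object with `torch.load()` and inspect it. Because we saved `model_1.state_dict()` rather than the whole model, the file should hold a mapping from parameter names (e.g. `linear_layer.weight` and `linear_layer.bias`) to tensors. The variable names `saved_state_dict`, `param_name` and `param_tensor` are just for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Load the raw saved object: a state dict mapping parameter names to tensors\n",
"saved_state_dict = torch.load(MODEL_SAVE_PATH)\n",
"\n",
"# Print each parameter name and its shape to confirm what was saved\n",
"for param_name, param_tensor in saved_state_dict.items():\n",
"    print(f\"{param_name}: {tuple(param_tensor.shape)}\")"
]
},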
{
"cell_type": "markdown",
"metadata": {
"id": "lk0rvpwV7slc"
},
"source": [
"And just to make sure everything worked well, let's load it back in.\n",
"\n",
"We'll:\n",
"* Create a new instance of the `LinearRegressionModelV2()` class\n",
"* Load in the model state dict using `torch.nn.Module.load_state_dict()`\n",
"* Send the new instance of the model to the target device (to ensure our code is device-agnostic)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jMnVHzf1ec0-",
"outputId": "76f10046-cd42-4b39-a372-aa95227828e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded model:\n",
"LinearRegressionModelV2(\n",
" (linear_layer): Linear(in_features=1, out_features=1, bias=True)\n",
")\n",
"Model on device:\n",
"cuda:0\n"
]
}
],
"source": [
"# Instantiate a fresh instance of LinearRegressionModelV2\n",
"loaded_model_1 = LinearRegressionModelV2()\n",
"\n",
"# Load model state dict \n",
"loaded_model_1.load_state_dict(torch.load(MODEL_SAVE_PATH))\n",
"\n",
"# Put the model on the target device (if your data is on the GPU, the model must also be on the GPU to make predictions)\n",
"loaded_model_1.to(device)\n",
"\n",
"print(f\"Loaded model:\\n{loaded_model_1}\")\n",
"print(f\"Model on device:\\n{next(loaded_model_1.parameters()).device}\")"
]
},
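{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above loads the state dict on a machine with the same device setup it was saved on. As a sketch for a different setup, say the model was saved from a GPU but you want to load it on a CPU-only machine, `torch.load()` accepts a `map_location` argument that remaps the saved tensors onto another device while loading. The name `loaded_model_1_cpu` is hypothetical and only used for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Create another fresh instance of the model class (hypothetical name for illustration)\n",
"loaded_model_1_cpu = LinearRegressionModelV2()\n",
"\n",
"# map_location=\"cpu\" remaps any GPU tensors stored in the file onto the CPU while loading\n",
"loaded_model_1_cpu.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=\"cpu\"))\n",
"\n",
"print(f\"Model on device:\\n{next(loaded_model_1_cpu.parameters()).device}\")"
]
},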
{
"cell_type": "markdown",
"metadata": {
"id": "Hv6EMEx99LV2"
},
"source": [
"Now we can evaluate the loaded model to see if its predictions line up with the predictions made prior to saving."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fYODT7ONec0_",
"outputId": "c8184cd1-595a-43e4-8155-89dcecc4d0b0"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True],\n",
" [True]], device='cuda:0')"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate loaded model\n",
"loaded_model_1.eval()\n",
"with torch.inference_mode():\n",
" loaded_model_1_preds = loaded_model_1(X_test)\n",
"y_preds == loaded_model_1_preds"
]
},
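{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comparing with `==` returns one boolean per prediction. To collapse that into a single answer, a possible alternative (not used in the original notebook) is `torch.equal()`, which checks that both tensors have the same shape and identical values, or `torch.allclose()`, which returns `True` only if every pair of elements matches within a small tolerance. The tolerance version is handy when exact equality is too strict, for example when comparing results computed on different hardware (after moving both tensors to the same device)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Exact element-wise equality collapsed into a single boolean\n",
"print(torch.equal(y_preds, loaded_model_1_preds))\n",
"\n",
"# Equality within the default tolerances (rtol=1e-05, atol=1e-08)\n",
"print(torch.allclose(y_preds, loaded_model_1_preds))"
]
},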
{
"cell_type": "markdown",
"metadata": {
"id": "7M_kcRC89YrZ"
},
"source": [
"Everything adds up! Nice!\n",
"\n",
"Well, we've come a long way. You've now built and trained your first two neural network models in PyTorch!\n",
"\n",
"Time to practice your skills."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o6rf3hTWec0_"
},
"source": [
"## Exercises\n",
"\n",
"All exercises are inspired by code throughout the notebook.\n",
"\n",
"There is one exercise per major section.\n",
"\n",
"You should be able to complete them by referencing their specific section.\n",
"\n",
"> **Note:** For all exercises, your code should be device-agnostic (meaning it could run on the CPU, or on the GPU if one is available).\n",
"\n",
"1. Create a straight-line dataset using the linear regression formula (`weight * X + bias`).\n",
"   * Set `weight=0.3` and `bias=0.9`; there should be at least 100 datapoints in total.\n",
"   * Split the data into 80% training, 20% testing.\n",
"   * Plot the training and testing data to visualize it.\n",
"2. Build a PyTorch model by subclassing `nn.Module`.\n",
"   * Inside should be a randomly initialized `nn.Parameter()` with `requires_grad=True`, one for `weights` and one for `bias`.\n",
"   * Implement the `forward()` method to compute the linear regression function you used to create the dataset in 1.\n",
"   * Once you've constructed the model, make an instance of it and check its `state_dict()`.\n",
"   * **Note:** You can use `nn.Linear()` instead of `nn.Parameter()` if you'd like.\n",
"3. Create a loss function and optimizer using `nn.L1Loss()` and `torch.optim.SGD(params, lr)` respectively.\n",
"   * Set the learning rate of the optimizer to 0.01; the parameters to optimize should be the parameters of the model you created in 2.\n",
"   * Write a training loop to perform the appropriate training steps for 300 epochs.\n",
"   * The training loop should test the model on the test dataset every 20 epochs.\n",
"4. Make predictions with the trained model on the test data.\n",
"   * Visualize these predictions against the original training and testing data (**note:** you may need to make sure the predictions are *not* on the GPU if you want to use non-CUDA-enabled libraries such as matplotlib to plot).\n",
"5. Save your trained model's `state_dict()` to file.\n",
"   * Create a new instance of the model class you made in 2. and load the `state_dict()` you just saved into it.\n",
"   * Perform predictions on your test data with the loaded model and confirm they match the original model predictions from 4.\n",
"\n",
"> **Resource:** See the [exercise notebook templates](https://github.com/mrdbourke/pytorch-deep-learning/tree/main/extras/exercises) and [solutions](https://github.com/mrdbourke/pytorch-deep-learning/tree/main/extras/solutions) on the course GitHub.\n",
"\n",
"## Extra-curriculum\n",
"* Listen to [The Unofficial PyTorch Optimization Loop Song](https://youtu.be/Nutpusq_AFw) (to help remember the steps in a PyTorch training/testing loop).\n",
"* Read [What is `torch.nn`, really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html) by Jeremy Howard for a deeper understanding of how one of the most important modules in PyTorch works.\n",
"* Spend 10 minutes scrolling through and checking out the [PyTorch documentation cheatsheet](https://pytorch.org/tutorials/beginner/ptcheat.html) for all of the different PyTorch modules you might come across.\n",
"* Spend 10 minutes reading the [loading and saving documentation on the PyTorch website](https://pytorch.org/tutorials/beginner/saving_loading_models.html) to become more familiar with the different saving and loading options in PyTorch.\n",
"* Spend 1-2 hours reading/watching the following for an overview of the internals of gradient descent and backpropagation, the two main algorithms that have been working in the background to help our model learn.\n",
"   * [Wikipedia page for gradient descent](https://en.wikipedia.org/wiki/Gradient_descent)\n",
"   * [Gradient Descent Algorithm — a deep dive](https://towardsdatascience.com/gradient-descent-algorithm-a-deep-dive-cf04e8115f21) by Robert Kwiatkowski\n",
"   * [Gradient descent, how neural networks learn video](https://youtu.be/IHZwWFHWa-w) by 3Blue1Brown\n",
"   * [What is backpropagation really doing?](https://youtu.be/Ilg3gGewQ5U) video by 3Blue1Brown\n",
"   * [Backpropagation Wikipedia Page](https://en.wikipedia.org/wiki/Backpropagation)\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "01_pytorch_workflow.ipynb",
"provenance": [],
"toc_visible": true
},
"interpreter": {
"hash": "3fbe1355223f7b2ffc113ba3ade6a2b520cadace5d5ec3e828c83ce02eb221bf"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}