1、问题
使用Pytorch及CNN框架实现手写数字识别问题
2、代码
1、导包
导入如下一些包,后面会用到
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import time
from torchvision import transforms
import matplotlib.pyplot as plt
2、设置如下超参数
#超参数
EPOCH = 1
BATCH_SIZE = 1
LR = 0.001
# DOWNLOAD_MNIST = False
if_use_gpu = 0
分别是循环次数、每一批次训练的样本个数(因为要输出,所以设置为1)、学习率、是否使用cuda加速
3、导入训练集并传入到train_loader中
# 获取训练集dataset
training_data = torchvision.datasets.MNIST(
root=path, # dataset存储路径
train=True, # True表示是train训练集,False表示test测试集
transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间
download=False,
)
# 获取测试集dataset
test_data = torchvision.datasets.MNIST(
root=path, # dataset存储路径
train=False, # True表示是train训练集,False表示test测试集
transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间
download=False,
)
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE,shuffle=True)
使用pytorch提供的MNIST库,当d