手写数字识别神经网络实战
时间: 2025-02-02 19:51:42 浏览: 34
### 手写数字识别神经网络实战教程
#### 使用TensorFlow/Keras进行手写数字识别
对于基于TensorFlow和Keras的手写数字识别,流程主要分为几个部分:
加载所需的功能包之后,通过`tensorflow.keras.datasets.mnist.load_data()`函数来获取MNIST数据集[^1]。
```python
import tensorflow as tf
from tensorflow import keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
```
接着是对这些图像数据做标准化处理以及标签的one-hot编码转换。这一步骤是为了让输入到模型中的数值处于较小范围之内,从而加速训练过程并提高准确性。
构建一个多层感知机(MLP),即全连接前馈神经网络结构用于分类任务。此架构通常由多个隐藏层组成,每层之间完全相连,并采用ReLU激活函数促进非线性表达能力;最后一层则使用softmax回归完成多类别概率预测。
```python
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=[28, 28]),
keras.layers.Dense(300, activation="relu"),
keras.layers.Dense(100, activation="relu"),
keras.layers.Dense(10, activation="softmax")
])
```
编译该模型指定优化器(optimizer)、损失函数(loss function)及评估指标(metrics)。
```python
model.compile(
optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
```
调用fit方法执行训练循环,在给定轮次内迭代整个训练集以调整权重参数直至收敛或达到最大epoch数限制。
测试阶段会计算验证集中样本被正确分类的比例作为最终性能度量标准之一。
---
#### 利用手动定义PyTorch实现相同目标的方法如下所示:
引入必要的库文件后同样先下载准备好的mnist图片资源集合[^2]。
```python
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
valset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)
```
设计自定义类继承于nn.Module基类表示我们的简单卷积神经网路(CNN)[^3]。
```python
class Classifier(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(784, 256)
self.fc2 = torch.nn.Linear(256, 128)
self.fc3 = torch.nn.Linear(128, 64)
self.fc4 = torch.nn.Linear(64, 10)
def forward(self, x):
# Flatten the input tensor into a vector of length 784.
x = x.view(x.shape[0], -1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = torch.log_softmax(self.fc4(x), dim=1)
return x
```
实例化上述定义的对象并设置好相应的超参配置项如学习率等信息以便后续调用optimizer.step()更新梯度下降方向上的权值矩阵。
最后按照惯例编写一段完整的训练逻辑代码片段用来交替地读取批次数据喂入网络内部向前传播求解loss再反向传递误差信号修正各节点间的连接强度直到满足终止条件为止。
阅读全文
相关推荐















