基于mnist数据集的手写数字识别模型的训练可视化预测

使用 tensorflow 库创建训练模型

数据集使用公开的 mnist 

一、构建模型

from  tensorflow.keras.layers import Dense, Dropout
import tensorflow as tf
def mnistModel():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)), #对其进行形状转化 转化成一维28*28
        tf.keras.layers.Dense(128, activation='relu'), # 添加隐藏层,使用全连接层,设置128个节点,使用relu作为激活函数
        tf.keras.layers.Dense(10, activation='softmax')  # 添加输出层,使用全连接层,设置10个节点,使用 softmax 作为激活函数
    ])
    return model

  1.  将输入28x28 转化成一维
  2. 添加隐藏层,使用全连接层,设置128个节点,使用relu作为激活函数
  3. 添加输出层,使用全连接层,设置10个节点,使用 softmax 作为激活函数

二、构建训练模型脚本

import argparse
import tensorflow as tf
from model import mnistModel

# 设置使用GPU,防止内存不足
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

def run(optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
        batch_size=64,  # 每个批量使用64条数据
        epochs=5,  # 训练轮数为5
        validation_split=0.2,  # 划分出20%作为验证
        save_model="./model/mnist_model.h5"
        ):
    model = mnistModel()
    # 配置训练方法
    model.compile(optimizer=optimizer,  # 使用adam优化器
                  loss=loss,  # 损失函数使用稀疏叉熵损失函数
                  metrics=metrics  # 准确率使用稀疏分类准确率函数
                  )
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    # 属性归一化
    X_train, X_test = tf.cast(train_x / 255.0, tf.float32), tf.cast(test_x / 255.0, tf.float32)
    y_train, y_test = tf.cast(train_y, tf.int16), tf.cast(test_y, tf.int16)
    # 训练模型
    model.fit(X_train,
              y_train,
              batch_size=batch_size,  # 每个批量使用64条数据
              epochs=epochs,  # 训练轮数为5
              validation_split=validation_split  # 划分出20%作为验证
              )
    if save_model:
        model.save(save_model, overwrite=True, save_format=None)
    # 评估模型
    accuracy = model.evaluate(X_test, y_test)
    print(accuracy)

def parse_args():
    parser = argparse.ArgumentParser(description="Train a MNIST model with custom parameters.")
    parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer to use for training.')
    parser.add_argument('--loss', type=str, default='sparse_categorical_crossentropy', help='Loss function to use.')
    parser.add_argument('--metrics', type=str, nargs='+', default=['sparse_categorical_accuracy'], help='List of metrics to evaluate during training.')
    parser.add_argument('--batch_size', type=int, default=64, help='Number of samples per gradient update.')
    parser.add_argument('--epochs', type=int, default=5, help='Number of epochs to train the model.')
    parser.add_argument('--validation_split', type=float, default=0.2, help='Fraction of the training data to be used as validation data.')
    parser.add_argument('--save_model', type=str, default='./model/mnist_model.h5', help='Path to save the trained model.')

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    run(optimizer=args.optimizer,
        loss=args.loss,
        metrics=args.metrics,
        batch_size=args.batch_size,
        epochs=args.epochs,
        validation_split=args.validation_split,
        save_model=args.save_model)
  1. 使用adam 当作优化器
  2. 使用sparse_categorical_crossentropy 作为评价指标
  3. 划分出两成数据当作验证集validation_split=0.2
  4. batch_size=65,每个批量使用64条数据
  5. epochs=5 设置训练轮数,可以设置成200甚至更高
  6. 设置保存路径,配置训练方法
    model.compile(optimizer=optimizer,  # 使用adam优化器
                  loss=loss,  # 损失函数使用稀疏叉熵损失函数
                  metrics=metrics  # 准确率使用稀疏分类准确率函数
                  )

       7. 读取数据集并进行归一化

    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    # 属性归一化
    X_train, X_test = tf.cast(train_x / 255.0, tf.float32), tf.cast(test_x / 255.0, tf.float32)
    y_train, y_test = tf.cast(train_y, tf.int16), tf.cast(test_y, tf.int16)

 8. 训练模型并保存

    # 训练模型
    model.fit(X_train,
              y_train,
              batch_size=batch_size,  # 每个批量使用64条数据
              epochs=epochs,  # 训练轮数为5
              validation_split=validation_split  # 划分出20%作为验证
              )
    if save_model:
        model.save(save_model, overwrite=True, save_format=None)

9. 评估模型

    # 评估模型
    accuracy = model.evaluate(X_test, y_test)

三、训练模型

 运行train.py脚本

 

四、使用

使用PyQt5构建可视化操作界面

​​​​​​​

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值