Parsing TensorFlow Examples: A Hands-On Guide to Preprocessing the notMNIST Dataset

Overview

The notMNIST dataset is a classic machine learning dataset. It mimics the well-known MNIST handwritten-digit dataset, but uses the letters A-J (10 classes) instead of digits. Its distinguishing feature is that it is closer to real-world data than MNIST, with more noise and irregularities, which makes it a harder challenge for models.

Dataset Description

The notMNIST dataset contains:

  • Training set: about 500,000 grayscale images of 28x28 pixels
  • Test set: about 19,000 images
  • 10 classes (letters A-J)

Environment Setup

Before processing any data, we import the necessary Python libraries:

from __future__ import print_function
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle

# Configure matplotlib to render plots inline in Jupyter Notebook
%matplotlib inline
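
The download and extraction helpers below also reference a few module-level settings that must be defined up front. The following sketch assumes the mirror URL used by the original Udacity notMNIST notebook; adjust url and data_root for your own environment:

url = 'https://commondatastorage.googleapis.com/books1000/'  # assumed mirror, change if needed
data_root = '.'               # directory for the downloaded archives and extracted data
num_classes = 10              # letters A through J
last_percent_reported = None  # state shared with download_progress_hook
np.random.seed(133)           # make the shuffling steps reproducible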

Downloading and Extracting the Data

Downloading the dataset

We first implement a download helper that reports progress:

def download_progress_hook(count, blockSize, totalSize):
    """下载进度回调函数"""
    global last_percent_reported
    percent = int(count * blockSize * 100 / totalSize)
    
    if last_percent_reported != percent:
        if percent % 5 == 0:
            sys.stdout.write("%s%%" % percent)
            sys.stdout.flush()
        else:
            sys.stdout.write(".")
            sys.stdout.flush()
        last_percent_reported = percent

def maybe_download(filename, expected_bytes, force=False):
    """下载文件(如果不存在)并验证大小"""
    dest_filename = os.path.join(data_root, filename)
    if force or not os.path.exists(dest_filename):
        print('Attempting to download:', filename) 
        filename, _ = urlretrieve(url + filename, dest_filename, 
                                 reporthook=download_progress_hook)
        print('\nDownload Complete!')
    statinfo = os.stat(dest_filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified', dest_filename)
    else:
        raise Exception('Failed to verify ' + dest_filename)
    return dest_filename

# Download the training and test archives
train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)

Extracting the dataset

Once the downloads complete, we extract the .tar.gz archives:

def maybe_extract(filename, force=False):
    """解压文件(如果尚未解压)"""
    root = os.path.splitext(os.path.splitext(filename)[0])[0]  # strip .tar.gz
    if os.path.isdir(root) and not force:
        print('%s already present - Skipping extraction.' % root)
    else:
        print('Extracting data for %s. Please wait.' % root)
        tar = tarfile.open(filename)
        sys.stdout.flush()
        tar.extractall(data_root)
        tar.close()
    
    # Verify the extracted folder structure
    data_folders = [
        os.path.join(root, d) for d in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, d))]
    if len(data_folders) != num_classes:
        raise Exception('Expected %d folders, found %d.' % 
                      (num_classes, len(data_folders)))
    print(data_folders)
    return data_folders

train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)
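
Before any numeric processing, it is worth eyeballing a few of the raw PNG files straight from one of the extracted class folders; this is what the IPython.display import is for. A minimal check (the choice of folder and number of files shown is arbitrary):

# Peek at a few raw PNGs from the first training folder (e.g. notMNIST_large/A)
sample_folder = train_folders[0]
for name in sorted(os.listdir(sample_folder))[:3]:
    display(Image(filename=os.path.join(sample_folder, name)))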

Data Preprocessing

Image loading and normalization

We load the image data into a 3D array (image index × width × height) and normalize the pixel values:

image_size = 28  # width/height of each image, in pixels
pixel_depth = 255.0  # number of levels per pixel

def load_letter(folder, min_num_images):
    """加载单个字母类别的数据"""
    image_files = os.listdir(folder)
    dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                         dtype=np.float32)
    print(folder)
    num_images = 0
    
    for image in image_files:
        image_file = os.path.join(folder, image)
        try:
            # Read the image and normalize it to the range [-0.5, 0.5]
            image_data = (imageio.imread(image_file).astype(float) - 
                         pixel_depth / 2) / pixel_depth
            if image_data.shape != (image_size, image_size):
                raise Exception('Unexpected image shape: %s' % str(image_data.shape))
            dataset[num_images, :, :] = image_data
            num_images += 1
        except (IOError, ValueError) as e:
            print('Could not read:', image_file, ':', e, "- skipping.")
    
    dataset = dataset[0:num_images, :, :]
    if num_images < min_num_images:
        raise Exception('Too few images: %d < %d' % (num_images, min_num_images))
        
    print('Full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset
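
As a quick sanity check of the normalization used in load_letter, the formula (pixel - pixel_depth / 2) / pixel_depth maps the raw values 0, 128, and 255 to roughly -0.5, 0.0, and 0.5:

# Verify the normalization on a few representative raw pixel values
raw = np.array([0, 128, 255], dtype=float)
print((raw - pixel_depth / 2) / pixel_depth)  # approximately [-0.5, 0.002, 0.5]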

Serializing the Data

For convenience in later steps, we serialize the processed data into pickle files:

def maybe_pickle(data_folders, min_num_images_per_class, force=False):
    """将数据集序列化为pickle文件"""
    dataset_names = []
    for folder in data_folders:
        set_filename = folder + '.pickle'
        dataset_names.append(set_filename)
        if os.path.exists(set_filename) and not force:
            print('%s already present - Skipping.' % set_filename)
        else:
            print('Pickling %s.' % set_filename)
            dataset = load_letter(folder, min_num_images_per_class)
            try:
                with open(set_filename, 'wb') as f:
                    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
            except Exception as e:
                print('Unable to save data to', set_filename, ':', e)
    return dataset_names

train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)

Data Validation

Visual inspection

After processing, we should visually inspect the data to make sure it was loaded correctly:

def display_samples(pickle_file):
    """显示pickle文件中的样本图像"""
    with open(pickle_file, 'rb') as f:
        data = pickle.load(f)
    
    plt.figure(figsize=(10, 10))
    for i in range(9):
        plt.subplot(3, 3, i+1)
        plt.imshow(data[i], cmap='gray')
        plt.axis('off')
    plt.show()

# Display samples from the training and test sets
display_samples(train_datasets[0])  # samples of letter A (training)
display_samples(test_datasets[0])   # samples of letter A (test)

Class balance check

A good dataset should be balanced across classes:

def check_balance(pickle_files):
    """检查各类别样本数量是否平衡"""
    class_counts = []
    for pickle_file in pickle_files:
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            class_counts.append(len(data))
    
    plt.bar(range(len(class_counts)), class_counts)
    plt.xlabel('Class')
    plt.ylabel('Number of samples')
    plt.title('Class distribution')
    plt.show()
    return class_counts

train_counts = check_balance(train_datasets)
test_counts = check_balance(test_datasets)

Merging and Splitting the Datasets

To prepare for model training, we merge the per-class data and split it into training, validation, and test sets:

def make_arrays(nb_rows, img_size):
    """创建空的数据和标签数组"""
    if nb_rows:
        dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
        labels = np.ndarray(nb_rows, dtype=np.int32)
    else:
        dataset, labels = None, None
    return dataset, labels

def merge_datasets(pickle_files, train_size, valid_size=0):
    """合并数据集并分割为训练集和验证集"""
    num_classes = len(pickle_files)
    valid_dataset, valid_labels = make_arrays(valid_size, image_size)
    train_dataset, train_labels = make_arrays(train_size, image_size)
    
    vsize_per_class = valid_size // num_classes
    tsize_per_class = train_size // num_classes
    
    start_v, start_t = 0, 0
    end_v, end_t = vsize_per_class, tsize_per_class
    end_l = vsize_per_class + tsize_per_class
    
    for label, pickle_file in enumerate(pickle_files):
        try:
            with open(pickle_file, 'rb') as f:
                letter_set = pickle.load(f)
                np.random.shuffle(letter_set)
                
                if valid_dataset is not None:
                    valid_letter = letter_set[:vsize_per_class, :, :]
                    valid_dataset[start_v:end_v, :, :] = valid_letter
                    valid_labels[start_v:end_v] = label
                    start_v += vsize_per_class
                    end_v += vsize_per_class
                    
                train_letter = letter_set[vsize_per_class:end_l, :, :]
                train_dataset[start_t:end_t, :, :] = train_letter
                train_labels[start_t:end_t] = label
                start_t += tsize_per_class
                end_t += tsize_per_class
        except Exception as e:
            print('Unable to process', pickle_file, ':', e)
            raise
    
    return train_dataset, train_labels, valid_dataset, valid_labels

# Build the training, validation, and test splits
train_size = 200000
valid_size = 10000
test_size = 10000

train_dataset, train_labels, valid_dataset, valid_labels = merge_datasets(
    train_datasets, train_size, valid_size)
test_dataset, test_labels, _, _ = merge_datasets(test_datasets, test_size)

print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)
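
Note that the integer labels assigned by merge_datasets simply follow the alphabetically sorted class folders, so label 0 corresponds to the letter A and label 9 to J:

# Labels come from enumerate() over the sorted class pickles: 0 -> 'A', ..., 9 -> 'J'
label_to_letter = {label: chr(ord('A') + label) for label in range(num_classes)}
print(label_to_letter)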

Randomizing the Data

Finally, we shuffle the data so that the order of samples does not bias training:

def randomize(dataset, labels):
    """随机化数据集顺序"""
    permutation = np.random.permutation(labels.shape[0])
    shuffled_dataset = dataset[permutation, :, :]
    shuffled_labels = labels[permutation]
    return shuffled_dataset, shuffled_labels

train_dataset, train_labels = randomize(train_dataset, train_labels)
test_dataset, test_labels = randomize(test_dataset, test_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)
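
It is worth convincing ourselves that images and labels are still aligned after shuffling. A quick spot-check displays a few shuffled training samples together with their letter labels:

# Spot-check: show a few shuffled training images with their letter labels
plt.figure(figsize=(8, 2))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(train_dataset[i], cmap='gray')
    plt.title(chr(ord('A') + train_labels[i]))
    plt.axis('off')
plt.show()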

Summary

With the steps above, we have completed the full preprocessing pipeline for the notMNIST dataset:

  1. Download the raw dataset
  2. Extract it and verify the folder structure
  3. Load the images and normalize the pixel values
  4. Serialize the processed data for later use
  5. Visually inspect the data quality
  6. Verify class balance
  7. Merge the data and split it into training, validation, and test sets
  8. Randomize the sample order

These preprocessing steps lay a solid foundation for the subsequent model training. The processed data has the following properties:

  • Normalized to the range [-0.5, 0.5]
  • Balanced sample counts across classes
  • Already split into training, validation, and test sets
  • Sample order randomized
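
For the training stage it is often convenient to bundle these splits into a single pickle file as well; the sketch below does this, where the notMNIST.pickle filename is just a convention and not required by anything above:

# Bundle all splits into one pickle file for later training scripts
pickle_file = os.path.join(data_root, 'notMNIST.pickle')
save = {
    'train_dataset': train_dataset, 'train_labels': train_labels,
    'valid_dataset': valid_dataset, 'valid_labels': valid_labels,
    'test_dataset': test_dataset, 'test_labels': test_labels,
}
with open(pickle_file, 'wb') as f:
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
print('Pickle size:', os.stat(pickle_file).st_size, 'bytes')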

In practice, a standardized preprocessing pipeline like this greatly improves the efficiency and reproducibility of machine learning projects.

