TensorFlow Example Walkthrough: A Hands-On Guide to Preprocessing the notMNIST Dataset
Overview
The notMNIST dataset is a classic machine learning dataset modeled on the well-known MNIST handwritten-digit dataset, but it uses the letters A through J (10 classes) instead of digits. Its images are closer to real-world data than MNIST, containing more noise and irregularity, which makes it a harder challenge for models.
About the Dataset
The notMNIST dataset contains:
- Training set: about 500,000 grayscale images of 28x28 pixels
- Test set: about 19,000 images
- 10 classes, the letters A through J (mapped to integer labels 0-9 during preprocessing; see the small sketch below)
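Because the class folders are read in sorted (alphabetical) order later in the pipeline, label 0 corresponds to A, label 1 to B, and so on up to label 9 for J. The helper below is a convenience sketch, not part of the original pipeline, for turning labels back into letters when inspecting results:

# Map integer class labels (0-9) back to the letters A-J.
# Assumes the class folders are processed in sorted (alphabetical) order,
# which is what the preprocessing code below does.
LETTERS = 'ABCDEFGHIJ'

def label_to_letter(label):
    """Return the letter corresponding to an integer class label."""
    return LETTERS[label]

print(label_to_letter(0), label_to_letter(9))  # prints: A J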
Environment Setup
Before processing any data, we import the necessary Python libraries:
from __future__ import print_function
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
# Configure matplotlib to render inline in Jupyter Notebook
%matplotlib inline
Downloading and Extracting the Data
Downloading the Dataset
We start with a download helper that reports progress as the archives are fetched:
# Data source and local storage location. The URL below is the hosting
# location used by the original Udacity notMNIST example; adjust data_root
# to store the archives elsewhere.
url = 'https://commondatastorage.googleapis.com/books1000/'
data_root = '.'
last_percent_reported = None

def download_progress_hook(count, blockSize, totalSize):
    """Download progress callback: prints a percentage every 5%."""
    global last_percent_reported
    percent = int(count * blockSize * 100 / totalSize)
    if last_percent_reported != percent:
        if percent % 5 == 0:
            sys.stdout.write("%s%%" % percent)
            sys.stdout.flush()
        else:
            sys.stdout.write(".")
            sys.stdout.flush()
        last_percent_reported = percent

def maybe_download(filename, expected_bytes, force=False):
    """Download a file if it is not already present, then verify its size."""
    dest_filename = os.path.join(data_root, filename)
    if force or not os.path.exists(dest_filename):
        print('Attempting to download:', filename)
        filename, _ = urlretrieve(url + filename, dest_filename,
                                  reporthook=download_progress_hook)
        print('\nDownload Complete!')
    statinfo = os.stat(dest_filename)
    if statinfo.st_size == expected_bytes:
        print('Found and verified', dest_filename)
    else:
        raise Exception('Failed to verify ' + dest_filename)
    return dest_filename

# Download the training and test archives
train_filename = maybe_download('notMNIST_large.tar.gz', 247336696)
test_filename = maybe_download('notMNIST_small.tar.gz', 8458043)
Extracting the Dataset
Once the downloads finish, we extract the .tar.gz archives:
num_classes = 10      # the letters A through J
np.random.seed(133)   # fix the seed so the shuffling further below is reproducible

def maybe_extract(filename, force=False):
    """Extract the archive if it has not been extracted already."""
    root = os.path.splitext(os.path.splitext(filename)[0])[0]  # strip .tar.gz
    if os.path.isdir(root) and not force:
        print('%s already present - Skipping extraction.' % root)
    else:
        print('Extracting data for %s. Please wait.' % root)
        tar = tarfile.open(filename)
        sys.stdout.flush()
        tar.extractall(data_root)
        tar.close()
    # Verify the extracted folder structure: one folder per class
    data_folders = [
        os.path.join(root, d) for d in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, d))]
    if len(data_folders) != num_classes:
        raise Exception('Expected %d folders, found %d.' %
                        (num_classes, len(data_folders)))
    print(data_folders)
    return data_folders

train_folders = maybe_extract(train_filename)
test_folders = maybe_extract(test_filename)
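If everything went well, each call prints a list of ten per-class folders, along the lines of the following (assuming the default data_root of '.'):

# Expected output, roughly:
# ['./notMNIST_large/A', './notMNIST_large/B', ..., './notMNIST_large/J']
# ['./notMNIST_small/A', './notMNIST_small/B', ..., './notMNIST_small/J']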
Data Preprocessing
Loading and Normalizing the Images
We convert the image data into a 3D array (image index × width × height) and normalize the pixel values:
image_size = 28      # width and height of each image in pixels
pixel_depth = 255.0  # number of grey levels per pixel
def load_letter(folder, min_num_images):
    """Load the data for a single letter class."""
    image_files = os.listdir(folder)
    dataset = np.ndarray(shape=(len(image_files), image_size, image_size),
                         dtype=np.float32)
    print(folder)
    num_images = 0
    for image in image_files:
        image_file = os.path.join(folder, image)
        try:
            # Read the image and normalize pixel values to the [-0.5, 0.5] range
            image_data = (imageio.imread(image_file).astype(float) -
                          pixel_depth / 2) / pixel_depth
            if image_data.shape != (image_size, image_size):
                raise Exception('Unexpected image shape: %s' % str(image_data.shape))
            dataset[num_images, :, :] = image_data
            num_images += 1
        except (IOError, ValueError) as e:
            print('Could not read:', image_file, ':', e, "- skipping.")

    dataset = dataset[0:num_images, :, :]
    if num_images < min_num_images:
        raise Exception('Too few images: %d < %d' % (num_images, min_num_images))

    print('Full dataset tensor:', dataset.shape)
    print('Mean:', np.mean(dataset))
    print('Standard deviation:', np.std(dataset))
    return dataset
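A quick way to convince yourself that the expression above really maps pixels into [-0.5, 0.5] is to plug in the two extreme pixel values, using the constants defined earlier:

# The darkest and brightest possible pixels land exactly on the ends of the range.
print((0 - pixel_depth / 2) / pixel_depth)    # -0.5
print((255 - pixel_depth / 2) / pixel_depth)  #  0.5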
Serializing the Data
To make later steps easier, we serialize the processed data to pickle files, one per class:
def maybe_pickle(data_folders, min_num_images_per_class, force=False):
    """Pickle the dataset for each class folder, if not already done."""
    dataset_names = []
    for folder in data_folders:
        set_filename = folder + '.pickle'
        dataset_names.append(set_filename)
        if os.path.exists(set_filename) and not force:
            print('%s already present - Skipping.' % set_filename)
        else:
            print('Pickling %s.' % set_filename)
            dataset = load_letter(folder, min_num_images_per_class)
            try:
                with open(set_filename, 'wb') as f:
                    pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
            except Exception as e:
                print('Unable to save data to', set_filename, ':', e)
    return dataset_names

train_datasets = maybe_pickle(train_folders, 45000)
test_datasets = maybe_pickle(test_folders, 1800)
Data Validation
Visual Inspection
After processing, we should inspect the data visually to make sure it was loaded correctly:
def display_samples(pickle_file):
    """Display a few sample images from a pickled class dataset."""
    with open(pickle_file, 'rb') as f:
        data = pickle.load(f)
    plt.figure(figsize=(10, 10))
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(data[i], cmap='gray')
        plt.axis('off')
    plt.show()

# Show samples from the training and test sets
display_samples(train_datasets[0])  # training samples of the letter A
display_samples(test_datasets[0])   # test samples of the letter A
Checking Class Balance
A good dataset should be roughly balanced across classes:
def check_balance(pickle_files):
    """Check whether the classes contain roughly equal numbers of samples."""
    class_counts = []
    for pickle_file in pickle_files:
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
            class_counts.append(len(data))
    plt.bar(range(len(class_counts)), class_counts)
    plt.xlabel('Class')
    plt.ylabel('Number of samples')
    plt.title('Class distribution')
    plt.show()
    return class_counts

train_counts = check_balance(train_datasets)
test_counts = check_balance(test_datasets)
Merging and Splitting the Datasets
For model training we merge the per-class data and split it into training, validation, and test sets:
def make_arrays(nb_rows, img_size):
    """Allocate empty data and label arrays (or None if nb_rows is 0)."""
    if nb_rows:
        dataset = np.ndarray((nb_rows, img_size, img_size), dtype=np.float32)
        labels = np.ndarray(nb_rows, dtype=np.int32)
    else:
        dataset, labels = None, None
    return dataset, labels

def merge_datasets(pickle_files, train_size, valid_size=0):
    """Merge the per-class pickles and split them into training and validation parts."""
    num_classes = len(pickle_files)
    valid_dataset, valid_labels = make_arrays(valid_size, image_size)
    train_dataset, train_labels = make_arrays(train_size, image_size)
    vsize_per_class = valid_size // num_classes
    tsize_per_class = train_size // num_classes

    start_v, start_t = 0, 0
    end_v, end_t = vsize_per_class, tsize_per_class
    end_l = vsize_per_class + tsize_per_class
    for label, pickle_file in enumerate(pickle_files):
        try:
            with open(pickle_file, 'rb') as f:
                letter_set = pickle.load(f)
                # Shuffle within the class so the validation/training split is random
                np.random.shuffle(letter_set)
                if valid_dataset is not None:
                    valid_letter = letter_set[:vsize_per_class, :, :]
                    valid_dataset[start_v:end_v, :, :] = valid_letter
                    valid_labels[start_v:end_v] = label
                    start_v += vsize_per_class
                    end_v += vsize_per_class

                train_letter = letter_set[vsize_per_class:end_l, :, :]
                train_dataset[start_t:end_t, :, :] = train_letter
                train_labels[start_t:end_t] = label
                start_t += tsize_per_class
                end_t += tsize_per_class
        except Exception as e:
            print('Unable to process', pickle_file, ':', e)
            raise

    return train_dataset, train_labels, valid_dataset, valid_labels

# Build the training, validation, and test splits
train_size = 200000
valid_size = 10000
test_size = 10000

train_dataset, train_labels, valid_dataset, valid_labels = merge_datasets(
    train_datasets, train_size, valid_size)
# merge_datasets returns the training split first, so the test arrays come first here
test_dataset, test_labels, _, _ = merge_datasets(test_datasets, test_size)

print('Training:', train_dataset.shape, train_labels.shape)
print('Validation:', valid_dataset.shape, valid_labels.shape)
print('Testing:', test_dataset.shape, test_labels.shape)
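If the split sizes above are used unchanged, the three print statements should report shapes matching the preallocated arrays:

# Expected output:
# Training: (200000, 28, 28) (200000,)
# Validation: (10000, 28, 28) (10000,)
# Testing: (10000, 28, 28) (10000,)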
Randomizing the Data
Finally, we shuffle the data so that the order of the samples does not bias training:
def randomize(dataset, labels):
    """Shuffle a dataset and its labels with the same permutation."""
    permutation = np.random.permutation(labels.shape[0])
    shuffled_dataset = dataset[permutation, :, :]
    shuffled_labels = labels[permutation]
    return shuffled_dataset, shuffled_labels

train_dataset, train_labels = randomize(train_dataset, train_labels)
test_dataset, test_labels = randomize(test_dataset, test_labels)
valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)
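Before wrapping up, it is worth running a quick sanity check that the preprocessed data is actually learnable. The LogisticRegression class imported in the setup section can be fit on a small subset of the training data; the snippet below is a minimal sketch, where the subset size of 5000 and max_iter=1000 are illustrative choices rather than fixed requirements:

# Fit a quick logistic-regression baseline on a small subset as a sanity check.
n_samples = 5000
X_train = train_dataset[:n_samples].reshape(n_samples, image_size * image_size)
y_train = train_labels[:n_samples]
X_valid = valid_dataset.reshape(len(valid_dataset), image_size * image_size)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)
print('Validation accuracy:', clf.score(X_valid, valid_labels))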
Summary
With the steps above, the full notMNIST preprocessing pipeline is complete:
- Download the raw dataset
- Extract it and verify the folder structure
- Load the images and normalize the pixel values
- Serialize the processed data for later use
- Inspect the data visually
- Check that the classes are balanced
- Merge the per-class data and split it into training, validation, and test sets
- Randomize the sample order
These preprocessing steps lay a solid foundation for the model training that follows. The processed data has the following properties:
- Pixel values normalized to the [-0.5, 0.5] range
- Roughly equal numbers of samples per class
- Split into training, validation, and test sets
- Sample order randomized
In practice, a standardized preprocessing pipeline like this greatly improves the efficiency and reproducibility of machine learning projects.
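It is also convenient to store all six arrays in a single pickle file so that later training scripts can load them in one step. A minimal sketch, assuming the (arbitrary) file name notMNIST.pickle:

# Save all splits into a single pickle file for later reuse.
# The file name 'notMNIST.pickle' is just a convention; any name works.
pickle_file = os.path.join(data_root, 'notMNIST.pickle')
save = {
    'train_dataset': train_dataset, 'train_labels': train_labels,
    'valid_dataset': valid_dataset, 'valid_labels': valid_labels,
    'test_dataset': test_dataset, 'test_labels': test_labels,
}
with open(pickle_file, 'wb') as f:
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
print('Pickle size on disk:', os.stat(pickle_file).st_size, 'bytes')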