### MNIST Implementation for Hypothetical Data of Shop-floor###
# MNIST notes:
# Labels go from 0-9.
# Total number of samples: training: 60000, test: 10000.
# Samples per class are imbalanced.
'''
Samples per class(0-9):
5923
6742
5958
6131
5842
5421
5918
6265
5851
5949
'''
import numpy as np
import os
import sys
# Major version of the running interpreter; selects which urllib flavour
# is imported below and which download call downloadfiles() uses.
python_version = sys.version_info[0]
if python_version == 3:
    # For Python 3.0 and later
    from urllib.request import urlretrieve
else:
    # Python 2's urllib (used as urllib.urlretrieve)
    import urllib
import gzip
import struct
from struct import unpack
import random
# For visualization of data
from PIL import Image
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from skimage import transform as tf
from skimage.transform import rotate
import sklearn.preprocessing
class Mnist:
def __init__(self):
    """Download the MNIST archives and count training samples per class.

    Populates ``training_labels``/``training_data`` and
    ``test_labels``/``test_data`` through ``loadbatch`` and fills
    ``SAMPLES_PER_CLASS`` with one entry per class, keyed by the class
    number as a string ('0' .. '9').
    """
    self.NUM_CLASSES = 10
    # dict to store number of samples per class in training data
    self.SAMPLES_PER_CLASS = {}
    # Canonical MNIST host; the four archive names are appended to it.
    path = 'http://yann.lecun.com/exdb/mnist/'
    self.training_labels, self.training_data = self.loadbatch(
        path + 'train-labels-idx1-ubyte.gz',
        path + 'train-images-idx3-ubyte.gz')
    self.test_labels, self.test_data = self.loadbatch(
        path + 't10k-labels-idx1-ubyte.gz',
        path + 't10k-images-idx3-ubyte.gz')
    # Keys are strings because the other methods look classes up via str(label).
    for i in range(self.NUM_CLASSES):
        self.SAMPLES_PER_CLASS[str(i)] = len(np.where(self.training_labels == i)[0])
def load_training_batch(self):
    """Return the training split as a ``(data, labels)`` pair."""
    batch = (self.training_data, self.training_labels)
    return batch
def load_test_batch(self):
    """Return the test split as a ``(data, labels)`` pair."""
    batch = (self.test_data, self.test_labels)
    return batch
# download data
def downloadfiles(self, url, force_download=True):
    """Fetch ``url`` into the current directory and return the local file name.

    The local name is the last '/'-separated component of ``url``. When
    ``force_download`` is false and a file of that name already exists,
    nothing is downloaded and the existing name is returned.
    """
    local_name = url.rsplit("/", 1)[-1]
    if not force_download and os.path.exists(local_name):
        return local_name
    # Pick the download call matching the urllib flavour imported at module level.
    if python_version != 3:
        urllib.urlretrieve(url, local_name)
    else:
        urlretrieve(url, local_name)
    return local_name
def loadbatch(self, label_url, image_url):
    """Download and decode one label archive and one image archive.

    Both archives are fetched with ``self.downloadfiles`` and read as gzip
    streams. The label stream has an 8-byte header (two big-endian unsigned
    ints) and the image stream a 16-byte header (four big-endian unsigned
    ints); everything after the header is taken as raw sample bytes.

    Returns:
        ``(label, image)`` where ``label`` is a writable int8 array and
        ``image`` is a flat float32 array of all pixel values (samples are
        stored back to back, not reshaped).

    Raises:
        struct.error: if a stream is shorter than its header.
    """
    # read label files
    with gzip.open(self.downloadfiles(label_url), 'rb') as labelfile:
        # Consume the header (magic number, item count); unpack also
        # fails loudly when the file is too short to contain it.
        struct.unpack(">II", labelfile.read(8))
        # frombuffer returns a read-only view of the bytes; copy so callers
        # get an ordinary writable array, as np.fromstring used to provide.
        label = np.frombuffer(labelfile.read(), dtype=np.int8).copy()
    # read image files
    with gzip.open(self.downloadfiles(image_url), 'rb') as imagefile:
        # Header: magic number, item count, rows, cols.
        struct.unpack(">IIII", imagefile.read(16))
        # astype allocates a new float32 array, so no explicit copy is needed.
        image = np.frombuffer(imagefile.read(), dtype=np.uint8).astype('float32')
    return (label, image)
# function that randomly removes a specified fraction of data samples corresponding to class_id
def remove_samples(self, data, labels, label_to_remove, fraction):
    """Randomly drop a fraction of one class and return the reduced set.

    Args:
        data: flat array holding 28*28 values per sample, samples back to back.
        labels: one label per sample (1-D, or a column of shape (N, 1)).
        label_to_remove: class whose samples are thinned out.
        fraction: share of that class to remove, applied to the count stored
            in ``self.SAMPLES_PER_CLASS`` and truncated to an int.

    Returns:
        ``(new_data, new_labels)``: the surviving samples in their original
        order, as a flat float64 array and a float64 column of shape (N, 1).

    Raises:
        ValueError: from ``random.sample`` when the requested number of
            removals is negative or exceeds the samples present in ``labels``.

    Side effects:
        Lowers ``self.SAMPLES_PER_CLASS[str(label_to_remove)]`` by the number
        of removed samples, and only once the sampling has succeeded.
    """
    pixels = 28 * 28
    key = str(label_to_remove)
    num_samples = self.SAMPLES_PER_CLASS[key]
    samples_to_remove = int(fraction * num_samples)

    # Flatten so both 1-D labels and (N, 1) columns index the same way.
    flat_labels = np.reshape(labels, (-1))
    candidates = np.where(flat_labels == label_to_remove)[0].tolist()
    # Sample before touching the bookkeeping: a bad fraction raises here
    # and leaves SAMPLES_PER_CLASS exactly as it was.
    dropped = random.sample(candidates, samples_to_remove)

    keep = np.ones(len(flat_labels), dtype=bool)
    # Explicit integer dtype keeps the indexing valid when nothing is dropped.
    keep[np.asarray(dropped, dtype=np.intp)] = False

    self.SAMPLES_PER_CLASS[key] = num_samples - samples_to_remove

    # Boolean-mask selection keeps the survivors in their original order.
    new_data = np.reshape(data, (-1, pixels))[keep].reshape(-1).astype(np.float64)
    new_labels = flat_labels[keep].astype(np.float64).reshape(-1, 1)
    return new_data, new_labels
# removes samples from classes to have the same amount of samples for all classes
def balance_data(self, data, labels):
    """Randomly drop samples until every class has as many as the smallest.

    Class sizes are taken from ``self.SAMPLES_PER_CLASS``. The exact surplus
    of each class is removed (no detour through a floating-point fraction,
    which can truncate to one sample too few).

    Args:
        data: flat array holding 28*28 values per sample, samples back to back.
        labels: one label per sample (1-D, or a column of shape (N, 1)).

    Returns:
        ``(data, labels)``. When no class has a surplus the inputs are
        returned unchanged; otherwise the surviving samples are returned in
        their original order as a flat float64 array and a float64 column
        of shape (N, 1).

    Raises:
        ValueError: from ``random.sample`` when a class has fewer samples in
            ``labels`` than its surplus according to the bookkeeping.

    Side effects:
        Sets every entry of ``self.SAMPLES_PER_CLASS`` to the smallest class
        size, only after all sampling has succeeded.
    """
    pixels = 28 * 28
    counts = self.SAMPLES_PER_CLASS
    least_samples = min(counts.values())

    # Flatten so both 1-D labels and (N, 1) columns index the same way.
    flat_labels = np.reshape(labels, (-1))
    dropped = []
    for n in counts:
        surplus = counts[n] - least_samples
        if surplus > 0:
            candidates = np.where(flat_labels == int(n))[0].tolist()
            dropped.extend(random.sample(candidates, surplus))

    if not dropped:
        return data, labels

    # All sampling succeeded; only now bring the bookkeeping in line.
    for n in list(counts):
        counts[n] = least_samples

    keep = np.ones(len(flat_labels), dtype=bool)
    keep[np.asarray(dropped, dtype=np.intp)] = False
    new_data = np.reshape(data, (-1, pixels))[keep].reshape(-1).astype(np.float64)
    new_labels = flat_labels[keep].astype(np.float64).reshape(-1, 1)
    return new_data, new_labels
# JUST FOR VISUALIZATION
def display_data(self, data, images_per_row):
    """Show the first ``images_per_row ** 2`` samples tiled in a square grid.

    ``data`` is read as consecutive runs of 784 values, each reshaped to
    28x28 and pasted row by row into one greyscale ('L') canvas, which is
    then opened with ``Image.show()``.
    """
    side = 28
    pixels = 784
    canvas = Image.new('L', (side * images_per_row, side * images_per_row))
    for cell in range(images_per_row * images_per_row):
        # cell counts samples left-to-right, top-to-bottom.
        grid_row, grid_col = divmod(cell, images_per_row)
        tile = np.reshape(data[cell * pixels:(cell + 1) * pixels], (side, side))
        canvas.paste(Image.fromarray(tile), (grid_col * side, grid_row * side))
    canvas.show()
# Augments data so that all classes have the same amount of samples corresponding to the maximum
# possible modes: oversample - duplicate existing sample, random - use any augmentation method for new sample
def augment_data(self, data, labels, mode='oversample'):
    """Top up every class to the size of the largest one and shuffle.

    Class sizes come from ``self.SAMPLES_PER_CLASS`` (which is not modified).
    Missing samples are produced by ``self.create_samples`` using ``mode``.

    Args:
        data: array reshapeable to (-1, 28, 28).
        labels: one label per sample; flattened internally.
        mode: forwarded to ``create_samples``.

    Returns:
        ``(output_data, output_labels)``: a flat array of all original plus
        new samples and the matching 1-D label array, both permuted with
        the same ``np.random.permutation``.
    """
    print("Augmentation mode:", mode)
    class_counts = self.SAMPLES_PER_CLASS
    # Largest class size, never below zero.
    target = max([0] + list(class_counts.values()))

    images = np.reshape(data, (-1, 28, 28))
    flat_labels = np.reshape(labels, ((-1)))

    total = len(images) + sum(target - class_counts[n] for n in class_counts)
    output_labels = np.zeros((total))
    output_data = np.zeros((total, 28, 28))

    # Originals go first; generated samples are appended behind them.
    output_labels[0:len(flat_labels)] = flat_labels[:]
    output_data[0:len(images)] = images[:]
    cursor = len(flat_labels)

    for n in class_counts:
        shortfall = target - class_counts[n]
        if shortfall <= 0:
            continue
        print("Samples to create for label %s: %i" % (n, shortfall))
        # augment the data
        extra_data, extra_labels = self.create_samples(images, flat_labels, int(n), shortfall, mode)
        stop = cursor + shortfall
        output_labels[cursor:stop] = extra_labels[:]
        output_data[cursor:stop] = extra_data[:]
        cursor = stop

    # One permutation applied to both arrays keeps samples and labels aligned.
    order = np.random.permutation(len(output_labels))
    output_labels = output_labels[order]
    output_data = output_data[order]
    return np.reshape(output_data, (len(output_data) * 28 * 28)), output_labels
def create_samples(self, data, labels, label, number_of_samples, mode='oversample'):
    """Generate ``number_of_samples`` extra samples for class ``label``.

    Each new sample starts from a randomly chosen existing sample of that
    class (``np.random.choice`` over the matching indices).

    Args:
        data: array reshapeable to (-1, 28, 28).
        labels: one label per sample.
        label: class to generate samples for.
        number_of_samples: how many samples to create.
        mode: 'oversample' copies the chosen sample as is; 'random' passes it
            through ``self.create_transformed_sample``.

    Returns:
        ``(new_data, new_labels)``: a float32 array of shape
        (number_of_samples, 28, 28) and a 1-D label array of the same length.

    Raises:
        Exception: for any other ``mode``, when at least one sample is requested.
    """
    images = np.reshape(data, (-1, 28, 28))
    generated = np.zeros((number_of_samples, 28, 28))
    generated_labels = np.zeros((number_of_samples))
    candidates = np.where(labels == label)[0]

    for slot in range(number_of_samples):
        # Checked per sample, so an unknown mode only fails when work is requested.
        if mode not in ('oversample', 'random'):
            raise Exception("Invalid data augmentation mode picked")
        # pick random sample to use as basis for new sample
        pick = np.random.choice(candidates)
        if mode == 'random':
            generated[slot], generated_labels[slot] = self.create_transformed_sample(images[pick], labels[pick])
        else:
            generated_labels[slot] = labels[pick]
            generated[slot] = images[pick]

    return generated.astype(np.float32), generated_labels
def create_transformed_sample(self,sample,label):
random_rotation = 20*np.random.rand()-10 # interval [-10,10]
random_translation_x = 4*np.random.rand()-2 # interval [-2,2] is probably max for mnist
random_translation_y = 4*np.random.rand()-2 # interval [-2,2]
tform = tf.SimilarityTransform(scale=1,rotation=0,translation=(random_translation_x,random_translation_y))
sample

xilance
- 粉丝: 7
最新资源
- 【OFDM-MIMO系统单射频链束训练】对具有1个射频链的OFDM-MIMO系统进行束扫描研究附Matlab代码.rar
- 【SCI】利用信念传播在超密集无线网络中进行分布式信道分配附Matlab代码.rar
- 【PSO-LSTM】基于PSO优化LSTM网络的电力负荷预测附Python代码.rar
- 【SVPWM的模型】基于三相VSC的空间矢量PWM方法研究附Simulink仿真.rar
- 【UAV】改进的多旋翼无人机动态模拟的模块化仿真环境附Matlab、Simulink.rar
- 【UAV】【倾斜旋翼六旋翼飞行器】激活多体系统动力学的重力补偿和最优控制研究附Matlab代码.rar
- 【VMD-SSA-LSSVM】基于变分模态分解与麻雀优化Lssvm的负荷预测【多变量】附Matlab代码.rar
- 【UAV四旋翼的PD控制】使用AscTec Pelican四旋翼无人机的PD控制器研究附Matlab代码.rar
- 【UDQ正弦PWM】单相统一功率因数变流器控制、单相VSI或交直变流器以统一功率因数模式运行、控制器采用不平衡d-q控制在同步参考框架中实现研究附Simulink仿真.rar
- 【车间调度】基于卷积神经网络的柔性作业车间调度问题的两阶段算法附Matlab代码.rar
- 【车牌识别】使用傅里叶分析从车牌中提取字符附Matlab代码.rar
- 【车间调度FJSP】基于全球邻域和爬山优化算法的模糊柔性车间调度问题研究附Matlab代码.rar
- 【电池组模型】用于模拟电池的电压、电流、功率和SOC特性,包含6V、12V、24V和48V的模型,通过考虑电池中观察到的各种电压降来实现附Simulink仿真.rar
- 【车牌识别】使用形态学算子进行车牌检测附Matlab代码.rar
- 【大规模 MIMO 检测】基于ADMM的大型MU-MIMO无穷大范数检测研究附Matlab代码.rar
- 【创新、复现】基于蜣螂优化算法的无线传感器网络覆盖优化研究附Matlab代码.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈


