# Official MNIST dataset site: http://yann.lecun.com/exdb/mnist/
import os

import cv2
import numpy as np

# IDX header sizes in bytes. The image file starts with four big-endian
# 32-bit integers (magic, image count, rows, cols); the label file starts
# with two (magic, label count).
_IMAGE_HEADER_SIZE = 16
_LABEL_HEADER_SIZE = 8


def save_mnist_to_png(mnist_image_file, mnist_label_file, save_dir):
    """Convert a binary MNIST image/label file pair into per-label PNG files.

    Each image is written to ``<save_dir>/<label>/<prefix>_<index>_<label>.png``
    so the label is recoverable from both the directory and the file name.
    ``prefix`` is ``'train'`` when the image file's base name contains
    ``'train'`` and ``'test'`` otherwise.

    Args:
        mnist_image_file: Path to an IDX3 ubyte image file.
        mnist_label_file: Path to the matching IDX1 ubyte label file.
        save_dir: Root output directory; per-label sub-directories are
            created on demand.

    Raises:
        ValueError: If either file is too short for its header, the image
            and label counts disagree, or the payload is truncated.
        OSError: If an image cannot be written.
    """
    prefix = 'train' if 'train' in os.path.basename(mnist_image_file) else 'test'

    with open(mnist_image_file, 'rb') as f1:
        image_bytes = f1.read()
    with open(mnist_label_file, 'rb') as f2:
        label_bytes = f2.read()

    if len(image_bytes) < _IMAGE_HEADER_SIZE:
        raise ValueError(
            'image file is too short for an IDX header: {}'.format(mnist_image_file))
    if len(label_bytes) < _LABEL_HEADER_SIZE:
        raise ValueError(
            'label file is too short for an IDX header: {}'.format(mnist_label_file))

    # Read the counts and image shape from the headers instead of assuming
    # 60000/10000 images of 28x28, so subsets and other sizes also work.
    num_images = int.from_bytes(image_bytes[4:8], 'big')
    rows = int.from_bytes(image_bytes[8:12], 'big')
    cols = int.from_bytes(image_bytes[12:16], 'big')
    num_labels = int.from_bytes(label_bytes[4:8], 'big')

    if num_images != num_labels:
        raise ValueError(
            'image count ({}) does not match label count ({})'.format(
                num_images, num_labels))

    pixels_per_image = rows * cols
    if len(image_bytes) - _IMAGE_HEADER_SIZE < num_images * pixels_per_image:
        raise ValueError('image file is truncated: {}'.format(mnist_image_file))
    if len(label_bytes) - _LABEL_HEADER_SIZE < num_labels:
        raise ValueError('label file is truncated: {}'.format(mnist_label_file))

    # One byte per pixel / per label: view the payloads directly as uint8
    # arrays rather than converting byte by byte in Python.
    images = np.frombuffer(
        image_bytes,
        dtype=np.uint8,
        count=num_images * pixels_per_image,
        offset=_IMAGE_HEADER_SIZE,
    ).reshape(num_images, rows, cols, 1)
    labels = np.frombuffer(
        label_bytes,
        dtype=np.uint8,
        count=num_labels,
        offset=_LABEL_HEADER_SIZE,
    )

    created_dirs = set()
    for i in range(num_images):
        label = int(labels[i])

        label_dir = os.path.join(save_dir, str(label))
        if label_dir not in created_dirs:
            # cv2.imwrite does not create missing directories.
            os.makedirs(label_dir, exist_ok=True)
            created_dirs.add(label_dir)

        file_name = '{}_{}_{}.png'.format(prefix, i, label)
        save_name = os.path.join(label_dir, file_name)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(save_name, images[i]):
            raise OSError('failed to write image: {}'.format(save_name))
        print('{} ==> {}'.format(i, file_name))


if __name__ == '__main__':
    train_image_file = './data/MNIST/raw/train-images-idx3-ubyte'
    train_label_file = './data/MNIST/raw/train-labels-idx1-ubyte'
    test_image_file = './data/MNIST/raw/t10k-images-idx3-ubyte'
    test_label_file = './data/MNIST/raw/t10k-labels-idx1-ubyte'

    save_train_dir = './data/MNIST/train_images'
    save_test_dir = './data/MNIST/test_images'

    # Create one output directory per digit class up front so all ten exist
    # even if some label never occurs in the input.
    for digit in range(10):
        os.makedirs(os.path.join(save_train_dir, str(digit)), exist_ok=True)
        os.makedirs(os.path.join(save_test_dir, str(digit)), exist_ok=True)

    # Convert the training set.
    save_mnist_to_png(train_image_file, train_label_file, save_train_dir)

    # Convert the test set.
    save_mnist_to_png(test_image_file, test_label_file, save_test_dir)