将 MNIST 数据集转换为 PNG 图像(Python 版)


官网地址:http://yann.lecun.com/exdb/mnist/

 

import numpy as np
import cv2
import os


# Convert the binary (IDX) MNIST dataset into .png images and save them;
# each image's label is embedded in its file name.
def save_mnist_to_png(mnist_image_file, mnist_label_file, save_dir):
    """Export every sample of an MNIST image/label file pair as a PNG.

    Each image is written to
    ``<save_dir>/<label>/<prefix>_<index>_<label>.png``, where ``prefix`` is
    ``'train'`` when the base name of ``mnist_image_file`` contains
    ``'train'`` and ``'test'`` otherwise. One progress line is printed per
    image.

    Args:
        mnist_image_file: Path to an uncompressed IDX3 image file.
        mnist_label_file: Path to the matching uncompressed IDX1 label file.
        save_dir: Output root; a sub-directory per label is created on demand.

    Raises:
        ValueError: If a file does not carry the expected IDX magic number,
            the two files disagree on the sample count, or a file holds less
            data than its header announces.
        OSError: If a file cannot be read or an image cannot be written.
    """
    prefix = 'train' if 'train' in os.path.basename(mnist_image_file) else 'test'

    with open(mnist_image_file, 'rb') as f1:
        image_file = f1.read()

    with open(mnist_label_file, 'rb') as f2:
        label_file = f2.read()

    # IDX headers are big-endian 32-bit integers:
    #   images: magic (2051), count, rows, cols  -> 16 bytes
    #   labels: magic (2049), count              ->  8 bytes
    if len(image_file) < 16 or int.from_bytes(image_file[:4], 'big') != 2051:
        raise ValueError('{} is not an IDX3 image file'.format(mnist_image_file))
    if len(label_file) < 8 or int.from_bytes(label_file[:4], 'big') != 2049:
        raise ValueError('{} is not an IDX1 label file'.format(mnist_label_file))

    # Take the sample count and image size from the header instead of
    # assuming 60000/10000 samples of 28x28 pixels.
    num_file = int.from_bytes(image_file[4:8], 'big')
    rows = int.from_bytes(image_file[8:12], 'big')
    cols = int.from_bytes(image_file[12:16], 'big')
    num_label = int.from_bytes(label_file[4:8], 'big')
    if num_file != num_label:
        raise ValueError('image file holds {} samples but label file holds {}'
                         .format(num_file, num_label))

    # Zero-copy views over the payloads; frombuffer raises ValueError when a
    # file is shorter than its header claims.
    images = np.frombuffer(image_file, dtype=np.uint8,
                           count=num_file * rows * cols,
                           offset=16).reshape(num_file, rows, cols, 1)
    labels = np.frombuffer(label_file, dtype=np.uint8, count=num_file, offset=8)

    # Make sure every label directory that will be written to exists.
    for label in np.unique(labels):
        os.makedirs(os.path.join(save_dir, str(int(label))), exist_ok=True)

    for i in range(num_file):
        label = int(labels[i])
        file_name = '{}_{}_{}.png'.format(prefix, i, label)
        save_name = os.path.join(save_dir, str(label), file_name)
        # frombuffer views over bytes are read-only, so hand OpenCV its own
        # copy. imwrite reports failure through its return value, so check it.
        if not cv2.imwrite(save_name, images[i].copy()):
            raise OSError('failed to write {}'.format(save_name))
        print('{} ==> {}'.format(i, file_name))


if __name__ == '__main__':
    # Raw (uncompressed) MNIST IDX files.
    train_image_file = './data/MNIST/raw/train-images-idx3-ubyte'
    train_label_file = './data/MNIST/raw/train-labels-idx1-ubyte'
    test_image_file = './data/MNIST/raw/t10k-images-idx3-ubyte'
    test_label_file = './data/MNIST/raw/t10k-labels-idx1-ubyte'

    # Output roots; each gets one sub-directory per digit.
    save_train_dir = './data/MNIST/train_images'
    save_test_dir = './data/MNIST/test_images'

    # Create the output directories (0-9) for both splits. exist_ok=True
    # replaces the separate exists() check, so reruns are harmless.
    for digit in range(10):
        for root_dir in (save_train_dir, save_test_dir):
            os.makedirs(os.path.join(root_dir, str(digit)), exist_ok=True)

    # Convert the training split.
    save_mnist_to_png(train_image_file, train_label_file, save_train_dir)

    # Convert the test split.
    save_mnist_to_png(test_image_file, test_label_file, save_test_dir)

发表回复

您的电子邮箱地址不会被公开。