Usage of enumerate with a PyTorch DataLoader

import torch
from torch.utils.data import DataLoader, TensorDataset

a_data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
a_label = torch.tensor([77, 88, 99, 77, 88, 99, 77, 88, 99, 77, 88, 99])
train_data = TensorDataset(a_data, a_label)  # wrap the data tensor and label tensor together

# Slicing a TensorDataset returns a tuple: (data rows, corresponding labels)
print(train_data[0:2])
print('-' * 100)

# Iterate over the dataset one (sample, label) pair at a time
for x_train, y_label in train_data:
    print(x_train, y_label)

print('-' * 100)

# Wrap the dataset in a DataLoader. batch_size sets the number of samples per batch; shuffle=True reshuffles the samples at every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=4, shuffle=True)
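
# A quick sanity check (added illustration): DataLoader implements __len__ for
# sized datasets; with drop_last at its default of False, the 12 samples above
# split into ceil(12 / 4) = 3 batches.
print(len(train_loader))  # -> 3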

for i, data in enumerate(train_loader):  # enumerate yields two values: the batch index and the batch itself (a list of [data, labels])
    z_data, z_label = data
    print('batch:{0}\n z_data:{1}\n z_label:{2}'.format(i, z_data.numpy(), z_label.numpy()))
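
The same loop is more commonly written by unpacking the batch directly in the for statement. As a minimal sketch reusing the train_loader defined above: enumerate also accepts an optional start argument if you prefer batch numbering to begin at 1 instead of 0.

# Unpack the (data, labels) batch in the loop header; start=1 shifts the counter
for i, (z_data, z_label) in enumerate(train_loader, start=1):
    print('batch:{0}\n z_data:{1}\n z_label:{2}'.format(i, z_data.numpy(), z_label.numpy()))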
