from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
a_data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
a_label = torch.tensor([77, 88, 99, 77, 88, 99, 77, 88, 99, 77, 88, 99])
train_data = TensorDataset(a_data, a_label) # 封装数据与标签
# 切片输出
print(train_data[0:2])
print('-' * 100)
# 循环取数据
for x_train, y_label in train_data:
print(x_train, y_label)
print('-' * 100)
# DataLoader进行数据封装。batch_size批尺寸。shuffle将序列的所有元素随机排序。
train_loader = DataLoader(dataset=train_data, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader): # 注意enumerate返回值有两个。一个是序号,一个是数据(包含训练数据和标签)。
z_data, z_label = data
print('batch:{0}\n z_data:{1}\n z_label:{2}'.format(i, z_data.numpy(), z_label.numpy()))
