PyTorch学习笔记(二)数据加载器

  在定义好数据集后,PyTorch提供了一个DataLoader类作为数据集的迭代器,该类位于torch.utils.data模块中,我们仅需要在其构造函数中传入指定的数据集,即可直接使用for语句使用。除此之外,构造函数还可以接收一些实用的参数:

  • batch_size:指定每一轮迭代载入的数据数量(默认值:1)
  • shuffle:是否随机打乱顺序(默认值:否)
  • num_workers:指定用于载入数据的子线程数(默认值:0)

代码示例

  在定义好data_loader后,直接使用for语句迭代,每次返回batch_size组数据,数据的类型即为dataset中定义的__getitem__(self, index)方法返回的类型。

1
2
3
4
5
6
7
8
9
10
11
from torch.utils.data import DataLoader

BATCH_SIZE = 64

dataset = BallDataset(os.path.join(".", "data", "ball.txt"))
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# 使用
for labels, inputs in data_loader:
outputs = model(inputs)
loss = criterion(outputs, labels)

参考文献

扩展阅读