PyTorch学习笔记(二)数据加载器
在定义好数据集后,PyTorch提供了一个DataLoader
类作为数据集的迭代器,该类位于torch.utils.data
模块中,我们仅需要在其构造函数中传入指定的数据集,即可直接使用for
语句使用。除此之外,构造函数还可以接收一些实用的参数:
batch_size
:指定每一轮迭代载入的数据数量(默认值:1)shuffle
:是否随机打乱顺序(默认值:否)num_workers
:指定用于载入数据的子线程数(默认值:0)
代码示例
在定义好data_loader后,直接使用for
语句迭代,每次返回batch_size
组数据,数据的类型即为dataset中定义的__getitem__(self, index)
方法返回的类型。
1 |
from torch.utils.data import DataLoader |
参考文献
扩展阅读
- PyTorch学习笔记(一)自定义数据集
- PyTorch学习笔记(二)数据加载器
- PyTorch学习笔记(二)构建神经网络