PyTorch学习笔记(一)自定义数据集
利用PyTorch进行深度学习时,首先要做的就是自定义一个数据集类,用于告诉PyTorch如何加载数据集。继承torch.utils.data.Dataset
抽象类后,重写两个方法,即可完成自定义数据集的创建:
__len__(self)
:用于返回数据集的长度__getitem__(self, index)
:用于返回第i
个数据,传入一个index
参数,表示索引
代码示例
在这里,数据集大小并不大,因此在构造函数中,直接将所有数据缓存至内存中,内部使用一个列表存储数据。__getitem__(self, index)
方法返回2个变量,用于表示第i
个label及data,同时一并将data放入Tensor。
1 |
import torch |
参考文献
扩展阅读
- PyTorch学习笔记(一)自定义数据集
- PyTorch学习笔记(二)数据加载器
- PyTorch学习笔记(二)构建神经网络