PyTorch学习笔记(一)自定义数据集

  利用PyTorch进行深度学习时,首先要做的就是自定义一个数据集类,用于告诉PyTorch如何加载数据集。继承torch.utils.data.Dataset抽象类后,重写两个方法,即可完成自定义数据集的创建:

  • __len__(self):用于返回数据集的长度
  • __getitem__(self, index):用于返回第i个数据,传入一个index参数,表示索引

代码示例

  在这里,数据集大小并不大,因此在构造函数中,直接将所有数据缓存至内存中,内部使用一个列表存储数据。__getitem__(self, index)方法返回2个变量,用于表示第i个label及data,同时一并将data放入Tensor。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch.utils.data import Dataset
from utils import converter


class BallDataset(Dataset):
"""自定义数据集"""

def __init__(self, data_file):
"""
:param data_file: 数据集文件路径
"""
if not os.path.exists(data_file) or not os.path.isfile(data_file):
raise FileNotFoundError("数据集文件不存在:{}".format(data_file))
data = []
file = open(data_file, "r")
for line in file:
row = line.split("\t")
data.append({"label": int(row[0]), "data": row[1:8]})
file.close()
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, index):
item = self.data[index]
return item["label"], torch.IntTensor(item["data"])

参考文献

扩展阅读