0%

PyTorch学习笔记(一)自定义数据集

  利用PyTorch进行深度学习时,首先要做的就是自定义一个数据集类,用于告诉PyTorch如何加载数据集。继承torch.utils.data.Dataset抽象类后,重写两个方法,即可完成自定义数据集的创建:

  • __len__(self):用于返回数据集的长度
  • __getitem__(self, index):用于返回第i个数据,传入一个index参数,表示索引

代码示例

  在这里,数据集大小并不大,因此在构造函数中,直接将所有数据缓存至内存中,内部使用一个列表存储数据。__getitem__(self, index)方法返回2个变量,用于表示第i个label及data,同时一并将data放入Tensor。

import torch
from torch.utils.data import Dataset
from utils import converter


class BallDataset(Dataset):
    """自定义数据集"""

    def __init__(self, data_file):
        """
        :param data_file: 数据集文件路径
        """
        if not os.path.exists(data_file) or not os.path.isfile(data_file):
            raise FileNotFoundError("数据集文件不存在:{}".format(data_file))
        data = []
        file = open(data_file, "r")
        for line in file:
            row = line.split("\t")
            data.append({"label": int(row[0]), "data": row[1:8]})
        file.close()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        return item["label"], torch.IntTensor(item["data"])

参考文献

扩展阅读