从零开始深度学习Pytorch笔记(10)—— Dataset类

浏览: 3237

前文传送门:

从零开始深度学习Pytorch笔记(1)——安装Pytorch

从零开始深度学习Pytorch笔记(2)——张量的创建(上)

从零开始深度学习Pytorch笔记(3)——张量的创建(下)

从零开始深度学习Pytorch笔记(4)——张量的拼接与切分

从零开始深度学习Pytorch笔记(5)——张量的索引与变换

从零开始深度学习Pytorch笔记(6)——张量的数学运算

从零开始深度学习Pytorch笔记(7)—— 使用Pytorch实现线性回归

从零开始深度学习Pytorch笔记(8)—— 计算图与自动求导(上)

从零开始深度学习Pytorch笔记(9)—— 计算图与自动求导(下)

在该系列的上一篇,我们讲解了计算图和自动求导的知识点,这个内容是Pytorch的基础也是重点,如果不记得了,回去看看吧~我们本篇聊聊Pytorch中的Dataset类。

在进行深度学习的时候,最重要的是什么?没错,就是数据!数据的形式多种多样,可以是文本,可以是表格数据,可以是声音,可以是图像,甚至视频。当我们手上有了数据,接下来的步骤就是将数据读取处理给模型使用,Pytorch提供了很多工具,能让我们读取数据和预处理数据变得easy!

Pytorch的Dataset类是一个抽象类,源码如下,其内部有三个魔法方法:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """


    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

当我们加载数据时,可以定义子类继承Dataset类,定义的子类需要重载两个方法,分别是:

__len__方法,用来提供数据库的大小。
__getitem__方法,支持一个整形索引,重来获取单个数据,范围是__len__定义的,范围是[0, len(self)]

例如我们可以定义自己的数据类,继承和重写这个抽象类,例如:

import torch
import pandas as pd
from torch.utils.data import Dataset

我们已经导入必要的模块,然后可以按你的需要继承重写:

class myDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.csv_data = csv_file
        self.root_dir = root_dir

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, idx):
        data = (self.csv_data[idx])
        return data

    def read(self, csv, index):
        return pd.read_csv(csv[index])

在上面的代码中,我们在初始化中写了个文件的名称和路径,调用__len__方法可以获取传入数据的个数,而__getitem__可以根据索引获取传入数据的名称。使用read方法可以打开传入的数据。

csv_file = [r'F:/train.csv',r'F:/test.csv']
root_dir = r'F/'
ds1 = myDataset(csv_file,root_dir)
ds1[0]

可以得到:

len(ds1)

得到 2,可以得出一共有两个文件

ds1.read(ds1,1)

可以读取csv的内容:

image.png

欢迎关注公众号学习之后的深度学习连载部分~

扫码下图关注我们不会让你失望!

image.png

推荐 0
本文由 ID王大伟 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册