前文传送门:
从零开始深度学习Pytorch笔记(1)——安装Pytorch
从零开始深度学习Pytorch笔记(2)——张量的创建(上)
从零开始深度学习Pytorch笔记(3)——张量的创建(下)
从零开始深度学习Pytorch笔记(4)——张量的拼接与切分
从零开始深度学习Pytorch笔记(5)——张量的索引与变换
从零开始深度学习Pytorch笔记(6)——张量的数学运算
从零开始深度学习Pytorch笔记(7)—— 使用Pytorch实现线性回归
从零开始深度学习Pytorch笔记(8)—— 计算图与自动求导(上)
从零开始深度学习Pytorch笔记(9)—— 计算图与自动求导(下)
在该系列的上一篇,我们讲解了计算图和自动求导的知识点,这个内容是Pytorch的基础也是重点,如果不记得了,回去看看吧~我们本篇聊聊Pytorch中的Dataset类。
在进行深度学习的时候,最重要的是什么?没错,就是数据!数据的形式多种多样,可以是文本,可以是表格数据,可以是声音,可以是图像,甚至视频。当我们手上有了数据,接下来的步骤就是将数据读取处理给模型使用,Pytorch提供了很多工具,能让我们读取数据和预处理数据变得easy!
Pytorch的Dataset类是一个抽象类,源码如下,其内部有三个魔法方法:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
当我们加载数据时,可以定义子类继承Dataset类,定义的子类需要重载两个方法,分别是:
__len__方法,用来提供数据库的大小。
__getitem__方法,支持一个整形索引,重来获取单个数据,范围是__len__定义的,范围是[0, len(self)]
例如我们可以定义自己的数据类,继承和重写这个抽象类,例如:
import torch
import pandas as pd
from torch.utils.data import Dataset
我们已经导入必要的模块,然后可以按你的需要继承重写:
class myDataset(Dataset):
def __init__(self, csv_file, root_dir):
self.csv_data = csv_file
self.root_dir = root_dir
def __len__(self):
return len(self.csv_data)
def __getitem__(self, idx):
data = (self.csv_data[idx])
return data
def read(self, csv, index):
return pd.read_csv(csv[index])
在上面的代码中,我们在初始化中写了个文件的名称和路径,调用__len__方法可以获取传入数据的个数,而__getitem__可以根据索引获取传入数据的名称。使用read方法可以打开传入的数据。
csv_file = [r'F:/train.csv',r'F:/test.csv']
root_dir = r'F/'
ds1 = myDataset(csv_file,root_dir)
可以得到:
得到 2,可以得出一共有两个文件
可以读取csv的内容:
欢迎关注公众号学习之后的深度学习连载部分~
扫码下图关注我们不会让你失望!