从零开始深度学习Pytorch笔记(12)—— nn.Module

浏览: 2925

前文传送门:

从零开始深度学习Pytorch笔记(1)——安装Pytorch

从零开始深度学习Pytorch笔记(2)——张量的创建(上)

从零开始深度学习Pytorch笔记(3)——张量的创建(下)

从零开始深度学习Pytorch笔记(4)——张量的拼接与切分

从零开始深度学习Pytorch笔记(5)——张量的索引与变换

从零开始深度学习Pytorch笔记(6)——张量的数学运算

从零开始深度学习Pytorch笔记(7)—— 使用Pytorch实现线性回归

从零开始深度学习Pytorch笔记(8)—— 计算图与自动求导(上)

从零开始深度学习Pytorch笔记(9)—— 计算图与自动求导(下)

从零开始深度学习Pytorch笔记(10)—— Dataset类

从零开始深度学习Pytorch笔记(11)—— DataLoader类

在该系列的上一篇,我们讲解了DataLoader类,本篇我们来聊聊nn.Module。

Module是pytorch提供的一个基类,一般我们搭建深度学习模型会继承这个类,因为会让我们的搭建网络变得简单。

torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。

nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法。

我们来看看Mondule的初始化方法:

def __init__(self):
    self._backend = thnn_backend#指定我们前向传播时用thnn这种实现方式
    self._parameters = OrderedDict()#用来存放注册的 Parameter 对象
    self._buffers = OrderedDict()#用来存放注册的 Buffer 对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)
    self._backward_hooks = OrderedDict()##钩子技术,用来提取中间变量
    self._forward_hooks = OrderedDict()#钩子技术,用来提取中间变量
    self._forward_pre_hooks = OrderedDict()#钩子技术,用来提取中间变量
    self._modules = OrderedDict()#用来保存注册的 Module 对象。
    self.training = True#标志位,用来表示是不是在 training 状态下,dropout在训练和测试中采取的模式不同,通过training决定前向传播策略。

我们简单搭建一个网络:

import torch.nn as nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(n_feature, n_hidden)
        self.out = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.out(x)
        return x

我们定义的MyNet继承了nn.Module的初始化,super(MyNet, self).__init__()

然后自己在MyNet的初始化中定义了隐藏层hidden,该层为线性输出,并且定义了输出层out,该层也是线性输出

只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)

定义的forward方法是前向传播计算方法,具体输入会进入隐藏层,然后经过非线性激活函数relu,最后经过输出层输出结果。

如果子类中没有实现forward就会报错raise NotImplementedError

forward方法在Module的call中被调用,

def forward(self, *input):
    raise NotImplementedError

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        ...

我们实例化一个网络,定义每层的神经元数量,并且打印出网络的结构

net = Net(n_feature=5, n_hidden=20, n_output=10)
print(net)

欢迎关注公众号学习之后的深度学习连载部分~

扫码下图关注我们不会让你失望!

image.png

推荐 0
本文由 ID王大伟 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册