前文传送门:
从零开始深度学习Pytorch笔记(1)——安装Pytorch
从零开始深度学习Pytorch笔记(2)——张量的创建(上)
从零开始深度学习Pytorch笔记(3)——张量的创建(下)
从零开始深度学习Pytorch笔记(4)——张量的拼接与切分
从零开始深度学习Pytorch笔记(5)——张量的索引与变换
从零开始深度学习Pytorch笔记(6)——张量的数学运算
从零开始深度学习Pytorch笔记(7)—— 使用Pytorch实现线性回归
从零开始深度学习Pytorch笔记(8)—— 计算图与自动求导(上)
从零开始深度学习Pytorch笔记(9)—— 计算图与自动求导(下)
从零开始深度学习Pytorch笔记(10)—— Dataset类
从零开始深度学习Pytorch笔记(11)—— DataLoader类
在该系列的上一篇,我们讲解了DataLoader类,本篇我们来聊聊nn.Module。
Module是pytorch提供的一个基类,一般我们搭建深度学习模型会继承这个类,因为会让我们的搭建网络变得简单。
torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。
nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法。
我们来看看Mondule的初始化方法:
def __init__(self):
self._backend = thnn_backend#指定我们前向传播时用thnn这种实现方式
self._parameters = OrderedDict()#用来存放注册的 Parameter 对象
self._buffers = OrderedDict()#用来存放注册的 Buffer 对象。(pytorch 中 buffer 的概念就是 不需要反向传导更新的值)
self._backward_hooks = OrderedDict()##钩子技术,用来提取中间变量
self._forward_hooks = OrderedDict()#钩子技术,用来提取中间变量
self._forward_pre_hooks = OrderedDict()#钩子技术,用来提取中间变量
self._modules = OrderedDict()#用来保存注册的 Module 对象。
self.training = True#标志位,用来表示是不是在 training 状态下,dropout在训练和测试中采取的模式不同,通过training决定前向传播策略。
我们简单搭建一个网络:
import torch.nn as nn
import torch.nn.functional as F
class MyNet(nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(MyNet, self).__init__()
self.hidden = nn.Linear(n_feature, n_hidden)
self.out = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.out(x)
return x
我们定义的MyNet继承了nn.Module的初始化,super(MyNet, self).__init__()
然后自己在MyNet的初始化中定义了隐藏层hidden,该层为线性输出,并且定义了输出层out,该层也是线性输出
只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)
定义的forward方法是前向传播计算方法,具体输入会进入隐藏层,然后经过非线性激活函数relu,最后经过输出层输出结果。
如果子类中没有实现forward就会报错raise NotImplementedError
forward方法在Module的call中被调用,
def forward(self, *input):
raise NotImplementedError
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
...
我们实例化一个网络,定义每层的神经元数量,并且打印出网络的结构
net = Net(n_feature=5, n_hidden=20, n_output=10)
print(net)
欢迎关注公众号学习之后的深度学习连载部分~
扫码下图关注我们不会让你失望!