PyTorch 中使用 nn.Module 定义模型
nn.Module 是什么?
nn.Module 是 PyTorch 中定义所有神经网络模块的基类,包括层、整个模型等。使用 nn.Module 基类可以使模型的创建、管理和使用变得更简单。大多数的神经网络层(如 nn.Linear, nn.Conv2d 等)都继承自 nn.Module。当你创建自己的复杂网络结构时,你也通常会从这个基类继承。
以下是 nn.Module 的一些主要特性和功能:
-
参数管理: 任何继承自
nn.Module的子类(通常代表神经网络层或整个模型)可以自动跟踪其定义为参数(nn.Parameter)的属性。这使得整个模型的参数可以轻松地通过调用.parameters()方法来获取。 -
层与模块的嵌套: 可以包含其他
nn.Module子类的实例作为其属性,这允许我们构建嵌套的结构和更复杂的网络。 -
GPU/CPU转移: 使用
.to(device)方法,可以轻松地将模型和其所有参数移动到 CPU 或 GPU。 -
前向传播定义: 通过定义
forward()方法来指定模型的前向传播逻辑。 -
保存和加载:
nn.Module提供了简便的方法来保存和加载模型。
一个简单的示例:
以下是一个简单的多层感知器(MLP)模型,它继承自 nn.Module:
import torch.nn as nn
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
# 定义网络的层
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
def forward(self, x):
# 定义前向传播逻辑
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 实例化模型
model = MLP(10, 20, 1)
在上面的例子中,我们定义了一个简单的 MLP 类,它继承自 nn.Module。在 __init__ 方法中,我们定义了模型的层。在 forward 方法中,我们定义了模型的前向传播逻辑。
总的来说,nn.Module 是 PyTorch 中创建神经网络的核心类,它为创建、训练和使用模型提供了很多方便的功能和方法。
forward 是什么?
forward 方法是在 PyTorch 的 nn.Module 类中定义的一个方法,它用于指定神经网络模块如何处理输入数据并生成输出。换句话说,当你向网络模块传递一个输入并试图获得输出时,背后调用的实际方法就是 forward 方法。
以下是一个简单的 例子:
例子:假设我们有一个简单的全连接层(也叫线性层)。该层接受一个输入向量并返回一个输出向量。
import torch.nn as nn
class LinearLayer(nn.Module):
def __init__(self, input_size, output_size):
super(LinearLayer, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, x):
return self.fc(x)
以下是一个描述这个模型如何工作的简单图示:
当我们使用以下代码:
model = LinearLayer(5, 3)
output = model(torch.randn(5))
实际上发生的是:
- 我们创建了一个名为
model的LinearLayer实例。 - 当我们调用
model(torch.randn(5)),我们实际上是在调用LinearLayer的forward方法,传入一个随机初始化的大小为 5 的张量作为输入。 forward方法内部调用了self.fc,这是一个nn.Linear层,它执行线性变换(权重乘以输入加上偏置)。- 结果(
output)是一个大小为 3 的张量。
这就是 forward 方法的工作方式:定义了如何从输入获得输出。在更复杂的模型中,forward 方法可以涉及更多的操作、层和逻辑,但核心思想始终不变:指定从输入到输出的转换。
常用的 nn.Module 子类
torch.nn(通常简写为nn)库中包含了许多预定义的nn.Module子类,这些子类代表了常见的神经网络层和其他相关的功能。以下是一些常用的nn.Module子类:
-
线性和全连接层:
nn.Linear: 全连接层。
-
卷积层:
nn.Conv1d: 一维卷积层。nn.Conv2d: 二维卷积层(常用于图像处理)。nn.Conv3d: 三维卷积层(常用于体积数据)。
-
池化层:
nn.MaxPool1d,nn.AvgPool1d: 一维最大/平均池化。nn.MaxPool2d,nn.AvgPool2d: 二维最大/平均池化。nn.MaxPool3d,nn.AvgPool3d: 三维最大/平均池化。
-
激活函数:
nn.ReLU: ReLU 激活函数。nn.LeakyReLU: Leaky ReLU 激活函数。nn.Sigmoid: Sigmoid 激活函数。nn.Tanh: Tanh 激活函数。nn.Softmax: Softmax 函数。
-
归一化层:
nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d: 批归一化。nn.LayerNorm: 层归一化。nn.InstanceNorm1d,nn.InstanceNorm2d,nn.InstanceNorm3d: 实例归一化。
-
循环层和LSTM:
nn.RNN: 基本循环神经网络层。nn.LSTM: 长短时记忆网络。nn.GRU: 门控循环单元。
-
嵌入层:
nn.Embedding: 一个简单的查找表,用于存储固定大小的字典和嵌入大小,通常用于NLP任务中的词嵌入。
-
丢弃层:
nn.Dropout,nn.Dropout2d,nn.Dropout3d: 丢弃层,用于防止过拟合。
-
损失函数 (虽然它们是
nn.Module的子类,但通常被认为是独立的组件):nn.MSELoss: 均方误差损失。nn.CrossEntropyLoss: 交叉熵损失。nn.BCELoss: 二进制交叉熵损失。nn.SmoothL1Loss: Huber Loss 或 Smooth L1 Loss。
-
容器:
nn.Sequential: 一个有序的容器,其中的模块按顺序执行。nn.ModuleList: 将子模块保存在列表中的容器。nn.ModuleDict: 将子模块保存在字典中的容器。
以上只是nn.Module的部分子类,PyTorch中还有许多其他特定用途和功能的模块。不过,上面列出的是最常用的一些模块,应该足以覆盖许多常见的神经网络设计和任务。
定义自己的网络类
对于具有分支、多输入/多输出或其他复杂连接的网络结构,需要直接继承 nn.Module 来定义自己的网络类,并在其中实现自定义的前向传播逻辑。
以下是一些示例,展示如何定义具有复杂结构的网络:
1. 具有分支的网络:
假设我们想定义一个网络,它有两个并行的分支,最后再将这两个分支的输出合并。
import torch.nn as nn
class BranchedNet(nn.Module):
def __init__(self):
super(BranchedNet, self).__init__()
self.branch1 = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
)
self.branch2 = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 10)
)
def forward(self, x):
out1 = self.branch1(x)
out2 = self.branch2(x)
return out1 + out2 # 合并两个分支的输出
2. 具有多输入的网络:
假设我们有一个网络,它接受两个不同的输入,并将它们的信息合并为一个输出。
class MultiInputNet(nn.Module):
def __init__(self):
super(MultiInputNet, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(15, 20)
self.fc3 = nn.Linear(40, 10)
def forward(self, x1, x2):
out1 = self.fc1(x1)
out2 = self.fc2(x2)
combined = torch.cat((out1, out2), dim=1) # 合并两个输入的信息
return self.fc3(combined)
3. 具有多输出的网络:
假设我们有一个网络,它基于一个输入产生两个不同的输出。
class MultiOutputNet(nn.Module):
def __init__(self):
super(MultiOutputNet, self).__init__()
self.shared_layers = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU()
)
self.output1 = nn.Linear(20, 10)
self.output2 = nn.Linear(20, 5)
def forward(self, x):
shared_out = self.shared_layers(x)
out1 = self.output1(shared_out)
out2 = self.output2(shared_out)
return out1, out2 # 返回两个输出
这些示例展示了如何使用nn.Module来定义具有复杂结构的网络。通过继承nn.Module并实现自定义的前向传播逻辑,你可以创建任意复杂的网络结构。