跳到主要内容

PyTorch 中使用 nn.Module 定义模型

nn.Module 是什么?

nn.Module 是 PyTorch 中定义所有神经网络模块的基类,包括层、整个模型等。使用 nn.Module 基类可以使模型的创建、管理和使用变得更简单。大多数的神经网络层(如 nn.Linear, nn.Conv2d 等)都继承自 nn.Module。当你创建自己的复杂网络结构时,你也通常会从这个基类继承。

以下是 nn.Module 的一些主要特性和功能:

  1. 参数管理: 任何继承自 nn.Module 的子类(通常代表神经网络层或整个模型)可以自动跟踪其定义为参数(nn.Parameter)的属性。这使得整个模型的参数可以轻松地通过调用 .parameters() 方法来获取。

  2. 层与模块的嵌套: 可以包含其他 nn.Module 子类的实例作为其属性,这允许我们构建嵌套的结构和更复杂的网络。

  3. GPU/CPU转移: 使用 .to(device) 方法,可以轻松地将模型和其所有参数移动到 CPU 或 GPU。

  4. 前向传播定义: 通过定义 forward() 方法来指定模型的前向传播逻辑。

  5. 保存和加载: nn.Module 提供了简便的方法来保存和加载模型。

一个简单的示例:

以下是一个简单的多层感知器(MLP)模型,它继承自 nn.Module

import torch.nn as nn

class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()

# 定义网络的层
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()

def forward(self, x):
# 定义前向传播逻辑
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x

# 实例化模型
model = MLP(10, 20, 1)

在上面的例子中,我们定义了一个简单的 MLP 类,它继承自 nn.Module。在 __init__ 方法中,我们定义了模型的层。在 forward 方法中,我们定义了模型的前向传播逻辑。

总的来说,nn.Module 是 PyTorch 中创建神经网络的核心类,它为创建、训练和使用模型提供了很多方便的功能和方法。

forward 是什么?

forward 方法是在 PyTorch 的 nn.Module 类中定义的一个方法,它用于指定神经网络模块如何处理输入数据并生成输出。换句话说,当你向网络模块传递一个输入并试图获得输出时,背后调用的实际方法就是 forward 方法。

以下是一个简单的例子:

例子:假设我们有一个简单的全连接层(也叫线性层)。该层接受一个输入向量并返回一个输出向量。

import torch.nn as nn

class LinearLayer(nn.Module):
def __init__(self, input_size, output_size):
super(LinearLayer, self).__init__()
self.fc = nn.Linear(input_size, output_size)

def forward(self, x):
return self.fc(x)

以下是一个描述这个模型如何工作的简单图示:

当我们使用以下代码:

model = LinearLayer(5, 3)
output = model(torch.randn(5))

实际上发生的是:

  1. 我们创建了一个名为 modelLinearLayer 实例。
  2. 当我们调用 model(torch.randn(5)),我们实际上是在调用 LinearLayerforward 方法,传入一个随机初始化的大小为 5 的张量作为输入。
  3. forward 方法内部调用了 self.fc,这是一个 nn.Linear 层,它执行线性变换(权重乘以输入加上偏置)。
  4. 结果(output)是一个大小为 3 的张量。

这就是 forward 方法的工作方式:定义了如何从输入获得输出。在更复杂的模型中,forward 方法可以涉及更多的操作、层和逻辑,但核心思想始终不变:指定从输入到输出的转换。

常用的 nn.Module 子类

torch.nn(通常简写为nn)库中包含了许多预定义的nn.Module子类,这些子类代表了常见的神经网络层和其他相关的功能。以下是一些常用的nn.Module子类:

  1. 线性和全连接层:

    • nn.Linear: 全连接层。
  2. 卷积层:

    • nn.Conv1d: 一维卷积层。
    • nn.Conv2d: 二维卷积层(常用于图像处理)。
    • nn.Conv3d: 三维卷积层(常用于体积数据)。
  3. 池化层:

    • nn.MaxPool1d, nn.AvgPool1d: 一维最大/平均池化。
    • nn.MaxPool2d, nn.AvgPool2d: 二维最大/平均池化。
    • nn.MaxPool3d, nn.AvgPool3d: 三维最大/平均池化。
  4. 激活函数:

    • nn.ReLU: ReLU 激活函数。
    • nn.LeakyReLU: Leaky ReLU 激活函数。
    • nn.Sigmoid: Sigmoid 激活函数。
    • nn.Tanh: Tanh 激活函数。
    • nn.Softmax: Softmax 函数。
  5. 归一化层:

    • nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d: 批归一化。
    • nn.LayerNorm: 层归一化。
    • nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d: 实例归一化。
  6. 循环层和LSTM:

    • nn.RNN: 基本循环神经网络层。
    • nn.LSTM: 长短时记忆网络。
    • nn.GRU: 门控循环单元。
  7. 嵌入层:

    • nn.Embedding: 一个简单的查找表,用于存储固定大小的字典和嵌入大小,通常用于NLP任务中的词嵌入。
  8. 丢弃层:

    • nn.Dropout, nn.Dropout2d, nn.Dropout3d: 丢弃层,用于防止过拟合。
  9. 损失函数 (虽然它们是nn.Module的子类,但通常被认为是独立的组件):

    • nn.MSELoss: 均方误差损失。
    • nn.CrossEntropyLoss: 交叉熵损失。
    • nn.BCELoss: 二进制交叉熵损失。
    • nn.SmoothL1Loss: Huber Loss 或 Smooth L1 Loss。
  10. 容器:

    • nn.Sequential: 一个有序的容器,其中的模块按顺序执行。
    • nn.ModuleList: 将子模块保存在列表中的容器。
    • nn.ModuleDict: 将子模块保存在字典中的容器。

以上只是nn.Module的部分子类,PyTorch中还有许多其他特定用途和功能的模块。不过,上面列出的是最常用的一些模块,应该足以覆盖许多常见的神经网络设计和任务。

定义自己的网络类

对于具有分支、多输入/多输出或其他复杂连接的网络结构,需要直接继承 nn.Module 来定义自己的网络类,并在其中实现自定义的前向传播逻辑。

以下是一些示例,展示如何定义具有复杂结构的网络:

1. 具有分支的网络:

假设我们想定义一个网络,它有两个并行的分支,最后再将这两个分支的输出合并。

import torch.nn as nn

class BranchedNet(nn.Module):
def __init__(self):
super(BranchedNet, self).__init__()
self.branch1 = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 10)
)
self.branch2 = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 10)
)

def forward(self, x):
out1 = self.branch1(x)
out2 = self.branch2(x)
return out1 + out2 # 合并两个分支的输出

2. 具有多输入的网络:

假设我们有一个网络,它接受两个不同的输入,并将它们的信息合并为一个输出。

class MultiInputNet(nn.Module):
def __init__(self):
super(MultiInputNet, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(15, 20)
self.fc3 = nn.Linear(40, 10)

def forward(self, x1, x2):
out1 = self.fc1(x1)
out2 = self.fc2(x2)
combined = torch.cat((out1, out2), dim=1) # 合并两个输入的信息
return self.fc3(combined)

3. 具有多输出的网络:

假设我们有一个网络,它基于一个输入产生两个不同的输出。

class MultiOutputNet(nn.Module):
def __init__(self):
super(MultiOutputNet, self).__init__()
self.shared_layers = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU()
)
self.output1 = nn.Linear(20, 10)
self.output2 = nn.Linear(20, 5)

def forward(self, x):
shared_out = self.shared_layers(x)
out1 = self.output1(shared_out)
out2 = self.output2(shared_out)
return out1, out2 # 返回两个输出

这些示例展示了如何使用nn.Module来定义具有复杂结构的网络。通过继承nn.Module并实现自定义的前向传播逻辑,你可以创建任意复杂的网络结构。