神经网络模块(PyTorch)

神经网络模块

神经网络模块,深度学习框架的核心力量。在深度学习领域,神经网络模块(Neural Network Module)扮演着至关重要的角色。它为构建和组织复杂的神经网络结构提供了抽象化的接口。以PyTorch为例,nn.Module作为基础类,允许用户通过继承来自定义自己的神经网络模型。

层的封装,将单个或多个神经网络层(如线性层、卷积层、激活函数等)组合在一起,形成具有层次结构的模块。参数管理:自动管理模型内部的所有可学习参数,包括权重和偏置等。这些参数在训练过程中被优化算法更新。前向传播:通过重写forward()方法来实现模型从输入到输出的计算逻辑。子模块嵌套:允许一个模块内部包含其他的nn.Module实例,构建深层次,多分支的复杂网络结构。状态保存与恢复:整个模块的状态(包括所有子模块的参数)可以方便的保存到磁盘并在需要时加载回来。损失函数集成:PyTorch中的nn库还包含了各种常用的损失函数,它们同样是nn.Module的实例,可以轻松应用在训练过程中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 定义一个简单的两层全连接神经网络模块
class TwoLayoutNet(nn.Module):
def __init__(self,input_size, hidden_size, num_classes):
super(TwoLayoutNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()

self.fc2 = nn.Linear(hidden_size, num_classes)

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

# 创建一个随机的小数据集, 假设我们有100个样本,每个样本5个特征
X = torch.randn(100, 5)
# 假设这是二分类问题,标签为0或1
y = torch.randint(0, 2, (100,))

#将数据转换为TensorDataSet并创建DataLoader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 初始化模型
model= TwoLayoutNet(36, 24, 100)