神经网络模块
神经网络模块,深度学习框架的核心力量。在深度学习领域,神经网络模块(Neural Network Module)扮演着至关重要的角色。它为构建和组织复杂的神经网络结构提供了抽象化的接口。以PyTorch为例,nn.Module作为基础类,允许用户通过继承来自定义自己的神经网络模型。
层的封装,将单个或多个神经网络层(如线性层、卷积层、激活函数等)组合在一起,形成具有层次结构的模块。参数管理:自动管理模型内部的所有可学习参数,包括权重和偏置等。这些参数在训练过程中被优化算法更新。前向传播:通过重写forward()方法来实现模型从输入到输出的计算逻辑。子模块嵌套:允许一个模块内部包含其他的nn.Module实例,构建深层次,多分支的复杂网络结构。状态保存与恢复:整个模块的状态(包括所有子模块的参数)可以方便的保存到磁盘并在需要时加载回来。损失函数集成:PyTorch中的nn库还包含了各种常用的损失函数,它们同样是nn.Module的实例,可以轻松应用在训练过程中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset
class TwoLayoutNet(nn.Module): def __init__(self,input_size, hidden_size, num_classes): super(TwoLayoutNet, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, num_classes) def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out
X = torch.randn(100, 5)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(X, y) dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
model= TwoLayoutNet(36, 24, 100)
|