1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
import torch
import torch.nn as nn
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder is a shared MLP trunk (input_dim -> 512 -> 256, ReLU after
    each layer). Two linear heads then map the 256-wide hidden features to
    the mean and log-variance of a ``latent_dim``-dimensional Gaussian. The
    decoder maps a latent sample back to ``input_dim`` features squashed
    into (0, 1) by a sigmoid.

    Args:
        input_dim: Number of features per input (the last dimension of ``x``).
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # The trunk must end at the 256-wide hidden layer, because that is
        # what the ``mu`` / ``log_var`` heads below take as input.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        self.mu = nn.Linear(256, latent_dim)
        self.log_var = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, exp(log_var)) with the reparameterization trick.

        Writing the sample as ``mu + eps * std`` with ``eps ~ N(0, I)`` keeps
        it differentiable with respect to ``mu`` and ``log_var``.

        Args:
            mu: Mean of the latent Gaussian.
            log_var: Log-variance of the latent Gaussian (same shape as mu).

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * log_var)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        # Additive shift: the sample must be centred on mu, not scaled by it.
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            Tuple ``(reconstructed_x, mu, log_var)``:

            - ``reconstructed_x``: the decoder output, with values in (0, 1)
              and last dimension ``input_dim``.
            - ``mu`` and ``log_var``: the encoder's latent mean and
              log-variance, each with last dimension ``latent_dim``.
        """
        hidden = self.encoder(x)
        mu = self.mu(hidden)
        log_var = self.log_var(hidden)
        z = self.reparameterize(mu, log_var)
        reconstructed_x = self.decoder(z)
        return reconstructed_x, mu, log_var
# Build a model for 24-feature inputs with a 32-dimensional latent space.
model = VAE(input_dim=24, latent_dim=32)
|