import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or reuse existing ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop for one batch
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # log the loss to the configured logger
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
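The os, Tensor, MNIST, ToTensor, and utils imports above are not otherwise used in the snippet, which suggests the example continues with data loading and training. A minimal sketch of how this model could be trained with Lightning's Trainer follows; the limit_train_batches and max_epochs values are illustrative assumptions, not prescribed settings.

# set up the MNIST training data (downloaded to the current working directory)
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the autoencoder; Trainer arguments here are illustrative only
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)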