
I am new to PyTorch and am trying to implement a VAE for MNIST data. When I train my model, it appears to force mu and logvar to zero (or something very close to zero) regardless of the input. It looks as if the MSE part of the loss function is being ignored (the KL term alone is minimized exactly at mu = 0 and logvar = 0, which is where the model ends up), but I don't understand why.

Here's the complete code I am using:

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F


batch_size = 32

loc_data = 'MNIST'
transformations = transforms.ToTensor()
mnist_train = datasets.MNIST(loc_data, train=True, download=True, transform=transformations)
mnist_test = datasets.MNIST(loc_data, train=False, download=True, transform=transformations)

train_loader = DataLoader(mnist_train,
                          batch_size=batch_size,
                          drop_last=True,
                          shuffle=True)
test_loader = DataLoader(mnist_test,
                         batch_size=batch_size,
                         drop_last=True,
                         shuffle=True)

class Encoder(nn.Module):
  def __init__(self, latent_dim=10):
    super(Encoder, self).__init__()
    self.latent_dim = latent_dim
    # single-hidden-layer MLP whose output packs mu and logvar side by side
    self._encoder = nn.Sequential(nn.Linear(in_features=28*28, out_features=512),
                                  nn.ReLU(),
                                  nn.Linear(in_features=512, out_features=2*latent_dim)
                                 )

  def forward(self, x):
    # split the 2*latent_dim outputs into mu and logvar
    x = torch.reshape(self._encoder(x), (-1, 2, self.latent_dim))
    mu, logvar = x[:, 0, :], x[:, 1, :]
    return mu, logvar


class Decoder(nn.Module):
  def __init__(self, latent_dim=10):
    super(Decoder, self).__init__()
    self.latent_dim = latent_dim
    # mirror of the encoder; Sigmoid keeps pixel values in [0, 1]
    self._decoder = nn.Sequential(nn.Linear(in_features=latent_dim, out_features=512),
                                  nn.ReLU(),
                                  nn.Linear(in_features=512, out_features=28*28),
                                  nn.Sigmoid())

  def forward(self, x):
    return self._decoder(x)

def sample(mu, logvar):
  # reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
  eps = torch.randn_like(mu)
  return mu + torch.exp(0.5*logvar) * eps

def vae_loss(x, x_hat, mu, logvar):
  # reconstruction term: squared error summed over pixels, averaged over the batch
  mse = (x - x_hat).pow(2).sum() / x.shape[0]
  # KL divergence between N(mu, sigma^2) and N(0, I), summed over the batch
  KL_loss = 0.5*torch.sum(-1 + torch.pow(mu, 2) - logvar + torch.exp(logvar))
  return mse + KL_loss

def train(encoder, decoder, train_loader, optimizer, num_epochs=10):
  encoder.train()
  decoder.train()
  for ii in range(num_epochs):
    print("Epoch {}".format(ii))
    for jj, (x, y) in enumerate(train_loader):
        # flatten the images and move them to the same device as the models
        x = torch.reshape(x, (-1, 28*28)).to(device)
        _mu, _logvar = encoder(x)
        _z = sample(_mu, _logvar)
        x_hat = decoder(_z)  # .reshape((-1,28,28))
        optimizer.zero_grad()
        loss = vae_loss(x, x_hat, _mu, _logvar)
        loss.backward()
        optimizer.step()
        if jj % 100 == 0:
            print(loss)
  return loss



latent_dim = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-2)

train(encoder, decoder, train_loader, optimizer, num_epochs=1)

When I probe mu or logvar for some test data, the result is almost identically zero.
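
For reference, this is roughly how I inspect the encoder outputs (a minimal sketch; the batch drawn from test_loader and the printed statistics are just for illustration):

encoder.eval()
with torch.no_grad():
    x_test, _ = next(iter(test_loader))
    x_test = torch.reshape(x_test, (-1, 28*28)).to(device)
    mu, logvar = encoder(x_test)
    # if the collapse described above is happening, both of these means are essentially zero
    print(mu.abs().mean(), logvar.abs().mean())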
