import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class InfoVAE(BaseVAE):
def __init__(self,
             in_channels: int,
             latent_dim: int,
             hidden_dims: List = None,
             alpha: float = -0.5,
             beta: float = 5.0,
             reg_weight: int = 100,
             kernel_type: str = 'imq',
             latent_var: float = 2.,
             **kwargs) -> None:
    """
    Builds the convolutional encoder / decoder of the InfoVAE.
    :param in_channels: (Int) Number of channels of the input images; the
                        reconstruction is produced with the same count.
    :param latent_dim: (Int) Dimensionality of the latent code.
    :param hidden_dims: (List[int]) Channel widths of the stride-2 encoder
                        stages; the decoder walks them in reverse order.
                        Defaults to [32, 64, 128, 256, 512]. The list passed
                        in is not modified.
    :param alpha: (Float) Must be <= 0; weights the KL and MMD terms in
                  loss_function.
    :param beta: (Float) Weight of the reconstruction term in loss_function.
    :param reg_weight: (Int) Weight of the MMD term in loss_function.
    :param kernel_type: (Str) 'imq' or 'rbf'; kernel used for the MMD term.
    :param latent_var: (Float) Scale (z_var) used by the MMD kernels.
    """
    super(InfoVAE, self).__init__()

    self.latent_dim = latent_dim
    self.reg_weight = reg_weight
    self.kernel_type = kernel_type
    self.z_var = latent_var

    assert alpha <= 0, 'alpha must be negative or zero.'
    self.alpha = alpha
    self.beta = beta

    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256, 512]
    # Work on a copy so the caller's list is never mutated.
    hidden_dims = list(hidden_dims)
    # in_channels is overwritten while building the encoder; remember it so
    # reconstructions have as many channels as the inputs.
    image_channels = in_channels

    # Build Encoder: each stage halves the spatial resolution.
    modules = []
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels=h_dim,
                          kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU()))
        in_channels = h_dim
    self.encoder = nn.Sequential(*modules)
    # "* 4" assumes the encoder output is 2 x 2 spatially (decode() reshapes
    # to [.., C, 2, 2]), e.g. a 64 x 64 input with five stages -- TODO confirm
    # the expected input resolution.
    self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
    self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    # Build Decoder
    self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
    # The decoder starts from the widest (last encoder) width, which is the
    # channel count decode() reshapes decoder_input's output into.
    decoder_dims = hidden_dims[::-1]
    modules = []
    for in_dim, out_dim in zip(decoder_dims, decoder_dims[1:]):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(in_dim,
                                   out_dim,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU()))
    self.decoder = nn.Sequential(*modules)

    # The loop above upsamples len(hidden_dims) - 1 times while the encoder
    # downsamples len(hidden_dims) times; one more stride-2 stage restores the
    # input resolution before projecting back to image channels.
    self.final_layer = nn.Sequential(
        nn.ConvTranspose2d(decoder_dims[-1],
                           decoder_dims[-1],
                           kernel_size=3,
                           stride=2,
                           padding=1,
                           output_padding=1),
        nn.BatchNorm2d(decoder_dims[-1]),
        nn.LeakyReLU(),
        nn.Conv2d(decoder_dims[-1], out_channels=image_channels,
                  kernel_size=3, padding=1),
        # NOTE(review): Tanh bounds outputs to [-1, 1]; this assumes inputs are
        # normalised to that range for the MSE reconstruction loss -- confirm.
        nn.Tanh())
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Encodes the input by passing through the encoder network
    and returns the latent codes.
    :param input: (Tensor) Input tensor to encoder [N x C x H x W]
    :return: (List[Tensor]) [mu, log_var]: mean and log-variance of the
             latent Gaussian, each [N x latent_dim]
    """
    result = self.encoder(input)
    result = torch.flatten(result, start_dim=1)

    # Split the result into mu and var components
    # of the latent Gaussian distribution
    mu = self.fc_mu(result)
    log_var = self.fc_var(result)
    return [mu, log_var]
def decode(self, z: Tensor) -> Tensor:
    """
    Maps latent codes back to image space.
    :param z: (Tensor) Latent codes [B x latent_dim]
    :return: (Tensor) Output of the decoder followed by final_layer
    """
    result = self.decoder_input(z)
    # decoder_input emits (channels * 2 * 2) features per sample. Recover the
    # channel count from the layer itself instead of hard-coding 512, so that
    # models built with non-default hidden_dims decode correctly.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    """
    Reparameterization trick to sample from N(mu, var) using a draw
    eps ~ N(0, 1): z = eps * exp(0.5 * logvar) + mu, so gradients flow
    through mu and logvar.
    :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
    :return: (Tensor) [B x D]
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps * std + mu
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """
    Runs a full encode -> sample -> decode pass.
    :param input: (Tensor) Batch fed to encode()
    :return: (List[Tensor]) [reconstruction, input, z, mu, log_var]
    """
    latent_mean, latent_log_var = self.encode(input)
    latent_sample = self.reparameterize(latent_mean, latent_log_var)
    reconstruction = self.decode(latent_sample)
    return [reconstruction, input, latent_sample, latent_mean, latent_log_var]
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Computes the InfoVAE objective
        beta * recons_loss
        + (1 - alpha) * kld_weight * kld_loss
        + (alpha + reg_weight - 1) / (N * (N - 1)) * mmd_loss
    :param args: (recons, input, z, mu, log_var), i.e. the list returned by forward()
    :param kwargs: must contain 'M_N', the weight applied to the KL term
    :return: (dict) 'loss', 'Reconstruction_Loss', 'MMD' and 'KLD' (the
             negated KL term)
    :raises KeyError: if 'M_N' is not supplied.
    """
    recons = args[0]
    input = args[1]
    z = args[2]
    mu = args[3]
    log_var = args[4]

    batch_size = input.size(0)
    # N * (N - 1) presumably normalises the off-diagonal kernel sum returned by
    # compute_inv_mult_quad; clamp to 1 so a single-sample batch does not
    # raise ZeroDivisionError.
    # NOTE(review): compute_rbf returns the full kernel matrix (compute_mmd
    # then takes its mean), so this extra scaling affects the two kernel types
    # differently -- confirm intended.
    bias_corr = max(batch_size * (batch_size - 1), 1)
    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

    recons_loss = F.mse_loss(recons, input)
    mmd_loss = self.compute_mmd(z)
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

    loss = self.beta * recons_loss + \
        (1. - self.alpha) * kld_weight * kld_loss + \
        (self.alpha + self.reg_weight - 1.) / bias_corr * mmd_loss
    # 'KLD' reports the negated KL term, matching the original output.
    return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'MMD': mmd_loss, 'KLD': -kld_loss}
def compute_kernel(self,
                   x1: Tensor,
                   x2: Tensor) -> Tensor:
    """
    Evaluates the configured kernel between every row of x1 and every row of x2.
    :param x1: (Tensor) [N x D]
    :param x2: (Tensor) [M x D]
    :return: (Tensor) Result of compute_rbf / compute_inv_mult_quad applied to
             the pairwise-expanded [N x M x D] tensors
    :raises ValueError: if self.kernel_type is neither 'rbf' nor 'imq'.
    """
    D = x1.size(1)
    N = x1.size(0)
    M = x2.size(0)

    x1 = x1.unsqueeze(-2)  # [N x 1 x D], column of rows
    x2 = x2.unsqueeze(-3)  # [1 x M x D], row of rows

    # Materialise both as [N x M x D] so entry (i, j) pairs x1[i] with x2[j].
    # Using x2's own row count lets x1 and x2 differ along dimension 0.
    x1 = x1.expand(N, M, D)
    x2 = x2.expand(N, M, D)

    if self.kernel_type == 'rbf':
        result = self.compute_rbf(x1, x2)
    elif self.kernel_type == 'imq':
        result = self.compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError('Undefined kernel type.')

    return result
def compute_rbf(self,
                x1: Tensor,
                x2: Tensor,
                eps: float = 1e-7) -> Tensor:
    """
    Computes the RBF Kernel between x1 and x2,
        k(x_1, x_2) = exp(-mean((x_1 - x_2)^2) / sigma),  sigma = 2 * D * z_var
    where the mean runs over the last (feature) dimension of size D.
    :param x1: (Tensor) [..., D]
    :param x2: (Tensor) [..., D], broadcastable against x1
    :param eps: (Float) Unused; kept for signature parity with compute_inv_mult_quad
    :return: (Tensor) Element-wise kernel values (last dimension reduced)
    """
    z_dim = x2.size(-1)
    sigma = 2. * z_dim * self.z_var

    result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
    return result
def compute_inv_mult_quad(self,
                          x1: Tensor,
                          x2: Tensor,
                          eps: float = 1e-7) -> Tensor:
    r"""
    Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
    given by

        k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2},  C = 2 * D * z_var

    where the sum runs over all off-diagonal (i, j) pairs.
    :param x1: (Tensor) [N x M x D] pairwise-expanded first operand
    :param x2: (Tensor) [N x M x D] pairwise-expanded second operand
    :param eps: (Float) Added to the denominator for numerical stability
    :return: (Tensor) Scalar: sum of the kernel matrix minus its diagonal
    """
    z_dim = x2.size(-1)
    C = 2 * z_dim * self.z_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

    # Exclude diagonal elements
    result = kernel.sum() - kernel.diag().sum()

    return result
def compute_mmd(self, z: Tensor) -> Tensor:
    """
    Estimates the MMD between the latent codes z and an equally shaped batch
    drawn from the standard-normal prior.
    :param z: (Tensor) Latent codes
    :return: (Tensor) mean(k(p, p)) + mean(k(z, z)) - 2 * mean(k(p, z)),
             with p ~ N(0, I) and k = self.compute_kernel
    """
    # Reference batch from the prior, same shape/dtype/device as z.
    reference = torch.randn_like(z)

    k_prior = self.compute_kernel(reference, reference).mean()
    k_posterior = self.compute_kernel(z, z).mean()
    k_cross = self.compute_kernel(reference, z).mean()

    return k_prior + k_posterior - 2 * k_cross
def sample(self,
current_device: int, **kwargs) -> Tensor:
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of sa