Let $X$ be the input we are trying to model.
Let $Y$ be a latent representation of $X$, which follows a standard normal distribution.
We want to express $p_X(x)$ in terms of $p_Y(y)$, so that we can maximize the likelihood of the observed data.
Let $f$ be an invertible function with inverse $g = f^{-1}$, so that $y = f(x)$ and $x = g(y)$ for all $x$.
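Concretely (this step is not spelled out above, but it is what the code below implements), the change-of-variables formula relates the two densities through the Jacobian of $f$:

$$p_X(x) = p_Y(f(x)) \, \left| \det J_f(x) \right|$$

so the quantity we minimize is the negative log-likelihood

$$-\log p_X(x) = \frac{1}{2} \lVert f(x) \rVert^2 + \frac{d}{2} \log(2\pi) - \log \left| \det J_f(x) \right|,$$

where $d$ is the dimension of the latent $y = f(x)$. In the code this is `negative_log_prob(y) + log_p`, since each coupling layer returns its own $-\log|\det J|$ contribution.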
import torch
from torch import nn
class CustomCouplingLayer(nn.Module):
    """One invertible coupling step: each half is rescaled and shifted by a small network of the other half."""

    def __init__(self, n_dim=10000):
        super(CustomCouplingLayer, self).__init__()
        self.activation = nn.GELU()
        # Maps x to a (scale, shift) pair that is applied to y.
        self.nonlinear_x = nn.Sequential(
            nn.Linear(1, n_dim),
            # nn.Dropout(p=0.1),
            # nn.BatchNorm1d(n_dim),
            self.activation,
            nn.Linear(n_dim, 2)
        )
        # Maps y' to a (scale, shift) pair that is applied to x.
        self.nonlinear_y_prime = nn.Sequential(
            nn.Linear(1, n_dim),
            # nn.Dropout(p=0.1),
            # nn.BatchNorm1d(n_dim),
            self.activation,
            nn.Linear(n_dim, 2)
        )

    def forward(self, x, y, print_=False):
        # y' = y * s(x) + t(x), then x' = x * s(y') + t(y')
        nonlinear_x = self.nonlinear_x(x)
        y_prime = (y * nonlinear_x[..., :1]) + nonlinear_x[..., 1:]
        nonlinear_y_prime = self.nonlinear_y_prime(y_prime)
        x_prime = (x * nonlinear_y_prime[..., :1]) + nonlinear_y_prime[..., 1:]
        if print_:
            print(nonlinear_x[..., :1])
            print(nonlinear_y_prime[..., :1])
        # Third return value: negative log-determinant of this step's Jacobian, -log|s(x)| - log|s(y')|
        neg_log_det = -torch.log(torch.abs(nonlinear_x[..., :1])).sum(axis=-1) \
                      - torch.log(torch.abs(nonlinear_y_prime[..., :1])).sum(axis=-1)
        return x_prime, y_prime, neg_log_det

    def reverse(self, x_prime, y_prime, print_=False):
        # Invert the forward pass: recover x from (x', y'), then y from (y', x).
        nonlinear_y_prime = self.nonlinear_y_prime(y_prime)
        x = (x_prime - nonlinear_y_prime[..., 1:]) / nonlinear_y_prime[..., :1]
        nonlinear_x = self.nonlinear_x(x)
        y = (y_prime - nonlinear_x[..., 1:]) / nonlinear_x[..., :1]
        return x, y
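As a quick aside (this check is not in the original notebook), a single coupling layer makes the forward/reverse relationship concrete; the small width `n_dim=64` is an arbitrary choice just to keep it fast on CPU.

# Minimal sanity check of one coupling layer: reverse() should undo forward(),
# and the third return value is the per-example negative log-determinant of the step.
layer = CustomCouplingLayer(n_dim=64)
x = torch.rand((8, 1))
y = x.clone()
x_prime, y_prime, neg_log_det = layer(x, y)
x_rec, y_rec = layer.reverse(x_prime, y_prime)
print((x - x_rec).abs().max().item(), (y - y_rec).abs().max().item())  # reconstruction errors, should be near zero
print(neg_log_det.shape)  # torch.Size([8]): one value per example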
class NormFlow(nn.Module):
    """Stack of coupling layers mapping the observed 1-D data to a 2-D standard normal."""

    def __init__(self):
        super(NormFlow, self).__init__()
        self.num_couplings = 5
        self.couplings = nn.ModuleList([CustomCouplingLayer() for _ in range(self.num_couplings)])

    def forward(self, x, print_=False):
        x = x.to(device)  # `device` is defined globally below, before the model is used
        logp_sum = 0
        # Duplicate the 1-D input into the two halves the coupling layers act on.
        x, y = x.clone(), x.clone()
        for coupling in self.couplings:
            x, y, logp = coupling(x, y, print_)
            logp_sum += logp
        # Returns the latent representation (which should follow a standard normal distribution)
        # together with the accumulated negative log-determinant.
        return torch.cat([x, y], dim=-1), logp_sum

    def reverse(self, y, print_=False):
        # Split the latent vector back into its two halves and undo the couplings in reverse order.
        x, y = y[..., :y.shape[-1]//2], y[..., y.shape[-1]//2:]
        for coupling in self.couplings[::-1]:
            x, y = coupling.reverse(x, y)
        # Returns samples from the observed distribution.
        return x, y

    def generate(self, n):
        # Sample latents from a standard normal and push them back through the inverse flow.
        y = torch.normal(mean=torch.zeros((n, 2)), std=torch.ones((n, 2))).to(device)
        return self.reverse(y)[1].squeeze()

    def negative_log_prob(self, y):
        # Negative log-density of a standard normal: ||y||^2 / 2 + (d / 2) * log(2*pi)
        return (torch.square(y) / 2).sum(axis=-1) + ((y.shape[-1] / 2) * np.log(2 * np.pi))

    def negative_log_loss(self, x):
        # Negative log-likelihood of x under the flow (change of variables).
        y, log_p = self.forward(x)
        return (self.negative_log_prob(y) + log_p).mean()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = NormFlow().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-7)
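The same kind of round-trip check can be run on the full (still untrained) flow. This snippet is not part of the original training run; it just confirms that `reverse` inverts `forward` end to end on whatever device was selected above.

# Optional end-to-end check: push a small batch through the flow and back again.
with torch.no_grad():
    batch = torch.rand((4, 1)).to(device)
    z, neg_log_det = model(batch)   # z has shape (4, 2): the concatenated [x, y] halves
    x_rec, y_rec = model.reverse(z)
    print((batch - x_rec).abs().max().item(), (batch - y_rec).abs().max().item())  # should be near zero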
num_training_examples = 1000000
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
num_epochs = 10
plot_every = 5
def plot_model():
    # Compare histograms of model samples against fresh draws from the observed uniform data,
    # plotted in both overlay orders so neither histogram hides the other.
    fig, ax = plt.subplots(1, 2)
    gen_size = 1000
    generated = model.generate(gen_size).detach().cpu().numpy()
    examples = torch.rand((gen_size, 1)).numpy()
    # bins = np.linspace(min(generated.min(), 0), max(generated.max(), 1), 20)
    bins = np.linspace(0, 1, 20)
    ax[0].hist(examples, color='blue', label='observed', bins=bins)
    ax[0].hist(generated, color='orange', label='model', bins=bins)
    ax[0].legend()
    ax[1].hist(generated, color='orange', label='model', bins=bins)
    ax[1].hist(examples, color='blue', label='observed', bins=bins)
    ax[1].legend()
    plt.show()
training_loss = []
for epoch in range(1, num_epochs+1):
    curr_loss = n = 0
    # Observed data: fresh samples from a uniform distribution on [0, 1) each epoch.
    examples = torch.rand((num_training_examples, 1))
    dataloader = torch.utils.data.DataLoader(examples, batch_size=1000, shuffle=True)
    for i, example in enumerate(tqdm(dataloader)):
        loss = model.negative_log_loss(example)
        optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        curr_loss += loss.item()
        n += 1
        # if not i % 200:
        #     print(curr_loss / n)
        #     plot_model()
    training_loss.append(curr_loss / n)
    if not epoch % plot_every:
        plot_model()
        print(f"Epoch {epoch}: Loss={training_loss[-1]}")
plt.plot(np.arange(num_epochs), training_loss)
(tqdm progress bars omitted; each epoch of 1000 batches runs at roughly 6.6 it/s, about 2.5 minutes per epoch. Training also emits a UserWarning that the operator 'aten::sgn.out' is not currently supported on the MPS backend and falls back to the CPU, which may have performance implications.)
Epoch 5: Loss=-4.077750432610512
Epoch 10: Loss=-4.935740837678313