# Draw ten faces from the model's prior and lay them out on a 2x5 grid.
# NOTE(review): this cell appears to be exported out of order -- `plt` and
# `VAE` are only defined further down in this file.
fig, axs = plt.subplots(ncols=5, nrows=2, figsize=(20, 10))
for ax_ in axs.flat:  # row-major, i.e. row by row, left to right
    ax_.imshow(VAE.generate())
    ax_.axis(False)
plt.show()
import torch
import os
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from torch import distributions
from torch import nn
import torchvision
import torchvision.transforms
# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
image_size = (256, 256)  # (height, width) every image is resized to
batch_size = 64
# Resize every image to `image_size` and convert it to a float tensor in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_size),
    torchvision.transforms.ToTensor(),
])
# ImageFolder expects class sub-directories under 'CelebA'; it yields (image, label) pairs.
images = datasets.ImageFolder('CelebA', transform=transform)
dataset = DataLoader(images, batch_size=batch_size, shuffle=True, drop_last=True)
# Preview one sample: np.dstack turns the (C, H, W) tensor into an (H, W, C) array for imshow.
imshow(np.dstack(images[201][0].numpy()))
# Output: <matplotlib.image.AxesImage at 0x7f581b35f950>
class DenseBlock(nn.Module):
    """Residual convolution block with a convolutional projection shortcut.

    Main path: 3x3 conv -> BatchNorm -> GELU -> 3x3 conv -> BatchNorm.
    Shortcut:  3x3 conv -> BatchNorm applied to the block input.
    The two paths are summed and passed through GELU. All convolutions use
    'same' padding, so the spatial size is preserved and only the channel
    count changes from ``input_channels`` to ``output_channels``.
    """

    def __init__(self, input_channels, output_channels):
        super(DenseBlock, self).__init__()
        # Order matters: index 0/1 are the main path, index 2 is the shortcut.
        channel_pairs = (
            (input_channels, output_channels),
            (output_channels, output_channels),
            (input_channels, output_channels),
        )
        self.conv2d = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, 3, padding='same') for c_in, c_out in channel_pairs]
        )
        self.activation = nn.GELU()
        self.batch_norm = nn.ModuleList(nn.BatchNorm2d(output_channels) for _ in range(3))

    def forward(self, x):
        conv_a, conv_b, shortcut_conv = self.conv2d
        norm_a, norm_b, shortcut_norm = self.batch_norm
        hidden = self.activation(norm_a(conv_a(x)))
        hidden = norm_b(conv_b(hidden))
        shortcut = shortcut_norm(shortcut_conv(x))
        return self.activation(shortcut + hidden)
class EncoderBlock(nn.Module):
    """Downsampling encoder stage.

    Max-pools ``x`` by 2, concatenates it channel-wise with the carried
    ``save_state``, and maps the result to ``input_channels * 2`` channels with
    a ``DenseBlock``. ``x`` has ``input_channels`` channels. For the
    concatenation to match, ``save_state`` must already be at the pooled
    spatial size and carry ``input_channels // 2`` channels.
    """

    def __init__(self, input_channels):
        super(EncoderBlock, self).__init__()
        self.max_pool2 = nn.MaxPool2d(2)
        # Input to the dense block: input_channels (pooled x) + input_channels // 2 (state).
        self.dense = DenseBlock((input_channels//2)*3, input_channels*2)

    def forward(self, x, save_state):
        """Return ``(features, next_state)``.

        ``features`` has ``2 * input_channels`` channels at half of ``x``'s
        spatial size. ``next_state`` is ``x`` pooled twice (``input_channels``
        channels), which is the state the next, wider block expects.
        """
        x = self.max_pool2(x)
        output = torch.cat([x, save_state], dim=1)
        output = self.dense(output)
        return output, self.max_pool2(x)
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a 1024-wide vector.

    ``VAE.encode`` splits this vector into latent means and log-variances.

    Args:
        num_encoder_blocks: number of ``EncoderBlock`` stages. Each stage
            halves the spatial size and doubles the channel count.
        input_channels: channels of the input images.
        start_channels: channels produced by the initial ``DenseBlock``.
        input_size: side length of the square input images. It is used to size
            the first linear layer. The default 256 reproduces the previously
            hard-coded width of 8192 for the default arguments.

    Raises:
        ValueError: if ``input_size`` is too small for ``num_encoder_blocks``
            (the feature map would shrink to zero pixels).
    """

    def __init__(self, num_encoder_blocks=5, input_channels=3, start_channels=16, input_size=256):
        super(Encoder, self).__init__()
        self.initial_dense = DenseBlock(input_channels, start_channels)
        self.initial_state = DenseBlock(input_channels, start_channels//2)
        self.max_pool2 = nn.MaxPool2d(2)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(start_channels << i) for i in range(num_encoder_blocks)])
        self.flatten = nn.Flatten(start_dim=1)
        # Every encoder block pools once and forward() pools once more before flattening.
        final_channels = start_channels << num_encoder_blocks
        final_side = input_size >> (num_encoder_blocks + 1)
        if final_side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {num_encoder_blocks} encoder blocks"
            )
        flattened_size = final_channels * final_side * final_side
        self.sequential = nn.Sequential(nn.Linear(flattened_size, 1024), nn.GELU(), nn.Linear(1024, 1024))

    def forward(self, x):
        # The carried state starts at half resolution with start_channels // 2 channels.
        x, prev_x = self.initial_dense(x), self.max_pool2(self.initial_state(x))
        for encoder_block in self.encoder_blocks:
            x, prev_x = encoder_block(x, prev_x)
        x = self.flatten(self.max_pool2(x))
        return self.sequential(x)
class DecoderBlock(nn.Module):
    """Upsampling decoder stage.

    Doubles the spatial size of ``x`` with a stride-2 transposed convolution,
    concatenates it with the carried ``save_state`` (``input_channels // 2``
    channels), and reduces the result to ``input_channels // 2`` channels.
    It also emits the next carried state: the upsampled input, upsampled again
    to ``input_channels // 4`` channels.
    """

    def __init__(self, input_channels):
        super(DecoderBlock, self).__init__()
        half = input_channels // 2
        quarter = input_channels // 4
        self.conv2d_transpose = nn.ConvTranspose2d(input_channels, input_channels, 2, stride=2)
        self.conv2d_transpose_state = nn.ConvTranspose2d(input_channels, quarter, 2, stride=2)
        self.dense = DenseBlock(half * 3, half)

    def forward(self, x, save_state):
        upsampled = self.conv2d_transpose(x)
        merged = torch.cat((upsampled, save_state), dim=1)
        features = self.dense(merged)
        next_state = self.conv2d_transpose_state(upsampled)
        return features, next_state
class Decoder(nn.Module):
    """Decode a latent vector into an image.

    The latent (size ``flattened_input_size``) is projected and reshaped into a
    ``flattened_input_size``-channel 4x4 map. Each of ``num_decoder_blocks``
    stages then doubles the spatial size and halves the channels. A 1x1
    convolution plus sigmoid produces ``num_output_channels`` channels with
    values in (0, 1). With the defaults the output is 3 x 256 x 256.
    """

    def __init__(self, flattened_input_size=512, num_decoder_blocks=6, num_output_channels=3):
        super(Decoder, self).__init__()
        channels = flattened_input_size
        # Project the latent to channels * 4 * 4 values (reshaped to a 4x4 map in forward()).
        project_in = nn.Linear(flattened_input_size, flattened_input_size)
        project_act = nn.GELU()
        project_out = nn.Linear(flattened_input_size, channels * 16)
        self.embedding = nn.Sequential(project_in, project_act, project_out)
        self.initial_dense = DenseBlock(channels, channels)
        self.initial_state = nn.ConvTranspose2d(channels, channels // 2, 2, stride=2)
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(channels >> depth) for depth in range(num_decoder_blocks)]
        )
        final_channels = channels >> num_decoder_blocks
        self.finishing_touches = nn.Sequential(
            nn.Conv2d(final_channels, num_output_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        embedded = self.embedding(x)
        grid = embedded.reshape([embedded.shape[0], -1, 4, 4])
        features = self.initial_dense(grid)
        carried = self.initial_state(grid)
        for block in self.decoder_blocks:
            features, carried = block(features, carried)
        return self.finishing_touches(features)
class VAE(torch.nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    The encoder output is split along its last axis. The first
    ``embedding_shape`` values are the posterior means and the remaining values
    are the posterior log-variances.
    """

    def __init__(self):
        # Zero-argument super(): the notebook later rebinds the global name `VAE`
        # to a model instance, which would break `super(VAE, self)`.
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.embedding_shape = 512  # latent size; half of the encoder's output width

    def encode(self, x):
        """Encode ``x`` and sample a latent with the reparameterization trick.

        Returns:
            ``(z, kl)``. ``z`` has ``embedding_shape`` features per sample.
            ``kl`` is the closed-form KL(N(mu, sigma^2) || N(0, 1)) averaged over
            every sample and latent dimension.
        """
        output = self.encoder(x)
        embedding_means = output[..., :self.embedding_shape]
        embedding_log_var = output[..., self.embedding_shape:]
        embedding_stds = torch.exp(embedding_log_var * 0.5)
        # Independent noise per sample and per latent dim, on the same device/dtype.
        normal_values = torch.randn_like(embedding_stds)
        sample = embedding_means + normal_values * embedding_stds
        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
        kl = 0.5 * (
            torch.exp(embedding_log_var) + torch.square(embedding_means) - 1.0 - embedding_log_var
        )
        return sample, kl.mean()

    def decode(self, embedding):
        """Map latent vectors to decoder outputs (images)."""
        return self.decoder(embedding)

    def forward(self, x):
        """Return ``(reconstruction, kl)`` for the input batch ``x``."""
        embedding, kl = self.encode(x)
        output = self.decode(embedding)
        return output, kl

    def generate(self):
        """Decode one latent drawn from N(0, I) into an (H, W, C) numpy image clipped to [0, 1].

        Runs in eval mode so BatchNorm layers use their running statistics and are
        not updated with generated data. The previous train/eval mode is restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                reference = next(self.parameters())
                embedding = torch.randn(
                    1, self.embedding_shape, device=reference.device, dtype=reference.dtype
                )
                generated_face = self.decode(embedding).cpu()
        finally:
            self.train(was_training)
        generated_face = torch.clip(generated_face, min=0, max=1).numpy()
        # (C, H, W) -> (H, W, C) for matplotlib.
        return np.dstack(generated_face[0])
# NOTE: this rebinds the class name `VAE` to the model instance.
VAE = VAE().to(device)
C = 1  # weight applied to the KL term in the loss
MSE = torch.nn.MSELoss()


def RMSE(x, y):
    """Root-mean-squared error between ``x`` and ``y``, moved to ``device``."""
    mean_squared = MSE(x, y).to(device)
    return torch.sqrt(mean_squared)


from torch import optim
optimizer = optim.Adam(VAE.parameters(), lr=1e-3)


def criterion(x, y, kl):
    """Training loss: reconstruction RMSE plus the ``C``-weighted KL term."""
    reconstruction = RMSE(x, y)
    weighted_kl = C * kl
    return (reconstruction + weighted_kl).sum()
num_epochs = 50
training_losses = []  # mean training loss of each epoch
for epoch in range(1,num_epochs+1):
    # Per-epoch running sums of the metrics; n counts the batches seen.
    curr_loss = kl_loss = rmse_loss = n = 0
    # NOTE(review): `tqdm` is not imported in the visible source -- presumably
    # imported in another notebook cell; confirm.
    for i, batch in enumerate(tqdm(dataset), start=1):
        optimizer.zero_grad()
        # ImageFolder batches are (images, labels); only the images are used.
        batch = batch[0].to(device)
        output, kl = VAE(batch)
        loss = criterion(output, batch, kl)
        loss.backward()
        # Clip the total gradient norm to 1 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(VAE.parameters(), 1)
        optimizer.step()
        curr_loss += loss.item()
        # The reconstruction term is recomputed here only for logging.
        rmse_loss += RMSE(output, batch).item()
        kl_loss += kl.item()
        n += 1
    print(f"Epoch {epoch}\t\tTraining Loss {curr_loss/n:.5f}\t\tRMSE {rmse_loss/n:.5f}\t\tKL Div {kl_loss/n:.5f}")
    # After each epoch, show two images produced by VAE.generate().
    fig, ax = plt.subplots(ncols=2)
    for ax_ in ax:
        ax_.imshow(VAE.generate())
        ax_.axis(False)
    plt.show()
    training_losses.append(curr_loss / n)
    # Save a checkpoint every 5 epochs, tagged with the epoch and KL weight C.
    if not epoch % 5:
        os.makedirs('models', exist_ok=True)
        # NOTE(review): `notebook_name` is not defined in the visible source -- confirm where it is set.
        torch.save(VAE.state_dict(), f"models/{notebook_name}_{epoch}_{C:.1e}.pt")
# Final weights after the last epoch.
os.makedirs('models', exist_ok=True)
torch.save(VAE.state_dict(), f"models/{notebook_name}.pt")
# Reload a saved checkpoint into the model, mapping tensors onto `device`.
# The file name follows the "{notebook_name}_{epoch}_{C:.1e}.pt" pattern used by
# the periodic checkpoints (presumably notebook_name='VAE', epoch 40, C=1).
checkpoint = torch.load("models/VAE_40_1.0e+00.pt", map_location=device)
VAE.load_state_dict(checkpoint)
# Output: <All keys matched successfully>
# Draw ten faces from the model's prior and lay them out on a 2x5 grid.
fig, axs = plt.subplots(ncols=5, nrows=2, figsize=(20, 10))
for ax_ in axs.flat:  # row-major, i.e. row by row, left to right
    ax_.imshow(VAE.generate())
    ax_.axis(False)
plt.show()