from requests import get
notebook_name = 'transformer'
import torchtext
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
import io
from tqdm import tqdm
from unidecode import unidecode
from torch.distributions import Categorical
import numpy as np
import pickle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
src_vocab = open(f'data/src_vocab.txt').read().split('\n')
tgt_vocab = open(f'data/tgt_vocab.txt').read().split('\n')
src_array = torch.Tensor(np.load('data/english.npy')).type(torch.long)
tgt_array = torch.Tensor(np.load('data/french.npy')).type(torch.long)
PAD_IDX_SRC = src_vocab.index('<pad>')
PAD_IDX_TGT = tgt_vocab.index('<pad>')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_src_array, val_src_array, train_tgt_array, val_tgt_array = train_test_split(src_array, tgt_array)
train_data = [(src, tgt) for src, tgt in zip(train_src_array, train_tgt_array)]
val_data = [(src, tgt) for src, tgt in zip(val_src_array, val_tgt_array)]
BATCH_SIZE = 25
def generate_batch(data_batch):
src_batch, tgt_batch = [], []
valid_lengths = []
for (src_item, tgt_item) in data_batch:
src_batch.append(src_item)
tgt_batch.append(tgt_item)
        valid_lengths.append((src_item != PAD_IDX_SRC).sum().item())  # count non-pad tokens in this source sentence
src_batch, tgt_batch = map(lambda x: torch.stack(x, dim=1), (src_batch, tgt_batch))
return src_batch, tgt_batch, torch.Tensor(valid_lengths)
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=generate_batch)
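As a quick sanity check (illustrative only, not part of the original pipeline), one batch can be pulled from the loader. Because generate_batch stacks along dim=1, the source and target tensors come out as (seq_len, batch_size) and valid_lengths as (batch_size,):
example_src, example_tgt, example_lengths = next(iter(train_iter))
print(example_src.shape, example_tgt.shape, example_lengths.shape)  # e.g. (src_len, 25), (tgt_len, 25), (25,)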
I used this d2l.ai article as a reference.
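The attention used in predict_next below is the additive (Bahdanau) form: each encoder output $\mathbf{h}_i$ is combined with the previous decoder state $\mathbf{s}_{t-1}$ through a tanh of learned projections and reduced to a scalar score, which is masked-softmaxed into weights that mix the encoder outputs into a context vector:

$$e_{t,i} = \mathbf{w}_v^\top \tanh(\mathbf{W}_1 \mathbf{h}_i + \mathbf{W}_2 \mathbf{s}_{t-1}), \qquad \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_j \exp(e_{t,j})}, \qquad \mathbf{c}_t = \sum_i \alpha_{t,i} \mathbf{h}_i$$

In the code, $\mathbf{W}_1$ is query_layer, $\mathbf{W}_2$ is key_layer, and $\mathbf{w}_v$ is attention_layer; the torch.bmm in predict_next computes $\mathbf{c}_t$.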
src_vocab_size, tgt_vocab_size = len(src_vocab), len(tgt_vocab)
embedding_dim = 500
num_GRU_layers = 1
num_GRU_units = 1000
attention_units = 500
context_units = 500
decoder_hidden_units = num_GRU_units
decoder_final_units = tgt_vocab_size
dropout_prob = 0.2
class Transformer(torch.nn.Module):
def __init__(self):
super(Transformer, self).__init__()
self.dropout = torch.nn.Dropout(dropout_prob)
self.softmax = torch.nn.Softmax(dim=2)
self.src_embedding = torch.nn.Embedding(src_vocab_size, embedding_dim, device=device)
        self.tgt_embedding = torch.nn.Embedding(tgt_vocab_size, embedding_dim, device=device)
self.encoder = torch.nn.GRU(embedding_dim, num_GRU_units, num_GRU_layers, batch_first=True, device=device)
self.decoder_hidden = torch.nn.GRU(num_GRU_units+embedding_dim, decoder_hidden_units, num_GRU_layers, batch_first=True, device=device)
self.decoder_final = torch.nn.Linear(decoder_hidden_units, decoder_final_units, device=device)
self.query_layer = torch.nn.Linear(num_GRU_units, attention_units, bias=True, device=device)
self.key_layer = torch.nn.Linear(num_GRU_units, attention_units, bias=True, device=device)
self.attention_layer = torch.nn.Linear(attention_units, 1, bias=True, device=device)
self.tanh = torch.nn.Tanh()
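    # masked_softmax assigns a large negative score to source positions beyond each
    # sentence's valid (non-padding) length, so padding receives ~zero attention weight.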
def masked_softmax(self, attention, valid_lengths, value=-1e6):
max_len = attention.shape[-1]
mask = torch.arange(max_len, device=device)[None, :] >= valid_lengths[:,None]
mask = mask.unsqueeze(1)
        attention[mask] = value
return self.softmax(attention)
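    # predict_next runs one decoder step: additive attention over the encoder outputs
    # (scored against the previous decoder state), the resulting context concatenated
    # with the embedding of the previous target token, then a decoder GRU step.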
def predict_next(self, encodings, state, valid_lengths, prev):
attention = self.tanh(self.query_layer(encodings) + self.key_layer(state))
attention = self.attention_layer(attention)
attention = attention.permute([0, 2, 1])
attention = self.masked_softmax(attention, valid_lengths)
context = torch.bmm(self.dropout(attention), encodings)
        embedding = self.tgt_embedding(prev.unsqueeze(1))  # embed the previous target token
context = torch.cat([context, embedding], axis=-1)
output, state = self.decoder_hidden(context)
return (output, state)
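    # forward uses teacher forcing: the ground-truth previous target token is fed in at
    # every step, and per-token distributions over the target vocabulary are returned.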
def forward(self, src_batch, tgt_batch, valid_lengths):
src_batch, tgt_batch, valid_lengths = map(lambda x: x.to(device), (src_batch, tgt_batch, valid_lengths))
src_batch = src_batch.permute([1, 0])
src_embedding_batch = self.src_embedding(src_batch)
encodings, state = self.encoder(src_embedding_batch)
state = state.permute([1, 0, 2])
outputs = []
for prev in tgt_batch:
output, state = self.predict_next(encodings, state, valid_lengths, prev)
state = state.permute([1, 0, 2])
outputs.append(output)
outputs = torch.cat(outputs, axis=1)
predictions = self.decoder_final(outputs)
return self.softmax(predictions)
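    # generate decodes greedily from <bos> for generate_length steps, feeding the argmax
    # token back in as the previous token, and returns the generated words as strings.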
def generate(self, src_batch, valid_lengths, generate_length):
src_batch, valid_lengths = map(lambda x: x.to(device), (src_batch, valid_lengths))
src_batch = src_batch.permute([1, 0])
src_embedding_batch = self.src_embedding(src_batch)
encodings, state = self.encoder(src_embedding_batch)
state = state.permute([1, 0, 2])
outputs = []
prev = torch.Tensor([tgt_vocab.index('<bos>')]).type(torch.long).to(device)
for _ in range(generate_length):
output, state = self.predict_next(encodings, state, valid_lengths, prev)
state = state.permute([1, 0, 2])
            prev = self.decoder_final(output).argmax(dim=-1).reshape(-1)  # greedy: pick the most likely next token
outputs.append(tgt_vocab[prev.item()])
return outputs
transformer = Transformer()
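As an illustrative check (using the batch conventions above), an untrained forward pass returns one distribution per target position, shaped (batch_size, tgt_seq_len, tgt_vocab_size):
with torch.no_grad():
    example_batch = next(iter(train_iter))
    print(transformer(*example_batch).shape)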
import torch.optim as optim
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0005)
def criterion(predictions, tgt_tokens):
one_hot_tgt_tokens = torch.nn.functional.one_hot(tgt_tokens[:,1:], len(tgt_vocab))
one_hot_tgt_tokens *= (tgt_tokens[:,1:] != tgt_vocab.index('<pad>')).unsqueeze(axis=-1)
loss = (-torch.log(predictions[:,:-1][one_hot_tgt_tokens.type(torch.bool)]))
return loss.sum() / len(loss)
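Written out, the criterion is token-level cross entropy of the (already softmaxed) predictions against the targets shifted by one position, averaged over the $N$ non-padding target tokens:

$$\mathcal{L} = -\frac{1}{N}\sum_{t\,:\,y_t \neq \texttt{<pad>}} \log p_\theta\!\left(y_t \mid y_{<t}, x\right)$$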
train, print_ = True, False
if train:
num_epochs = 100
train_loss, val_loss = [], []
for epoch in tqdm(range(1, num_epochs+1)):
curr_loss = n = 0
for batch in train_iter:
predictions = transformer(*batch)
optimizer.zero_grad()
loss = criterion(predictions, batch[1].permute([1, 0]))
loss.backward()
torch.nn.utils.clip_grad_norm_(transformer.parameters(), 1)
optimizer.step()
curr_loss += loss.item()
n += 1
train_loss.append(curr_loss / n)
curr_loss = n = 0
with torch.no_grad():
for batch in valid_iter:
predictions = transformer(*batch)
loss = criterion(predictions, batch[1].permute([1, 0]))
curr_loss += loss.item()
n += 1
val_loss.append(curr_loss / n)
if print_:
print(f"Epoch {epoch}\t\tTraining Loss {train_loss[-1]}\t\tValidation Loss {val_loss[-1]}")
torch.save(transformer.state_dict(), f"models/{notebook_name}.pt")
else:
checkpoint = torch.load(f'models/{notebook_name}.pt')
transformer.load_state_dict(checkpoint)
100%|█████████████████████████| 100/100 [07:59<00:00, 4.80s/it]
import matplotlib.pyplot as plt
plt.plot(range(1, len(train_loss)+1), train_loss, label='Training Loss')
plt.plot(range(1, len(val_loss)+1), val_loss, label='Validation Loss')
plt.ylabel('Cross Entropy')
plt.xlabel('Epoch')
plt.legend()
[Figure: training and validation cross-entropy loss per epoch]
phrases = [
"go . <eos>".split(' '),
"i lost . <eos>".split(' '),
"i'm home . <eos>".split(' ')
]
transformer.eval()
print("English Prompt".ljust(20) + "Predicted French Translation".ljust(20))
for phrase in phrases:
tokens = torch.Tensor(list(map(src_vocab.index, phrase))).type(torch.long).unsqueeze(axis=1)
length = torch.Tensor([len(tokens)]).type(torch.long)
generated_sequence = transformer.generate(tokens, length, 10)
generated_sequence = generated_sequence[:generated_sequence.index('<eos>')]
print(f"{' '.join(phrase[:-1]).ljust(20)}{' '.join(generated_sequence).ljust(20)}")
English Prompt      Predicted French Translation
go .                va !
i lost .            j'ai perdu .
i'm home .          je suis chez moi .
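The same steps could be wrapped into a small helper for ad-hoc inputs (a sketch only; translate is not defined in the notebook above and assumes every word of the input already appears in src_vocab):
def translate(sentence, max_len=10):
    words = sentence.split(' ') + ['<eos>']
    tokens = torch.tensor([src_vocab.index(w) for w in words], dtype=torch.long).unsqueeze(1)
    length = torch.tensor([len(tokens)], dtype=torch.long)
    out = transformer.generate(tokens, length, max_len)
    return ' '.join(out[:out.index('<eos>')] if '<eos>' in out else out)

translate("i'm home .")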