Original Paper:
(Chen et al. 2018) https://proceedings.neurips.cc/paper_files/paper/2018/file/69386f6bb1dfed68692a24c8686939b9-Paper.pdf
The adjoint sensitivity derivation from the paper starts from the loss evaluated at the final state:
$L(\mathbf{z}(t_1)) = L\left[\mathbf{z}(t_0) + \int_{t_0}^{t_1} f(\mathbf{z}(t), t, \theta)\, dt\right]$
Let $\mathbf{a}(t) = \frac{\partial L}{\partial \mathbf{z}(t)}$ (the adjoint state). Then:
$\frac{d \mathbf{a}}{dt}(t) = \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{1}{\epsilon} \left[ \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} - \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t)}\right]$
$ = \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{1}{\epsilon} \left[ \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} - \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t+\epsilon)} \frac{\partial \mathbf{z}(t+\epsilon)}{\partial \mathbf{z}(t)}\right]$
$ = \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} \left[ \frac{1}{\epsilon}\left(1 - \frac{\partial \mathbf{z}(t+\epsilon)}{\partial \mathbf{z}(t)}\right)\right]$
$ = \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} \left[ \frac{1}{\epsilon}\left(1 - \frac{\partial \left(\mathbf{z}(t)+\epsilon f(\mathbf{z}(t), t, \theta)\right)}{\partial \mathbf{z}(t)}\right)\right]$
$ = \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} \left[ \frac{1}{\epsilon}\left(1 - \frac{\partial \mathbf{z}(t)}{\partial \mathbf{z}(t)} - \epsilon\frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}(t)}\right)\right]$
$ = - \frac{\partial L}{\partial \mathbf{z}(t_1)} \lim_{\epsilon \to 0} \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t + \epsilon)} \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}(t)}$
$ = - \frac{\partial L}{\partial \mathbf{z}(t_1)} \frac{\partial \mathbf{z}(t_1)}{\partial \mathbf{z}(t)} \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}(t)}$
$ = - \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}(t)}$
Similarly, treating the parameters as part of an augmented state with $\frac{d\theta}{dt} = 0$, the same argument gives an adjoint $\mathbf{a}_{\theta}(t)$ for $\theta$ with $\mathbf{a}_{\theta}(t_1) = 0$ and $\frac{d\mathbf{a}_{\theta}}{dt} = -\mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta}$, so the parameter gradient is accumulated by integrating backwards from $t_1$ to $t_0$:
$\frac{\partial L}{\partial \theta} = \int_{t_1}^{t_0} -\mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta}\, dt$
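Putting the pieces together: the `aug` method in the implementation below integrates the augmented state $[\mathbf{z}(t), \mathbf{a}(t), \mathbf{a}_{\theta}(t)]$ backwards in time from $t_1$ to $t_0$. A compact restatement of the equations above (written for the scalar case used in this notebook):
$\frac{d}{dt}\begin{bmatrix}\mathbf{z}(t)\\ \mathbf{a}(t)\\ \mathbf{a}_{\theta}(t)\end{bmatrix} = \begin{bmatrix} f(\mathbf{z}(t), t, \theta)\\ -\mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}(t)}\\ -\mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta}\end{bmatrix}, \qquad \begin{bmatrix}\mathbf{z}(t_1)\\ \mathbf{a}(t_1)\\ \mathbf{a}_{\theta}(t_1)\end{bmatrix} = \begin{bmatrix}\mathbf{z}_1\\ \frac{\partial L}{\partial \mathbf{z}(t_1)}\\ \mathbf{0}\end{bmatrix}$
After the backward solve, $\frac{\partial L}{\partial \theta} = \mathbf{a}_{\theta}(t_0)$ is read off and written into the parameters' gradients.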
The network below will learn dynamics $f$ whose integral approximates the solution of
$\frac{dz}{dt} = (z - c)\sin(t)$
which is
$z(t) = e^{1-\cos(t)} + c$
(the function `z(t, c)` implemented below).
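A quick standalone numerical check of this closed form against the stated dynamics (a small sketch using finite differences; the grid size is arbitrary):
import numpy as np
# Check that z(t) = exp(1 - cos(t)) + c satisfies dz/dt = (z - c) * sin(t)
t = np.linspace(0, 10, 10001)
c = 0.0
z_true = np.exp(1 - np.cos(t)) + c
dz_numeric = np.gradient(z_true, t)              # finite-difference derivative
dz_analytic = (z_true - c) * np.sin(t)
print(np.max(np.abs(dz_numeric - dz_analytic)))  # small: finite-difference error only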
Implementation following the original paper (Chen et al. 2018):
import torch
from torch import nn
import torch.optim
from tqdm import tqdm
from scipy.integrate import solve_ivp, simpson
import warnings
import matplotlib.pyplot as plt
import numpy as np
device = torch.device('cpu')
def z(t, c):
    # Closed-form solution of the target ODE: z(t) = exp(1 - cos(t)) + c
    if type(t) != torch.Tensor:
        t = torch.Tensor([t]).to(device)
    if not len(t.shape):
        t = t.unsqueeze(axis=0)
    return torch.exp(-torch.cos(t) + 1) + c
class ODENet(nn.Module):
    def __init__(self):
        super(ODENet, self).__init__()
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            # f approximates the dynamics dz/dt; its input is the concatenation [z, t]
            self.f = nn.Sequential(
                nn.Linear(2, 50),
                nn.GELU(),
                nn.Linear(50, 75),
                nn.GELU(),
                nn.Linear(75, 1)
            )
        self.loss_fn = nn.MSELoss()
        self.num_parameters = sum([np.prod(x.shape) for x in self.f.parameters()])

    def forward(self, t, z):
        # Evaluate the learned dynamics dz/dt = f(z, t)
        if type(z) != torch.Tensor:
            z = torch.Tensor(z)
        z = z.to(device)
        input_ = torch.cat([z, torch.Tensor([t]).to(device)])
        dz = self.f(input_)
        return dz

    def forward_detached(self, t, z):
        # NumPy-returning wrapper so scipy's solve_ivp can call the network
        dz = self.forward(t, z).cpu().detach().numpy()
        return dz

    def get_z1(self, t0, t1, z0):
        # Forward pass: integrate the learned dynamics from t0 to t1
        z0 = z0.cpu().numpy()
        return torch.Tensor(solve_ivp(self.forward_detached, [t0, t1], z0, method='LSODA').y[:, -1]).to(device)

    def loss(self, z1_pred, z1_true):
        return self.loss_fn(z1_pred, z1_true)

    def get_grad_cat(self):
        # Flatten and concatenate all parameter gradients into one vector
        grad = []
        for layer in self.f:
            if hasattr(layer, 'weight'):
                grad.append(layer.weight.grad.flatten())
                grad.append(layer.bias.grad.flatten())
        return torch.cat(grad)

    def aug(self, t, x):
        # Augmented dynamics [dz/dt, da/dt, da_theta/dt], integrated backwards in time
        self.zero_grad()
        x = torch.Tensor(x).to(device)
        z, a, _ = torch.split(x, [1, 1, x.shape[0] - 2])
        z.requires_grad = True
        ft = self.forward(t, z)
        ft.backward()
        partial_f_partial_z = z.grad
        partial_f_partial_theta = self.get_grad_cat()
        da_dt = -a * partial_f_partial_z
        # Integrand of dL/dtheta, i.e. da_theta/dt = -a^T df/dtheta
        partial_L_partial_theta = -a * partial_f_partial_theta
        return torch.cat([ft, da_dt, partial_L_partial_theta]).cpu().detach().numpy()

    def update_grad(self, t0, t1, z0, true_z1):
        # Adjoint method: forward solve, then backward solve of the augmented ODE
        if type(z0) != torch.Tensor:
            z0 = torch.Tensor(z0)
        pred_z1 = self.get_z1(t0, t1, z0)
        pred_z1.requires_grad = True
        loss = self.loss(pred_z1, true_z1)
        loss.backward()
        partial_L_partial_z1 = pred_z1.grad
        pred_z1.requires_grad = False
        partial_L_partial_z1.requires_grad = False
        # Augmented state at t1: [z(t1), a(t1) = dL/dz(t1), a_theta(t1) = 0]
        x1 = torch.cat([pred_z1.cpu(), partial_L_partial_z1.cpu(), torch.zeros(self.num_parameters)])
        x0 = torch.Tensor(solve_ivp(self.aug, [t1, t0], x1, method='LSODA').y[:, -1]).to(device)
        _, _, grad = torch.split(x0, [1, 1, x0.shape[0] - 2])
        # Write the accumulated dL/dtheta back into the parameters' .grad fields
        start_index = 0
        for layer in self.f:
            if hasattr(layer, 'weight'):
                n = layer.weight.flatten().shape[0]
                layer.weight.grad = grad[start_index:start_index + n].reshape(layer.weight.grad.shape)
                start_index += n
                n = layer.bias.flatten().shape[0]
                layer.bias.grad = grad[start_index:start_index + n].reshape(layer.bias.grad.shape)
                start_index += n
        return loss.item(), pred_z1
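A minimal usage sketch with hypothetical time values and an untrained instance, showing the two entry points used below: `get_z1` for the forward solve and `update_grad` for the adjoint-based gradient. The real training loop follows.
# Minimal usage sketch (hypothetical values); not part of the training run below
demo_model = ODENet().to(device)
demo_z0 = z(0.0, 0.0)                                  # true state at t0 = 0
demo_true_z1 = z(1.0, 0.0)                             # true state at t1 = 1
demo_pred_z1 = demo_model.get_z1(0.0, 1.0, demo_z0)    # forward solve of the learned dynamics
demo_loss, _ = demo_model.update_grad(0.0, 1.0, demo_z0, demo_true_z1)  # fills parameter .grad via the adjoint solve
print(demo_pred_z1, demo_loss)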
model = ODENet().to(device)
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 100
training_loss = []
for epoch in tqdm(range(num_epochs)):
    curr_loss = n = 0
    # Sample 100 random unit-length intervals [t0, t0 + 1] with t0 in [-1, 9)
    for t0 in torch.rand(100).to(device):
        t0 *= 10
        t0 -= 1
        t1 = t0 + 1
        c = torch.zeros(1).to(device)
        true_z0 = z(t0, c)
        true_z1 = z(t1, c)
        optimizer.zero_grad()
        # Adjoint-based gradient, then a per-sample optimizer step
        curr_loss_, pred_z1 = model.update_grad(t0, t1, true_z0, true_z1)
        curr_loss += curr_loss_
        n += 1
        optimizer.step()
    training_loss.append(curr_loss / n)
    # print(f"Epoch {epoch} Loss: {curr_loss / n}")
plt.plot(np.arange(1, len(training_loss)+1), training_loss)
plt.title('ODENet Training Loss')
plt.ylabel('MSE Loss')
plt.xlabel('Epoch')
100%|█████████████████████████| 100/100 [11:44<00:00, 7.05s/it]
c = 0
t0 = 0
t1 = 10
num_points = 100
z_initial = torch.Tensor([z(0, c)]).to(device)
plt.plot(np.linspace(t0, t1, num_points), [z(t, c).cpu() for t in np.linspace(t0, t1, num_points)], label='Original')
plt.plot(np.linspace(t0, t1, num_points), [model.get_z1(0, t, z_initial).cpu().numpy() for t in tqdm(np.linspace(t0, t1, num_points))], label='Predicted')
plt.title("Visualization of Model Performance")
plt.legend()
100%|████████████████████████| 100/100 [00:00<00:00, 179.93it/s]