import tensorflow as tf
import numpy as np
# Toy single-channel input image, 6 rows x 8 columns: ones in columns 0-1 and
# 6-7, zeros in columns 2-5, so every row has a 1->0 edge and a 0->1 edge.
# Built as float32 to match the target Y and the float32 kernel the Conv2D
# layer trains (a float64 input would only be cast down again).
X = tf.convert_to_tensor(
    np.concatenate(
        [np.ones((6, 2)), np.zeros((6, 4)), np.ones((6, 2))], axis=1
    ).astype(np.float32)
)
print(X)
# Expected values (shape (6, 8), dtype float32):
# [[1. 1. 0. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 0. 1. 1.]
#  [1. 1. 0. 0. 0. 0. 1. 1.]]
# Target output: horizontal differences of X, Y[i, j] = X[i, j] - X[i, j + 1],
# i.e. X run through the 1x2 kernel [1, -1]. This is +1 at each 1->0 edge,
# -1 at each 0->1 edge and 0 elsewhere; shape (6, 7) for the 6x8 X.
# Deriving Y from X (instead of a hand-typed literal) keeps the two consistent;
# the cast pins float32 regardless of the dtype X was built with.
Y = tf.cast(X[:, :-1] - X[:, 1:], tf.float32)
print(Y)
# Expected values (shape (6, 7), dtype float32):
# [[ 0.  1.  0.  0.  0. -1.  0.]
#  [ 0.  1.  0.  0.  0. -1.  0.]
#  [ 0.  1.  0.  0.  0. -1.  0.]
#  [ 0.  1.  0.  0.  0. -1.  0.]
#  [ 0.  1.  0.  0.  0. -1.  0.]
#  [ 0.  1.  0.  0.  0. -1.  0.]]
# One output filter with a 1x2 window and no bias; training should recover
# the [1, -1] difference kernel that maps X to Y.
conv2d = tf.keras.layers.Conv2D(1, (1, 2), use_bias=False)

# Conv2D takes (batch, height, width, channels): add a batch axis and a
# channel axis (both of size 1) around the 2-D X and Y.
X = X[tf.newaxis, ..., tf.newaxis]
Y = Y[tf.newaxis, ..., tf.newaxis]

lr = 3e-2  # Learning rate
epochs = 20

# Call the layer once so it creates its weights before training.
conv2d(X)
kernel = conv2d.weights[0]  # the only weight, since use_bias=False

for epoch in range(epochs):
    # Trainable variables read inside the tape are watched automatically,
    # so no explicit tape.watch() is needed.
    with tf.GradientTape() as tape:
        Y_hat = conv2d(X)
        loss = tf.reduce_sum((Y_hat - Y) ** 2)  # sum of squared errors
    grad = tape.gradient(loss, kernel)
    # Plain gradient-descent step, applied to the variable in place.
    kernel.assign_sub(lr * grad)
    print(f"epoch {epoch}, loss {loss}")
# Recorded run (the kernel's starting values are not seeded here, so exact
# numbers differ between runs):
# epoch 0, loss 37.925758361816406
# epoch 1, loss 19.48427963256836
# epoch 2, loss 10.508689880371094
# epoch 3, loss 5.922234535217285
# epoch 4, loss 3.4611868858337402
# epoch 5, loss 2.0803825855255127
# epoch 6, loss 1.276240348815918
# epoch 7, loss 0.7941816449165344
# epoch 8, loss 0.4990140199661255
# epoch 9, loss 0.3155749440193176
# epoch 10, loss 0.20041388273239136
# epoch 11, loss 0.127628356218338
# epoch 12, loss 0.08142148703336716
# epoch 13, loss 0.052002955228090286
# epoch 14, loss 0.03323804587125778
# epoch 15, loss 0.021254435181617737
# epoch 16, loss 0.013595545664429665
# epoch 17, loss 0.008698157966136932
# epoch 18, loss 0.0055656046606600285
# epoch 19, loss 0.003561462042853236

print(conv2d.get_weights()[0])
# Recorded run, learned kernel of shape (1, 2, 1, 1), dtype float32:
# array([[[[ 0.9936623]], [[-1.0059878]]]], dtype=float32)