Gradient example
This example shows how to use an equivariant UNet to learn the gradient of a scalar field discretized to an image. The input is a scalar image, the output is a vector image, and the intermediate multi images contain both scalar and vector images. This requires scalar, vector, and 2-tensor filters to map between the intermediate multi images. We start by specifying which GPUs to use and importing packages.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
import matplotlib.pyplot as plt
from typing import Any, Optional, Self, Union
import jax
import jax.numpy as jnp
import jax.random as random
from jaxtyping import ArrayLike
import equinox as eqx
import optax
import ginjax.geometric as geom
import ginjax.ml as ml
import ginjax.models as models
env: CUDA_DEVICE_ORDER=PCI_BUS_ID env: CUDA_VISIBLE_DEVICES=0
Next we generate our data. Any input scalar field would suffice, but to make the fields look nice and the derivatives easy to calculate, we use 3rd degree polynomials in x and y. We evaluate the scalar field on the box $[-1.5, 1.5] \times [-1.5, 1.5]$ and discretize it onto a 32 by 32 pixel image. The 10 polynomial coefficients are sampled from a normal distribution with mean 0 and standard deviation 0.4. We calculate the gradient of the polynomial exactly, then apply the same discretization to get the gradient image. Finally, we remove two rows and columns of pixels from each side of the gradient image to get a 28 by 28 image. We do this because estimating the gradient at a pixel from the scalar image alone requires knowledge of the surrounding pixels, which are missing at the border.
def gen_scalar_gradient_imgs(
    N: int,
    border_pixels: int,
    coeffs: jax.Array,
    x_lims: tuple[float, float],
    y_lims: tuple[float, float],
) -> tuple[jax.Array, jax.Array]:
    """Sample a cubic polynomial on an N x N grid together with its exact gradient.

    The polynomial is ``sum_i coeffs[i] * library[i]`` over the monomial library
    ``[1, x, y, x^2, xy, y^2, x^3, x^2 y, x y^2, y^3]``.

    Args:
        N: number of grid points along each axis.
        border_pixels: pixels removed from every side of the gradient image;
            must satisfy ``0 <= border_pixels`` and ``2 * border_pixels < N``.
        coeffs: one coefficient per monomial; must broadcast against an
            array of shape (10, N, N), e.g. shape (10, 1, 1).
        x_lims: (start, stop) of the grid along x (axis 0).
        y_lims: (start, stop) of the grid along y (axis 1).

    Returns:
        ``(scalar_img, gradient_img)`` where ``scalar_img`` has shape (N, N) and
        ``gradient_img`` has shape (N - 2*border_pixels, N - 2*border_pixels, 2)
        with last axis ordered (d/dx, d/dy).

    Raises:
        ValueError: if ``border_pixels`` is negative or leaves no pixels.
    """
    # A negative border would index from the wrong end, and a border of at least
    # N/2 leaves an empty image; reject both explicitly.
    if border_pixels < 0 or 2 * border_pixels >= N:
        raise ValueError(
            f"border_pixels must satisfy 0 <= border_pixels < N/2, "
            f"got border_pixels={border_pixels} with N={N}"
        )

    # Dimension D=2; indexing="ij" makes x vary along axis 0 and y along axis 1.
    x, y = jnp.meshgrid(jnp.linspace(*x_lims, N), jnp.linspace(*y_lims, N), indexing="ij")
    ones = jnp.ones(x.shape)
    zeros = jnp.zeros(x.shape)
    x2 = x**2
    xy = x * y
    y2 = y**2

    # library is [1, x, y, x^2, xy, y^2, x^3, x^2y, xy^2, y^3]
    library = jnp.stack([ones, x, y, x2, xy, y2, x**3, x2 * y, x * y2, y**3])
    scalar_img = jnp.sum(coeffs * library, axis=0)
    assert scalar_img.shape == (N, N)

    # d/dx of the library: [0, 1, 0, 2x, y, 0, 3x^2, 2xy, y^2, 0]
    # d/dy of the library: [0, 0, 1, 0, x, 2y, 0, x^2, 2xy, 3y^2]
    dx_library = jnp.stack([zeros, ones, zeros, 2 * x, y, zeros, 3 * x2, 2 * xy, y2, zeros])
    dy_library = jnp.stack([zeros, zeros, ones, zeros, x, 2 * y, zeros, x2, 2 * xy, 3 * y2])
    dx_sum = jnp.sum(coeffs * dx_library, axis=0)
    dy_sum = jnp.sum(coeffs * dy_library, axis=0)

    # Use an explicit stop index: slicing with `-border_pixels` would give an
    # empty image when border_pixels == 0 (since -0 == 0).
    stop = N - border_pixels
    gradient_img = jnp.stack([dx_sum, dy_sum], axis=2)[border_pixels:stop, border_pixels:stop]
    assert gradient_img.shape == (stop - border_pixels, stop - border_pixels, 2)
    return scalar_img, gradient_img
def gen_images_batch(
    n_images: int,
    N: int,
    border_pixels: int,
    x_lims: tuple[float, float],
    y_lims: tuple[float, float],
    coeffs_mean_std: tuple[float, float],
    key: Any,
) -> tuple[jax.Array, jax.Array]:
    """Generate ``n_images`` random cubic-polynomial images and their gradients.

    Coefficients are drawn i.i.d. from a normal distribution with mean
    ``coeffs_mean_std[0]`` and standard deviation ``coeffs_mean_std[1]``. Each
    image is then produced by ``gen_scalar_gradient_imgs``.

    Returns:
        The batched ``(scalar_imgs, gradient_imgs)``, with a leading axis of size
        ``n_images``.
    """
    mean, std = coeffs_mean_std[0], coeffs_mean_std[1]
    # 10 coefficients per image (one per monomial); the trailing singleton axes
    # let them broadcast over the N x N grid.
    standard_normal = random.normal(key, shape=(n_images, 10, 1, 1))
    coeffs = standard_normal * std + mean

    # Only the coefficients are batched; every other argument is shared.
    batched_generator = jax.vmap(gen_scalar_gradient_imgs, in_axes=(None, None, 0, None, None))
    return batched_generator(N, border_pixels, coeffs, x_lims, y_lims)
def gen_data(
    N: int,
    border_pixels: int,
    n_train: int,
    n_val: int,
    n_test: int,
    x_lims: tuple[float, float],
    y_lims: tuple[float, float],
    coeffs_mean_std: tuple[float, float],
    key: ArrayLike,
    normalize: bool = True,
) -> tuple[
    geom.MultiImage,
    geom.MultiImage,
    geom.MultiImage,
    geom.MultiImage,
    geom.MultiImage,
    geom.MultiImage,
]:
    """Build train/val/test splits of (scalar image, gradient image) pairs.

    When ``normalize`` is True, the inputs are standardized and the targets are
    divided by their standard deviation. The statistics are computed from the
    train and val splits only, then applied to all three splits.

    Returns:
        ``(train_x, train_y, val_x, val_y, test_x, test_y)``. Each inputs
        MultiImage holds a single (0, 0) entry and each targets MultiImage holds
        a single (1, 0) entry. Each entry has a channel axis of size 1 inserted
        at axis 1, with D=2 and a non-torus domain.
    """
    # The first split key is discarded; one fresh key per data split.
    _, *split_keys = random.split(key, num=4)
    split_sizes = (n_train, n_val, n_test)
    (train_x, train_y), (val_x, val_y), (test_x, test_y) = (
        gen_images_batch(size, N, border_pixels, x_lims, y_lims, coeffs_mean_std, split_key)
        for size, split_key in zip(split_sizes, split_keys)
    )

    if normalize:
        fit_x = jnp.concatenate([train_x, val_x])
        x_mean = jnp.mean(fit_x)
        x_std = jnp.std(fit_x)
        train_x, val_x, test_x = ((arr - x_mean) / x_std for arr in (train_x, val_x, test_x))

        y_std = jnp.std(jnp.concatenate([train_y, val_y]))
        train_y, val_y, test_y = (arr / y_std for arr in (train_y, val_y, test_y))

    def as_multi_image(arr, signature):
        # Insert the channel axis and wrap as a 2-D, non-periodic MultiImage.
        return geom.MultiImage({signature: jnp.expand_dims(arr, axis=1)}, 2, is_torus=False)

    scalar_sig, vector_sig = (0, 0), (1, 0)
    return (
        as_multi_image(train_x, scalar_sig),
        as_multi_image(train_y, vector_sig),
        as_multi_image(val_x, scalar_sig),
        as_multi_image(val_y, vector_sig),
        as_multi_image(test_x, scalar_sig),
        as_multi_image(test_y, vector_sig),
    )
# data params
D = 2  # image dimension
N = 32  # side length (pixels) of the scalar input images
border_pixels = 2  # pixels cropped from each side of the gradient targets
n_train, n_val, n_test = 128, 32, 128
domain_lims = (-1.5, 1.5)  # used for both the x and the y limits
coeff_mean_std = (0, 0.4)  # (mean, std) of the polynomial coefficients

key = random.PRNGKey(0)
key, subkey = random.split(key)
train_x, train_y, val_x, val_y, test_x, test_y = gen_data(
    N,
    border_pixels,
    n_train,
    n_val,
    n_test,
    x_lims=domain_lims,
    y_lims=domain_lims,
    coeffs_mean_std=coeff_mean_std,
    key=subkey,
)
print(train_x)
print(train_y)
<class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (False, False) (0, 0): (128, 1, 32, 32) <class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (False, False) (1, 0): (128, 1, 28, 28, 2)
Let's take a look at one image pair to check that the data look correct.
# Show the first training pair: the scalar input next to its gradient target.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
scalar_ax, gradient_ax = axes
train_x.get_one().to_images()[0].plot(scalar_ax, 'Scalar Image')
train_y.get_one().to_images()[0].plot(gradient_ax, 'Gradient Image')
Next we define our map_and_loss function. To work with our training scripts, this function must take 4 arguments: the model, the input multi image, the target output multi image, and the aux data used by layers like BatchNorm. Likewise, it must return the loss and the aux data. We use the sum of mean squared errors (SMSE) loss, which sums over the tensor and channel indices and averages over the spatial dimensions and the batch.
def map_and_loss(
    model: models.MultiImageModule,
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    aux_data: Optional[eqx.nn.State] = None,
) -> tuple[jax.Array, Optional[eqx.nn.State]]:
    """Apply ``model`` across the batch and return ``(smse_loss, aux_data)``.

    The batch axis of ``multi_image_x`` is mapped over. ``aux_data`` is passed
    to every call unbatched, and the aux data returned by the model is
    likewise unbatched.
    """
    batched_model = jax.vmap(model, in_axes=(0, None), out_axes=(0, None))
    prediction, new_aux_data = batched_model(multi_image_x, aux_data)
    loss = ml.smse_loss(prediction, multi_image_y)
    return loss, new_aux_data
Next, we define our model. We use an equivariant UNet defined in ginjax.models. It requires invariant filter bases for the convolutions and for upsampling. We specify 2 downsampling blocks, with 2 convolutions per block and 32 channels each. This model is more than enough firepower for this simple problem. Again, the input is a scalar image, the output is a vector image, and the intermediate multi images have both scalar and vector images. We modify the UNet slightly by cropping 2 rows and columns from each side of the output image (controlled by the border_pixels parameter). We also create a similar non-equivariant model to compare against.
class ModifiedUNet(models.MultiImageModule):
    """Wrap a UNet and crop ``b`` pixels from every side of each output image.

    The crop matches the targets, which lose their border pixels during data
    generation.
    """

    _unet: models.UNet
    b: int  # number of pixels cropped from each side of both spatial axes

    def __init__(self: Self, border_pixels: int, unet: models.UNet):
        # A negative border would crop from the wrong end of the image.
        if border_pixels < 0:
            raise ValueError(f"border_pixels must be non-negative, got {border_pixels}")
        self._unet = unet
        self.b = border_pixels

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """Run the wrapped UNet and return its cropped output and aux data."""
        unet_out, aux_data = self._unet(x, aux_data)
        assert isinstance(unet_out, geom.MultiImage)
        out = unet_out.empty()
        for (k, parity), image in unet_out.items():
            out.append(k, parity, self._crop(image))
        return out, aux_data

    def _crop(self: Self, image: jax.Array) -> jax.Array:
        """Crop ``self.b`` pixels from both ends of axes 1 and 2; axis 0 is kept."""
        # Explicit stop indices: `-self.b` would empty the image when b == 0.
        # NOTE(review): assumes axes 1 and 2 are the two spatial axes (D=2),
        # matching the original slicing.
        b = self.b
        return image[:, b : image.shape[1] - b, b : image.shape[2] - b]
# Invariant filter bases for the equivariant model. The first list is
# presumably the kernel sizes (3 for convolutions, 2 for upsampling), then the
# tensor orders (0-2) and the parities (0, 1). TODO confirm against the
# geom.get_invariant_filters signature.
operators = geom.make_all_operators(D)
conv_filters = geom.get_invariant_filters([3], [0, 1, 2], [0, 1], D, operators)
upsample_filters = geom.get_invariant_filters([2], [0, 1, 2], [0, 1], D, operators)

key, subkey1, subkey2 = random.split(key, num=3)

# Hyper-parameters shared by both UNets so that the comparison is like-for-like.
unet_shared_kwargs = dict(
    depth=32,
    num_downsamples=2,
    num_conv=2,
    activation_f=jax.nn.gelu,
)

equiv_unet = ModifiedUNet(
    border_pixels,
    models.UNet(
        D,
        train_x.get_signature(),
        train_y.get_signature(),
        use_bias='auto',
        equivariant=True,
        conv_filters=conv_filters,
        upsample_filters=upsample_filters,
        key=subkey1,
        **unet_shared_kwargs,
    ),
)

vanilla_unet = ModifiedUNet(
    border_pixels,
    models.UNet(
        D,
        train_x.get_signature(),
        train_y.get_signature(),
        use_bias=True,
        equivariant=False,
        kernel_size=3,
        key=subkey2,
        **unet_shared_kwargs,
    ),
)
Finally, we specify our training hyper-parameters and call our training loop. The epochs, batch size, learning rate, and optimizer could likely be further tuned, but the choices below are sufficient for this simple problem.
# training_params
epochs = 10
batch_size = 8
# (display name, model, learning rate)
model_list = [
    ('Equivariant UNet', equiv_unet, 1e-3),
    ('Vanilla UNet', vanilla_unet, 2e-3),
]

trained_models = []
for model_name, model, lr in model_list:
    print(model_name, lr)
    key, subkey = random.split(key)
    stop_condition = ml.EpochStop(epochs, verbose=1)
    optimizer = optax.adamw(lr)
    trained_model, _, train_loss, val_loss, _ = ml.train(
        train_x,
        train_y,
        map_and_loss,
        model,
        subkey,
        stop_condition,
        batch_size,
        optimizer,
        val_x,
        val_y,
    )
    # This split only advances `key`; the resulting subkey is not used. It is
    # kept so that the random stream, and hence the training results, is
    # reproduced exactly.
    key, subkey = random.split(key)
    print('train_loss:', train_loss)
    print('val_loss:', val_loss)
    trained_models.append((model_name, trained_model))
Equivariant UNet 0.001 Epoch 1 Train: 0.9681907 Val: 0.0517643 Epoch time: 33.94479 Epoch 2 Train: 0.0274913 Val: 0.0120747 Epoch time: 4.55271 Epoch 3 Train: 0.0054589 Val: 0.0036202 Epoch time: 4.78631 Epoch 4 Train: 0.0029062 Val: 0.0020292 Epoch time: 5.14995 Epoch 5 Train: 0.0018013 Val: 0.0018112 Epoch time: 4.85557 Epoch 6 Train: 0.0013603 Val: 0.0008283 Epoch time: 4.71776 Epoch 7 Train: 0.0008071 Val: 0.0007172 Epoch time: 4.72322 Epoch 8 Train: 0.0006658 Val: 0.0006565 Epoch time: 4.66001 Epoch 9 Train: 0.0006014 Val: 0.0007755 Epoch time: 4.78692 Epoch 10 Train: 0.0005596 Val: 0.0005794 Epoch time: 5.01457 train_loss: 0.0005596265 val_loss: 0.0005793912 Vanilla UNet 0.002 Epoch 1 Train: 1.8165448 Val: 1.1921453 Epoch time: 6.44025 Epoch 2 Train: 0.9374802 Val: 0.5293492 Epoch time: 1.06413 Epoch 3 Train: 0.2831277 Val: 0.1687648 Epoch time: 1.07501 Epoch 4 Train: 0.0737486 Val: 0.0423572 Epoch time: 1.06805 Epoch 5 Train: 0.0343936 Val: 0.0252487 Epoch time: 1.06876 Epoch 6 Train: 0.0188330 Val: 0.0149000 Epoch time: 1.07510 Epoch 7 Train: 0.0131375 Val: 0.0121628 Epoch time: 1.07215 Epoch 8 Train: 0.0101806 Val: 0.0098821 Epoch time: 1.44962 Epoch 9 Train: 0.0083471 Val: 0.0078261 Epoch time: 1.07428 Epoch 10 Train: 0.0068873 Val: 0.0069367 Epoch time: 1.07469 train_loss: 0.006887291 val_loss: 0.006936726
Let's see how the models perform on one example from the training set.
# Prediction / target / difference for the first training example, one row per model.
n_rows = len(trained_models)
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works with a single model.
fig, axes = plt.subplots(nrows=n_rows, ncols=3, figsize=(6 * 3, 6 * n_rows), squeeze=False)
for i, (model_name, trained_model) in enumerate(trained_models):
    # Same batching convention as map_and_loss: batch the input, share aux_data.
    vmap_trained_model = jax.vmap(trained_model, in_axes=(0, None), out_axes=(0, None))
    pred_y = vmap_trained_model(train_x.get_one(), None)[0].get_one().to_images()[0]
    y_img = train_y.get_one().to_images()[0]
    pred_y.plot(axes[i, 0], f'{model_name} Prediction')
    y_img.plot(axes[i, 1], 'Target')
    (pred_y - y_img).plot(axes[i, 2], 'Difference')
Finally, we would like to know how the models perform on the test set. For each model we print the average test loss, followed by the loss on a single test example, and we plot that example.
# Test-set evaluation plus prediction / target / difference plots, one row per model.
n_rows = len(trained_models)
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works with a single model.
fig, axes = plt.subplots(nrows=n_rows, ncols=3, figsize=(6 * 3, 6 * n_rows), squeeze=False)
for i, (model_name, trained_model) in enumerate(trained_models):
    # Average loss over the whole test set, evaluated batch_size examples at a time.
    test_loss = ml.map_loss_in_batches(map_and_loss, trained_model, test_x, test_y, batch_size, None)
    print(f'{model_name} test_loss:', test_loss)

    # Use the same batching convention as map_and_loss (batch the input, share
    # aux_data). The original built this function but then called a differently
    # configured jax.vmap instead.
    vmap_trained_model = jax.vmap(trained_model, in_axes=(0, None), out_axes=(0, None))
    pred_y = vmap_trained_model(test_x.get_one(), None)[0].get_one().to_images()[0]
    y_img = test_y.get_one().to_images()[0]

    single_loss = map_and_loss(trained_model, test_x.get_one(), test_y.get_one(), None)[0]
    print(f'{model_name} single example loss:', single_loss)

    pred_y.plot(axes[i, 0], f'{model_name} Prediction')
    y_img.plot(axes[i, 1], 'Target')
    (pred_y - y_img).plot(axes[i, 2], 'Difference')
Equivariant UNet test_loss: 0.00058578723 0.0010609535 Vanilla UNet test_loss: 0.0071529485 0.008420781
We would like a better idea of how the equivariant and non-equivariant models compare. The equivariant model appears better both from a single example picture and from the average loss. However, we might ask whether the non-equivariant model actually performs well in general and just has a few outliers. To investigate this further, we look at a histogram of the losses on each test data point.
First we define map_plus_loss, which returns the loss, the aux data, and the mapped inputs (the model predictions). We use the predictions to compute the loss on each test point, which we then use to generate the histogram.
def map_plus_loss(
    model: models.MultiImageModule,
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    aux_data: Optional[eqx.nn.State] = None,
) -> tuple[jax.Array, Optional[eqx.nn.State], geom.MultiImage]:
    """Like ``map_and_loss``, but also return the batched predictions.

    Returns:
        ``(smse_loss, aux_data, predictions)``. The batch axis of
        ``multi_image_x`` is mapped over, while ``aux_data`` is shared
        (unbatched) in and out.
    """
    batched_model = jax.vmap(model, in_axes=(0, None), out_axes=(0, None))
    prediction, new_aux_data = batched_model(multi_image_x, aux_data)
    loss = ml.smse_loss(prediction, multi_image_y)
    return loss, new_aux_data, prediction
# One array of per-test-example losses for each trained model. reduce=None
# presumably disables averaging over the examples; verify against ml.smse_loss.
losses_per_batch = []
for _, trained_model in trained_models:
    _, pred_y = ml.map_plus_loss_in_batches(map_plus_loss, trained_model, test_x, test_y, batch_size, None)
    losses_per_batch.append(ml.smse_loss(pred_y, test_y, reduce=None))

# the visualization is better if we make the x-scale log
all_losses = jnp.stack(losses_per_batch)
logmin = jnp.log10(all_losses.min())
logmax = jnp.log10(all_losses.max())
log_bins = jnp.logspace(logmin, logmax, 50)
model_names = [name for name, _ in trained_models]

fig = plt.hist(losses_per_batch, bins=log_bins, label=model_names)
plt.legend()
plt.xscale('log')
A further interesting experiment would be to see how the model performs on data which looks very different, for example different domains, different coefficient sampling schemes or even different equations. Try it out!