Conversion example
In this example, we will take an existing vanilla model and convert it to an equivariant model. To better showcase model performance, we will pick a dataset with both scalar fields and vector fields. We will use a tiny subset of the 2D computational fluid dynamics (CFD) data from PDEBench, stored in the data folder. The full dataset (10,000 trajectories) can be found here: https://darus.uni-stuttgart.de/dataset.xhtml?persistentId=doi:10.18419/darus-2986.
This data consists of pressure (scalar), density (scalar), and velocity (vector) fields discretized on 128 x 128 images. The mini dataset has 4 trajectories, each with 21 time steps. Our model will take 4 time steps as input and try to predict the next time step. We start by loading the data and cutting the time steps into overlapping windows of 4 input steps and 1 output step, which we use for training (3 trajectories) and testing (1 trajectory).
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=6
# set to -1 to prevent JAX from using the GPUs.
import matplotlib.pyplot as plt
from typing import Any, Optional, Self
import jax
import jax.numpy as jnp
import jax.random as random
from jaxtyping import ArrayLike
import equinox as eqx
import optax
import h5py
import ginjax.geometric as geom
import ginjax.ml as ml
import ginjax.models as models
import ginjax.data as gc_data
from ginjax import layers
env: CUDA_DEVICE_ORDER=PCI_BUS_ID env: CUDA_VISIBLE_DEVICES=6
def read_one_h5(filename: str, num_trajectories: int) -> tuple[jax.Array, jax.Array, jax.Array]:
    """
    Read the first `num_trajectories` trajectories of density, pressure, and velocity from an HDF5 file.

    The file is opened read-only and is always closed, even if one of the reads fails.
    Every returned array is explicitly placed on the CPU device.

    args:
        filename (str): the full path to the HDF5 file
        num_trajectories (int): how many trajectories (leading axis) to read
    returns:
        density, pressure, vxy as jax arrays. density and pressure have shape
        (num_trajectories, t, x, y); vxy stacks the "Vx" and "Vy" datasets on a new
        trailing axis, giving shape (num_trajectories, t, x, y, 2).
    """
    cpu = jax.devices("cpu")[0]

    # Explicit read-only mode; the context manager closes the file on success and on error.
    with h5py.File(filename, "r") as data_dict:

        def _load(name: str) -> jax.Array:
            # Slicing the h5py dataset reads only the requested trajectories into memory.
            return jax.device_put(jnp.array(data_dict[name][:num_trajectories]), cpu)

        # all of these are shape (num_trajectories, t, x, y), e.g. (10K, 21, 128, 128) for the full dataset
        density = _load("density")
        pressure = _load("pressure")
        vx = _load("Vx")
        vy = _load("Vy")

    vxy = jnp.stack([vx, vy], axis=-1)
    return density, pressure, vxy
def get_data(
    D: int,
    filename: str,
    n_train: int,
    n_test: int,
    past_steps: int,
    normalize: bool = True,
) -> tuple[geom.MultiImage, ...]:
    """
    Load the CFD data and cut every trajectory into overlapping (past_steps -> 1 step) windows.

    The first `n_train` trajectories in the file are used for training and the following
    `n_test` trajectories for testing.

    args:
        D (int): spatial dimension of the data
        filename (str): path to the HDF5 file
        n_train (int): number of training trajectories
        n_test (int): number of test trajectories
        past_steps (int): number of input time steps in each window
        normalize (bool): if True, rescale the fields using statistics of the training trajectories only
    returns:
        train_X, train_Y, test_X, test_Y as MultiImages
    """
    density, pressure, velocity = read_one_h5(filename, n_train + n_test)

    if normalize:
        # Statistics come only from the training trajectories, so no test information leaks in.
        density = (density - jnp.mean(density[:n_train])) / jnp.std(density[:n_train])
        pressure = (pressure - jnp.mean(pressure[:n_train])) / jnp.std(pressure[:n_train])
        # The velocity is only divided by one scalar (no mean shift), which commutes with
        # rotations/reflections of the vectors.
        velocity = velocity / jnp.std(velocity[:n_train])

    # (batch,2,timesteps,spatial)
    density_pressure = jnp.concatenate([density[:, None], pressure[:, None]], axis=1)
    # (batch,2*timesteps,spatial)
    density_pressure = density_pressure.reshape(
        (len(density_pressure), -1) + density_pressure.shape[3:]
    )

    # Number of time steps per trajectory, read from the data (axis 1) instead of hard-coding 21.
    total_steps = density.shape[1]
    is_torus = True  # periodic domain, matching the periodic CFD data used in this notebook
    constant_fields = geom.MultiImage({}, D, is_torus)

    def _split(trajectories: slice) -> tuple[geom.MultiImage, geom.MultiImage]:
        # Cut the selected trajectories into windows of past_steps inputs and 1 output step.
        return gc_data.batch_time_series(
            geom.MultiImage(
                {(0, 0): density_pressure[trajectories], (1, 0): velocity[trajectories]},
                D,
                is_torus,
            ),
            constant_fields,
            total_steps,
            past_steps,
            1,
        )

    train_X, train_Y = _split(slice(None, n_train))
    test_X, test_Y = _split(slice(n_train, None))
    return train_X, train_Y, test_X, test_Y
The mini dataset is provided in the repo; update the relative path if necessary.
# Problem configuration: 2D fields on an N x N grid (N is kept for reference only),
# 3 training trajectories, 1 test trajectory, and 4 past time steps as model input.
D, N = 2, 128
n_train, n_test = 3, 1
past_steps = 4
normalize = True

train_X, train_Y, test_X, test_Y = get_data(
    D=D,
    filename='../data/mini_2D_CFD_Rand_M0.1_Eta0.01_Zeta0.01_periodic_128_Train.hdf5',
    n_train=n_train,
    n_test=n_test,
    past_steps=past_steps,
    normalize=normalize,
)
print(train_X)
print(train_Y)
print(test_X)
print(test_Y)

# Flatten every image type into plain scalar channels: the channels-first layout a vanilla CNN consumes.
typical_cnn_input = train_X.to_scalar_multi_image()[((), 0)]
print(f'Typical CNN shape (N,C,H,W): {typical_cnn_input.shape}')
<class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (True, True) ((), 0): (51, 8, 128, 128) ((False,), 0): (51, 4, 128, 128, 2) <class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (True, True) ((), 0): (51, 2, 128, 128) ((False,), 0): (51, 1, 128, 128, 2) <class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (True, True) ((), 0): (17, 8, 128, 128) ((False,), 0): (17, 4, 128, 128, 2) <class 'ginjax.geometric.multi_image.MultiImage'> D: 2, is_torus: (True, True) ((), 0): (17, 2, 128, 128) ((False,), 0): (17, 1, 128, 128, 2) Typical CNN shape (N,C,H,W): (51, 16, 128, 128)
We can see that the inputs have 8 scalar channels (4 timesteps each of density and pressure) and 4 vector channels (4 timesteps of velocity). Likewise, the outputs have 2 scalar channels and 1 vector channel. We can convert the input data into a purely scalar image, splitting each vector channel into its two components; this is the typical input to a vanilla CNN. It has 16 input channels: 8 for the 8 scalar images and 8 for the 4 vector images. Now let us define a typical Equinox model: a small encoder-decoder CNN with circular padding.
class VanillaCNN(eqx.Module):
    """
    Plain (non-equivariant) encoder-decoder CNN whose convolutions use circular padding.

    Stages: two convolutions at full resolution, a 2x2 max-pool, two convolutions at half
    resolution with doubled width, a stride-2 transposed convolution back to full resolution,
    and two final convolutions. Every convolution except the last is followed by a ReLU.
    """

    layers: list

    def __init__(
        self: Self,
        in_channels: int,
        out_channels: int,
        width: int,
        kernel_size: int,
        key: ArrayLike,
    ) -> None:
        # One PRNG key per learnable layer (7 in total), handed out in layer order.
        subkeys = iter(jax.random.split(key, 7))

        def conv(c_in: int, c_out: int) -> eqx.nn.Conv2d:
            # 'SAME' + circular padding keeps the spatial size and wraps around the image edges.
            return eqx.nn.Conv2d(
                c_in, c_out, kernel_size, padding='SAME', padding_mode='CIRCULAR', key=next(subkeys)
            )

        wide = 2 * width
        encoder = [
            conv(in_channels, width), jax.nn.relu,
            conv(width, width), jax.nn.relu,
        ]
        bottleneck = [
            eqx.nn.MaxPool2d(kernel_size=2, stride=2),
            conv(width, wide), jax.nn.relu,
            conv(wide, wide), jax.nn.relu,
            eqx.nn.ConvTranspose2d(
                wide,
                width,
                kernel_size=2,
                stride=2,
                padding='SAME',
                padding_mode='CIRCULAR',
                key=next(subkeys),
            ),
            jax.nn.relu,
        ]
        decoder = [
            conv(width, width), jax.nn.relu,
            conv(width, out_channels),
        ]
        self.layers = encoder + bottleneck + decoder

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply every layer in order to one unbatched, channels-first image."""
        out = x
        for fn in self.layers:
            out = fn(out)
        return out
# Root random state for the notebook; each model is initialized from a fresh subkey.
key = random.PRNGKey(0)
key, subkey = random.split(key)
vanilla_model = VanillaCNN(
    in_channels=16,  # 8 scalar channels + 4 vector channels x 2 components
    out_channels=4,  # 2 scalar channels + 1 vector channel x 2 components
    width=32,
    kernel_size=3,
    key=subkey,
)
print(vanilla_model)
VanillaCNN(
layers=[
Conv2d(
num_spatial_dims=2,
weight=f32[32,16,3,3],
bias=f32[32,1,1],
in_channels=16,
out_channels=32,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
Conv2d(
num_spatial_dims=2,
weight=f32[32,32,3,3],
bias=f32[32,1,1],
in_channels=32,
out_channels=32,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
MaxPool2d(
init=-inf,
operation=<function max>,
num_spatial_dims=2,
kernel_size=(2, 2),
stride=(2, 2),
padding=((0, 0), (0, 0)),
use_ceil=False
),
Conv2d(
num_spatial_dims=2,
weight=f32[64,32,3,3],
bias=f32[64,1,1],
in_channels=32,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
Conv2d(
num_spatial_dims=2,
weight=f32[64,64,3,3],
bias=f32[64,1,1],
in_channels=64,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
ConvTranspose2d(
num_spatial_dims=2,
weight=f32[32,64,2,2],
bias=f32[32,1,1],
in_channels=64,
out_channels=32,
kernel_size=(2, 2),
stride=(2, 2),
padding='SAME',
output_padding=(0, 0),
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
Conv2d(
num_spatial_dims=2,
weight=f32[32,32,3,3],
bias=f32[32,1,1],
in_channels=32,
out_channels=32,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
),
<wrapped function relu>,
Conv2d(
num_spatial_dims=2,
weight=f32[4,32,3,3],
bias=f32[4,1,1],
in_channels=32,
out_channels=4,
kernel_size=(3, 3),
stride=(1, 1),
padding='SAME',
dilation=(1, 1),
groups=1,
use_bias=True,
padding_mode='CIRCULAR'
)
]
)
To convert this model to an equivariant version, we can replace each layer with its equivariant counterpart. These layers process MultiImages, so rather than defining a single number of input and output channels, we must define the number of input and output channels for each type of geometric image. We call this object the MultiImage signature: a tuple of (key, number of channels) pairs. Each key is itself a tuple whose first element denotes which axes are covariant (True) or contravariant (False), and whose second element is the parity. We can easily get the signature of a MultiImage A with A.get_signature(). See below for the signatures.
# A signature lists (image-type key, number of channels) pairs for a MultiImage.
print(f'Input signature: {train_X.get_signature()}')
print(f'Output signature: {train_Y.get_signature()}')
Input signature: ((((), 0), 8), (((False,), 0), 4)) Output signature: ((((), 0), 2), (((False,), 0), 1))
The input signature has 8 scalar channels (key `()`, even parity 0) and 4 contravariant vector channels (key `(False,)`, even parity 0). The output signature has 2 scalar channels and 1 contravariant vector channel, with the same keys and parities.
Each intermediate step of the model will also be a MultiImage, so we need to define which tensor orders and parities those will have. Additionally, we need to define the number of channels for each image type. Typically, we let the image types be the union of the input and output image types, and we set a width, which is the number of channels used for each image type. For activation functions, we use the Vector Neuron nonlinearity from https://arxiv.org/abs/2104.12229. There are other options for equivariant nonlinearities, but this one performs the best of the ones we have tried, and it is flexible enough to adapt typical nonlinearities such as relu, gelu, etc.
Note that when defining the hidden layer widths, the width is the number of channels of each image type. So a width of 32 for an equivariant CNN that operates on both scalar and vector images gives 32 channels of each. For a more apples-to-apples comparison between the two networks, reduce the number of channels in the equivariant CNN.
The layer conversions are as follows:
- Conv -> ConvContract
- activation -> VectorNeuronNonlinear of that activation
- MaxPool -> MaxNormPool
- ConvTranspose -> ConvContract with `lhs_dilation` (and explicit `padding`)
class EquivariantCNN(eqx.Module):
    """
    Equivariant counterpart of VanillaCNN built from ginjax layers acting on MultiImages.

    Every hidden MultiImage contains each image type of the input and output signatures,
    with `width` channels per type (2 * width in the half-resolution stage). Upsampling is a
    ConvContract with lhs_dilation=2 and explicit padding, replacing the transposed convolution.
    """

    D: int
    layers: list

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        width: int,
        invariant_filters: geom.MultiImage,
        upsample_filters: geom.MultiImage,
        key: ArrayLike,
    ) -> None:
        self.D = D
        mid_keys = geom.signature_union(input_keys, output_keys, width)
        mid_keys2 = geom.signature_union(input_keys, output_keys, width * 2)
        # 14 keys are split but only 13 are consumed; the count is kept as-is so the
        # initialization (and therefore training results) stays reproducible.
        subkeys = iter(random.split(key, 14))

        def conv(in_sig: geom.Signature, out_sig: geom.Signature) -> layers.ConvContract:
            # equivariant convolution using the regular invariant filter basis
            return layers.ConvContract(in_sig, out_sig, invariant_filters, key=next(subkeys))

        def nonlin(sig: geom.Signature) -> layers.VectorNeuronNonlinear:
            # Vector Neuron nonlinearity wrapping relu
            return layers.VectorNeuronNonlinear(sig, D, jax.nn.relu, key=next(subkeys))

        # Keys are drawn in list order, so the element order below must not change.
        self.layers = [
            # encoder
            conv(input_keys, mid_keys),
            nonlin(mid_keys),
            conv(mid_keys, mid_keys),
            nonlin(mid_keys),
            # pooling + half-resolution stage
            layers.MaxNormPool(patch_len=2),
            conv(mid_keys, mid_keys2),
            nonlin(mid_keys2),
            conv(mid_keys2, mid_keys2),
            nonlin(mid_keys2),
            # upsampling: dilate the input by 2, pad by 1 on every side, then apply the upsample filters
            layers.ConvContract(
                mid_keys2,
                mid_keys,
                upsample_filters,
                padding=((1, 1),) * D,
                lhs_dilation=(2,) * D,
                key=next(subkeys),
            ),
            nonlin(mid_keys),
            # decoder
            conv(mid_keys, mid_keys),
            nonlin(mid_keys),
            conv(mid_keys, output_keys),
        ]

    def __call__(
        self: Self,
        x: geom.MultiImage,
        aux_data: Optional[eqx.nn.State] = None,
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """Apply the layers in order; aux_data is passed through unchanged."""
        out = x
        for fn in self.layers:
            out = fn(out)
        return out, aux_data
# Invariant filter bases: size [3] for the regular ConvContract layers and size [2]
# for the upsampling layer, built from the same group operators.
operators = geom.make_all_operators(D)
conv_filters = geom.get_invariant_filters([3], [0, 1, 2], [0, 1], D, operators)
upsample_filters = geom.get_invariant_filters([2], [0, 1, 2], [0, 1], D, operators)

key, subkey = random.split(key)
equiv_model = EquivariantCNN(
    D=D,
    input_keys=train_X.get_signature(),
    output_keys=train_Y.get_signature(),
    width=32,
    invariant_filters=conv_filters,
    upsample_filters=upsample_filters,
    key=subkey,
)
print(equiv_model)
EquivariantCNN(
D=2,
layers=[
ConvContract(
weights={
((), 0):
{((), 0): f32[32,8,3], ((False,), 0): f32[32,8,2]},
((False,), 0):
{((), 0): f32[32,4,2], ((False,), 0): f32[32,4,5]}
},
bias={((), 0): f32[32], ((False,), 0): f32[32]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 8), (((False,), 0), 4)),
target_keys=((((), 0), 32), (((False,), 0), 32)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[32,32]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
ConvContract(
weights={
((), 0):
{((), 0): f32[32,32,3], ((False,), 0): f32[32,32,2]},
((False,), 0):
{((), 0): f32[32,32,2], ((False,), 0): f32[32,32,5]}
},
bias={((), 0): f32[32], ((False,), 0): f32[32]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 32), (((False,), 0), 32)),
target_keys=((((), 0), 32), (((False,), 0), 32)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[32,32]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
MaxNormPool(patch_len=2, use_norm=True),
ConvContract(
weights={
((), 0):
{((), 0): f32[64,32,3], ((False,), 0): f32[64,32,2]},
((False,), 0):
{((), 0): f32[64,32,2], ((False,), 0): f32[64,32,5]}
},
bias={((), 0): f32[64], ((False,), 0): f32[64]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 32), (((False,), 0), 32)),
target_keys=((((), 0), 64), (((False,), 0), 64)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[64,64]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
ConvContract(
weights={
((), 0):
{((), 0): f32[64,64,3], ((False,), 0): f32[64,64,2]},
((False,), 0):
{((), 0): f32[64,64,2], ((False,), 0): f32[64,64,5]}
},
bias={((), 0): f32[64], ((False,), 0): f32[64]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 64), (((False,), 0), 64)),
target_keys=((((), 0), 64), (((False,), 0), 64)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[64,64]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
ConvContract(
weights={
((), 0):
{((), 0): f32[32,64,1], ((False,), 0): f32[32,64,1]},
((False,), 0):
{((), 0): f32[32,64,1], ((False,), 0): f32[32,64,2]}
},
bias={((), 0): f32[32], ((False,), 0): f32[32]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66bc6d5550>,
input_keys=((((), 0), 64), (((False,), 0), 64)),
target_keys=((((), 0), 32), (((False,), 0), 32)),
use_bias='auto',
stride=1,
padding=((1, 1), (1, 1)),
lhs_dilation=(2, 2),
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[32,32]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
ConvContract(
weights={
((), 0):
{((), 0): f32[32,32,3], ((False,), 0): f32[32,32,2]},
((False,), 0):
{((), 0): f32[32,32,2], ((False,), 0): f32[32,32,5]}
},
bias={((), 0): f32[32], ((False,), 0): f32[32]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 32), (((False,), 0), 32)),
target_keys=((((), 0), 32), (((False,), 0), 32)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
),
VectorNeuronNonlinear(
weights={((False,), 0): f32[32,32]},
eps=1e-05,
D=2,
scalar_activation=<wrapped function relu>
),
ConvContract(
weights={
((), 0):
{((), 0): f32[2,32,3], ((False,), 0): f32[1,32,2]},
((False,), 0):
{((), 0): f32[2,32,2], ((False,), 0): f32[1,32,5]}
},
bias={((), 0): f32[2], ((False,), 0): f32[1]},
invariant_filters=<ginjax.geometric.multi_image.MultiImage object at 0x7f66dc75ef30>,
input_keys=((((), 0), 32), (((False,), 0), 32)),
target_keys=((((), 0), 2), (((False,), 0), 1)),
use_bias='auto',
stride=1,
padding=None,
lhs_dilation=None,
rhs_dilation=1,
D=2,
fast_mode=False,
missing_filter=False
)
]
)
We can now proceed with our typical training pipeline. We define our map_and_loss function using timestep_smse_loss. To conform with our train function, map_and_loss takes 4 arguments: the model, the input MultiImage, the target output MultiImage, and aux data that would be used by layers like BatchNorm. Likewise, it returns the loss and the aux data.
def map_and_loss(
    model: models.MultiImageModule,
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    aux_data: Optional[eqx.nn.State] = None,
) -> tuple[jax.Array, Optional[eqx.nn.State]]:
    """
    Predict on a batch of MultiImages and score the prediction against the target.

    The model is vmapped over the leading batch axis of `multi_image_x`, while `aux_data`
    is shared across the batch (in_axes None) and returned once (out_axes None).

    args:
        model: callable taking (multi_image, aux_data) and returning (prediction, aux_data)
        multi_image_x: batched input MultiImage
        multi_image_y: batched target MultiImage
        aux_data: optional state, e.g. for BatchNorm-style layers
    returns:
        (loss, aux_data), where loss is element 0 of ml.timestep_smse_loss(..., n_steps=1)
    """
    batched_model = jax.vmap(model, in_axes=(0, None), out_axes=(0, None))
    pred_y, aux_data = batched_model(multi_image_x, aux_data)
    smse_outputs = ml.timestep_smse_loss(pred_y, multi_image_y, n_steps=1)
    return smse_outputs[0], aux_data
Since our training loop works with MultiImage inputs, we would like our vanilla model to also work with MultiImages. We can use the ModelWrapper class to wrap any model that handles image-like data.
# Wrap the plain CNN so it consumes and returns MultiImages like the equivariant model.
# The output signature is presumably what the wrapper uses to map the 4 raw output channels
# back to 2 scalar images and 1 vector image.
vanilla_multi_image_model = models.ModelWrapper(
    D,
    vanilla_model,
    train_Y.get_signature(),
    train_Y.is_torus,
)
Finally, we specify our training hyper-parameters and call our training loop. The epochs, batch size, learning rate, and optimizer could likely be further tuned, but the choices below are sufficient for this simple problem.
# training_params
epochs = 50
batch_size = 5
# (display name, model, learning rate) for every model trained with the same loop and settings
model_list = [
    ('Equivariant Model', equiv_model, 1e-3),
    ('Vanilla Model', vanilla_multi_image_model, 1e-3),
]
trained_models = []
for model_name, model, lr in model_list:
    # prints name, parameter count, and learning rate
    print(f'{model_name} ({models.count_params(model)}), {lr}')
    key, subkey = random.split(key)
    # No validation data is passed here; the printed output shows val_loss comes back as None.
    trained_model, _, train_loss, val_loss, _ = ml.train(
        train_X,
        train_Y,
        map_and_loss,
        model,
        subkey,
        ml.EpochStop(epochs, verbose=1),
        batch_size,
        optax.adamw(lr),
    )
    # NOTE: this subkey is never used, but the split still advances `key`; removing the line
    # would change the key the next model trains with.
    key, subkey = random.split(key)
    print('train_loss:', train_loss)
    print('val_loss:', val_loss)
    trained_models.append((model_name, trained_model))
Equivariant Model (126905), 0.001 Epoch 5 Train: 0.4490222 Epoch time: 1.88518 Epoch 10 Train: 0.3164000 Epoch time: 2.04646 Epoch 15 Train: 0.2579377 Epoch time: 1.99601 Epoch 20 Train: 0.2073154 Epoch time: 1.74278 Epoch 25 Train: 0.2163653 Epoch time: 1.91679 Epoch 30 Train: 0.1780696 Epoch time: 1.81148 Epoch 35 Train: 0.1477995 Epoch time: 1.89265 Epoch 40 Train: 0.1635500 Epoch time: 1.81615 Epoch 45 Train: 0.1613974 Epoch time: 1.78835 Epoch 50 Train: 0.1464429 Epoch time: 1.73702 train_loss: 0.14644289 val_loss: None Vanilla Model (87940), 0.001 Epoch 5 Train: 0.8677309 Epoch time: 0.38511 Epoch 10 Train: 0.8414311 Epoch time: 0.37431 Epoch 15 Train: 0.7406319 Epoch time: 0.36105 Epoch 20 Train: 0.6784432 Epoch time: 0.38231 Epoch 25 Train: 0.6981007 Epoch time: 0.37258 Epoch 30 Train: 0.6370169 Epoch time: 0.35824 Epoch 35 Train: 0.6091697 Epoch time: 0.37975 Epoch 40 Train: 0.5731363 Epoch time: 0.34888 Epoch 45 Train: 0.5503518 Epoch time: 0.38084 Epoch 50 Train: 0.5220579 Epoch time: 0.38528 train_loss: 0.5220579 val_loss: None
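We can see that the training loss of the equivariant model is smaller than that of the vanilla model (note that, at these widths, the equivariant model also has more parameters: 126,905 vs. 87,940). Now let's see how the models perform on one example from the training set.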
img_idx = 0  # 0 for density, 1 for pressure, 2 for velocity
fig, axes = plt.subplots(
    nrows=len(trained_models),
    ncols=3,
    figsize=(6 * 3, 6 * len(trained_models)),
    squeeze=False,  # keep `axes` 2D so axes[i, j] also works when only one model is plotted
)

# The target does not depend on the model, so extract it once.
y_img = train_Y.get_one().to_images()[img_idx]
# One [prediction, target, difference, name] entry per trained model (any number of models).
images = []
for model_name, trained_model in trained_models:
    vmap_trained_model = jax.vmap(trained_model, in_axes=(0, None), out_axes=(0, None))
    pred_y = vmap_trained_model(train_X.get_one(), None)[0].to_images()[img_idx]
    images.append([pred_y, y_img, pred_y - y_img, model_name])

# Shared, symmetric color scales across all rows: one for predictions/targets, one for differences.
image_data = jnp.concatenate(
    [img.data.flatten() for pred, target, _, _ in images for img in (pred, target)]
)
diff_data = jnp.concatenate([diff.data.flatten() for _, _, diff, _ in images])
vmax = jnp.max(jnp.abs(image_data))
vmax_diff = jnp.max(jnp.abs(diff_data))

for i, (pred_y, y_img, y_diff, model_name) in enumerate(images):
    pred_y.plot(axes[i, 0], f'{model_name} Prediction', colorbar=True, vmin=-vmax, vmax=vmax)
    y_img.plot(axes[i, 1], 'Target', colorbar=True, vmin=-vmax, vmax=vmax)
    y_diff.plot(axes[i, 2], 'Difference', colorbar=True, vmin=-vmax_diff, vmax=vmax_diff)
For this small example we only train long enough to get reasonable-looking results, but hopefully this gives you an idea of the problem. Finally, we would like to know how the models perform on a test dataset. We print the test loss of each model, as well as one example.
img_idx = 0  # 0 for density, 1 for pressure, 2 for velocity
fig, axes = plt.subplots(
    nrows=len(trained_models),
    ncols=3,
    figsize=(6 * 3, 6 * len(trained_models)),
    squeeze=False,  # keep `axes` 2D so axes[i, j] also works when only one model is plotted
)

# The target does not depend on the model, so extract it once.
y_img = test_Y.get_one().to_images()[img_idx]
# One [prediction, target, difference, name] entry per trained model (any number of models).
images = []
for model_name, trained_model in trained_models:
    # loss over the whole test set, evaluated in batches of batch_size
    test_loss = ml.map_loss_in_batches(map_and_loss, trained_model, test_X, test_Y, batch_size, None)
    print(f'{model_name} test_loss:', test_loss)
    vmap_trained_model = jax.vmap(trained_model, in_axes=(0, None), out_axes=(0, None))
    pred_y = vmap_trained_model(test_X.get_one(), None)[0].to_images()[img_idx]
    images.append([pred_y, y_img, pred_y - y_img, model_name])

# Shared, symmetric color scales across all rows: one for predictions/targets, one for differences.
image_data = jnp.concatenate(
    [img.data.flatten() for pred, target, _, _ in images for img in (pred, target)]
)
diff_data = jnp.concatenate([diff.data.flatten() for _, _, diff, _ in images])
vmax = jnp.max(jnp.abs(image_data))
vmax_diff = jnp.max(jnp.abs(diff_data))

for i, (pred_y, y_img, y_diff, model_name) in enumerate(images):
    pred_y.plot(axes[i, 0], f'{model_name} Prediction', colorbar=True, vmin=-vmax, vmax=vmax)
    y_img.plot(axes[i, 1], 'Target', colorbar=True, vmin=-vmax, vmax=vmax)
    y_diff.plot(axes[i, 2], 'Difference', colorbar=True, vmin=-vmax_diff, vmax=vmax_diff)
Equivariant Model test_loss: 0.7887535 Vanilla Model test_loss: 6.456304
This trajectory has very different initial conditions from some of the training trajectories, and both models do significantly worse than on the training data. The equivariant model does better, but we caution against reading too much into this result, since the training dataset is so small in this toy example. Hopefully this notebook gives you a good idea of how to convert your non-equivariant model to an equivariant version.