Scalar example
This is a simple example in which we use ginjax to learn scalar filters. We start by specifying which GPU to use and importing the packages we need.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=6
import time
import optax
from typing_extensions import Optional, Self
import jax
from jax import random
from jaxtyping import ArrayLike
import equinox as eqx
import ginjax.geometric as geom
import ginjax.ml as ml
import ginjax.models as models
from ginjax import layers
env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=6
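Outside a notebook the %env magics are not available. A minimal sketch of the plain-Python equivalent, assuming the same GPU index 6 as above, is to set the variables through os.environ before JAX initializes its GPU backend. Setting them before importing jax is the simplest way to guarantee that.

```python
import os

# These must be set before JAX initializes CUDA; doing it before `import jax`
# is the simplest way to be sure.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus order (as nvidia-smi does)
os.environ["CUDA_VISIBLE_DEVICES"] = "6"        # expose only physical GPU 6 to this process
```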
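Now let's define our images X and the filters we will use. The images are 2D, 64 x 64 scalar images, and we generate 10 of them with random normal pixel values. The filters are 3 x 3, and we use only the scalar filters that are invariant under the rotations and reflections of the grid (the operators returned by geom.make_all_operators(D)). There are 3 of these, and the first one is the identity.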
key = random.PRNGKey(time.time_ns())
D = 2
N = 64 # image size
M = 3 # filter image size
num_images = 10
group_actions = geom.make_all_operators(D)
conv_filters = geom.get_invariant_filters(
    Ms=[M], ks=[0], parities=[0], D=D, operators=group_actions
)
key, subkey = random.split(key)
# Batch of num_images single-channel scalar images, shape (num_images, 1, N, N).
multi_image_X = geom.MultiImage(
    {(0, 0): random.normal(subkey, shape=(num_images, 1) + (N,) * D)}, D
)
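To make "invariant scalar filters" concrete, here is a small sketch that builds a basis of 3 x 3 filters invariant under the 8 rotations and reflections of the square by group averaging one-hot images. It is an illustration under our own assumptions, not ginjax's implementation: the helpers square_symmetries and invariant_scalar_filters are hypothetical, and geom.get_invariant_filters may order and normalize its filters differently.

```python
import numpy as np
import jax.numpy as jnp


def square_symmetries(img):
    """The 8 images obtained by acting on img with the symmetries of the square (group D4)."""
    rotations = [jnp.rot90(img, k) for k in range(4)]
    return rotations + [jnp.fliplr(r) for r in rotations]


def invariant_scalar_filters(M=3):
    """Basis of M x M scalar filters invariant under D4, built by group averaging."""
    one_hots = jnp.eye(M * M).reshape(M * M, M, M)
    # Averaging a one-hot image over the group spreads it uniformly over its orbit.
    averaged = [sum(square_symmetries(e)) / 8.0 for e in one_hots]
    # One-hots in the same orbit give identical averages; keep one per orbit.
    flat = np.round(np.stack([np.asarray(a).ravel() for a in averaged]), 6)
    unique = np.unique(flat, axis=0)
    return jnp.asarray(unique.reshape(-1, M, M))


filters = invariant_scalar_filters(3)
print(filters.shape)  # (3, 3, 3)
```

The three orbits of a 3 x 3 grid (the center, the four edge neighbors, the four corners) give exactly three invariant filters, matching the count stated above; the center filter is the identity.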
Next we define our target function and use it to construct the target images Y. The target function simply convolves by the filter at index 1 and then by the filter at index 2.
def target_function(
    multi_image: geom.MultiImage, conv_filter_a: jax.Array, conv_filter_b: jax.Array
) -> geom.MultiImage:
    # Convolve the scalar images by conv_filter_a first, then by conv_filter_b.
    convolved_data = geom.convolve(
        multi_image.D,
        geom.convolve(
            multi_image.D, multi_image[((), 0)], conv_filter_a[None, None], multi_image.is_torus
        ),
        conv_filter_b[None, None],
        multi_image.is_torus,
    )
    return geom.MultiImage({(0, 0): convolved_data}, multi_image.D, multi_image.is_torus)
multi_image_y = target_function(multi_image_X, conv_filters[((), 0)][1], conv_filters[((), 0)][2])
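The sketch below is a plain-JAX version of the same two-step filtering, assuming periodic (torus) boundaries and a cross-correlation convention; torus_correlate and target are hypothetical names, not ginjax functions. For filters that are symmetric under reflection, such as the invariant filters here, cross-correlation and convolution give the same result, so the convention does not matter for this example. The final check shows that the order of the two filters does not matter either, because shift-invariant linear maps on the torus commute.

```python
import jax.numpy as jnp
from jax import random


def torus_correlate(image, filt):
    """Periodic (torus) 2D cross-correlation of an (N, N) image with a square, odd-sized filter."""
    r = filt.shape[0] // 2
    out = jnp.zeros_like(image)
    for i in range(filt.shape[0]):
        for j in range(filt.shape[1]):
            # jnp.roll(image, s)[x] == image[x - s], so s = r - i picks image[x + i - r],
            # wrapping around the edges like a torus.
            out = out + filt[i, j] * jnp.roll(image, shift=(r - i, r - j), axis=(0, 1))
    return out


def target(image, filt_a, filt_b):
    # Same structure as target_function: filter by filt_a, then by filt_b.
    return torus_correlate(torus_correlate(image, filt_a), filt_b)


k1, k2, k3 = random.split(random.PRNGKey(0), 3)
x = random.normal(k1, (16, 16))
a = random.normal(k2, (3, 3))
b = random.normal(k3, (3, 3))
print(jnp.allclose(target(x, a, b), target(x, b, a), atol=1e-4))  # True
```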
We now define our network and loss function. Machine learning in GeometricImageNet is done on the MultiImage object, which collects batches of multiple channels of images, at possibly different tensor orders, in a single object.
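As a rough mental model only, a MultiImage can be pictured as a dictionary from a (tensor order, parity) key to one array per key. The sketch below is a hypothetical stand-in, not the ginjax class. It assumes the layout (batch, channels) + (N,) * D + (D,) * k, which matches the scalar images created above, with a tensor of order k adding k trailing axes of size D.

```python
import jax.numpy as jnp

D, N, batch = 2, 64, 10

# Hypothetical stand-in for a MultiImage: a dict from (tensor order k, parity)
# to an array of shape (batch, channels) + (N,) * D + (D,) * k.
multi_image = {
    (0, 0): jnp.zeros((batch, 1) + (N,) * D),            # 1 channel of scalar images
    (1, 0): jnp.zeros((batch, 3) + (N,) * D + (D,) * 1),  # 3 channels of vector images
}

for (k, parity), data in multi_image.items():
    print((k, parity), data.shape)
# (0, 0) (10, 1, 64, 64)
# (1, 0) (10, 3, 64, 64, 2)
```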
For this toy example, we keep the task straightforward: the network is two ConvContract layers, each of which learns one weight per filter in our set of three. Composing the two layers, the network is a linear combination of all pairs of filters chosen from the set of three with replacement, with each pair's coefficient given by products of the layer weights. Because convolution commutes, there are 6 distinct pairs, and our target function corresponds to the 5th of them (filters 1 and 2). Our loss is the squared-error loss ml.smse_loss. The ml.train function expects a map_and_loss function that operates on MultiImages.
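The counting argument can be checked directly. In this sketch (PAIRS and pair_coefficients are hypothetical names), each layer is modeled as one weight per filter, as the printed weights suggest. Expanding (a_0 F_0 + a_1 F_1 + a_2 F_2) composed with (b_0 F_0 + b_1 F_1 + b_2 F_2) gives a coefficient on each unordered pair; because convolution commutes, F_i then F_j equals F_j then F_i, so the two orderings merge.

```python
from itertools import combinations_with_replacement

import jax.numpy as jnp

# Unordered pairs of filter indices, chosen with replacement from {0, 1, 2}.
PAIRS = list(combinations_with_replacement(range(3), 2))
print(PAIRS)  # [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]


def pair_coefficients(a, b):
    """Coefficient on each pair after expanding (sum_i a_i F_i) composed with (sum_j b_j F_j)."""
    a, b = jnp.asarray(a), jnp.asarray(b)
    coeffs = []
    for i, j in PAIRS:
        c = a[i] * b[j]
        if i != j:
            c = c + a[j] * b[i]  # F_i * F_j == F_j * F_i, so merge both orderings
        coeffs.append(c)
    return jnp.stack(coeffs)


# Layer 1 picks filter 1, layer 2 picks filter 2: only the (1, 2) entry is nonzero.
print(pair_coefficients([0.0, 1.0, 0.0], [0.0, 0.0, 1.0]))
```

The (1, 2) pair sits at index 4, i.e. it is the 5th of the 6. The two-layer network can only represent pair coefficients that factor this way, but that is enough to express the target.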
class SimpleModel(models.MultiImageModule):
    D: int
    net: list[layers.ConvContract]
    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        conv_filters: geom.MultiImage,
        key: ArrayLike,
    ):
        self.D = D
        key, subkey1, subkey2 = random.split(key, num=3)
        # Two convolution layers, each built from the same set of invariant filters.
        self.net = [
            layers.ConvContract(input_keys, output_keys, conv_filters, False, key=subkey1),
            layers.ConvContract(output_keys, output_keys, conv_filters, False, key=subkey2),
        ]
    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        # Apply the layers in sequence; aux_data is passed through unchanged.
        for layer in self.net:
            x = layer(x)
        return x, aux_data
def map_and_loss(
    model: models.MultiImageModule,
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    aux_data: Optional[eqx.nn.State] = None,
) -> tuple[jax.Array, Optional[eqx.nn.State]]:
    # Map the model over the batch axis of x; aux_data is shared across the batch.
    pred_y, aux_data = jax.vmap(model, in_axes=(0, None), out_axes=(0, None))(multi_image_x, aux_data)
    return ml.smse_loss(multi_image_y, pred_y), aux_data
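The jax.vmap call in map_and_loss maps the model over the batch axis of the input while passing aux_data through unbatched: in_axes=(0, None) maps the first argument over axis 0 and broadcasts the second, and out_axes=(0, None) stacks the first output along axis 0 and returns the second once. A toy sketch of the same pattern, with a made-up per_example function standing in for the model:

```python
import jax
import jax.numpy as jnp


def per_example(x, state):
    # x: a single example of shape (4, 4); state: a shared, unbatched value.
    return x * 2.0, state


batched = jax.vmap(per_example, in_axes=(0, None), out_axes=(0, None))
xs = jnp.ones((10, 4, 4))              # batch of 10 examples
ys, state = batched(xs, jnp.array(0.5))
print(ys.shape)  # (10, 4, 4)
print(state)     # returned once, not stacked per example
```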
Now we train our model using the train function from ml.py. It takes the input data as a MultiImage, the target data as a MultiImage, a map_and_loss function with arguments (model, x, y, aux_data), the model, a random key used for batching, a stopping condition (here ml.EpochStop(500, verbose=1), which stops after 500 epochs), the batch size, and the desired optax optimizer. Since the batch size equals num_images, each epoch processes the whole dataset as a single batch.
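The optimizer uses optax.exponential_decay with transition_steps=1 and decay_rate=0.99, i.e. a learning rate of 0.1 * 0.99 ** step in optax's default non-staircase mode. The sketch below reimplements that formula in plain Python (exponential_decay_lr is a hypothetical helper, not part of optax) to show how quickly it shrinks. Assuming ml.train takes one optimizer step per batch, there is one step per epoch here, so by epoch 500 the learning rate is about 0.1 * 0.99 ** 500 ≈ 6.6e-4.

```python
def exponential_decay_lr(step, init_value=0.1, decay_rate=0.99, transition_steps=1):
    """Learning rate after `step` optimizer steps, using the continuous (non-staircase)
    formula init_value * decay_rate ** (step / transition_steps)."""
    return init_value * decay_rate ** (step / transition_steps)


for step in (0, 100, 500):
    print(step, exponential_decay_lr(step))
```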
key, subkey = random.split(key)
model = SimpleModel(
    D, multi_image_X.get_signature(), multi_image_y.get_signature(), conv_filters, subkey
)
key, subkey = random.split(key)
trained_model, _, _, _, _ = ml.train(
    multi_image_X,
    multi_image_y,
    map_and_loss,
    model,
    subkey,
    ml.EpochStop(500, verbose=1),  # stop after 500 epochs
    num_images,  # batch size: the whole dataset in one batch
    optimizer=optax.adam(optax.exponential_decay(0.1, transition_steps=1, decay_rate=0.99)),
)
assert isinstance(trained_model, SimpleModel)
print(trained_model.net[0].weights)
print(trained_model.net[1].weights)
Epoch 50 Train: 0.1227631 Epoch time: 0.01350
Epoch 100 Train: 0.0005918 Epoch time: 0.01245
Epoch 150 Train: 0.0000003 Epoch time: 0.01249
Epoch 200 Train: 0.0000000 Epoch time: 0.01358
Epoch 250 Train: 0.0000000 Epoch time: 0.01210
Epoch 300 Train: 0.0000000 Epoch time: 0.01402
Epoch 350 Train: 0.0000000 Epoch time: 0.01245
Epoch 400 Train: 0.0000000 Epoch time: 0.01337
Epoch 450 Train: 0.0000000 Epoch time: 0.01242
Epoch 500 Train: 0.0000000 Epoch time: 0.01413
{((), 0): {((), 0): Array([[[1.4731792e-06, 9.1884673e-01, 4.2977635e-07]]], dtype=float32)}}
{((), 0): {((), 0): Array([[[-1.589166e-05, -7.275783e-07, 1.088328e+00]]], dtype=float32)}}
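We can see that each layer has learned to put essentially all of its weight on a single filter: the first layer on the filter at index 1 (about 0.919) and the second layer on the filter at index 2 (about 1.088), with every other weight close to 0. Since the two layers compose, only the product of these weights matters, and 0.919 × 1.088 ≈ 1, so the model has recovered the target function. Hooray!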