Quickstart
Basic Featuresยค
See the script quick_start.py for this example in code form.
First our imports. ginjax is built in JAX, and the majority of the model code resides in geometric.
import jax.numpy as jnp
import jax.random as random
import ginjax.geometric as geom
First we construct our image. Suppose you have some data that forms a 3 by 3 vector image. Thus the dimension D=2, the sidelength N=3, and the tensor order k=1 for a block of data N x N x D. Currently only D=2 or D=3 images are valid. The parity is how the image responds when it is reflected. Normal images have parity 0, an image of pseudovectors like angular velocity will have parity 1.
key = random.PRNGKey(0)
key, subkey = random.split(key)
N = 3 # sidelength
D = 2 # dimension
k = 1 # tensor order
parity = 0
data = random.normal(subkey, shape=((N,)*D + (D,)*k))
image = geom.GeometricImage(data, parity=0, D=2)
We can visualize this image with the plotting tools in utils. You will need to call matplotlib.pypolot.show() to display.
image.plot()
Now we can do various operations on this geometric image
image2 = geom.GeometricImage.fill(N, parity, D, fill=jnp.array([1,0])) # fill constructor, each pixel is fill
# pixel-wise addition
image + image2
# pixel-wise subtraction
image - image2
# pixel-wise tensor product
image * image2
# scalar multiplication
image * 3
We can also apply a group action on the image. First we generate all the operators for dimension D, then we apply one
operators = geom.make_all_operators(D)
print("Number of operators:", len(operators))
image.times_group_element(operators[1])
Now let us generate all 3 by 3 filters of tensor order k=0,1 and parity=0,1 that are invariant to the operators
invariant_filters = geom.get_invariant_filters(
Ms=[3],
ks=[0,1],
parities=[0,1],
D=D,
operators=operators,
scale='one', #all the values of the filter are 1, can also 'normalize' so the norm of the tensor pixel is 1
return_list=True,
)
print('Number of invariant filters N=3, k=0,1 parity=0,1:', len(invariant_filters))
Using these filters, we can perform convolutions on our image. Since the filters are invariant, the convolution will be equivariant.
gg = operators[1] # one operator, a flip over the y-axis
ff_k0 = invariant_filters[1] # one filter, a non-trivial scalar filter
print(
"Equivariant:",
jnp.allclose(
image.times_group_element(gg).convolve_with(ff_k0).data,
image.convolve_with(ff_k0).times_group_element(gg).data,
rtol=1e-2,
atol=1e-2,
),
)
When convolving with filters that have tensor order > 0, the resulting image have tensor order img.k + filter.k
ff_k1 = invariant_filters[5]
print('image k:', image.k)
print('filter k:', ff_k1.k)
convolved_image = image.convolve_with(ff_k1)
print('convolved image k:', convolved_image.k)
After convolving, the image has tensor order 1+1=2 pixels. We can transpose the indices of the tensor:
convolved_image.transpose((1,0))
Since the tensor order is >= 2, we can perform a contraction on those indices which will reduce it to tensor order 0.
print('contracted image k:', convolved_image.contract(0,1).k)