Skip to content

Functional geometric image

ginjax.geometric.functional_geometric_image ¤

parse_shape(shape: tuple[int, ...], D: int) -> tuple[tuple[int, ...], int] ¤

Given a geometric image shape and dimension D, return the sidelength tuple and tensor order k.

Parameters:

Name Type Description Default
shape tuple[int, ...]

the shape of the data of a single geometric image

required
D int

dimension of the image

required

Returns:

Type Description
tuple[tuple[int, ...], int]

tuple of spatial dimensions, tensor order

Source code in ginjax/geometric/functional_geometric_image.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def parse_shape(shape: tuple[int, ...], D: int) -> tuple[tuple[int, ...], int]:
    """
    Split the shape of a single geometric image into its spatial part and its tensor order.

    The first D entries of shape are the spatial sidelengths; every remaining axis counts as one
    tensor index, so the tensor order k is the number of axes left over.

    args:
        shape: the shape of the data of a single geometric image
        D: dimension of the image

    returns:
        tuple of spatial dimensions, tensor order
    """
    assert isinstance(shape, tuple), f"parse_shape: Shape must be a tuple, but it is {type(shape)}"
    n_axes = len(shape)
    assert n_axes >= D, f"parse_shape: Shape {shape} is shorter than D={D}"

    spatial_dims = shape[:D]
    tensor_order = n_axes - D
    return spatial_dims, tensor_order

hash(D: int, spatial_dims: tuple[int, ...], indices: ArrayLike) -> tuple[jax.Array, ...] ¤

Converts an array of indices to their pixels on the torus by modding the indices with the spatial dimensions.

Parameters:

Name Type Description Default
D int

dimension of the image

required
spatial_dims tuple[int, ...]

the spatial dimensions of the data

required
indices ArrayLike

array of indices, shape (num_idx, D) to apply the remainder to

required

Returns:

Type Description
tuple[Array, ...]

the pixel indices as a d-tuple of jax arrays

Source code in ginjax/geometric/functional_geometric_image.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def hash(D: int, spatial_dims: tuple[int, ...], indices: ArrayLike) -> tuple[jax.Array, ...]:
    """
    Wrap an array of indices onto the torus: each index component is reduced modulo the
    sidelength of its spatial dimension, so out-of-range indices land on valid pixels.

    args:
        D: dimension of the image
        spatial_dims: the spatial dimensions of the data
        indices: array of indices, shape (num_idx, D) to apply the remainder to

    returns:
        the pixel indices as a d-tuple of jax arrays
    """
    # a (1, D) row so the sidelengths broadcast against every (num_idx, D) row of indices
    wrap_lengths = jnp.array(spatial_dims).reshape((1, D))
    wrapped = jnp.remainder(indices, wrap_lengths)

    # transpose to (D, num_idx): iterating the leading axis yields one integer array per dimension
    per_dimension = wrapped.transpose().astype(int)
    return tuple(per_dimension)

nonempty_pixels(D: int, data: jax.Array, n_lead: int = 0) -> jax.Array ¤

Get the nonempty pixels as a true/false array.

Parameters:

Name Type Description Default
D int

the dimension

required
data Array

array of shape (n_lead,spatial,tensor)

required
n_lead int

the number of leading batch axes

0

Returns:

Type Description
Array

a true/false array of flattened shape (n_lead,image_size)

Source code in ginjax/geometric/functional_geometric_image.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def nonempty_pixels(D: int, data: jax.Array, n_lead: int = 0) -> jax.Array:
    """
    Mark which pixels are nonempty, i.e. have at least one tensor component that is not
    (numerically) zero.

    args:
        D: the dimension
        data: array of shape (n_lead,spatial,tensor)
        n_lead: the number of leading batch axes

    returns:
        a true/false array of flattened shape (n_lead,image_size)
    """
    lead_shape = data.shape[:n_lead]
    spatial_dims, k = parse_shape(data.shape[n_lead:], D)

    # flatten to one row per pixel, holding that pixel's D**k tensor components
    n_pixels = math.prod(spatial_dims)
    per_pixel = data.reshape(lead_shape + (n_pixels, D**k))

    component_is_zero = jnp.isclose(per_pixel, 0.0, rtol=TINY, atol=TINY)
    return jnp.any(~component_is_zero, axis=-1)

pixel_idxs(spatial_dims: tuple[int, ...]) -> jax.Array ¤

Get the idxs of pixels for spatial_dims, ordered in the flattened image order.

Parameters:

Name Type Description Default
spatial_dims tuple[int, ...]

tuple of the spatial dimensions

required

Returns:

Type Description
Array

pixels idxs, shape (num_pixels,D)

Source code in ginjax/geometric/functional_geometric_image.py
76
77
78
79
80
81
82
83
84
85
86
87
88
def pixel_idxs(spatial_dims: tuple[int, ...]) -> jax.Array:
    """
    Enumerate the index of every pixel of an image with the given spatial dimensions, in the
    same order the pixels appear when the image is flattened.

    args:
        spatial_dims: tuple of the spatial dimensions

    returns:
        pixels idxs, shape (num_pixels,D)
    """
    axis_ranges = [jnp.arange(sidelength) for sidelength in spatial_dims]

    # "ij" indexing keeps the first spatial axis as the slowest-varying one
    coordinate_grids = jnp.meshgrid(*axis_ranges, indexing="ij")
    stacked = jnp.stack(coordinate_grids, axis=-1)
    return stacked.reshape((-1, len(spatial_dims)))

get_torus_expanded(image: jax.Array, is_torus: tuple[bool, ...], filter_spatial_dims: tuple[int, ...], rhs_dilation: tuple[int, ...]) -> tuple[jax.Array, tuple[tuple[int, int], ...]] ¤

For a particular filter, expand the image so that we no longer have to do convolutions on the torus, we are just doing convolutions on the expanded image and will get the same result.

Parameters:

Name Type Description Default
image Array

image data, (batch,spatial,channels)

required
is_torus tuple[bool, ...]

d-length tuple of bools specifying which spatial dimensions are toroidal

required
filter_spatial_dims tuple[int, ...]

d-length tuple of the spatial dimensions of the filter

required
rhs_dilation tuple[int, ...]

dilation to apply to each filter dimension D

required

Returns:

Type Description
tuple[Array, tuple[tuple[int, int], ...]]

The new expanded torus, and the appropriate padding_literal to use in convolve

Source code in ginjax/geometric/functional_geometric_image.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_torus_expanded(
    image: jax.Array,
    is_torus: tuple[bool, ...],
    filter_spatial_dims: tuple[int, ...],
    rhs_dilation: tuple[int, ...],
) -> tuple[jax.Array, tuple[tuple[int, int], ...]]:
    """
    Expand the image for a particular filter so that a convolution on the torus becomes an ordinary
    convolution on the expanded image with the same result. Toroidal dimensions are padded by wrapping
    the image around; the remaining dimensions get no padding here and instead receive the returned
    zero-padding amounts.

    args:
        image: image data, (batch,spatial,channels)
        is_torus: d-length tuple of bools specifying which spatial dimensions are toroidal
        filter_spatial_dims: d-length tuple of the spatial dimensions of the filter
        rhs_dilation: dilation to apply to each filter dimension D

    Returns:
        The new expanded torus, and the appropriate padding_literal to use in convolve
    """
    # every filter sidelength must be odd
    assert all(M % 2 == 1 for M in filter_spatial_dims)

    # toroidal dimensions get ((M - 1) // 2) * dilation pixels on each side, the others get (0, 0)
    wrap_padding = get_same_padding(filter_spatial_dims, rhs_dilation, is_torus)

    # wrap-pad the spatial axes only; the leading (batch) and trailing (channel) axes are untouched
    no_padding = ((0, 0),)
    expanded_image = jnp.pad(image, no_padding + wrap_padding + no_padding, mode="wrap")

    # the dimensions that were not wrap-padded are the ones to be zero-padded by the convolution
    non_torus_dims = tuple(not torus for torus in is_torus)
    zero_padding = get_same_padding(filter_spatial_dims, rhs_dilation, non_torus_dims)

    return expanded_image, zero_padding

get_same_padding(filter_spatial_dims: tuple[int, ...], rhs_dilation: tuple[int, ...], pad_dims: Optional[tuple[bool, ...]] = None) -> tuple[tuple[int, int], ...] ¤

Calculate the padding for each dimension D necessary for 'SAME' padding, including rhs_dilation.

Parameters:

Name Type Description Default
filter_spatial_dims tuple[int, ...]

filter spatial dimensions, length D tuple

required
rhs_dilation tuple[int, ...]

rhs (filter) dilation, length D tuple

required
pad_dims Optional[tuple[bool, ...]]

d-tuple of dimensions to pad, default (None) is all dimensions

None

Returns:

Type Description
tuple[tuple[int, int], ...]

d-tuple of pairs of amount of pixels to pad

Source code in ginjax/geometric/functional_geometric_image.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def get_same_padding(
    filter_spatial_dims: tuple[int, ...],
    rhs_dilation: tuple[int, ...],
    pad_dims: Optional[tuple[bool, ...]] = None,
) -> tuple[tuple[int, int], ...]:
    """
    Compute, per dimension, the symmetric padding needed for 'SAME' padding with a dilated filter:
    ((M - 1) // 2) * dilation pixels on each side for a filter sidelength M.

    args:
        filter_spatial_dims: filter spatial dimensions, length D tuple
        rhs_dilation: rhs (filter) dilation, length D tuple
        pad_dims: d-tuple of dimensions to pad, default (None) is all dimensions

    returns:
        d-tuple of pairs of amount of pixels to pad
    """
    if pad_dims is None:
        pad_dims = (True,) * len(filter_spatial_dims)

    padding = []
    for M, dilation, pad in zip(filter_spatial_dims, rhs_dilation, pad_dims):
        # dimensions that are not selected for padding get (0, 0)
        one_side = ((M - 1) // 2) * dilation if pad else 0
        padding.append((one_side, one_side))

    return tuple(padding)

pre_tensor_product_expand(D: int, image_a: jax.Array, image_b: jax.Array, a_offset: int = 0, b_offset: int = 0, dtype: Optional[jnp.dtype] = None) -> tuple[jax.Array, jax.Array] ¤

Rather than take a tensor product of two tensors, we can first take a tensor product of each with a tensor of ones with the shape of the other. Then we have two matching shapes, and we can then do whatever operations.

Parameters:

Name Type Description Default
D int

dimension of the image

required
image_a Array

one geometric image whose tensors we will later be doing tensor products on

required
image_b Array

other geometric image

required
a_offset int

number of axes of image_a prior to the spatial dims

0
b_offset int

number of axes of image_b prior to the spatial dims

0
dtype Optional[dtype]

if present, cast both outputs to dtype

None

Returns:

Type Description
tuple[Array, Array]

tuple of the expanded images

Source code in ginjax/geometric/functional_geometric_image.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def pre_tensor_product_expand(
    D: int,
    image_a: jax.Array,
    image_b: jax.Array,
    a_offset: int = 0,
    b_offset: int = 0,
    dtype: Optional[jnp.dtype] = None,
) -> tuple[jax.Array, jax.Array]:
    """
    Prepare two images for a tensor product without computing the product itself: each image is
    tensored with a tensor of ones shaped like the other image's tensor, so that both end up with
    tensor axes (a_tensor,b_tensor) and any elementwise operation can follow.

    args:
        D: dimension of the image
        image_a: one geometric image whose tensors we will later be doing tensor products on
        image_b: other geometric image
        a_offset: number of axes of image_a prior to the spatial dims
        b_offset: number of axes of image_b prior to the spatial dims
        dtype: if present, cast both outputs to dtype

    returns:
        tuple of the expanded images
    """
    _, img_a_k = parse_shape(image_a.shape[a_offset:], D)
    _, img_b_k = parse_shape(image_b.shape[b_offset:], D)

    # image_a: its own tensor axes come first, so the ones are simply appended at the end
    # (a_offset,a_spatial,a_tensor) -> (a_offset,a_spatial,a_tensor,b_tensor)
    image_a_expanded = image_a
    if img_b_k > 0:
        ones_like_b_tensor = jnp.ones((D,) * img_b_k)
        image_a_expanded = jnp.tensordot(image_a, ones_like_b_tensor, axes=0)

    # image_b: the ones must sit between the spatial axes and image_b's own tensor axes, so they
    # are put on the front by the outer product and then permuted into the middle
    image_b_expanded = image_b
    if img_a_k > 0:
        ones_like_a_tensor = jnp.ones((D,) * img_a_k)
        # (b_offset,b_spatial,b_tensor) -> (a_tensor,b_offset,b_spatial,b_tensor)
        image_b_expanded = jnp.tensordot(ones_like_a_tensor, image_b, axes=0)

        # first axis after the (b_offset,b_spatial) axes in the outer product
        tensor_start = img_a_k + b_offset + D
        a_tensor_axes = tuple(range(img_a_k))
        leading_axes = tuple(range(img_a_k, tensor_start))
        b_tensor_axes = tuple(range(tensor_start, tensor_start + img_b_k))

        # (a_tensor,b_offset,b_spatial,b_tensor) -> (b_offset,b_spatial,a_tensor,b_tensor)
        image_b_expanded = image_b_expanded.transpose(leading_axes + a_tensor_axes + b_tensor_axes)

    if dtype is not None:
        image_a_expanded = image_a_expanded.astype(dtype)
        image_b_expanded = image_b_expanded.astype(dtype)

    return image_a_expanded, image_b_expanded

conv_contract_image_expand(D: int, image: jax.Array, filter_k: int) -> jax.Array ¤

For conv_contract, we will be immediately performing a contraction, so we don't need to fully expand each tensor, just the k image to the k+k' conv filter.

Parameters:

Name Type Description Default
D int

dimension of the space

required
image Array

image data, shape (batch,in_c,spatial,tensor)

required
filter_k int

the filter tensor order

required

Returns:

Type Description
Array

the expanded image data

Source code in ginjax/geometric/functional_geometric_image.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def conv_contract_image_expand(D: int, image: jax.Array, filter_k: int) -> jax.Array:
    """
    Expand an image for conv_contract. Because a contraction follows immediately, the image and
    filter do not need to be fully expanded against each other; the order-k image is only raised to
    the order k+k' of the conv filter by tensoring it with ones.

    args:
        D: dimension of the space
        image: image data; the tensor order is read from image.shape[2:], so two axes are expected
            before the spatial axes (presumably (batch,in_c,spatial,tensor) as passed by
            convolve_contract -- TODO confirm)
        filter_k: the filter tensor order

    returns:
        the expanded image data
    """
    _, img_k = parse_shape(image.shape[2:], D)

    # number of tensor axes the image is missing relative to the filter
    k_prime = filter_k - img_k
    assert k_prime >= 0

    missing_axes_ones = jnp.ones((D,) * k_prime)
    return jnp.tensordot(image, missing_axes_ones, axes=0)

mul(D: int, image_a: jax.Array, image_b: jax.Array, a_offset: int = 0, b_offset: int = 0) -> jax.Array ¤

Multiplication operator between two images, implemented as a tensor product of the pixels.

Parameters:

Name Type Description Default
D int

dimension of the images

required
image_a Array

image data

required
image_b Array

image data

required
a_offset int

number of axes before the spatial axes (batch, channels, etc.)

0
b_offset int

number of axes before the spatial axes (batch, channels, etc.)

0

Returns:

Type Description
Array

the multiplied images

Source code in ginjax/geometric/functional_geometric_image.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def mul(
    D: int,
    image_a: jax.Array,
    image_b: jax.Array,
    a_offset: int = 0,
    b_offset: int = 0,
) -> jax.Array:
    """
    Multiply two images pixelwise, where the product of two pixels is their tensor product.

    args:
        D: dimension of the images
        image_a: image data
        image_b: image data
        a_offset: number of axes before the spatial axes (batch, channels, etc.)
        b_offset: number of axes before the spatial axes (batch, channels, etc.)

    returns:
        the multiplied images
    """
    # expanding both images to the product's tensor shape first turns the tensor product into a
    # plain elementwise multiplication
    expanded_a, expanded_b = pre_tensor_product_expand(
        D,
        image_a,
        image_b,
        a_offset,
        b_offset,
    )
    product = expanded_a * expanded_b
    return product

convolve(D: int, image: jax.Array, filter_image: jax.Array, is_torus: Union[tuple[bool, ...], bool], stride: Union[int, tuple[int, ...]] = 1, padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1, tensor_expand: bool = True) -> jax.Array ¤

Here is how this function works:

  1. Expand the geom_image to its torus shape, i.e. add filter.m cells all around the perimeter of the image
  2. Do the tensor product (with 1s) to each image.k, filter.k so that they are both image.k + filter.k tensors. That is if image.k=2, filter.k=1, do (D,D) => (D,D) x (D,) and (D,) => (D,D) x (D,) with tensors of 1s
  3. Now we shape the inputs to work with jax.lax.conv_general_dilated
  4. Put image in NHWC (batch, height, width, channel). Thus we vectorize the tensor
  5. Put filter in HWIO (height, width, input, output). Input is 1, output is the vectorized tensor
  6. Plug all that stuff in to conv_general_dilated, and feature_group_count is the length of the vectorized tensor, and it is basically saying that each part of the vectorized tensor is treated separately in the filter.

It must be the case that channel = input * feature_group_count. See: https://jax.readthedocs.io/en/latest/notebooks/convolutions.html#id1 and https://www.tensorflow.org/xla/operation_semantics#conv_convolution

Parameters:

Name Type Description Default
D int

dimension of the images

required
image Array

image data, shape (batch,in_c,spatial,tensor)

required
filter_image Array

the convolution filter, shape (out_c,in_c,spatial,tensor)

required
is_torus Union[tuple[bool, ...], bool]

what dimensions of the image are toroidal

required
stride Union[int, tuple[int, ...]]

convolution stride, defaults to (1,)*self.D

1
padding Optional[Union[str, int, tuple[tuple[int, int], ...]]]

either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs, defaults to 'TORUS' if image.is_torus, else 'SAME'

None
lhs_dilation Optional[tuple[int, ...]]

amount of dilation to apply to image in each dimension D, also transposed conv

None
rhs_dilation Union[int, tuple[int, ...]]

amount of dilation to apply to filter in each dimension D, defaults to 1

1
tensor_expand bool

expand the tensor of image and filter to do tensor convolution, defaults to True. If there is something more complicated going on (e.g. conv_contract), you can skip this step.

True

Returns:

Type Description
Array

convolved_image, shape (batch,out_c,spatial,tensor)

Source code in ginjax/geometric/functional_geometric_image.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
@eqx.filter_jit
def convolve(
    D: int,
    image: jax.Array,
    filter_image: jax.Array,
    is_torus: Union[tuple[bool, ...], bool],
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
    tensor_expand: bool = True,
) -> jax.Array:
    """
    Tensor convolution of a geometric image with a filter. The steps are:

    1. Optionally (tensor_expand) take the tensor product with 1s of image and filter so that both
    are image.k + filter.k tensors. E.g. for image.k=2, filter.k=1, do (D,D) => (D,D) x (D,) and
    (D,) => (D,D) x (D,) with tensors of 1s.
    2. Put the image in NHWC-style layout (batch, spatial, channel), where the vectorized tensor
    and the input channels are merged into the channel axis.
    3. Put the filter in HWIO-style layout (spatial, input, output), where the vectorized tensor and
    the output channels are merged into the output axis.
    4. Hand both to convolve_ravel, which deals with the torus expansion / padding and calls
    jax.lax.conv_general_dilated with feature_group_count equal to the length of the vectorized
    tensor, so that each tensor component is treated separately by the filter.
    5. Unravel the tensor and output-channel axes of the result.

    It must be the case that channel = input * feature_group_count.
    See: https://jax.readthedocs.io/en/latest/notebooks/convolutions.html#id1 and
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

    args:
        D: dimension of the images
        image: image data, shape (batch,in_c,spatial,tensor)
        filter_image: the convolution filter, shape (out_c,in_c,spatial,tensor)
        is_torus: what dimensions of the image are toroidal
        stride: convolution stride, defaults to 1 in every dimension
        padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs,
            defaults to 'TORUS' if any dimension is toroidal, else 'SAME'
        lhs_dilation: amount of dilation to apply to image in each dimension D, also transposed conv
        rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1
        tensor_expand: expand the tensor of image and filter to do tensor convolution, defaults to True.
            If there is something more complicated going on (e.g. conv_contract), you can skip this step.

    returns:
        convolved_image, shape (batch,out_c,spatial,tensor)
    """
    assert 1 <= D <= 4  # for now
    assert image.shape[1] == filter_image.shape[1], (
        f"Second axis (in_channels) for image and filter_image "
        f"must equal, but got image {image.shape} and filter {filter_image.shape}"
    )

    filter_spatial_dims, _ = parse_shape(filter_image.shape[2:], D)
    out_c, in_c = filter_image.shape[:2]
    batch = len(image)

    img_expanded, filter_expanded = image, filter_image
    if tensor_expand:
        img_expanded, filter_expanded = pre_tensor_product_expand(
            D, image, filter_image, a_offset=2, b_offset=2, dtype=jnp.float32
        )

    image_spatial_dims, input_k = parse_shape(img_expanded.shape[2:], D)
    _, output_k = parse_shape(filter_expanded.shape[2:], D)
    channel_length = D**input_k  # number of components of the vectorized tensor

    # image to NHWC (or NHWDC), treating all the pixel values as channels
    # (batch,in_c,spatial,in_tensor) -> (batch,spatial,in_tensor,in_c) -> (batch,spatial,in_tensor*in_c)
    img_channels_last = jnp.moveaxis(img_expanded, 1, -1)
    img_formatted = img_channels_last.reshape(
        (batch, *image_spatial_dims, channel_length * in_c)
    )

    # filter to HWIO (or HWDIO)
    # (out_c,in_c,spatial,out_tensor) -> (in_c,spatial,out_tensor,out_c)
    filter_out_last = jnp.moveaxis(filter_expanded, 0, -1)
    # (in_c,spatial,out_tensor,out_c) -> (spatial,in_c,out_tensor,out_c)
    filter_spatial_first = jnp.moveaxis(filter_out_last, 0, D)
    # (spatial,in_c,out_tensor,out_c) -> (spatial,in_c,out_tensor*out_c)
    filter_formatted = filter_spatial_first.reshape(
        (*filter_spatial_dims, in_c, channel_length * out_c)
    )

    # (batch,spatial,out_tensor*out_c)
    convolved_array = convolve_ravel(
        D, img_formatted, filter_formatted, is_torus, stride, padding, lhs_dilation, rhs_dilation
    )

    # split the raveled last axis back into the tensor axes and out_c, then move out_c to 2nd axis
    tensor_axes = (D,) * output_k
    unraveled = convolved_array.reshape(convolved_array.shape[:-1] + tensor_axes + (out_c,))
    return jnp.moveaxis(unraveled, -1, 1)

convolve_ravel(D: int, image: jax.Array, filter_image: jax.Array, is_torus: Union[tuple[bool, ...], bool], stride: Union[int, tuple[int, ...]] = 1, padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1) -> jax.Array ¤

Raveled version of convolution. Assumes the channels are all lined up correctly for the tensor convolution. This assumes that the feature_group_count is image in_c // filter in_c.

See convolve for a full description of this function.

Parameters:

Name Type Description Default
D int

dimension of the images

required
image Array

image data, shape (batch,spatial,tensor*in_c)

required
filter_image Array

the convolution filter, shape (spatial,in_c,tensor*out_c)

required
is_torus Union[tuple[bool, ...], bool]

what dimensions of the image are toroidal

required
stride Union[int, tuple[int, ...]]

convolution stride, defaults to (1,)*self.D

1
padding Optional[Union[str, int, tuple[tuple[int, int], ...]]]

either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs, defaults to 'TORUS' if image.is_torus, else 'SAME'

None
lhs_dilation Optional[tuple[int, ...]]

amount of dilation to apply to image in each dimension D, also transposed conv

None
rhs_dilation Union[int, tuple[int, ...]]

amount of dilation to apply to filter in each dimension D, defaults to 1

1

Returns:

Type Description
Array

convolved_image, shape (batch,spatial,tensor*out_c)

Source code in ginjax/geometric/functional_geometric_image.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
@eqx.filter_jit
def convolve_ravel(
    D: int,
    image: jax.Array,
    filter_image: jax.Array,
    is_torus: Union[tuple[bool, ...], bool],
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
) -> jax.Array:
    """
    Raveled version of convolution. Assumes the channels are all lined up correctly for the tensor
    convolution. The feature_group_count is taken to be image channels // filter in_c.

    See [convolve](functional_geometric_image.md#ginjax.geometric.functional_geometric_image.convolve) for a full
    description of this function.

    args:
        D: dimension of the images
        image: image data, shape (batch,spatial,tensor*in_c)
        filter_image: the convolution filter, shape (spatial,in_c,tensor*out_c)
        is_torus: what dimensions of the image are toroidal
        stride: convolution stride, defaults to 1 in every dimension
        padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs,
            defaults to 'TORUS' if any dimension is toroidal, else 'SAME'
        lhs_dilation: amount of dilation to apply to image in each dimension D, also transposed conv
        rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1

    returns:
        convolved_image, shape (batch,spatial,tensor*out_c)
    """
    assert 1 <= D <= 4  # for now

    valid_is_torus = isinstance(is_torus, bool) or (
        isinstance(is_torus, tuple) and len(is_torus) == D
    )
    assert (
        valid_is_torus
    ), f"geom::convolve is_torus must be bool or tuple of bools, but got {is_torus}"

    # normalize the scalar shorthands to one entry per spatial dimension
    if isinstance(is_torus, bool):
        is_torus = (is_torus,) * D

    filter_spatial_dims, _ = parse_shape(filter_image.shape, D)

    # the named paddings pad symmetrically, which is only well defined for odd filter sidelengths
    has_even_sidelength = any(N % 2 == 0 for N in filter_spatial_dims)
    uses_named_padding = padding is None or padding == "TORUS" or padding == "SAME"
    assert not (
        has_even_sidelength and uses_named_padding
    ), f"convolve: Filters with even sidelengths {filter_spatial_dims} require literal padding, not {padding}"

    if not isinstance(rhs_dilation, tuple):
        rhs_dilation = (rhs_dilation,) * D

    if not isinstance(stride, tuple):
        stride = (stride,) * D

    if padding is None:  # if unspecified, infer from is_torus
        padding = "TORUS" if any(is_torus) else "SAME"

    if isinstance(padding, str) and (lhs_dilation is not None):
        print(
            "WARNING convolve: lhs_dilation (transposed convolution) should specify padding exactly, "
            "see https://arxiv.org/pdf/1603.07285.pdf for the appropriate cases."
        )

    # resolve the padding argument to what conv_general_dilated receives
    if padding == "TORUS":
        # toroidal dimensions are wrap-padded into the image itself, the rest are zero padded
        image, padding_literal = get_torus_expanded(
            image, is_torus, filter_spatial_dims, rhs_dilation
        )
    elif padding == "VALID":
        padding_literal = ((0, 0),) * D
    elif padding == "SAME":
        padding_literal = get_same_padding(filter_spatial_dims, rhs_dilation)
    elif isinstance(padding, int):
        padding_literal = ((padding, padding),) * D
    else:
        padding_literal = padding

    # channels-last layout for the image/output, spatial-in-out layout for the filter
    spatial_l = "XYZT"[:D]
    image_spec = "N" + spatial_l + "C"
    dimension_numbers = (image_spec, spatial_l + "IO", image_spec)

    # the image channels must split evenly into groups of the filter's input channels
    channel_length, leftover_channels = divmod(image.shape[-1], filter_image.shape[-2])
    assert leftover_channels == 0

    # (batch,spatial,out_tensor*out_c)
    return jax.lax.conv_general_dilated(
        image,  # lhs
        filter_image,  # rhs
        stride,
        padding_literal,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=channel_length,  # each tensor component is treated separately
    )

convolve_contract(D: int, image: jax.Array, filter_image: jax.Array, is_torus: Union[bool, tuple[bool, ...]], stride: Union[int, tuple[int, ...]] = 1, padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1) -> jax.Array ¤

Given an input k image and a k+k' filter, take the tensor convolution that contracts k times with one index each from the image and filter. This implementation is slightly more efficient than doing the convolution and contraction separately by avoiding constructing the k+k+k' intermediate tensor. See convolve for a full description of the convolution.

Parameters:

Name Type Description Default
D int

dimension of the images

required
image Array

image data, shape (batch,in_c,spatial,tensor)

required
filter_image Array

the convolution filter, shape (out_c,in_c,spatial,tensor)

required
is_torus Union[bool, tuple[bool, ...]]

what dimensions of the image are toroidal

required
stride Union[int, tuple[int, ...]]

convolution stride, defaults to (1,)*self.D

1
padding Optional[Union[str, int, tuple[tuple[int, int], ...]]]

either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs, defaults to 'TORUS' if image.is_torus, else 'SAME'

None
lhs_dilation Optional[tuple[int, ...]]

amount of dilation to apply to image in each dimension D, also transposed conv

None
rhs_dilation Union[int, tuple[int, ...]]

amount of dilation to apply to filter in each dimension D, defaults to 1

1

Returns:

Type Description
Array

convolved_image, shape (batch,out_c,spatial,tensor)

Source code in ginjax/geometric/functional_geometric_image.py
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
@eqx.filter_jit
def convolve_contract(
    D: int,
    image: jax.Array,
    filter_image: jax.Array,
    is_torus: Union[bool, tuple[bool, ...]],
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
) -> jax.Array:
    """
    Convolve a tensor-order-k image with a tensor-order-(k+k') filter and contract the k image
    indices against k of the filter indices in a single pass. Compared with convolving first and
    contracting afterwards, this avoids materializing the k+k+k' intermediate tensor. See
    [convolve](functional_geometric_image.md#ginjax.geometric.functional_geometric_image.convolve)
    for the details of the underlying convolution; every convolution option below is forwarded to
    it unchanged.

    args:
        D: dimension of the images
        image: image data, shape (batch,in_c,spatial,tensor)
        filter_image: the convolution filter, shape (out_c,in_c,spatial,tensor)
        is_torus: what dimensions of the image are toroidal
        stride: convolution stride, defaults to 1
        padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs. When
            None, the choice of default is left to convolve.
        lhs_dilation: amount of dilation to apply to image in each dimension D, also transposed conv
        rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1

    returns:
        convolved_image, shape (batch,out_c,spatial,tensor)
    """
    # the first two axes are batch and channel; parse_shape reports the tensor order of the rest
    image_order = parse_shape(image.shape[2:], D)[1]
    filter_order = parse_shape(filter_image.shape[2:], D)[1]

    # prepare the image so it can be convolved against the higher order filter directly
    expanded_image = conv_contract_image_expand(D, image, filter_order)
    convolved = convolve(
        D,
        expanded_image.astype("float32"),
        filter_image,
        is_torus,
        stride,
        padding,
        lhs_dilation,
        rhs_dilation,
        tensor_expand=False,
    )

    # the contraction: sum out the first image_order tensor axes, which sit right after
    # the batch, channel and D spatial axes
    first_tensor_axis = 2 + D
    contracted_axes = range(first_tensor_axis, first_tensor_axis + image_order)
    return jnp.sum(convolved, axis=contracted_axes)

get_contraction_indices(initial_k: int, final_k: int, swappable_idxs: tuple[tuple[int, int], ...] = ()) -> list[tuple[tuple[int, int], ...]] ¤

Get all possible unique indices for multicontraction. Returns a list of indices. The indices are a tuple of tuples where each of the inner tuples is a pair of indices. For example, if initial_k=5, final_k=1, one element of the list that is returned will be ((0,1), (2,3)), another will be ((0,2), (1,4)), etc.

Note that contracting (0,1) is the same as contracting (1,0). Also, contracting ((0,1),(2,3)) is the same as contracting ((2,3),(0,1)). In both of those cases, they won't be returned. There is also the optional argument swappable_idxs to specify indices that can be swapped without changing the contraction. Suppose we have A * c1 where c1 is a k=2, parity=0 invariant conv_filter. In that case, we can contract on either of its indices and it won't change the result because transposing the axes is a group operation.

Parameters:

Name Type Description Default
initial_k int

the starting number of indices that we have

required
final_k int

the final number of indices that we want to end up with

required
swappable_idxs tuple[tuple[int, int], ...]

Indices that can be swapped w/o changing the contraction

()

Returns:

Type Description
list[tuple[tuple[int, int], ...]]

all the possible contraction indices

Source code in ginjax/geometric/functional_geometric_image.py
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
def get_contraction_indices(
    initial_k: int,
    final_k: int,
    swappable_idxs: tuple[tuple[int, int], ...] = (),
) -> list[tuple[tuple[int, int], ...]]:
    """
    Enumerate every distinct way of contracting a tensor of order initial_k down to order final_k.
    Each entry of the returned list is a tuple of index pairs, one pair per contraction. For
    example, with initial_k=5 and final_k=1 (two contractions), one entry is ((0,1), (2,3)) and
    another is ((0,2), (1,4)).

    Contractions that only differ by the order inside a pair, (0,1) versus (1,0), or by the order
    of the pairs, ((0,1),(2,3)) versus ((2,3),(0,1)), are reported once. swappable_idxs lists
    pairs of indices that are interchangeable, so sets of contractions that only differ by
    exchanging the two are also reported once. For instance in A * c1, where c1 is a k=2,
    parity=0 invariant conv_filter, contracting on either index of c1 gives the same result
    because transposing its axes is a group operation.

    args:
        initial_k: the starting number of indices that we have
        final_k: the final number of indices that we want to end up with
        swappable_idxs: Indices that can be swapped w/o changing the contraction

    returns:
        all the possible contraction indices
    """
    assert ((initial_k + final_k) % 2) == 0
    assert initial_k >= final_k
    assert final_k >= 0

    n_contracted = initial_k - final_k  # every contraction consumes two indices

    # every way of choosing n_contracted / 2 index pairs, each flattened to one row
    index_pairs = it.combinations(range(initial_k), 2)
    candidates = np.array(
        [
            np.array(chosen).reshape((n_contracted,))
            for chosen in it.combinations(index_pairs, n_contracted // 2)
        ]
    )
    # an index can only be contracted once, so drop rows that repeat an index
    no_repeats = np.array([len(np.unique(cand)) == len(cand) for cand in candidates])
    valid = candidates[no_repeats]

    # collapse each swappable pair onto its first member so equivalent rows become identical
    for keep, replace in swappable_idxs:
        valid[np.isin(valid, replace)] = keep

    # canonical form: sort inside each pair, then sort the pairs of each row
    canonical = []
    for cand in valid:
        pairs = [sorted([x, y]) for x, y in zip(cand[0::2], cand[1::2])]
        pairs.sort()
        canonical.append(np.array(pairs).reshape((n_contracted,)))
    deduped = np.unique(np.array(canonical), axis=0)  # identical canonical rows are redundant

    # put the second member of each swappable pair back into the rows
    for pair in swappable_idxs:
        for row in deduped:
            hits = np.where(np.isin(row, pair))[0]
            if len(hits) > 0:
                row[np.max(hits)] = pair[1]
                # written last, so a lone occurrence ends up as pair[0]
                row[np.min(hits)] = pair[0]

    return [tuple(zip(row[0::2], row[1::2])) for row in deduped]

multicontract(data: jax.Array, indices: tuple[tuple[int, int], ...], idx_shift: int = 0) -> jax.Array ¤

Perform the Kronecker Delta contraction on the data. The data must have at least 2 dimensions and, because we implement this with einsum, at most 52 dimensions. Indices is a tuple of pairs of indices, where each pair is itself a tuple.

Parameters:

Name Type Description Default
data Array

data to perform the contraction on

required
indices tuple[tuple[int, int], ...]

index pairs to perform the contractions on

required
idx_shift int

indices are the tensor indices, so if data has spatial indices or channel/batch indices in the beginning we shift over by idx_shift

0

Returns:

Type Description
Array

the contracted data

Source code in ginjax/geometric/functional_geometric_image.py
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
@functools.partial(jax.jit, static_argnums=[1, 2])
def multicontract(
    data: jax.Array, indices: tuple[tuple[int, int], ...], idx_shift: int = 0
) -> jax.Array:
    """
    Perform the Kronecker Delta contraction on the data, summing over the diagonal of every pair
    of axes listed in indices. The contraction is a single einsum in which both axes of a pair
    carry the same label, so the data must have at least 2 dimensions and the number of
    dimensions plus the number of pairs must stay below 52.

    args:
        data: data to perform the contraction on
        indices: index pairs to perform the contractions on, a tuple of 2-tuples
        idx_shift: indices are the tensor indices, so if data has spatial indices or channel/batch
            indices in the beginning we shift over by idx_shift

    returns:
        the contracted data
    """
    n_axes = len(data.shape)
    assert n_axes + len(indices) < 52
    assert n_axes >= 2
    # NOTE: the pairs are not validated here; they are expected to be distinct, valid tensor axes

    # start with one distinct label per axis, taken from the front of LETTERS
    labels = list(LETTERS[:n_axes])
    for count, (first_idx, second_idx) in enumerate(indices, start=1):
        # labels for contracted pairs come from the back of LETTERS so they cannot clash
        shared_label = LETTERS[-count]
        labels[first_idx + idx_shift] = shared_label
        labels[second_idx + idx_shift] = shared_label

    # no explicit output: einsum sums over every label that appears twice
    return jnp.einsum("".join(labels), data)

raise_lower(data: jax.Array, metric_tensor: jax.Array, metric_tensor_inv: jax.Array, from_axes: tuple[bool, ...], to_axes: tuple[bool, ...], precision: Optional[jax.lax.Precision] = None) -> jax.Array ¤

Raise or lower the axes of a tensor or tensor image according to the metric tensor and axes.

Parameters:

Name Type Description Default
data Array

a tensor, or tensor image, shape (...,tensor)

required
metric_tensor Array

the metric tensor g_ij, shape (...,tensor)

required
metric_tensor_inv Array

the inverse metric tensor, g^ij. Must be same spatial shape as this

required
from_axes tuple[bool, ...]

covariant axes you are starting at, True for covariant, False contravariant

required
to_axes tuple[bool, ...]

covariant axes to convert to, True for covariant, False contravariant

required
precision Optional[Precision]

precision used for einsum

None

Returns:

Type Description
Array

the data with the modified axes

Source code in ginjax/geometric/functional_geometric_image.py
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
def raise_lower(
    data: jax.Array,
    metric_tensor: jax.Array,
    metric_tensor_inv: jax.Array,
    from_axes: tuple[bool, ...],
    to_axes: tuple[bool, ...],
    precision: Optional[jax.lax.Precision] = None,
) -> jax.Array:
    """
    Raise or lower the axes of a tensor or tensor image according to the metric tensor and axes.
    Every axis whose covariance differs between from_axes and to_axes is contracted with the
    metric tensor (when the target is covariant) or with the inverse metric tensor (when the
    target is contravariant). If no axis changes, data is returned as is.

    args:
        data: a tensor, or tensor image, shape (...,tensor)
        metric_tensor: the metric tensor g_ij, shape (...,tensor)
        metric_tensor_inv: the inverse metric tensor, g^ij. Must be same spatial shape as this
        from_axes: covariant axes you are starting at, True for covariant, False contravariant
        to_axes: covariant axes to convert to, True for covariant, False contravariant
        precision: precision used for einsum

    returns:
        the data with the modified axes
    """
    assert len(from_axes) == len(to_axes)
    k = len(from_axes)
    # axis i uses LETTERS[i] before and LETTERS[13 + i] after the change, so k must stay below 13
    assert k < 13

    if all(from_axis == to_axis for from_axis, to_axis in zip(from_axes, to_axes)):
        return data  # no axes are changed

    metric_specs = []  # einsum subscripts of the metric operands, one per changed axis
    metrics = []  # the matching metric operands
    output_labels = []  # labels of the tensor axes of the result
    for i, (from_axis, to_axis) in enumerate(zip(from_axes, to_axes)):
        if from_axis == to_axis:
            output_labels.append(LETTERS[i])
            continue

        new_label = LETTERS[13 + i]
        metric_specs.append("..." + new_label + LETTERS[i])
        # a contravariant target means the axis is raised, which uses the inverse metric
        metrics.append(metric_tensor_inv if int(to_axis) == 0 else metric_tensor)
        output_labels.append(new_label)

    einstr = f"...{LETTERS[:k]}," + ",".join(metric_specs) + "->..." + "".join(output_labels)
    return jnp.einsum(einstr, data, *metrics, precision=precision)

get_rotated_keys(D: int, spatial_dims: tuple[int, ...], gg: np.ndarray) -> np.ndarray ¤

Get the rotated keys of data when it will be rotated by gg. Note that we rotate the key vector indices by the inverse of gg per the definition (this is done by key_array @ gg, rather than gg @ key_array). When the spatial_dims are not square, this gets a little tricky. The gg needs to be a concrete (numpy) array, not a traced jax array.

Parameters:

Name Type Description Default
D int

dimension of image

required
spatial_dims tuple[int, ...]

the spatial dimensions of the data to be rotated

required
gg ndarray

group operation

required

Returns:

Type Description
ndarray

the rotated keys

Source code in ginjax/geometric/functional_geometric_image.py
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
def get_rotated_keys(D: int, spatial_dims: tuple[int, ...], gg: np.ndarray) -> np.ndarray:
    """
    Get the rotated keys of data when it will be rotated by gg. Note that we rotate the key vector indices
    by the inverse of gg per the definition (this is done by key_array @ gg, rather than gg @ key_array).
    When the spatial_dims are not square, this gets a little tricky.
    The gg needs to be a concrete (numpy) array, not a traced jax array. It may have an integer or
    a floating point dtype; the rotated side lengths and the returned keys are rounded to integers.

    args:
        D: dimension of image
        spatial_dims: the spatial dimensions of the data to be rotated
        gg: group operation, a DxD matrix

    returns:
        the rotated keys, an integer array with one row per pixel of the rotated image and D columns
    """
    dims = np.array(spatial_dims)
    # Round and cast to plain ints: with a float-valued gg the product is a float array, and
    # range() below only accepts integers.
    rotated_spatial_dims = tuple(int(n) for n in np.rint(np.abs(gg @ dims)))

    # When spatial_dims is nonsquare, we have to subtract one version, then add the rotated version.
    centering_coords = (dims.reshape((1, D)) - 1) / 2
    rot_centering_coords = (np.array(rotated_spatial_dims).reshape((1, D)) - 1) / 2

    # rotated keys will need to have the rotated_spatial_dims numbers
    key_array = np.array(list(it.product(*(range(N) for N in rotated_spatial_dims))))
    shifted_key_array = key_array - rot_centering_coords
    return np.rint((shifted_key_array @ gg) + centering_coords).astype(int)

times_group_element(D: int, data: jax.Array, parity: int, gg: np.ndarray, covariant_axes: tuple[bool, ...], precision: Optional[jax.lax.Precision] = None) -> jax.Array ¤

Apply a group element of O(d) to the geometric image. First apply the action to the location of the pixels, then apply the action to the pixels themselves.

Parameters:

Name Type Description Default
D int

dimension of the data

required
data Array

data block of image data to rotate, shape (batch,spatial,tensor)

required
parity int

parity of the data, 0 for even parity, 1 for odd parity

required
gg ndarray

a DxD matrix that rotates the tensor. Note that you cannot vmap by this argument because it needs to deal with concrete values

required
covariant_axes tuple[bool, ...]

which axes of the tensor are covariant (True) or contravariant (False). Also specifies the number of tensor axes.

required
precision Optional[Precision]

einsum precision, normally uses lower precision, use jax.lax.Precision.HIGHEST for testing equality in unit tests

None

Returns:

Type Description
Array

the rotated image data

Source code in ginjax/geometric/functional_geometric_image.py
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
def times_group_element(
    D: int,
    data: jax.Array,
    parity: int,
    gg: np.ndarray,
    covariant_axes: tuple[bool, ...],
    precision: Optional[jax.lax.Precision] = None,
) -> jax.Array:
    """
    Apply a group element of O(d) to the geometric image. The action is applied in two steps:
    first to the locations of the pixels, then to the tensors stored in the pixels.

    args:
        D: dimension of the data
        data: data block of image data to rotate, shape (batch,spatial,tensor). Every axis in
            front of the D spatial axes and the len(covariant_axes) tensor axes is treated as a
            leading batch axis.
        parity: parity of the data, 0 for even parity, 1 for odd parity
        gg: a DxD matrix that rotates the tensor. Note that you cannot vmap by this argument
            because it needs to deal with concrete values
        covariant_axes: which axes of the tensor are covariant (True) or contravariant (False).
            Also specifies the number of tensor axes.
        precision: einsum precision, normally uses lower precision, use jax.lax.Precision.HIGHEST
            for testing equality in unit tests

    returns:
        the rotated image data
    """
    n_lead = data.ndim - D - len(covariant_axes)
    spatial_dims, k = parse_shape(data.shape[n_lead:], D)
    tensor_shape = (D,) * k

    # sign of det(gg) raised to the parity: always 1 for parity 0, the sign itself for parity 1
    det_sign, _ = jnp.linalg.slogdet(gg)
    parity_flip = det_sign**parity

    rotated_spatial_dims = tuple(np.abs(gg @ np.array(spatial_dims)))
    rotated_keys = get_rotated_keys(D, spatial_dims, gg)

    def gather_rotated(single_image):
        # look up, for every pixel of the rotated image, the source pixel it comes from
        return single_image[hash(D, spatial_dims, rotated_keys)]

    # merge all leading axes into one so a single vmap covers them, then restore them
    merged = data.reshape((-1,) + spatial_dims + tensor_shape)
    rotated_pixels = jax.vmap(gather_rotated)(merged).reshape(
        data.shape[:n_lead] + rotated_spatial_dims + tensor_shape
    )

    if k == 0:
        # scalars have no tensor axes to rotate
        return 1.0 * rotated_pixels * parity_flip

    # Rotating a tensor amounts to multiplying each of its indices, viewed as a vector, by the
    # group action. The pixel locations have already been rotated above.
    operand_specs = []
    operands = []
    for i, covariant in enumerate(covariant_axes):
        old_label, new_label = LETTERS[i], LETTERS[i + 13]
        operand_specs.append(old_label + new_label if covariant else new_label + old_label)
        operands.append(gg.T if covariant else gg)

    einstr = f"...{LETTERS[:k]}," + ",".join(operand_specs) + f"->...{LETTERS[13:13+k]}"
    return jnp.einsum(einstr, rotated_pixels, *operands, precision=precision) * (parity_flip)

times_D8_element(D: int, data: jax.Array, parity: int, gg: np.ndarray, covariant_axes: tuple[bool, ...], precision: Optional[jax.lax.Precision] = None) -> jax.Array ¤

For the very limited case of a 3x3 GeometricImage, we define the group action of D8 as treating the 8 pixels around the center pixel as if they were equally spaced radially around the center.

Parameters:

Name Type Description Default
D int

dimension of the data

required
data Array

data block of image data to rotate, shape (batch,spatial,tensor)

required
parity int

parity of the data, 0 for even parity, 1 for odd parity

required
gg ndarray

a DxD matrix that rotates the tensor. Note that you cannot vmap by this argument because it needs to deal with concrete values

required
covariant_axes tuple[bool, ...]

which axes of the tensor are covariant (True) or contravariant (False). Also specifies the number of tensor axes.

required
precision Optional[Precision]

einsum precision, normally uses lower precision, use jax.lax.Precision.HIGHEST for testing equality in unit tests

None

Returns:

Type Description
Array

the rotated image data

Source code in ginjax/geometric/functional_geometric_image.py
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
def times_D8_element(
    D: int,
    data: jax.Array,
    parity: int,
    gg: np.ndarray,
    covariant_axes: tuple[bool, ...],
    precision: Optional[jax.lax.Precision] = None,
) -> jax.Array:
    """
    For the very limited case of a GeometricImage with sidelength 3, we define the group action
    of D8 as treating the pixels around the center pixel as if they were equally spaced radially
    around the center.

    args:
        D: dimension of the data
        data: data block of image data to rotate, shape (batch,spatial,tensor). Every spatial
            dimension must have length 3.
        parity: parity of the data, 0 for even parity, 1 for odd parity
        gg: a DxD matrix that rotates the tensor. Note that you cannot vmap by this argument
            because it needs to deal with concrete values
        covariant_axes: which axes of the tensor are covariant (True) or contravariant (False).
            Also specifies the number of tensor axes.
        precision: einsum precision, normally uses lower precision, use jax.lax.Precision.HIGHEST
            for testing equality in unit tests

    returns:
        the rotated image data
    """
    n_lead = data.ndim - D - len(covariant_axes)
    spatial_dims, k = parse_shape(data.shape[n_lead:], D)
    tensor_shape = (D,) * k

    # sign of det(gg) raised to the parity: always 1 for parity 0, the sign itself for parity 1
    det_sign, _ = jnp.linalg.slogdet(gg)
    parity_flip = det_sign**parity

    # this function is only defined for cubes with sidelength 3
    assert spatial_dims == (3,) * D

    centering_coords = (np.array(spatial_dims).reshape((1, D)) - 1) / 2

    # all pixel keys, shape (3**D,D), moved so the center pixel is the origin and then rotated
    key_array = np.array(list(it.product(*(range(N) for N in spatial_dims))))
    rotated_keys = (key_array - centering_coords) @ gg

    # leave the all zeros center key alone, but scale the others so the max value is 1
    is_center = np.all(
        np.isclose(rotated_keys, 0, atol=TINY, rtol=TINY), axis=1, keepdims=True
    )
    largest_coord = np.max(np.abs(rotated_keys) + 1e-5, axis=1, keepdims=True)
    rotated_keys = np.where(
        is_center, np.zeros_like(rotated_keys), rotated_keys / largest_coord
    )
    # move back to pixel coordinates and snap to the nearest pixel
    rotated_keys = np.rint(rotated_keys + centering_coords)

    def gather_rotated(single_image):
        # look up, for every pixel of the result, the source pixel it comes from
        return single_image[hash(D, spatial_dims, rotated_keys)]

    # merge all leading axes into one so a single vmap covers them, then restore them
    merged = data.reshape((-1,) + spatial_dims + tensor_shape)
    rotated_pixels = jax.vmap(gather_rotated)(merged).reshape(
        data.shape[:n_lead] + spatial_dims + tensor_shape
    )

    if k == 0:
        # scalars have no tensor axes to rotate
        return 1.0 * rotated_pixels * parity_flip

    # Rotating a tensor amounts to multiplying each of its indices, viewed as a vector, by the
    # group action. The pixel locations have already been rotated above.
    operand_specs = []
    operands = []
    for i, covariant in enumerate(covariant_axes):
        old_label, new_label = LETTERS[i], LETTERS[i + 13]
        operand_specs.append(old_label + new_label if covariant else new_label + old_label)
        operands.append(gg.T if covariant else gg)

    einstr = f"...{LETTERS[:k]}," + ",".join(operand_specs) + f"->...{LETTERS[13:13+k]}"
    return jnp.einsum(einstr, rotated_pixels, *operands, precision=precision) * (parity_flip)

tensor_times_gg(tensor: jax.Array, parity: int, gg: np.ndarray, precision: Optional[jax.lax.Precision] = None) -> jax.Array ¤

Apply a group element of SO(2) or SO(3) to a single tensor.

Parameters:

Name Type Description Default
tensor Array

data of the tensor

required
parity int

parity of the data, 0 for even parity, 1 for odd parity

required
gg ndarray

a DxD matrix that rotates the tensor. Note that you cannot vmap by this argument because it needs to deal with concrete values

required
precision Optional[Precision]

einsum precision, normally uses lower precision, use jax.lax.Precision.HIGH for testing equality in unit tests

None

Returns:

Type Description
Array

rotated tensor data

Source code in ginjax/geometric/functional_geometric_image.py
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
def tensor_times_gg(
    tensor: jax.Array,
    parity: int,
    gg: np.ndarray,
    precision: Optional[jax.lax.Precision] = None,
) -> jax.Array:
    """
    Apply a group element of SO(2) or SO(3) to a single tensor.

    args:
        tensor: data of the tensor
        parity: parity of the data, 0 for even parity, 1 for odd parity
        gg: a DxD matrix that rotates the tensor. Note that you cannot vmap
            by this argument because it needs to deal with concrete values
        precision: einsum precision, normally uses lower precision, use
            jax.lax.Precision.HIGH for testing equality in unit tests

    returns:
        rotated tensor data
    """
    k = len(tensor.shape)

    # sign of det(gg) raised to the parity: always 1 for parity 0, the sign itself for parity 1
    det_sign, _ = jnp.linalg.slogdet(gg)
    parity_flip = det_sign**parity

    if k == 0:
        # a scalar has no axes to rotate
        return 1.0 * tensor * parity_flip

    # Rotating a tensor amounts to multiplying each of its indices, viewed as a vector, by the
    # group action: axis i (label LETTERS[i]) is contracted with its own copy of gg.
    gg_specs = [LETTERS[i + 13] + LETTERS[i] for i in range(k)]
    # No explicit output is given, so the order of the result axes is left to einsum's
    # implicit mode.
    einstr = LETTERS[:k] + "," + ",".join(gg_specs)
    rotated = jnp.einsum(einstr, tensor, *([gg] * k), precision=precision)
    return rotated * (parity_flip)

translate(D: int, data: jax.Array, tau: jax.Array, n_lead: int = 0) -> jax.Array ¤

Translate an image by translation tau, on the torus. Translations on the data matrix use ij ordering. For example, a translation of [1,-1] moves the image down one row, then to the left one column.

Parameters:

Name Type Description Default
D int

dimension of the image

required
data Array

image data with n_lead batch axes, followed by spatial then tensor

required
tau Array

the translation

required
n_lead int

number of leading batch axes

0

Returns:

Type Description
Array

translated image data

Source code in ginjax/geometric/functional_geometric_image.py
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
def translate(D: int, data: jax.Array, tau: jax.Array, n_lead: int = 0) -> jax.Array:
    """
    Translate an image by translation tau, on the torus. Translations on the data matrix use ij
    ordering. For example, a translation of [1,-1] moves the image down one row, then to the
    left one column.

    args:
        D: dimension of the image
        data: image data with n_lead batch axes, followed by spatial then tensor
        tau: the translation
        n_lead: number of leading batch axes

    returns:
        translated image data
    """
    spatial_dims, k = parse_shape(data.shape[n_lead:], D)
    tensor_shape = (D,) * k

    # one key per pixel; each output pixel reads from the pixel that sits tau before it
    pixel_keys = np.array(list(it.product(*(range(N) for N in spatial_dims))))
    translated_keys = jnp.array(pixel_keys) - tau

    def gather_translated(single_image):
        # hash wraps the shifted keys back onto the torus before indexing
        return single_image[hash(D, spatial_dims, translated_keys)]

    # merge all leading axes into one so a single vmap covers them, then restore them
    merged = data.reshape((-1,) + spatial_dims + tensor_shape)
    translated_pixels = jax.vmap(gather_translated)(merged)
    return translated_pixels.reshape(data.shape[:n_lead] + spatial_dims + tensor_shape)

norm(idx_shift: int, data: jax.Array, keepdims: bool = False) -> jax.Array ¤

Perform the frobenius norm on each pixel tensor, returning a scalar image

Parameters:

Name Type Description Default
idx_shift int

the number of leading axes before the tensor axes: D for the spatial axes, plus any batch and channel axes if they are present

required
data Array

image data, shape (spatial,tensor)

required
keepdims bool

passed to jnp.linalg.norm

False

Returns:

Type Description
Array

the data of a scalar image after performing the norm

Source code in ginjax/geometric/functional_geometric_image.py
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
def norm(idx_shift: int, data: jax.Array, keepdims: bool = False) -> jax.Array:
    """
    Perform the frobenius norm on each pixel tensor, returning a scalar image

    args:
        idx_shift: the number of leading axes before the tensor axes; every axis from idx_shift
            onwards is flattened and normed over
        data: image data, shape (spatial,tensor), possibly with further leading axes
        keepdims: if True, the normed tensor axes are kept as axes of length 1 so the result has
            as many axes as data. Ignored when data has no tensor axes.

    returns:
        the data of a scalar image after performing the norm
    """
    assert (
        idx_shift <= data.ndim
    ), f"norm: idx shift must be at most ndim, but {idx_shift} > {data.ndim}"

    # flatten every tensor axis into one trailing axis and take the vector norm along it
    flattened = data.reshape(data.shape[:idx_shift] + (-1,))
    magnitudes = jnp.linalg.norm(flattened, axis=idx_shift)

    # without tensor axes the reshape above created the trailing axis, so there is nothing to keep
    if not keepdims or data.ndim == idx_shift:
        return magnitudes

    n_tensor_axes = data.ndim - magnitudes.ndim
    return magnitudes.reshape(magnitudes.shape + (1,) * n_tensor_axes)

max_pool(D: int, image_data: jax.Array, patch_len: int, use_norm: bool = True, comparator_image: Optional[jax.Array] = None) -> jax.Array ¤

Perform a max pooling operation where the length of the side of each patch is patch_len. Max is determined by the value of comparator_image if present, then the norm of image_data if use_norm is true, then finally the image_data otherwise.

Parameters:

Name Type Description Default
D int

the dimension of the space, must be between 1 and 4 inclusive

required
image_data Array

the image data, shape (spatial,tensor)

required
patch_len int

the side length of the patches, must evenly divide all spatial dims

required
use_norm bool

if true, use the norm (over the tensor) of the image as the comparator image

True
comparator_image Optional[Array]

scalar image whose argmax is used to determine what value to use.

None

Returns:

Type Description
Array

the image data that has been max pooled, shape (spatial,tensor)

Source code in ginjax/geometric/functional_geometric_image.py
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
@functools.partial(jax.jit, static_argnums=[0, 2, 3])
def max_pool(
    D: int,
    image_data: jax.Array,
    patch_len: int,
    use_norm: bool = True,
    comparator_image: Optional[jax.Array] = None,
) -> jax.Array:
    """
    Perform a max pooling operation where the length of the side of each patch is patch_len. Max is
    determined by the value of comparator_image if present, then the norm of image_data if use_norm
    is true, then finally the image_data otherwise.

    The image data is cast to float32 before the patches are extracted, so the pooled result is
    float32 as well.

    args:
        D: the dimension of the space, must be between 1 and 4 inclusive
        image_data: the image data, shape (spatial,tensor)
        patch_len: the side length of the patches, must evenly divide all spatial dims
        use_norm: if true, use the norm (over the tensor) of the image as the comparator image
        comparator_image: scalar image whose argmax is used to determine what value to use.

    returns:
        the image data that has been max pooled, shape (spatial,tensor)
    """
    assert 1 <= D <= 4, f"max_pool: D must be between 1 and 4 inclusive, but D={D}"
    spatial_dims, k = parse_shape(image_data.shape, D)
    assert (
        (comparator_image is not None) or use_norm or (k == 0)
    ), "max_pool: need a comparator_image or use_norm=True when the image is not a scalar image"
    # Patches are extracted with no padding, so a patch_len that does not divide a side length
    # would silently drop the trailing pixels. Enforce the documented precondition instead.
    assert all(
        N % patch_len == 0 for N in spatial_dims
    ), f"max_pool: patch_len {patch_len} must evenly divide all spatial dims {spatial_dims}"

    spatial_l = ("XYZT")[:D]
    dimension_numbers = ("N" + spatial_l + "C", "OI" + spatial_l, "NC" + spatial_l)

    def extract_patches(data: jax.Array) -> jax.Array:
        # Non-overlapping patches: stride equals the patch side length and there is no padding.
        # TODO: use the batch dimension of dilated_patches correctly
        return jax.lax.conv_general_dilated_patches(
            data.reshape((1,) + spatial_dims + (-1,)).astype("float32"),  # NHWDC
            filter_shape=(patch_len,) * D,
            window_strides=(patch_len,) * D,
            padding=((0, 0),) * D,
            dimension_numbers=dimension_numbers,
        )[
            0
        ]  # drop the singleton batch axis, leaving (channels,spatial)

    patches = extract_patches(image_data)
    new_spatial_dims = patches.shape[1:]
    patches = patches.reshape((D**k, patch_len**D, -1))  # (tensor,patch,num_patches)

    if comparator_image is not None:
        assert (
            comparator_image.shape == spatial_dims
        ), f"max_pool: comparator_image shape {comparator_image.shape} must equal {spatial_dims}"
        comparator_patches = extract_patches(comparator_image).reshape((patch_len**D, -1))
    elif use_norm:
        comparator_patches = jnp.linalg.norm(patches, axis=0)  # (patch,num_patches)
    else:
        assert len(patches) == 1  # can only use image as your comparator if its a scalar image
        comparator_patches = patches[0]

    # For every patch, find the position of the largest comparator value, then pull the full
    # tensor at that position out of the corresponding image patch.
    idxs = jnp.argmax(comparator_patches, axis=0)  # (num_patches,)
    vmap_max = jax.vmap(lambda patch, idx: patch[:, idx], in_axes=(2, 0))
    return vmap_max(patches, idxs).reshape(new_spatial_dims + (D,) * k)

average_pool(D: int, image_data: jax.Array, patch_len: int) -> jax.Array ¤

Perform an average pooling operation where the length of the side of each patch is patch_len. This is equivalent to doing a convolution where each element of the filter is 1 over the number of pixels in the filter, the stride length is patch_len, and the padding is 'VALID'.

Parameters:

Name Type Description Default
D int

dimension of data

required
image_data Array

image data, shape (spatial,tensor)

required
patch_len int

the side length of the patches, must evenly divide the sidelength

required

Returns:

Type Description
Array

the image data after being average pooled, shape (spatial, tensor)

Source code in ginjax/geometric/functional_geometric_image.py
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
@functools.partial(jax.jit, static_argnums=[0, 2])
def average_pool(D: int, image_data: jax.Array, patch_len: int) -> jax.Array:
    """
    Average pool an image over non-overlapping patches whose side length is patch_len. This is
    implemented as a convolution with a filter whose every entry is 1 over the number of pixels in
    the patch, using a stride of patch_len and 'VALID' padding.

    args:
        D: dimension of data
        image_data: image data, shape (spatial,tensor)
        patch_len: the side length of the patches, must evenly divide the sidelength

    returns:
        the image data after being average pooled, shape (spatial, tensor)
    """
    spatial_dims, _ = parse_shape(image_data.shape, D)
    assert all(side % patch_len == 0 for side in spatial_dims)

    window = (patch_len,) * D
    # uniform averaging filter with two leading singleton axes (presumably out_c and in_c)
    filter_data = (1 / (patch_len**D)) * jnp.ones((1, 1) + window)

    # add two leading singleton axes to the image for convolve, then strip them from the result
    pooled = convolve(
        D,
        image_data[None, None],
        filter_data,
        False,
        stride=window,
        padding="VALID",
    )
    return pooled[0, 0]