Skip to content

Layers

ginjax.layers ¤

ConvContract ¤

Bases: Module

A layer that performs a convolution followed by a contraction.

Source code in ginjax/layers.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
class ConvContract(eqx.Module):
    """
    A layer that performs a convolution followed by a contraction.

    For each pair of input type (k,p) and target type (k,p), the convolution filter is a learned
    linear combination of the invariant filters of the matching type. Each input block is convolved
    with that filter and contracted, and all contributions to the same target type are summed.
    An optional bias (see `use_bias` in the constructor) is applied afterwards.
    """

    # weights[(in_k, in_p)][(out_k, out_p)] has shape (out_c, in_c, num_invariant_filters)
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]
    # bias[(out_k, out_p)] has shape (out_c,)
    bias: dict[tuple[tuple[bool, ...], int], jax.Array]
    # used as constants: always wrapped in jax.lax.stop_gradient before use
    invariant_filters: geom.MultiImage

    input_keys: geom.Signature = eqx.field(static=True)
    target_keys: geom.Signature = eqx.field(static=True)
    # normalized bias mode: False, "auto", "mean", or "scalar" (True is stored as "auto")
    use_bias: Union[str, bool] = eqx.field(static=True)
    stride: Union[int, tuple[int, ...]] = eqx.field(static=True)
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = eqx.field(static=True)
    lhs_dilation: Optional[tuple[int, ...]] = eqx.field(static=True)
    rhs_dilation: Union[int, tuple[int, ...]] = eqx.field(static=True)
    # spatial dimension, copied from invariant_filters.D
    D: int = eqx.field(static=True)
    # whether __call__ dispatches to fast_convolve (currently forced to False)
    fast_mode: bool = eqx.field(static=True)
    # True if some input -> target pair had no matching invariant filter and was skipped
    missing_filter: bool = eqx.field(static=True)

    def __init__(
        self: Self,
        input_keys: geom.Signature,
        target_keys: geom.Signature,
        invariant_filters: geom.MultiImage,
        use_bias: Union[str, bool] = "auto",
        stride: Union[int, tuple[int, ...]] = 1,
        padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
        lhs_dilation: Optional[tuple[int, ...]] = None,
        rhs_dilation: Union[int, tuple[int, ...]] = 1,
        key: Any = None,
    ):
        """
        Constructor for equivariant tensor convolution then contraction.

        args:
            input_keys: A mapping of (k,p) to an integer representing the input channels
            target_keys: A mapping of (k,p) to an integer representing the output channels
            invariant_filters: A MultiImage of the invariant filters to build the convolution filters
            use_bias: One of 'auto', 'mean', or 'scalar', or True for 'auto' or False for no bias.
                Mean uses a mean scale for every type, scalar uses a regular bias for scalars only
                and auto does regular bias for scalars and mean for non-scalars.
            stride: convolution stride, defaults to (1,)*self.D
            padding: either 'TORUS', 'VALID', 'SAME', or a D-length tuple of (upper, lower) pairs;
                defaults to 'TORUS' if image.is_torus, else 'SAME'
            lhs_dilation: amount of dilation to apply to the image in each of the D dimensions;
                also used for transposed convolution
            rhs_dilation: amount of dilation to apply to the filter in each of the D dimensions
            key: jax.random key used to initialize weights and biases. Despite the None default,
                a key is required because it is passed to jax.random.split.

        raises:
            ValueError: if use_bias is neither a str nor a bool
        """
        # Validate and normalize use_bias BEFORE storing it, so that __call__ only ever sees
        # False, "auto", "mean", or "scalar". Previously True was stored verbatim, matched none
        # of the string comparisons in __call__, and every output image was silently dropped.
        if isinstance(use_bias, bool):
            use_bias = "auto" if use_bias else use_bias
        elif isinstance(use_bias, str):
            assert use_bias in {"auto", "mean", "scalar"}
        else:
            raise ValueError(
                f"ConvContract: bias must be str or bool, but found {type(use_bias)}:{use_bias}"
            )

        self.input_keys = input_keys
        self.target_keys = target_keys
        self.invariant_filters = invariant_filters
        self.use_bias = use_bias
        self.stride = stride
        self.padding = padding
        self.lhs_dilation = lhs_dilation
        self.rhs_dilation = rhs_dilation

        self.D = invariant_filters.D
        # if a particular desired convolution for input_keys -> target_keys is missing the needed
        # filter (possibly because an equivariant one doesn't exist), this is set to true
        self.missing_filter = False

        self.weights = {}  # presumably some way to jax.lax.scan this?
        self.bias = {}
        all_filter_spatial_dims = []
        for (in_k, in_p), in_c in self.input_keys:
            self.weights[(in_k, in_p)] = {}
            for (out_k, out_p), out_c in self.target_keys:
                key, subkey1, subkey2 = random.split(key, num=3)

                # filters are always contravariant
                filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)
                if filter_key not in self.invariant_filters:
                    self.missing_filter = True
                    continue  # relevant when there isn't an N=3, (0,1) filter

                num_filters = len(self.invariant_filters[filter_key])
                if False and filter_key == ((), 0):
                    # TODO: Currently unused, a work in progress
                    weight_per_ff = []
                    # TODO: jax.lax.scan here instead
                    for conv_filter, tensor_mul in zip(
                        self.invariant_filters[filter_key],
                        [1, (1 + 8 / 9), (1 + 2 / 3)],
                        # [1, 1, 1],
                    ):
                        key, subkey = random.split(key)

                        # number of weights that will appear in a single component output.
                        tensor_mul = scipy.special.comb(jnp.sum(conv_filter), 2, repetition=True)
                        # tensor_mul = jnp.sum(conv_filter**2, axis=tuple(range(self.D))) * tensor_mul
                        bound = jnp.sqrt(1 / (in_c * num_filters * tensor_mul))

                        weight_per_ff.append(
                            random.uniform(subkey, shape=(out_c, in_c), minval=-bound, maxval=bound)
                        )
                    self.weights[(in_k, in_p)][(out_k, out_p)] = jnp.stack(weight_per_ff, axis=-1)

                else:
                    # Fan-in style init: bound = 1/sqrt(in_c * filter spatial size * input tensor
                    # size). Empirically works well.
                    filter_spatial_dims, _ = geom.parse_shape(
                        self.invariant_filters[filter_key].shape[1:], self.D
                    )
                    bound_shape = (in_c,) + filter_spatial_dims + (self.D,) * len(in_k)
                    bound = 1 / jnp.sqrt(math.prod(bound_shape))
                    self.weights[(in_k, in_p)][(out_k, out_p)] = random.uniform(
                        subkey1,
                        shape=(out_c, in_c, num_filters),
                        minval=-bound,
                        maxval=bound,
                    )
                    all_filter_spatial_dims.append(filter_spatial_dims)

                if use_bias:
                    # this may get set multiple times (once per input type); the last one wins
                    self.bias[(out_k, out_p)] = random.uniform(
                        subkey2,
                        shape=(out_c,),
                        minval=-bound,
                        maxval=bound,
                    )

        # If all the in_c match, all out_c match, and all the filter dims match, can use fast_mode
        self.fast_mode = (
            (not self.missing_filter)
            and (len(set([in_c for _, in_c in input_keys])) == 1)
            and (len(set([out_c for _, out_c in target_keys])) == 1)
            and (len(set(all_filter_spatial_dims)) == 1)
        )
        # fast_convolve is disabled: it is only ~20% faster
        self.fast_mode = False

    def fast_convolve(
        self: Self,
        input_multi_image: geom.MultiImage,
        weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    ) -> geom.MultiImage:
        """
        Convolve with a single convolve_ravel call instead of one call per (input, target) type
        pair. Only valid when all filter_spatial_dims, in_c, and out_c match. Sadly, only ~20%
        speedup. Unlike individual_convolve, this does not handle a metric tensor.

        args:
            input_multi_image: the input
            weights: the weights used to combine the invariant filters

        returns:
            the convolved MultiImage
        """
        # These must all be equal to call fast_convolve
        in_c = self.input_keys[0][1]
        out_c = self.target_keys[0][1]

        one_img = next(iter(input_multi_image.values()))
        spatial_dims, _ = geom.parse_shape(one_img.shape[1:], self.D)
        one_filter = next(iter(self.invariant_filters.values()))
        filter_spatial_dims, _ = geom.parse_shape(one_filter.shape[1:], self.D)

        image_ravel = jnp.zeros(spatial_dims + (0, in_c))
        filter_ravel = jnp.zeros((in_c,) + filter_spatial_dims + (0, out_c))
        for (in_k, in_p), image_block in input_multi_image.items():
            # (in_c,spatial,tensor) -> (spatial,-1,in_c)
            img = jnp.moveaxis(image_block.reshape((in_c,) + spatial_dims + (-1,)), 0, -1)
            image_ravel = jnp.concatenate([image_ravel, img], axis=-2)

            filter_ravel_in = jnp.zeros(
                (in_c,) + filter_spatial_dims + (self.D,) * len(in_k) + (0, out_c)
            )
            for (out_k, out_p), weight_block in weights[(in_k, in_p)].items():
                # filters are always contravariant; must match the key built in __init__
                filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)

                # (out_c,in_c,num_filters),(num, spatial, tensor) -> (out_c,in_c,spatial,tensor)
                filter_block = jnp.einsum(
                    "ijk,k...->ij...",
                    weight_block,
                    jax.lax.stop_gradient(self.invariant_filters[filter_key]),
                )
                # (out_c,in_c,spatial,tensor) -> (in_c,spatial,in_tensor,-1,out_c)
                ff = jnp.moveaxis(
                    filter_block.reshape(
                        (out_c, in_c) + filter_spatial_dims + (self.D,) * len(in_k) + (-1,)
                    ),
                    0,
                    -1,
                )
                filter_ravel_in = jnp.concatenate([filter_ravel_in, ff], axis=-2)

            filter_ravel_in = filter_ravel_in.reshape(
                (in_c,) + filter_spatial_dims + (-1,) + (out_c,)
            )
            filter_ravel = jnp.concatenate([filter_ravel, filter_ravel_in], axis=-2)

        image_ravel = image_ravel.reshape(spatial_dims + (-1,))
        filter_ravel = jnp.moveaxis(filter_ravel, 0, self.D).reshape(
            filter_spatial_dims + (in_c, -1)
        )

        out = geom.convolve_ravel(
            self.D,
            image_ravel[None],  # add batch dim
            filter_ravel,
            input_multi_image.is_torus,
            self.stride,
            self.padding,
            self.lhs_dilation,
            self.rhs_dilation,
        )[0]
        new_spatial_dims = out.shape[: self.D]
        # (spatial,tensor_sum*out_c) -> (out_c,spatial,tensor_sum)
        out = jnp.moveaxis(out.reshape(new_spatial_dims + (-1, out_c)), -1, 0)

        out_k_sum = sum([self.D ** len(out_k) for (out_k, _), _ in self.target_keys])
        idx = 0
        out_multi_image = input_multi_image.empty()
        for in_k, in_p in input_multi_image.keys():
            length = (self.D ** len(in_k)) * out_k_sum
            # break off all the channels related to this particular in_k
            out_per_in = out[..., idx : idx + length].reshape(
                (out_c,) + new_spatial_dims + (self.D,) * len(in_k) + (-1,)
            )

            out_idx = 0
            for (out_k, out_p), _ in self.target_keys:
                out_length = self.D ** len(out_k)
                # separate the different out_k parts for particular in_k
                img_block = out_per_in[..., out_idx : out_idx + out_length]
                img_block = img_block.reshape(
                    (out_c,) + new_spatial_dims + (self.D,) * len(in_k + out_k)
                )
                # sum over the in_k tensor axes, which sit right after the spatial axes
                contracted_img = jnp.sum(
                    img_block, axis=tuple(range(1 + self.D, 1 + self.D + len(in_k)))
                )

                if (out_k, out_p) in out_multi_image:  # it already has that key
                    out_multi_image[(out_k, out_p)] = (
                        contracted_img + out_multi_image[(out_k, out_p)]
                    )
                else:
                    out_multi_image.append(out_k, out_p, contracted_img)

                out_idx += out_length

            idx += length

        return out_multi_image

    def individual_convolve(
        self: Self,
        x: geom.MultiImage,
        weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    ) -> geom.MultiImage:
        """
        Function to perform convolve_contract on an entire MultiImage by doing the pairwise convolutions
        individually. This is necessary when filters have unequal sizes, or the in_c or out_c are
        not all equal. The weights are passed as an argument to make it easier to test this function.

        Note: if x has a metric tensor but no cached inverse, the inverse is computed and stored on x.

        args:
            x: the input
            weights: the weights used to combine the invariant filters

        returns:
            the convolved MultiImage
        """
        if x.metric_tensor is not None and x.metric_tensor_inv is None:
            x.metric_tensor_inv = geom.get_metric_inverse(x.metric_tensor)

        # TODO: metric should only be carried over if the image isn't changing size
        out = x.empty()
        for (in_k, in_p), images_block in x.items():
            if x.metric_tensor is not None:
                assert x.metric_tensor_inv is not None
                # Lower all axes to covariant exactly once per input type. Doing this inside the
                # target loop would re-apply the metric to an already-lowered block for every
                # additional target type.
                images_block = geom.raise_lower(
                    images_block,
                    x.metric_tensor.data,
                    x.metric_tensor_inv.data,
                    in_k,
                    (True,) * len(in_k),
                )
            # without a metric tensor, we assume that its the flat euclidean metric in
            # which case lower == upper

            for (out_k, out_p), weight_block in weights[(in_k, in_p)].items():
                # filters are always contravariant
                filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)

                # (out_c,in_c,num_inv_filters) (num, spatial, tensor) -> (out_c,in_c,spatial,tensor)
                filter_block = jnp.einsum(
                    "ijk,k...->ij...",
                    weight_block,
                    jax.lax.stop_gradient(self.invariant_filters[filter_key]),
                )

                convolve_contracted_imgs = geom.convolve_contract(
                    x.D,
                    images_block[None],  # add batch dim
                    filter_block,
                    x.is_torus,
                    self.stride,
                    self.padding,
                    self.lhs_dilation,
                    self.rhs_dilation,
                )[0]

                if x.metric_tensor is not None:
                    assert x.metric_tensor_inv is not None
                    in_spatial, _ = geom.parse_shape(images_block.shape[1:], x.D)
                    out_spatial, _ = geom.parse_shape(convolve_contracted_imgs.shape[1:], x.D)
                    assert (
                        in_spatial == out_spatial
                    ), "For convolution with a metric tensor, spatial dimensions cannot change"
                    # restore axes to proper lower/upper
                    convolve_contracted_imgs = geom.raise_lower(
                        convolve_contracted_imgs,
                        x.metric_tensor.data,
                        x.metric_tensor_inv.data,
                        (True,) * len(out_k),
                        out_k,
                    )

                if (out_k, out_p) in out:  # it already has that key
                    out[(out_k, out_p)] = convolve_contracted_imgs + out[(out_k, out_p)]
                else:
                    out.append(out_k, out_p, convolve_contracted_imgs)

        return out

    def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
        """
        The callable, calls either fast_convolve or individual_convolve. Currently fast_convolve
        is not used because it is not much faster.

        Bias modes: scalars get an additive bias in 'scalar' and 'auto' modes; non-scalars get a
        bias scaled by their spatial mean in 'auto' mode; every type gets the mean-scaled bias in
        'mean' mode. Types that receive no bias are passed through unchanged.

        args:
            x: the input

        returns:
            the convolved MultiImage, which is a new object
        """
        if self.fast_mode:
            x = self.fast_convolve(x, self.weights)
        else:  # slow mode
            x = self.individual_convolve(x, self.weights)

        if not self.use_bias:
            return x

        # defensively treat a raw True (e.g. from an instance built before normalization) as "auto"
        bias_mode = "auto" if self.use_bias is True else self.use_bias

        biased_x = x.empty()
        for (k, p), image in x.items():
            is_scalar = (k, p) == ((), 0)
            if is_scalar and bias_mode in {"scalar", "auto"}:
                broadcast_shape = (len(self.bias[(k, p)]),) + (1,) * (self.D + len(k))
                biased_x.append(k, p, image + self.bias[(k, p)].reshape(broadcast_shape))
            elif (not is_scalar and bias_mode == "auto") or bias_mode == "mean":
                broadcast_shape = (len(self.bias[(k, p)]),) + (1,) * (self.D + len(k))
                # mean over the spatial axes only, keeping dims for broadcasting
                mean_image = jnp.mean(image, axis=tuple(range(1, 1 + self.D)), keepdims=True)
                biased_x.append(
                    k, p, image + mean_image * self.bias[(k, p)].reshape(broadcast_shape)
                )
            else:
                # e.g. non-scalar types in "scalar" mode: keep them unbiased instead of dropping
                biased_x.append(k, p, image)

        return biased_x
__init__(input_keys: geom.Signature, target_keys: geom.Signature, invariant_filters: geom.MultiImage, use_bias: Union[str, bool] = 'auto', stride: Union[int, tuple[int, ...]] = 1, padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1, key: Any = None) ¤

Constructor for equivariant tensor convolution then contraction.

Parameters:

Name Type Description Default
input_keys Signature

A mapping of (k,p) to an integer representing the input channels

required
target_keys Signature

A mapping of (k,p) to an integer representing the output channels

required
invariant_filters MultiImage

A MultiImage of the invariant filters to build the convolution filters

required
use_bias Union[str, bool]

One of 'auto', 'mean', or 'scalar', or True for 'auto' or False for no bias. Mean uses a mean scale for every type, scalar uses a regular bias for scalars only and auto does regular bias for scalars and mean for non-scalars.

'auto'
stride Union[int, tuple[int, ...]]

convolution stride, defaults to (1,)*self.D

1
padding Optional[Union[str, int, tuple[tuple[int, int], ...]]]

either 'TORUS', 'VALID', 'SAME', or a D-length tuple of (upper, lower) pairs; defaults to 'TORUS' if image.is_torus, else 'SAME'

None
lhs_dilation Optional[tuple[int, ...]]

amount of dilation to apply to the image in each of the D dimensions; also used for transposed convolution

None
rhs_dilation Union[int, tuple[int, ...]]

amount of dilation to apply to the filter in each of the D dimensions

1
Source code in ginjax/layers.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def __init__(
    self: Self,
    input_keys: geom.Signature,
    target_keys: geom.Signature,
    invariant_filters: geom.MultiImage,
    use_bias: Union[str, bool] = "auto",
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
    key: Any = None,
):
    """
    Constructor for equivariant tensor convolution then contraction.

    args:
        input_keys: A mapping of (k,p) to an integer representing the input channels
        target_keys: A mapping of (k,p) to an integer representing the output channels
        invariant_filters: A MultiImage of the invariant filters to build the convolution filters
        use_bias: One of 'auto', 'mean', or 'scalar', or True for 'auto' or False for no bias.
            Mean uses a mean scale for every type, scalar uses a regular bias for scalars only
            and auto does regular bias for scalars and mean for non-scalars.
        stride: convolution stride, defaults to (1,)*self.D
        padding: either 'TORUS', 'VALID', 'SAME', or a D-length tuple of (upper, lower) pairs;
            defaults to 'TORUS' if image.is_torus, else 'SAME'
        lhs_dilation: amount of dilation to apply to the image in each of the D dimensions;
            also used for transposed convolution
        rhs_dilation: amount of dilation to apply to the filter in each of the D dimensions
        key: jax.random key used to initialize weights and biases. Despite the None default,
            a key is required because it is passed to jax.random.split.

    raises:
        ValueError: if use_bias is neither a str nor a bool
    """
    # Validate and normalize use_bias BEFORE storing it, so the stored mode is always one of
    # False, "auto", "mean", "scalar". Storing a raw True caused every image to be dropped
    # when biases were applied, since True matches none of the string modes.
    if isinstance(use_bias, bool):
        use_bias = "auto" if use_bias else use_bias
    elif isinstance(use_bias, str):
        assert use_bias in {"auto", "mean", "scalar"}
    else:
        raise ValueError(
            f"ConvContract: bias must be str or bool, but found {type(use_bias)}:{use_bias}"
        )

    self.input_keys = input_keys
    self.target_keys = target_keys
    self.invariant_filters = invariant_filters
    self.use_bias = use_bias
    self.stride = stride
    self.padding = padding
    self.lhs_dilation = lhs_dilation
    self.rhs_dilation = rhs_dilation

    self.D = invariant_filters.D
    # if a particular desired convolution for input_keys -> target_keys is missing the needed
    # filter (possibly because an equivariant one doesn't exist), this is set to true
    self.missing_filter = False

    self.weights = {}  # presumably some way to jax.lax.scan this?
    self.bias = {}
    all_filter_spatial_dims = []
    for (in_k, in_p), in_c in self.input_keys:
        self.weights[(in_k, in_p)] = {}
        for (out_k, out_p), out_c in self.target_keys:
            key, subkey1, subkey2 = random.split(key, num=3)

            # filters are always contravariant
            filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)
            if filter_key not in self.invariant_filters:
                self.missing_filter = True
                continue  # relevant when there isn't an N=3, (0,1) filter

            num_filters = len(self.invariant_filters[filter_key])
            if False and filter_key == ((), 0):
                # TODO: Currently unused, a work in progress
                weight_per_ff = []
                # TODO: jax.lax.scan here instead
                for conv_filter, tensor_mul in zip(
                    self.invariant_filters[filter_key],
                    [1, (1 + 8 / 9), (1 + 2 / 3)],
                    # [1, 1, 1],
                ):
                    key, subkey = random.split(key)

                    # number of weights that will appear in a single component output.
                    tensor_mul = scipy.special.comb(jnp.sum(conv_filter), 2, repetition=True)
                    # tensor_mul = jnp.sum(conv_filter**2, axis=tuple(range(self.D))) * tensor_mul
                    bound = jnp.sqrt(1 / (in_c * num_filters * tensor_mul))

                    weight_per_ff.append(
                        random.uniform(subkey, shape=(out_c, in_c), minval=-bound, maxval=bound)
                    )
                self.weights[(in_k, in_p)][(out_k, out_p)] = jnp.stack(weight_per_ff, axis=-1)

            else:
                # Fan-in style init: bound = 1/sqrt(in_c * filter spatial size * input tensor
                # size). Empirically works well.
                filter_spatial_dims, _ = geom.parse_shape(
                    self.invariant_filters[filter_key].shape[1:], self.D
                )
                bound_shape = (in_c,) + filter_spatial_dims + (self.D,) * len(in_k)
                bound = 1 / jnp.sqrt(math.prod(bound_shape))
                self.weights[(in_k, in_p)][(out_k, out_p)] = random.uniform(
                    subkey1,
                    shape=(out_c, in_c, num_filters),
                    minval=-bound,
                    maxval=bound,
                )
                all_filter_spatial_dims.append(filter_spatial_dims)

            if use_bias:
                # this may get set multiple times (once per input type); the last one wins
                self.bias[(out_k, out_p)] = random.uniform(
                    subkey2,
                    shape=(out_c,),
                    minval=-bound,
                    maxval=bound,
                )

    # If all the in_c match, all out_c match, and all the filter dims match, can use fast_mode
    self.fast_mode = (
        (not self.missing_filter)
        and (len(set([in_c for _, in_c in input_keys])) == 1)
        and (len(set([out_c for _, out_c in target_keys])) == 1)
        and (len(set(all_filter_spatial_dims)) == 1)
    )
    # fast_convolve is disabled: it is only ~20% faster
    self.fast_mode = False
fast_convolve(input_multi_image: geom.MultiImage, weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]) -> geom.MultiImage ¤

When all filter_spatial_dims, in_c, and out_c match, convolve with a single call instead of one call per pair of types. Sadly, this gives only a ~20% speedup.

Source code in ginjax/layers.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def fast_convolve(
    self: Self,
    input_multi_image: geom.MultiImage,
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
) -> geom.MultiImage:
    """
    Convolve with a single convolve_ravel call instead of one call per (input, target) type
    pair. Only valid when all filter_spatial_dims, in_c, and out_c match. Sadly, only ~20%
    speedup. This path does not handle a metric tensor.

    args:
        input_multi_image: the input
        weights: the weights used to combine the invariant filters

    returns:
        the convolved MultiImage
    """
    # These must all be equal to call fast_convolve
    in_c = self.input_keys[0][1]
    out_c = self.target_keys[0][1]

    one_img = next(iter(input_multi_image.values()))
    spatial_dims, _ = geom.parse_shape(one_img.shape[1:], self.D)
    one_filter = next(iter(self.invariant_filters.values()))
    filter_spatial_dims, _ = geom.parse_shape(one_filter.shape[1:], self.D)

    image_ravel = jnp.zeros(spatial_dims + (0, in_c))
    filter_ravel = jnp.zeros((in_c,) + filter_spatial_dims + (0, out_c))
    for (in_k, in_p), image_block in input_multi_image.items():
        # (in_c,spatial,tensor) -> (spatial,-1,in_c)
        img = jnp.moveaxis(image_block.reshape((in_c,) + spatial_dims + (-1,)), 0, -1)
        image_ravel = jnp.concatenate([image_ravel, img], axis=-2)

        filter_ravel_in = jnp.zeros(
            (in_c,) + filter_spatial_dims + (self.D,) * len(in_k) + (0, out_c)
        )
        for (out_k, out_p), weight_block in weights[(in_k, in_p)].items():
            # filters are always contravariant; must match the key built in __init__
            filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)

            # (out_c,in_c,num_filters),(num, spatial, tensor) -> (out_c,in_c,spatial,tensor)
            filter_block = jnp.einsum(
                "ijk,k...->ij...",
                weight_block,
                jax.lax.stop_gradient(self.invariant_filters[filter_key]),
            )
            # (out_c,in_c,spatial,tensor) -> (in_c,spatial,in_tensor,-1,out_c)
            ff = jnp.moveaxis(
                filter_block.reshape(
                    (out_c, in_c) + filter_spatial_dims + (self.D,) * len(in_k) + (-1,)
                ),
                0,
                -1,
            )
            filter_ravel_in = jnp.concatenate([filter_ravel_in, ff], axis=-2)

        filter_ravel_in = filter_ravel_in.reshape(
            (in_c,) + filter_spatial_dims + (-1,) + (out_c,)
        )
        filter_ravel = jnp.concatenate([filter_ravel, filter_ravel_in], axis=-2)

    image_ravel = image_ravel.reshape(spatial_dims + (-1,))
    filter_ravel = jnp.moveaxis(filter_ravel, 0, self.D).reshape(
        filter_spatial_dims + (in_c, -1)
    )

    out = geom.convolve_ravel(
        self.D,
        image_ravel[None],  # add batch dim
        filter_ravel,
        input_multi_image.is_torus,
        self.stride,
        self.padding,
        self.lhs_dilation,
        self.rhs_dilation,
    )[0]
    new_spatial_dims = out.shape[: self.D]
    # (spatial,tensor_sum*out_c) -> (out_c,spatial,tensor_sum)
    out = jnp.moveaxis(out.reshape(new_spatial_dims + (-1, out_c)), -1, 0)

    out_k_sum = sum([self.D ** len(out_k) for (out_k, _), _ in self.target_keys])
    idx = 0
    out_multi_image = input_multi_image.empty()
    for in_k, in_p in input_multi_image.keys():
        length = (self.D ** len(in_k)) * out_k_sum
        # break off all the channels related to this particular in_k
        out_per_in = out[..., idx : idx + length].reshape(
            (out_c,) + new_spatial_dims + (self.D,) * len(in_k) + (-1,)
        )

        out_idx = 0
        for (out_k, out_p), _ in self.target_keys:
            out_length = self.D ** len(out_k)
            # separate the different out_k parts for particular in_k
            img_block = out_per_in[..., out_idx : out_idx + out_length]
            img_block = img_block.reshape(
                (out_c,) + new_spatial_dims + (self.D,) * len(in_k + out_k)
            )
            # sum over the in_k tensor axes, which sit right after the spatial axes
            contracted_img = jnp.sum(
                img_block, axis=tuple(range(1 + self.D, 1 + self.D + len(in_k)))
            )

            if (out_k, out_p) in out_multi_image:  # it already has that key
                out_multi_image[(out_k, out_p)] = (
                    contracted_img + out_multi_image[(out_k, out_p)]
                )
            else:
                out_multi_image.append(out_k, out_p, contracted_img)

            out_idx += out_length

        idx += length

    return out_multi_image
individual_convolve(x: geom.MultiImage, weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]) -> geom.MultiImage ¤

Function to perform convolve_contract on an entire MultiImage by doing the pairwise convolutions individually. This is necessary when filters have unequal sizes, or when the in_c or out_c values are not all equal. The weights are passed as an argument to make it easier to test this function.

Parameters:

Name Type Description Default
x MultiImage

the input

required
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

the weights used to combine the invariant filters

required

Returns:

Type Description
MultiImage

the convolved MultiImage

Source code in ginjax/layers.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
def individual_convolve(
    self: Self,
    x: geom.MultiImage,
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
) -> geom.MultiImage:
    """
    Function to perform convolve_contract on an entire MultiImage by doing the pairwise convolutions
    individually. This is necessary when filters have unequal sizes, or the in_c or out_c are
    not all equal. Weights are passed as an argument to make this function easier to test.

    For each input block (in_k, in_p) and each output key (out_k, out_p) in
    weights[(in_k, in_p)], the invariant filters are combined with the weights into a filter
    block, the input block is convolved and contracted with it, and results that land on the
    same output key are summed.

    If x has a metric tensor, each input block is lowered to all-covariant axes once, before
    looping over its output keys. Each convolved result is then converted from all-covariant
    axes to the axes given by out_k. Without a metric tensor the flat Euclidean metric is
    assumed, in which case lower == upper.

    args:
        x: the input. If it has a metric tensor but no cached inverse, the inverse is computed
            and stored on x (x is mutated).
        weights: the weights used to combine the invariant filters, indexed as
            weights[(in_k, in_p)][(out_k, out_p)]

    returns:
        the convolved MultiImage
    """
    if x.metric_tensor is not None and x.metric_tensor_inv is None:
        x.metric_tensor_inv = geom.get_metric_inverse(x.metric_tensor)

    # TODO: metric should only be carried over if the image isn't changing size
    out = x.empty()
    for (in_k, in_p), images_block in x.items():
        if x.metric_tensor is not None:
            assert x.metric_tensor_inv is not None
            # Lower all axes to covariant exactly ONCE per input block. Doing this inside the loop
            # over output keys (and overwriting images_block) re-applied the metric to an
            # already-lowered block for every additional output key, corrupting any input whose
            # in_k has contravariant axes.
            conv_input = geom.raise_lower(
                images_block,
                x.metric_tensor.data,
                x.metric_tensor_inv.data,
                in_k,
                (True,) * len(in_k),
            )
        else:
            # without a metric tensor, we assume that it's the flat euclidean metric in
            # which case lower == upper
            conv_input = images_block

        for (out_k, out_p), weight_block in weights[(in_k, in_p)].items():
            # filters are always contravariant
            filter_key = ((False,) * (len(in_k) + len(out_k)), (in_p + out_p) % 2)

            # (out_c,in_c,num_inv_filters) (num, spatial, tensor) -> (out_c,in_c,spatial,tensor)
            filter_block = jnp.einsum(
                "ijk,k...->ij...",
                weight_block,
                jax.lax.stop_gradient(self.invariant_filters[filter_key]),
            )

            convolve_contracted_imgs = geom.convolve_contract(
                x.D,
                conv_input[None],  # add batch dim
                filter_block,
                x.is_torus,
                self.stride,
                self.padding,
                self.lhs_dilation,
                self.rhs_dilation,
            )[0]

            if x.metric_tensor is not None:
                assert x.metric_tensor_inv is not None
                in_spatial, _ = geom.parse_shape(conv_input.shape[1:], x.D)
                out_spatial, _ = geom.parse_shape(convolve_contracted_imgs.shape[1:], x.D)
                assert (
                    in_spatial == out_spatial
                ), "For convolution with a metric tensor, spatial dimensions cannot change"
                # restore axes from all-covariant to the lower/upper pattern of out_k
                convolve_contracted_imgs = geom.raise_lower(
                    convolve_contracted_imgs,
                    x.metric_tensor.data,
                    x.metric_tensor_inv.data,
                    (True,) * len(out_k),
                    out_k,
                )

            if (out_k, out_p) in out:  # it already has that key, accumulate
                out[(out_k, out_p)] = convolve_contracted_imgs + out[(out_k, out_p)]
            else:
                out.append(out_k, out_p, convolve_contracted_imgs)

    return out
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

The callable, calls either fast_convolve or individual_convolve. Currently fast_convolve is not used because it is not much faster.

Parameters:

Name Type Description Default
x MultiImage

the input

required

Returns:

Type Description
MultiImage

the convolved MultiImage, which is a new object

Source code in ginjax/layers.py
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    The callable: convolves x with fast_convolve when self.fast_mode is set, otherwise with
    individual_convolve, then optionally adds a bias.

    Bias handling depends on self.use_bias:
        - falsy: no bias; the convolution output is returned as is.
        - "scalar": a per-channel bias is added to the ((), 0) block only.
        - "auto": as "scalar" for the ((), 0) block; every other block has its spatial mean,
          scaled by a per-channel bias, added to it.
        - "mean": every block has its spatial mean, scaled by a per-channel bias, added to it.
    Blocks to which no bias applies (e.g. non-scalar blocks under "scalar") are passed through
    unchanged instead of being dropped from the output.

    args:
        x: the input

    returns:
        the convolved MultiImage, which is a new object
    """
    if self.fast_mode:
        x = self.fast_convolve(x, self.weights)
    else:  # slow mode
        x = self.individual_convolve(x, self.weights)

    if not self.use_bias:
        return x

    biased_x = x.empty()
    for (k, p), image in x.items():
        is_scalar = (k, p) == ((), 0)
        if is_scalar and (self.use_bias == "scalar" or self.use_bias == "auto"):
            # (out_c,) -> (out_c, 1, ..., 1) to broadcast over spatial and tensor axes
            broadcast_shape = (len(self.bias[(k, p)]),) + (1,) * (self.D + len(k))
            biased_x.append(k, p, image + self.bias[(k, p)].reshape(broadcast_shape))
        elif (not is_scalar and self.use_bias == "auto") or self.use_bias == "mean":
            broadcast_shape = (len(self.bias[(k, p)]),) + (1,) * (self.D + len(k))
            # mean over the spatial axes keeps the bias term equivariant
            mean_image = jnp.mean(
                image, axis=tuple(range(1, 1 + self.invariant_filters.D)), keepdims=True
            )
            biased_x.append(
                k, p, image + mean_image * self.bias[(k, p)].reshape(broadcast_shape)
            )
        else:
            # No bias applies to this key. Previously the block was silently dropped here, so
            # the output lost entire (k, p) entries. Keep it unchanged instead.
            biased_x.append(k, p, image)

    return biased_x

GroupNorm ¤

Bases: Module

Implementation of GroupNorm for equivariant and non-equivariant models.

Source code in ginjax/layers.py
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
class GroupNorm(eqx.Module):
    """
    Group normalization for MultiImages that works for equivariant and non-equivariant models.

    Scalar (k=0) blocks are normalized with equinox's standard GroupNorm. Vector (k=1) blocks
    use an equivariant group norm followed by a learned per-channel scale and a learned
    per-channel multiple of the spatial mean. Tensor orders above 1 are not supported.
    """

    scale: dict[tuple[tuple[bool, ...], int], jax.Array]
    bias: dict[tuple[tuple[bool, ...], int], jax.Array]
    vanilla_norm: dict[tuple[tuple[bool, ...], int], eqx.nn.GroupNorm]

    D: int = eqx.field(static=True)
    groups: int = eqx.field(static=True)
    eps: float = eqx.field(static=True)

    def __init__(
        self: Self,
        input_keys: geom.Signature,
        D: int,
        groups: int,
        eps: float = 1e-5,
    ) -> None:
        """
        Build the per-key normalization parameters. With groups equal to the channel count this
        behaves like instance norm; with groups=1 it behaves like layer norm.

        args:
            input_keys: input key signature
            D: dimension
            groups: the number of channel groups for group_norm
            eps: number to add to variance so we aren't dividing by 0
        """
        self.D = D
        self.groups = groups
        self.eps = eps

        scales = {}
        biases = {}
        scalar_norms = {}  # scalars can use the basic equinox GroupNorm
        for key_kp, channels in input_keys:
            k, _ = key_kp
            assert (
                channels % groups
            ) == 0, f"group_norm: Groups must evenly divide channels, but got groups={groups}, channels={channels}."

            order = len(k)
            if order > 1:
                raise NotImplementedError(
                    f"ml::group_norm: Equivariant group_norm not implemented for k>1, but k={k}",
                )

            if order == 0:
                scalar_norms[key_kp] = eqx.nn.GroupNorm(groups, channels, eps)
            else:
                scales[key_kp] = jnp.ones(channels)
                biases[key_kp] = jnp.zeros(channels)

        self.scale = scales
        self.bias = biases
        self.vanilla_norm = scalar_norms

    def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
        """
        Normalize every block of x according to its tensor order.

        args:
            x: input MultiImage

        returns:
            the output normed MultiImage
        """
        normed_x = x.empty()
        for key_kp, block in x.items():
            k, p = key_kp
            order = len(k)
            if order > 1:
                raise NotImplementedError(
                    f"ml::group_norm: Equivariant group_norm not implemented for k>1, but k={k}",
                )

            if order == 0:
                result = self.vanilla_norm[key_kp](block)  # plain (non-equivariant) norm
            else:
                # spatial mean per channel, kept so the bias can add back a mean-like term
                spatial_axes = tuple(range(1, 1 + self.D))
                mean_vec = jnp.mean(block, axis=spatial_axes, keepdims=True)
                expected_shape = (block.shape[0],) + (1,) * self.D + (self.D,) * order
                assert mean_vec.shape == expected_shape

                param_shape = (-1,) + (1,) * (self.D + order)
                result = _group_norm_K1(self.D, block, self.groups, eps=self.eps)
                result = result * self.scale[key_kp].reshape(
                    param_shape
                ) + mean_vec * self.bias[key_kp].reshape(param_shape)

            normed_x.append(k, p, result)

        return normed_x
__init__(input_keys: geom.Signature, D: int, groups: int, eps: float = 1e-05) -> None ¤

Constructor for GroupNorm. When num_groups=num_channels, this is equivalent to instance_norm. When num_groups=1, this is equivalent to layer_norm.

Parameters:

Name Type Description Default
input_keys Signature

input key signature

required
D int

dimension

required
groups int

the number of channel groups for group_norm

required
eps float

number to add to variance so we aren't dividing by 0

1e-05
Source code in ginjax/layers.py
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def __init__(
    self: Self,
    input_keys: geom.Signature,
    D: int,
    groups: int,
    eps: float = 1e-5,
) -> None:
    """
    Build the per-key normalization parameters. With groups equal to the channel count this
    behaves like instance norm; with groups=1 it behaves like layer norm.

    args:
        input_keys: input key signature
        D: dimension
        groups: the number of channel groups for group_norm
        eps: number to add to variance so we aren't dividing by 0
    """
    self.D = D
    self.groups = groups
    self.eps = eps

    scales = {}
    biases = {}
    scalar_norms = {}  # scalars can use the basic equinox GroupNorm
    for key_kp, channels in input_keys:
        k, _ = key_kp
        assert (
            channels % groups
        ) == 0, f"group_norm: Groups must evenly divide channels, but got groups={groups}, channels={channels}."

        order = len(k)
        if order > 1:
            raise NotImplementedError(
                f"ml::group_norm: Equivariant group_norm not implemented for k>1, but k={k}",
            )

        if order == 0:
            scalar_norms[key_kp] = eqx.nn.GroupNorm(groups, channels, eps)
        else:
            scales[key_kp] = jnp.ones(channels)
            biases[key_kp] = jnp.zeros(channels)

    self.scale = scales
    self.bias = biases
    self.vanilla_norm = scalar_norms
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

Callable for GroupNorm.

Parameters:

Name Type Description Default
x MultiImage

input MultiImage

required

Returns:

Type Description
MultiImage

the output normed MultiImage

Source code in ginjax/layers.py
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    Normalize every block of x according to its tensor order.

    args:
        x: input MultiImage

    returns:
        the output normed MultiImage
    """
    normed_x = x.empty()
    for key_kp, block in x.items():
        k, p = key_kp
        order = len(k)
        if order > 1:
            raise NotImplementedError(
                f"ml::group_norm: Equivariant group_norm not implemented for k>1, but k={k}",
            )

        if order == 0:
            result = self.vanilla_norm[key_kp](block)  # plain (non-equivariant) norm
        else:
            # spatial mean per channel, kept so the bias can add back a mean-like term
            spatial_axes = tuple(range(1, 1 + self.D))
            mean_vec = jnp.mean(block, axis=spatial_axes, keepdims=True)
            expected_shape = (block.shape[0],) + (1,) * self.D + (self.D,) * order
            assert mean_vec.shape == expected_shape

            param_shape = (-1,) + (1,) * (self.D + order)
            result = _group_norm_K1(self.D, block, self.groups, eps=self.eps)
            result = result * self.scale[key_kp].reshape(
                param_shape
            ) + mean_vec * self.bias[key_kp].reshape(param_shape)

        normed_x.append(k, p, result)

    return normed_x

LayerNorm ¤

Bases: GroupNorm

LayerNorm, which is GroupNorm with a single group.

Source code in ginjax/layers.py
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
class LayerNorm(GroupNorm):
    """
    Layer normalization for MultiImages, implemented as a GroupNorm in which all channels share
    a single group.
    """

    def __init__(self: Self, input_keys: geom.Signature, D: int, eps: float = 1e-5) -> None:
        """
        Build a LayerNorm by delegating to GroupNorm with the group count fixed at 1.

        args:
            input_keys: the input signature
            D: the dimension
            eps: number to add to variance so we aren't dividing by 0
        """
        num_groups = 1  # a single group over all channels is what makes this a layer norm
        super(LayerNorm, self).__init__(input_keys, D, num_groups, eps)
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

Callable for GroupNorm.

Parameters:

Name Type Description Default
x MultiImage

input MultiImage

required

Returns:

Type Description
MultiImage

the output normed MultiImage

Source code in ginjax/layers.py
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    Normalize every block of x according to its tensor order.

    args:
        x: input MultiImage

    returns:
        the output normed MultiImage
    """
    normed_x = x.empty()
    for key_kp, block in x.items():
        k, p = key_kp
        order = len(k)
        if order > 1:
            raise NotImplementedError(
                f"ml::group_norm: Equivariant group_norm not implemented for k>1, but k={k}",
            )

        if order == 0:
            result = self.vanilla_norm[key_kp](block)  # plain (non-equivariant) norm
        else:
            # spatial mean per channel, kept so the bias can add back a mean-like term
            spatial_axes = tuple(range(1, 1 + self.D))
            mean_vec = jnp.mean(block, axis=spatial_axes, keepdims=True)
            expected_shape = (block.shape[0],) + (1,) * self.D + (self.D,) * order
            assert mean_vec.shape == expected_shape

            param_shape = (-1,) + (1,) * (self.D + order)
            result = _group_norm_K1(self.D, block, self.groups, eps=self.eps)
            result = result * self.scale[key_kp].reshape(
                param_shape
            ) + mean_vec * self.bias[key_kp].reshape(param_shape)

        normed_x.append(k, p, result)

    return normed_x
__init__(input_keys: geom.Signature, D: int, eps: float = 1e-05) -> None ¤

Constructor for LayerNorm.

Parameters:

Name Type Description Default
input_keys Signature

the input signature

required
D int

the dimension

required
eps float

number to add to variance so we aren't dividing by 0

1e-05
Source code in ginjax/layers.py
539
540
541
542
543
544
545
546
547
548
def __init__(self: Self, input_keys: geom.Signature, D: int, eps: float = 1e-5) -> None:
    """
    Build a LayerNorm by delegating to GroupNorm with the group count fixed at 1.

    args:
        input_keys: the input signature
        D: the dimension
        eps: number to add to variance so we aren't dividing by 0
    """
    num_groups = 1  # a single group over all channels is what makes this a layer norm
    super(LayerNorm, self).__init__(input_keys, D, num_groups, eps)

VectorNeuronNonlinear ¤

Bases: Module

The vector nonlinearity from the Vector Neurons paper (https://arxiv.org/pdf/2104.12229.pdf). A learned linear combination of the channels of a tensor image produces a direction vector. The inner product of the input with this normalized direction plays the role of the input to a typical scalar activation, and the result is used to rescale the component of the input parallel to the direction; the perpendicular component is passed through unchanged.

Source code in ginjax/layers.py
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
class VectorNeuronNonlinear(eqx.Module):
    """
    The vector nonlinearity in the Vector Neurons paper: https://arxiv.org/pdf/2104.12229.pdf

    For each non-scalar block, a learned linear mix of the channels gives a direction. The input
    is split into a component parallel to that (normalized) direction and a perpendicular
    remainder. The inner product with the direction acts like the input to a typical scalar
    activation and is used to rescale the parallel component; the perpendicular component is
    passed through unchanged. The ((), 0) scalar block simply goes through scalar_activation.
    """

    weights: dict[tuple[tuple[bool, ...], int], jax.Array]

    eps: float = eqx.field(static=True)
    D: int = eqx.field(static=True)
    scalar_activation: Callable = eqx.field(static=True)

    def __init__(
        self: Self,
        input_keys: geom.Signature,
        D: int,
        scalar_activation: Callable[[ArrayLike], jax.Array] = jax.nn.relu,
        eps: float = 1e-5,
        key: Any = None,
    ) -> None:
        """
        Constructor for VectorNeuronNonlinear.

        args:
            input_keys: the signature of the input MultiImage
            D: the dimension
            scalar_activation: nonlinearity used for scalars
            eps: small value to avoid dividing by zero if the k_vec is close to 0
            key: jax.random key. Required whenever input_keys contains a key other than
                ((), 0), because those channel-mixing weights are randomly initialized.

        raises:
            TypeError: if key is None but a non-scalar input key needs weights.
        """
        self.eps = eps
        self.D = D
        self.scalar_activation = scalar_activation

        self.weights = {}
        for (k, p), in_c in input_keys:
            if (k, p) == ((), 0):
                # plain scalars only go through scalar_activation and need no weights
                continue

            if key is None:
                # Fail with a clear message instead of an opaque error from random.split(None).
                raise TypeError(
                    "VectorNeuronNonlinear: a jax.random key is required to initialize the "
                    f"weights for input key {(k, p)}, but key=None"
                )

            # uniform init in [-1/sqrt(in_c), 1/sqrt(in_c)]
            bound = 1.0 / jnp.sqrt(in_c)
            key, subkey = random.split(key, num=2)
            self.weights[(k, p)] = random.uniform(
                subkey, shape=(in_c, in_c), minval=-bound, maxval=bound
            )

    def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
        """
        Callable for VectorNeuronNonlinearity

        args:
            x: the input

        returns:
            a new MultiImage output
        """
        out_x = x.empty()
        for (k, p), img_block in x.items():
            if (k, p) == ((), 0):
                out_x.append(k, p, self.scalar_activation(img_block))
                continue

            tensor_idx = geom.LETTERS[: len(k)]
            # mix channels with the (in_c, in_c) weights -> (out_c,spatial,tensor)
            k_vec = jnp.einsum("ij,j...->i...", self.weights[(k, p)], img_block)
            k_vec_normed = k_vec / (geom.norm(1 + self.D, k_vec, keepdims=True) + self.eps)

            # inner product over the tensor axes
            inner_prod = jnp.einsum(
                f"...{tensor_idx},...{tensor_idx}->...",
                img_block,
                k_vec_normed,
            )

            # split the vector into a parallel section and a perpendicular section
            v_parallel = jnp.einsum(
                f"...,...{tensor_idx}->...{tensor_idx}",
                inner_prod,
                k_vec_normed,
            )
            v_perp = img_block - v_parallel
            # ratio activation(ip) / |ip|, e.g. ~1 for positive ip and 0 for negative with relu
            h = self.scalar_activation(inner_prod) / (jnp.abs(inner_prod) + self.eps)

            scaled_parallel = jnp.einsum(
                f"...,...{tensor_idx}->...{tensor_idx}", h, v_parallel
            )
            out_x.append(k, p, scaled_parallel + v_perp)

        return out_x
__init__(input_keys: geom.Signature, D: int, scalar_activation: Callable[[ArrayLike], jax.Array] = jax.nn.relu, eps: float = 1e-05, key: Any = None) -> None ¤

Constructor for VectorNeuronNonlinear.

Parameters:

Name Type Description Default
input_keys Signature

the signature of the input MultiImage

required
D int

the dimension

required
scalar_activation Callable[[ArrayLike], Array]

nonlinearity used for scalars

relu
eps float

small value to avoid dividing by zero if the k_vec is close to 0

1e-05
key Any

jax.random key

None
Source code in ginjax/layers.py
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
def __init__(
    self: Self,
    input_keys: geom.Signature,
    D: int,
    scalar_activation: Callable[[ArrayLike], jax.Array] = jax.nn.relu,
    eps: float = 1e-5,
    key: Any = None,
) -> None:
    """
    Constructor for VectorNeuronNonlinear.

    args:
        input_keys: the signature of the input MultiImage
        D: the dimension
        scalar_activation: nonlinearity used for scalars
        eps: small value to avoid dividing by zero if the k_vec is close to 0
        key: jax.random key. Required whenever input_keys contains a key other than ((), 0),
            because those channel-mixing weights are randomly initialized.

    raises:
        TypeError: if key is None but a non-scalar input key needs weights.
    """
    self.eps = eps
    self.D = D
    self.scalar_activation = scalar_activation

    self.weights = {}
    for (k, p), in_c in input_keys:
        if (k, p) == ((), 0):
            # plain scalars only go through scalar_activation and need no weights
            continue

        if key is None:
            # Fail with a clear message instead of an opaque error from random.split(None).
            raise TypeError(
                "VectorNeuronNonlinear: a jax.random key is required to initialize the "
                f"weights for input key {(k, p)}, but key=None"
            )

        # uniform init in [-1/sqrt(in_c), 1/sqrt(in_c)]
        bound = 1.0 / jnp.sqrt(in_c)
        key, subkey = random.split(key, num=2)
        self.weights[(k, p)] = random.uniform(
            subkey, shape=(in_c, in_c), minval=-bound, maxval=bound
        )
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

Callable for VectorNeuronNonlinearity

Parameters:

Name Type Description Default
x MultiImage

the input

required

Returns:

Type Description
MultiImage

a new MultiImage output

Source code in ginjax/layers.py
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    Apply the vector-neuron nonlinearity to every block of x.

    The ((), 0) block goes through scalar_activation. Every other block is split into the
    component parallel to a learned, normalized direction (a channel mix using
    self.weights[(k, p)]) and the perpendicular remainder. The parallel component is rescaled by
    scalar_activation(<v, d>) / (|<v, d>| + eps), and the perpendicular component is kept.

    args:
        x: the input

    returns:
        a new MultiImage output
    """
    result = x.empty()
    for (k, p), block in x.items():
        if (k, p) == ((), 0):
            result.append(k, p, self.scalar_activation(block))
            continue

        axes = geom.LETTERS[: len(k)]
        # channel mix -> (out_c,spatial,tensor), then normalize over the tensor axes
        direction = jnp.einsum("ij,j...->i...", self.weights[(k, p)], block)
        direction = direction / (geom.norm(1 + self.D, direction, keepdims=True) + self.eps)

        projection = jnp.einsum(f"...{axes},...{axes}->...", block, direction)
        parallel = jnp.einsum(f"...,...{axes}->...{axes}", projection, direction)
        perpendicular = block - parallel
        gate = self.scalar_activation(projection) / (jnp.abs(projection) + self.eps)

        gated_parallel = jnp.einsum(f"...,...{axes}->...{axes}", gate, parallel)
        result.append(k, p, gated_parallel + perpendicular)

    return result

MaxNormPool ¤

Bases: Module

Layer that performs max pooling based on the norm of the tensor.

Source code in ginjax/layers.py
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
class MaxNormPool(eqx.Module):
    """
    Max-pooling layer for MultiImages; with use_norm the max is chosen by the tensor norm.
    """

    patch_len: int = eqx.field(static=True)
    use_norm: bool = eqx.field(static=True)

    def __init__(self: Self, patch_len: int, use_norm: bool = True) -> None:
        """
        Store the pooling configuration.

        args:
            patch_len: sidelength of the patch
            use_norm: whether to use norm to calculate the max
        """
        self.use_norm = use_norm
        self.patch_len = patch_len

    def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
        """
        Pool every block of x, one channel at a time.

        args:
            x: the input to the layer

        returns:
            a new max normed output MultiImage
        """
        # map over the leading channel axis; D, patch_len and use_norm are shared
        pool_per_channel = jax.vmap(geom.max_pool, in_axes=(None, 0, None, None))

        pooled_x = x.empty()
        for (k, p), block in x.items():
            pooled = pool_per_channel(x.D, block, self.patch_len, self.use_norm)
            pooled_x.append(k, p, pooled)

        return pooled_x
__init__(patch_len: int, use_norm: bool = True) -> None ¤

Constructor for MaxNormPool.

Parameters:

Name Type Description Default
patch_len int

sidelength of the patch

required
use_norm bool

whether to use norm to calculate the max

True
Source code in ginjax/layers.py
648
649
650
651
652
653
654
655
656
657
def __init__(self: Self, patch_len: int, use_norm: bool = True) -> None:
    """
    Constructor for MaxNormPool.

    args:
        patch_len: sidelength of the patch
        use_norm: whether to use norm to calculate the max
    """
    self.patch_len = patch_len
    self.use_norm = use_norm
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

Callable for MaxNormPool.

Parameters:

Name Type Description Default
x MultiImage

the input to the layer

required

Returns:

Type Description
MultiImage

a new max normed output MultiImage

Source code in ginjax/layers.py
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    Pool every block of x, one channel at a time.

    args:
        x: the input to the layer

    returns:
        a new max normed output MultiImage
    """
    # map over the leading channel axis; D, patch_len and use_norm are shared
    pool_per_channel = jax.vmap(geom.max_pool, in_axes=(None, 0, None, None))

    pooled_x = x.empty()
    for (k, p), block in x.items():
        pooled = pool_per_channel(x.D, block, self.patch_len, self.use_norm)
        pooled_x.append(k, p, pooled)

    return pooled_x

LayerWrapper ¤

Bases: Module

Wrapper class for any module which takes an image and converts it to taking and producing a MultiImage.

Source code in ginjax/layers.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
class LayerWrapper(eqx.Module):
    """
    Wrapper class for any module which takes an image and converts it to taking and producing a
    MultiImage. The wrapped module is applied separately to each (k, p) block.
    """

    modules: dict[tuple[tuple[bool, ...], int], Callable[..., Any]]

    def __init__(self: Self, module: Callable[..., Any], input_keys: geom.Signature) -> None:
        """
        Perform the module or callable (e.g., activation) on each layer of the input MultiImage.
        Since we only take input_keys, module should preserve the shape/tensor order and parity.

        args:
            module: module should have as input/output an image of shape (channels, spatial)
            input_keys: actual input (and output) signature this module will process
        """
        self.modules = {}
        for (k, p), _ in input_keys:
            # NOTE(review): the same module object is stored under every key. The original
            # author expected equinox to treat these as independent copies (see
            # https://docs.kidger.site/equinox/api/nn/shared/), but noted it may not, and that
            # it should be fine in the scalar case.
            self.modules[(k, p)] = module

    def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
        """
        Callable for LayerWrapper.

        args:
            x: the input

        returns:
            a new MultiImage
        """
        # Build the output with x.empty(), as every other layer in this file does.
        # x.__class__({}, x.D, x.is_torus) discarded any other MultiImage state (presumably the
        # metric tensor) even though this wrapper preserves the image shape.
        out = x.empty()
        for (k, p), image in x.items():
            out.append(k, p, self.modules[(k, p)](image))

        return out
__init__(module: Callable[..., Any], input_keys: geom.Signature) -> None ¤

Perform the module or callable (e.g., activation) on each layer of the input MultiImage. Since we only take input_keys, module should preserve the shape/tensor order and parity.

Parameters:

Name Type Description Default
module Callable[..., Any]

module should have as input/output an image of shape (channels, spatial)

required
input_keys Signature

actual input (and output) signature this module will process

required
Source code in ginjax/layers.py
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
def __init__(self: Self, module: Callable[..., Any], input_keys: geom.Signature) -> None:
    """
    Map every (k, p) key of the input signature to the given module or callable (e.g., an
    activation). Because only input_keys are used, the module must preserve the shape, tensor
    order and parity of each block.

    args:
        module: module should have as input/output an image of shape (channels, spatial)
        input_keys: actual input (and output) signature this module will process
    """
    # NOTE(review): the same module object is stored under every key. The original author
    # expected equinox to treat these as independent copies (see
    # https://docs.kidger.site/equinox/api/nn/shared/), but noted it may not, and that it should
    # be fine in the scalar case.
    self.modules = {key_kp: module for key_kp, _ in input_keys}
__call__(x: geom.MultiImage) -> geom.MultiImage ¤

Callable for LayerWrapper.

Parameters:

Name Type Description Default
x MultiImage

the input

required

Returns:

Type Description
MultiImage

a new MultiImage

Source code in ginjax/layers.py
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
def __call__(self: Self, x: geom.MultiImage) -> geom.MultiImage:
    """
    Callable for LayerWrapper: applies the module stored for each (k, p) key to that block.

    args:
        x: the input

    returns:
        a new MultiImage
    """
    # Build the output with x.empty(), as every other layer in this file does.
    # x.__class__({}, x.D, x.is_torus) discarded any other MultiImage state (presumably the
    # metric tensor) even though this wrapper preserves the image shape.
    out = x.empty()
    for (k, p), image in x.items():
        out.append(k, p, self.modules[(k, p)](image))

    return out

LayerWrapperAux ¤

Bases: Module

Wrapper class for any module which takes an image and aux data and converts it to taking and producing a MultiImage and aux data.

Source code in ginjax/layers.py
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
class LayerWrapperAux(eqx.Module):
    """
    Wrapper class for any module which takes an image and aux data and converts it to taking and
    producing a MultiImage and aux data. The aux data is threaded through the per-block calls in
    iteration order.
    """

    modules: dict[tuple[tuple[bool, ...], int], Callable[..., Any]]

    def __init__(self: Self, module: Callable[..., Any], input_keys: geom.Signature):
        """
        Perform the module or callable (e.g., activation) on each layer of the input MultiImage.
        Since we only take input_keys, module should preserve the shape/tensor order and parity.

        args:
            module: module should have as input/output an image of shape (channels, spatial) and
                aux data (likely batch_stats for BatchNorm).
            input_keys: actual input (and output) signature this module will process
        """
        self.modules = {}
        for (k, p), _ in input_keys:
            # NOTE(review): the same module object is stored under every key. The original
            # author expected equinox to treat these as independent copies (see
            # https://docs.kidger.site/equinox/api/nn/shared/), but noted it may not, and that
            # it should be fine in the scalar case.
            self.modules[(k, p)] = module

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State]
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Callable for LayerWrapperAux.

        args:
            x: the input
            aux_data: the aux_data, e.g. for BatchNorm

        returns:
            a new MultiImage and the aux_data returned by the last module call (or the input
            aux_data unchanged if x has no blocks)
        """
        # Build the output with x.empty(), as every other layer in this file does.
        # x.__class__({}, x.D, x.is_torus) discarded any other MultiImage state (presumably the
        # metric tensor) even though this wrapper preserves the image shape.
        out = x.empty()
        for (k, p), image in x.items():
            # each call receives the aux_data produced by the previous one
            out_image, aux_data = self.modules[(k, p)](image, aux_data)
            out.append(k, p, out_image)

        return out, aux_data
__init__(module: Callable[..., Any], input_keys: geom.Signature) ¤

Perform the module or callable (e.g., activation) on each layer of the input MultiImage. Since we only take input_keys, module should preserve the shape/tensor order and parity.

Parameters:

Name Type Description Default
module Callable[..., Any]

module should take and return an image of shape (channels, spatial) together with aux data (e.g., batch_stats for BatchNorm).

required
input_keys Signature

actual input (and output) signature this module will process

required
Source code in ginjax/layers.py
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
def __init__(self: Self, module: Callable[..., Any], input_keys: geom.Signature):
    """
    Register `module` once for every (tensor order, parity) key in `input_keys`.

    The output signature of this wrapper is identical to its input signature, so `module`
    must preserve the image shape, tensor order, and parity of whatever it is applied to.

    args:
        module: callable that takes an image of shape (channels, spatial) plus aux data
            (e.g. batch_stats for BatchNorm) and returns the same pair.
        input_keys: actual input (and output) signature this module will process
    """
    # The same object is stored under every key. Per the Equinox "Shared" docs
    # (https://docs.kidger.site/equinox/api/nn/shared/), an object that appears several times
    # in a PyTree is treated as separate nodes, so the entries presumably start identical and
    # are then trained independently. NOTE(review): stateful modules (e.g. BatchNorm) would
    # still share a single state index across keys — confirm this is intended.
    self.modules = {(k, p): module for (k, p), _ in input_keys}
__call__(x: geom.MultiImage, aux_data: Optional[eqx.nn.State]) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

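Callable for LayerWrapperAux: applies the wrapped module to each image of the input, passing the aux_data returned by one call into the next.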

Parameters:

Name Type Description Default
x MultiImage

the input

required
aux_data Optional[State]

the aux_data, e.g. for BatchNorm

required

Returns:

Type Description
tuple[MultiImage, Optional[State]]

a new MultiImage and the aux_data

Source code in ginjax/layers.py
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
def __call__(
    self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State]
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Apply the registered module to every image of `x`, one (k, p) key at a time.

    The aux data returned by one call is passed into the next, so any state updates
    accumulate across keys in the iteration order of `x.items()`.

    args:
        x: the input MultiImage
        aux_data: the aux_data, e.g. for BatchNorm

    returns:
        a new MultiImage (same class, D and is_torus as x) and the final aux_data
    """
    result = x.__class__({}, x.D, x.is_torus)
    state = aux_data
    for key, image in x.items():
        k, p = key
        new_image, state = self.modules[(k, p)](image, state)
        result.append(k, p, new_image)

    return result, state

_group_norm_K1(D: int, image_block: jax.Array, groups: int, method: str = 'eigh', eps: float = 1e-05) -> jax.Array ¤

Perform group-norm whitening on a vector image block. This is somewhat based on the Clifford Layers batch norm (link below). However, it differs in that it uses eigh rather than cholesky, because the Cholesky factor is not equivariant to all the elements of our group. https://github.com/microsoft/cliffordlayers/blob/main/cliffordlayers/nn/functional/batchnorm.py

Parameters:

Name Type Description Default
D int

the dimension of the space

required
image_block Array

data block of shape (channels,spatial,tensor)

required
groups int

the number of channel groups, must evenly divide channels

required
method str

method used for the whitening, either 'eigh' or 'cholesky'. Note that 'cholesky' is not equivariant.

'eigh'
eps float

small constant added to the covariance matrix to avoid non-invertible matrices

1e-05

Returns:

Type Description
Array

the whitened data, shape (channels,spatial,tensor)

Source code in ginjax/layers.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def _group_norm_K1(
    D: int, image_block: jax.Array, groups: int, method: str = "eigh", eps: float = 1e-5
) -> jax.Array:
    """
    Perform group-norm whitening on a vector (k=1) image block. The channels are split into
    `groups` groups; within each group the vectors are centered and then multiplied by the inverse
    square root of the group's (biased) D x D covariance, so each group ends up with approximately
    identity covariance. This is somewhat based on the Clifford Layers batch norm (link below), but
    by default it uses eigh rather than cholesky because the Cholesky factor is not equivariant to
    all the elements of our group.
    https://github.com/microsoft/cliffordlayers/blob/main/cliffordlayers/nn/functional/batchnorm.py

    args:
        D: the dimension of the space
        image_block: data block of shape (channels,spatial,tensor)
        groups: the number of channel groups, must evenly divide channels
        method: method used for the whitening, either 'eigh' or 'cholesky'. Note that
            'cholesky' is not equivariant.
        eps: small constant added to the covariance matrix to avoid non-invertible matrices

    returns:
        the whitened data, shape (channels,spatial,tensor)
    """
    in_c = len(image_block)
    spatial_dims, k = geom.parse_shape(image_block.shape[1:], D)
    assert (
        k == 1
    ), f"ml::_group_norm_K1: Equivariant group_norm is not implemented for k>1, but k={k}"
    assert (in_c % groups) == 0  # groups must evenly divide the number of channels
    channels_per_group = in_c // groups

    image_grouped = image_block.reshape((groups, channels_per_group) + spatial_dims + (D,))

    # average over the channels of each group and every spatial axis, but not the tensor axis
    reduce_axes = tuple(range(1, 2 + len(spatial_dims)))
    mean = jnp.mean(image_grouped, axis=reduce_axes, keepdims=True)  # (G,1,(1,)*D,D)
    centered_img = image_grouped - mean  # (G,in_c//G,spatial,tensor)

    X = centered_img.reshape((groups, -1, D))  # (G,spatial*in_c//G,D)
    cov = jnp.einsum("...ij,...ik->...jk", X, X) / X.shape[-2]  # biased cov, (G,D,D)

    if method == "eigh":
        # symmetrize_input=True seems to cause issues with autograd, and cov is already symmetric
        eigvals, eigvecs = jnp.linalg.eigh(cov, symmetrize_input=False)  # (G,D), (G,D,D)
        eigvals_invhalf = jnp.sqrt(1.0 / (eigvals + eps))  # (G,D)
        S_diag = jax.vmap(jnp.diag)(eigvals_invhalf)  # (G,D,D)
        # W = U S U^T = (cov + eps*I)^(-1/2); apply W to every row vector x_m of X: out_m = W x_m
        whitened_data = jnp.einsum(
            "...ij,...jk,...kl,...ml->...mi",
            eigvecs,
            S_diag,
            eigvecs.transpose((0, 2, 1)),
            X,
        )
    elif method == "cholesky":
        # Regularize the covariance itself (as the eigh branch does), then factor cov + eps*I = L L^T.
        # Adding eps to L after factorization would not protect a singular cov from a NaN factor.
        regularized_cov = cov + eps * jnp.eye(D, dtype=cov.dtype)
        L = jax.lax.linalg.cholesky(regularized_cov, symmetrize_input=False)  # (G,D,D)
        # For a row vector x the whitened vector is L^{-1} x, i.e. the row x L^{-T}. So solve
        # Z L^T = X (transpose_a=True); then Z^T Z / N = L^{-1} cov L^{-T} ~= I.
        whitened_data = jax.lax.linalg.triangular_solve(
            L,
            X,
            left_side=False,
            lower=True,
            transpose_a=True,
        )
    else:
        raise NotImplementedError(f"ml::_group_norm_K1: method {method} not implemented.")

    return whitened_data.reshape(image_block.shape)