
Geometric image

ginjax.geometric.geometric_image

GeometricImage

One of the main classes of the package. This class represents a single geometric image, i.e. an image where every pixel is a (k,p)-tensor: a tensor of order k with parity p. It is primarily used for simple operations on geometric images and for plotting.
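The data layout and the two-step group action that `times_group_element` describes (move each pixel, then transform its tensor) can be illustrated with plain NumPy. The sketch below is not ginjax's implementation, and `times_group_element_sketch` is a hypothetical name. It makes three assumptions. First, the image is square with side N. Second, pixel coordinates are measured from the image center. Third, `gg` is a signed permutation matrix, such as a 90 degree rotation or an axis flip, so every pixel lands exactly on another pixel. Because such a `gg` is orthogonal, its inverse transpose equals itself. Covariant and contravariant axes therefore transform the same way, and the sketch does not need `covariant_axes`.

```python
import numpy as np


def times_group_element_sketch(data, D, k, parity, gg):
    """Hypothetical helper, not part of ginjax.

    data has shape (N,)*D + (D,)*k, the layout GeometricImage uses; gg is assumed
    to be a DxD signed permutation matrix so pixels map exactly onto pixels.
    """
    N = data.shape[0]  # assumption: square image
    center = (N - 1) / 2.0  # assumption: coordinates measured from the image center
    det = int(round(np.linalg.det(gg)))  # +1 for rotations, -1 for reflections
    out = np.zeros_like(data)
    for idx in np.ndindex(*((N,) * D)):
        # 1) move the pixel location: x -> gg @ x
        new_loc = gg @ (np.array(idx) - center) + center
        new_idx = tuple(int(round(v)) for v in new_loc)
        # 2) act on the pixel: apply gg to each of its k tensor axes
        pixel = np.asarray(data[idx])
        for axis in range(k):
            pixel = np.moveaxis(np.tensordot(gg, pixel, axes=([1], [axis])), 0, axis)
        # 3) pseudotensors (parity 1) pick up an extra factor of det(gg)
        out[new_idx] = pixel * det**parity
    return out


# D=2, N=3, k=1: a vector image with a single nonzero vector at pixel (2, 1)
vec = np.zeros((3, 3, 2))
vec[2, 1] = [1.0, 0.0]
rot90 = np.array([[0, -1], [1, 0]])
rotated = times_group_element_sketch(vec, D=2, k=1, parity=0, gg=rot90)
# the vector moves to pixel (1, 2) and becomes [0, 1]
```

Rotating four times by `rot90` returns the original image. Under the axis flip `[[-1, 0], [0, 1]]`, a parity-1 vector along the flipped axis keeps its sign, because the factor det(gg) = -1 cancels the reflection.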

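The `contract` and `multicontract` methods sum over pairs of tensor axes and leave the leading spatial axes alone. They also drop the contracted entries of `covariant_axes`. The NumPy sketch below reproduces that bookkeeping under stated assumptions. `multicontract_sketch` and `remaining_covariant_axes` are hypothetical names, not ginjax functions. Pair indices count from the first tensor axis, as the `idx_shift` argument in the source suggests. No two pairs are assumed to share an axis.

```python
import string

import numpy as np


def multicontract_sketch(data, pairs, n_spatial):
    """Hypothetical helper, not part of ginjax: trace out pairs of tensor axes.

    Pair indices count from the first tensor axis; the n_spatial leading axes are kept.
    """
    letters = list(string.ascii_letters[: data.ndim])
    for i, j in pairs:
        # giving two axes the same letter makes einsum sum over their diagonal
        letters[n_spatial + j] = letters[n_spatial + i]
    contracted = {n_spatial + axis for pair in pairs for axis in pair}
    kept = "".join(s for ax, s in enumerate(letters) if ax not in contracted)
    return np.einsum("".join(letters) + "->" + kept, data)


def remaining_covariant_axes(covariant_axes, pairs):
    """Keep the covariant flag of every tensor axis that was not contracted."""
    contracted = {axis for pair in pairs for axis in pair}
    return tuple(flag for axis, flag in enumerate(covariant_axes) if axis not in contracted)


# a 3x3 image (D=2) in which every pixel is the 2x2 identity matrix (k=2)
img = np.broadcast_to(np.eye(2), (3, 3, 2, 2))
traces = multicontract_sketch(img, ((0, 1),), n_spatial=2)  # shape (3, 3), every entry 2.0
new_axes = remaining_covariant_axes((False, True, False), ((0, 2),))  # (True,)
```

`remaining_covariant_axes` gives the same result as the slicing between sorted indices in the `multicontract` method. It is written as a filter here for readability.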
Source code in ginjax/geometric/geometric_image.py
@register_pytree_node_class
class GeometricImage:
    """
    One of the main classes of the package. This class is a single geometric image, a.k.a. an image
    where every pixel is a k,p tensor. This class is primarily used for simple operations on
    geometric images and plotting.
    """

    D: int
    spatial_dims: tuple[int, ...]
    k: int
    covariant_axes: tuple[bool, ...]  # can be () for k==0
    data: jax.Array
    parity: int
    is_torus: tuple[bool, ...]

    # Constructors

    @classmethod
    def zeros(
        cls,
        N: Union[int, tuple[int, ...]],
        k: int,
        parity: int,
        D: int,
        is_torus: Union[bool, tuple[bool, ...]] = True,
        covariant_axes: Union[bool, tuple[bool, ...]] = False,
    ) -> Self:
        """
        Zero constructor for GeometricImage.

        args:
            N: length of all sides if an int, otherwise a tuple of the side lengths
            k: the order of the tensor in each pixel, i.e. 0 (scalar), 1 (vector), 2 (matrix), etc.
            parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
            D: dimension of the image, and length of vectors or side length of matrices or tensors.
            is_torus: whether the datablock is a torus, used for convolutions
            covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
                with the coordinate change. False for typical vectors, true for gradients.

        returns:
            constructed GeometricImage
        """
        spatial_dims = N if isinstance(N, tuple) else (N,) * D
        assert len(spatial_dims) == D
        return cls(jnp.zeros(spatial_dims + (D,) * k), parity, D, is_torus, covariant_axes)

    @classmethod
    def fill(
        cls,
        N: Union[int, tuple[int, ...]],
        parity: int,
        D: int,
        fill: Union[jax.Array, float],
        is_torus: Union[bool, tuple[bool, ...]] = True,
        covariant_axes: Union[bool, tuple[bool, ...]] = False,
    ) -> Self:
        """
        Fill constructor: construct a geometric image in which every pixel is set to fill.

        args:
            N: length of all sides if an int, otherwise a tuple of the side lengths
            parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
            D: dimension of the image, and length of vectors or side length of matrices or tensors.
            fill: tensor to fill the image with
            is_torus: whether the datablock is a torus, used for convolutions. Defaults to true.
            covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
                with the coordinate change. False for typical vectors, true for gradients.

        returns:
            Constructed GeometricImage
        """
        spatial_dims = N if isinstance(N, tuple) else (N,) * D
        assert len(spatial_dims) == D

        k = (
            len(fill.shape)
            if (isinstance(fill, jnp.ndarray) or isinstance(fill, np.ndarray))
            else 0
        )
        data = jnp.stack([fill for _ in range(np.multiply.reduce(spatial_dims))]).reshape(
            spatial_dims + (D,) * k
        )
        return cls(data, parity, D, is_torus, covariant_axes)

    def __init__(
        self: Self,
        data: jnp.ndarray,
        parity: int,
        D: int,
        is_torus: Union[bool, tuple[bool, ...]] = True,
        covariant_axes: Union[bool, tuple[bool, ...]] = False,
    ) -> None:
        """
        Constructor for GeometricImage. It will be (N^D x D^k), so if N=100, D=2, k=1, then it's
        (100 x 100 x 2). The spatial dimensions don't have to be square.

        args:
            data: image data, shape (spatial,tensor)
            parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
            D: dimension of the image, and length of vectors or side length of matrices or tensors.
            is_torus: whether the datablock is a torus, used for convolutions.
                Takes either a tuple of bools of length D specifying whether each dimension is toroidal,
                or simply True or False which sets all dimensions to that value.
            covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
                with the coordinate change. False for typical vectors, true for gradients. You
                can only take a contraction between 1 covariant axis and 1 contravariant axis,
                but for a flat Euclidean metric these vectors are numerically identical, so we will
                not enforce this.
        """
        self.D = D
        self.spatial_dims, self.k = parse_shape(data.shape, D)
        assert data.shape[D:] == self.k * (
            self.D,
        ), "GeometricImage: each pixel must be D cross D, k times"

        if isinstance(covariant_axes, bool):
            covariant_axes = (covariant_axes,) * self.k

        assert len(covariant_axes) == self.k

        self.covariant_axes = covariant_axes
        self.parity = parity % 2

        assert (isinstance(is_torus, tuple) and (len(is_torus) == D)) or isinstance(is_torus, bool)
        if isinstance(is_torus, bool):
            is_torus = (is_torus,) * D

        self.is_torus = is_torus

        self.data = jnp.copy(
            data
        )  # TODO: don't need to copy if data is already an immutable jnp array

    def copy(self: Self) -> Self:
        """
        Copy the geometric image.
        """
        return self.__class__(self.data, self.parity, self.D, self.is_torus, self.covariant_axes)

    # Getters, setters, basic info

    def hash(self: Self, indices: ArrayLike) -> tuple[jax.Array, ...]:
        """
        Converts an array of indices to their pixels on the torus by modding the indices with the
        spatial dimensions.

        args:
            indices: array of indices, shape (num_idx, D) to apply the remainder to

        returns:
            the pixel indices as a d-tuple of jax arrays
        """
        return hash(self.D, self.spatial_dims, indices)

    def __getitem__(self: Self, key: Any) -> jax.Array:
        """
        Accessor for data values, so you can write image[key] where key is any index or array slice.
        Note that JAX does not raise errors for out-of-bounds indexing.

        args:
            key: JAX/numpy indexer, e.g. "0", "0,1,3", "4:, 2:3, 0"

        returns:
            data from the specified index or slice.
        """
        return self.data[key]

    def __setitem__(self: Self, key: Any, val: Any) -> Self:
        """
        Set the data at key to the specified value. JAX arrays are immutable, so this rebuilds the
        data array with a copy, which is potentially slow.

        args:
            key: index or slice to access data
            val: value to set the data to

        returns:
            the geometric image
        """
        self.data = self.data.at[key].set(val)
        return self

    def shape(self: Self) -> tuple[int, ...]:
        """
        Return the full shape of the data block

        returns:
            The shape of the data block
        """
        return self.data.shape

    def image_shape(self: Self, plus_Ns: Optional[tuple[int, ...]] = None) -> tuple[int, ...]:
        """
        Return the spatial shape of the data block, i.e. the axes that come before the k-tensor axes.

        args:
            plus_Ns: D-length tuple of amounts to add to each spatial dimension

        returns:
            the shape of the image, modified by plus_Ns
        """
        plus_Ns = (0,) * self.D if (plus_Ns is None) else plus_Ns
        return tuple(N + plus_N for N, plus_N in zip(self.spatial_dims, plus_Ns))

    def image_size(self: Self) -> int:
        """
        Return the total number of pixels in the image.
        """
        return functools.reduce(lambda c, v: c * v, self.image_shape(), 1)

    def pixel_shape(self: Self) -> tuple[int, ...]:
        """
        Return the shape of the k-tensor axes of the data block, i.e. the shape of a single pixel.

        returns:
            the shape of the pixel
        """
        return self.k * (self.D,)

    def pixel_size(self: Self) -> int:
        """
        Get the number of entries in a pixel, D**k; e.g. a pixel of shape (D,D,D) has D**3 entries.

        returns:
            the size of the pixels
        """
        return self.D**self.k

    def __str__(self: Self) -> str:
        """
        returns:
            the string representation of the GeometricImage
        """
        return "<{} object in D={} with spatial_dims={}, k={}, parity={}, is_torus={}, covariant_axes={}>".format(
            self.__class__,
            self.D,
            self.spatial_dims,
            self.k,
            self.parity,
            self.is_torus,
            self.covariant_axes,
        )

    # itertools does not have type hints, but it will be a product[tuple[int,...]]
    def keys(self: Self) -> Any:
        """
        Iterate over the keys of GeometricImage
        """
        return it.product(*list(range(N) for N in self.spatial_dims))

    def key_array(self: Self) -> jax.Array:
        """
        returns:
            the pixel indices as a jax array
        """
        # equivalent to the old pixels function
        return jnp.array([key for key in self.keys()], dtype=int)

    def pixels(self: Self) -> Generator[jax.Array]:
        """
        Iterate over the pixels of GeometricImage.

        returns:
            a generator of the pixels
        """
        for key in self.keys():
            yield self[key]

    def items(self: Self) -> Generator[tuple[Any, jax.Array]]:
        """
        Iterate over the key, pixel pairs of GeometricImage.

        returns:
            a generator of pairs of the pixel index and its pixel
        """
        for key in self.keys():
            yield (key, self[key])

    # Binary Operators, Complicated functions

    def __eq__(self: Self, other: object, rtol: float = TINY, atol: float = TINY) -> bool:
        """
        Equality operator. Both images must have the same D, spatial dims, k, parity, is_torus, and
        covariant_axes, and their data must agree within the tolerances (TINY=1e-5 by default).

        args:
            other: an object to compare to this GeometricImage
            rtol: relative tolerance, passed to jnp.allclose
            atol: absolute tolerance, passed to jnp.allclose

        returns:
            true if they are equal, false otherwise
        """
        if isinstance(other, GeometricImage):
            return (
                self.D == other.D
                and self.spatial_dims == other.spatial_dims
                and self.k == other.k
                and self.parity == other.parity
                and self.is_torus == other.is_torus
                and self.covariant_axes == other.covariant_axes
                and self.data.shape == other.data.shape
                and bool(jnp.allclose(self.data, other.data, rtol, atol))
            )
        else:
            return False

    def __add__(self: Self, other: Self) -> Self:
        """
        Addition operator for GeometricImages. Both must have the same D, spatial dims, k, parity,
        is_torus, and covariant_axes. Returns a new GeometricImage.

        args:
            other: other image to add to this one

        returns:
            a new GeometricImage that is the sum of this one and the other one
        """
        assert self.D == other.D
        assert self.spatial_dims == other.spatial_dims
        assert self.k == other.k
        assert self.parity == other.parity
        assert self.is_torus == other.is_torus
        assert self.covariant_axes == other.covariant_axes
        assert self.data.shape == other.data.shape
        return self.__class__(
            self.data + other.data, self.parity, self.D, self.is_torus, self.covariant_axes
        )

    def __sub__(self: Self, other: Self) -> Self:
        """
        Subtraction operator for GeometricImages. Both must have the same D, spatial dims, k, parity,
        is_torus, and covariant_axes. Returns a new GeometricImage.

        args:
            other: other image to subtract from this one

        returns:
            a new GeometricImage that is the difference of this GeometricImage and the other one
        """
        assert self.D == other.D
        assert self.spatial_dims == other.spatial_dims
        assert self.k == other.k
        assert self.parity == other.parity
        assert self.is_torus == other.is_torus
        assert self.covariant_axes == other.covariant_axes
        assert self.data.shape == other.data.shape
        return self.__class__(
            self.data - other.data, self.parity, self.D, self.is_torus, self.covariant_axes
        )

    def __mul__(self: Self, other: Union[Self, float, int]) -> Self:
        """
        If other is a scalar, do scalar multiplication of the data. If it is another GeometricImage, do the tensor
        product at each pixel. Return the result as a new GeometricImage.

        args:
            other (GeometricImage or number): scalar or image to multiply by

        returns:
            a new GeometricImage that is the product of this GeometricImage with other
        """
        if isinstance(other, GeometricImage):
            assert self.D == other.D
            assert self.spatial_dims == other.spatial_dims
            assert self.is_torus == other.is_torus
            return self.__class__(
                mul(self.D, self.data, other.data),
                self.parity + other.parity,
                self.D,
                self.is_torus,
                self.covariant_axes + other.covariant_axes,
            )
        else:  # it's an integer or a float, or something we can multiply a JAX array by (like a DeviceArray)
            return self.__class__(
                self.data * other, self.parity, self.D, self.is_torus, self.covariant_axes
            )

    def __rmul__(self: Self, other: Union[Self, float, int]) -> Self:
        """
        If other is a scalar, multiply the data by the scalar. This is necessary for doing scalar * image, and it
        should only be called in that case.

        args:
            other (GeometricImage or number): scalar or image to multiply by

        returns:
            a new GeometricImage that is the product of this GeometricImage with other
        """
        return self * other

    def transpose(self: Self, axes_permutation: Sequence[int]) -> Self:
        """
        Transpose the tensor axes of each pixel, leaving the leading spatial axes unchanged.

        args:
            axes_permutation: new axes order

        returns:
            a new GeometricImage that has been transposed
        """
        idx_shift = len(self.image_shape())
        new_indices = tuple(
            tuple(range(idx_shift)) + tuple(axis + idx_shift for axis in axes_permutation)
        )
        new_covariant_axes = tuple(self.covariant_axes[axis] for axis in axes_permutation)
        return self.__class__(
            jnp.transpose(self.data, new_indices),
            self.parity,
            self.D,
            self.is_torus,
            new_covariant_axes,
        )

    @functools.partial(jax.jit, static_argnums=[2, 3, 4, 5])
    def convolve_with(
        self: Self,
        filter_image: Self,
        stride: Union[int, tuple[int, ...]] = 1,
        padding: Optional[tuple[tuple[int, int]]] = None,
        lhs_dilation: Optional[tuple[int, ...]] = None,
        rhs_dilation: Union[int, tuple[int, ...]] = 1,
    ) -> Self:
        """
        See [convolve](functional_geometric_image.md#ginjax.geometric.functional_geometric_image.convolve)
        for a description of this function.

        args:
            filter_image: the convolution filter, a GeometricImage of shape (spatial,tensor)
            stride: convolution stride, defaults to (1,)*self.D
            padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs,
                defaults to 'TORUS' if image.is_torus, else 'SAME'
            lhs_dilation: amount of dilation to apply to image in each dimension D, also transposed conv
            rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1

        returns:
            a new GeometricImage of shape (spatial,tensor)
        """
        convolved_array = convolve(
            self.D,
            self.data[None, None],  # add batch, in_channels axes
            filter_image.data[None, None],  # add out_channels, in_channels axes
            self.is_torus,
            stride,
            padding,
            lhs_dilation,
            rhs_dilation,
        )
        return self.__class__(
            convolved_array[0, 0],  # ignore batch, out_channels axes
            self.parity + filter_image.parity,
            self.D,
            self.is_torus,
            self.covariant_axes + filter_image.covariant_axes,
        )

    def max_pool(self: Self, patch_len: int, use_norm: bool = True) -> Self:
        """
        Perform a max pooling operation where the length of the side of each patch is patch_len. Max is determined
        by the norm of the pixel when use_norm is True. Note that for scalars, this will be the absolute value of
        the pixel. If you want to use the max instead, set use_norm to False (requires scalar images).

        args:
            patch_len: the side length of the patches, must evenly divide all spatial dims
            use_norm: whether to use norm to calculate the max

        returns:
            a new GeometricImage with the max pool applied
        """
        return self.__class__(
            max_pool(self.D, self.data, patch_len, use_norm),
            self.parity,
            self.D,
            self.is_torus,
            self.covariant_axes,
        )

    @functools.partial(jax.jit, static_argnums=1)
    def average_pool(self: Self, patch_len: int) -> Self:
        """
        Perform an average pooling operation where the length of the side of each patch is patch_len. This is
        equivalent to doing a convolution where each element of the filter is 1 over the number of pixels in the
        filter, the stride length is patch_len, and the padding is 'VALID'.

        args:
            patch_len: the side length of the patches, must evenly divide all spatial dims

        returns:
            a new GeometricImage with the average pool applied
        """
        return self.__class__(
            average_pool(self.D, self.data, patch_len),
            self.parity,
            self.D,
            self.is_torus,
            self.covariant_axes,
        )

    @functools.partial(jax.jit, static_argnums=1)
    def unpool(self: Self, patch_len: int) -> Self:
        """
        Each pixel turns into a (patch_len,)*self.D patch of that pixel. Also called
        "Nearest Neighbor" unpooling.

        args:
            patch_len: side length of the patch of our unpooled images

        returns:
            a new GeometricImage with the unpool applied
        """
        grow_filter = GeometricImage(jnp.ones((patch_len,) * self.D), 0, self.D)
        return self.convolve_with(
            grow_filter,
            padding=((patch_len - 1,) * 2,) * self.D,
            lhs_dilation=(patch_len,) * self.D,
        )

    def times_scalar(self: Self, scalar: float) -> Self:
        """
        Scale the data by a scalar, returning a new GeometricImage object. Alias of the multiplication operator.

        args:
            scalar: number to scale everything by

        returns:
            a new GeometricImage scaled by the scalar
        """
        return self * scalar

    @jax.jit
    def norm(self: Self) -> Self:
        """
        Calculate the norm pixel-wise. This becomes a scalar image.

        returns:
            a new scalar GeometricImage of the pixel norms
        """
        return self.__class__(norm(self.D, self.data), 0, self.D, self.is_torus)

    def normalize(self: Self) -> Self:
        """
        Normalize so that the largest pixel norm is 1, scaling every pixel by the same factor. If the
        largest norm is at most TINY, the image is returned unscaled.

        returns:
            a new GeometricImage scaled by the max norm
        """
        max_norm = float(jnp.max(self.norm().data))
        if max_norm > TINY:
            return self.times_scalar(1.0 / max_norm)
        else:
            return self.times_scalar(1.0)

    def activation_function(self: Self, function: Callable[[jnp.ndarray], jnp.ndarray]) -> Self:
        """
        Apply the specified activation function to the GeometricImage

        args:
            function: the activation function

        returns:
            a new GeometricImage with the activation function applied
        """
        assert (
            self.k == 0
        ), "Activation functions only implemented for k=0 tensors due to equivariance"
        return self.__class__(
            function(self.data), self.parity, self.D, self.is_torus, self.covariant_axes
        )

    def contract(self: Self, i: int, j: int) -> Self:
        """
        Use einsum to perform a Kronecker contraction on two axes of the tensor

        args:
            i: first index of tensor
            j: second index of tensor

        returns:
            a new GeometricImage contracted by those indices
        """
        assert self.k >= 2
        idx_shift = len(self.image_shape())

        first, second = min(i, j), max(i, j)
        axes_ls = self.covariant_axes
        new_covariant_axes = axes_ls[:first] + axes_ls[first + 1 : second] + axes_ls[second + 1 :]
        return self.__class__(
            multicontract(self.data, ((i, j),), idx_shift),
            self.parity,
            self.D,
            self.is_torus,
            new_covariant_axes,
        )

    def multicontract(self: Self, indices: tuple[tuple[int, int], ...]) -> Self:
        """
        Use einsum to perform Kronecker contractions on each given pair of tensor axes

        args:
            indices: tuple of (i, j) pairs of tensor axes to contract

        returns:
            a new GeometricImage contracted by those indices
        """
        assert self.k >= 2
        idx_shift = len(self.image_shape())
        sorted_idxs = sorted(list(sum(indices, ())))
        new_cov_axes = tuple(
            self.covariant_axes[prev + 1 : next]
            for prev, next in zip([-1] + sorted_idxs, sorted_idxs + [self.k])
        )
        return self.__class__(
            multicontract(self.data, indices, idx_shift),
            self.parity,
            self.D,
            self.is_torus,
            sum(new_cov_axes, ()),
        )

    def levi_civita_contract(self: Self, indices: Union[tuple[int, ...], int]) -> Self:
        """
        Perform the Levi-Civita contraction: take the outer product with the Levi-Civita symbol, then
        perform D-1 contractions. The resulting image has k = self.k - self.D + 2 and the opposite parity.

        args:
            indices: indices of tensor to perform contractions on

        returns:
            a new GeometricImage contracted by those indices
        """
        assert self.k >= (
            self.D - 1
        )  # so we have enough indices to work on since we perform D-1 contractions
        if not isinstance(indices, tuple):
            indices = (indices,)
        assert len(indices) == self.D - 1

        levi_civita = LeviCivitaSymbol.get(self.D)
        outer = jnp.tensordot(self.data, levi_civita, axes=0)

        # pair each specified index, in order, with an index of the levi_civita symbol
        idx_shift = len(self.image_shape())
        zipped_indices = tuple(
            (i + idx_shift, j + idx_shift)
            for i, j in zip(indices, range(self.k, self.k + len(indices)))
        )
        return self.__class__(
            multicontract(outer, zipped_indices),
            self.parity + 1,
            self.D,
            self.is_torus,
            self.covariant_axes[: self.k - self.D + 2],  # right length, but maybe wrong
        )

    def raise_lower(
        self: Self,
        metric_tensor: Self,
        metric_tensor_inv: Self,
        axes: tuple[bool, ...],
        precision: Optional[jax.lax.Precision] = None,
    ) -> Self:
        """
        Raise or lower the axes of the tensor according to the metric tensor and axes.

        args:
            metric_tensor: the metric tensor g_ij, must be same spatial shape as this
            metric_tensor_inv: the inverse metric tensor, g^ij. Must be same spatial shape as this
            axes: desired covariant axes
            precision: precision used for einsum

        returns:
            new GeometricImage with correct axes
        """
        return self.__class__(
            raise_lower(
                self.data,
                metric_tensor.data,
                metric_tensor_inv.data,
                self.covariant_axes,
                axes,
                precision,
            ),
            self.parity,
            self.D,
            self.is_torus,
            axes,
        )

    def raise_lower_precise(
        self: Self, metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...]
    ) -> Self:
        """
        Raise or lower the axes of the tensor according to the metric tensor and axes using the
        highest precision for einsum.

        args:
            metric_tensor: the metric tensor g_ij, must be same spatial shape as this
            metric_tensor_inv: the inverse metric tensor, g^ij. Must be same spatial shape as this
            axes: desired covariant axes

        returns:
            new GeometricImage with correct axes
        """
        return self.raise_lower(metric_tensor, metric_tensor_inv, axes, jax.lax.Precision.HIGHEST)

    def times_group_element(
        self: Self,
        gg: np.ndarray,
        precision: Optional[jax.lax.Precision] = None,
    ) -> Self:
        """
        Apply a group element of O(D) to the geometric image. First apply the action to the location
        of the pixels, then apply the action to the pixels themselves. The provided group element
        is the one that acts on contravariant axes; it is inverted to apply to covariant axes as
        well.

        args:
            gg: a DxD matrix that rotates a contravariant vector gg @ v
            precision: precision level for einsum, for equality tests use Precision.HIGHEST

        returns:
            a new GeometricImage that has been rotated
        """
        assert self.k < 14
        assert gg.shape == (self.D, self.D)

        return self.__class__(
            times_group_element(self.D, self.data, self.parity, gg, self.covariant_axes, precision),
            self.parity,
            self.D,
            rotate_is_torus(self.is_torus, gg),
            self.covariant_axes,
        )

    def times_gg_precise(self: Self, gg: np.ndarray) -> Self:
        """
        Apply a group element of O(D) to the geometric image using the highest precision einsum.
        See times_group_element for more details.

        args:
            gg: a DxD matrix that rotates a contravariant vector gg @ v

        returns:
            a new GeometricImage that has been rotated
        """
        return self.times_group_element(gg, jax.lax.Precision.HIGHEST)

    def translate(self: Self, tau: jax.Array) -> Self:
        """
        Translate the image on the torus. Translations on the data array use ij ordering. For
        example, a translation of [1,-1] moves the image down one row, then left one column.

        args:
            tau: the translation vector, length D

        returns:
            a geometric image that has been translated
        """
        assert (
            self.is_torus == (True,) * self.D
        ), f"GeometricImage::translate: Image must be a torus, but got {self.is_torus}"
        assert (
            len(tau) == self.D
        ), f"GeometricImage::translate: {self.D}D image received {len(tau)}D translation"

        return self.__class__(
            translate(self.D, self.data, tau, 0),
            self.parity,
            self.D,
            self.is_torus,
            self.covariant_axes,
        )

    def plot(
        self: Self,
        ax: Optional[matplotlib.axes.Axes] = None,
        title: str = "",
        boxes: bool = False,
        fill: bool = True,
        symbols: bool = False,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        colorbar: bool = False,
        cmap: matplotlib.colors.Colormap | str | None = None,
        vector_scaling: float = 0.5,
    ) -> None:
        """
        Plot the geometric image.

        args:
            ax: matplotlib.pyplot Axes to plot this geometric image on
            title: title of the plot
            boxes: whether to plot boxes around each pixel
            fill: whether to fill the pixels with an appropriate color
            symbols: whether to fill the pixels with a symbol
            vmin: min value to plot, everything below this is cut off. If None, will use actual min
            vmax: max value to plot, everything above this is cut off. If None, will use actual max
            colorbar: whether to plot a colorbar
            cmap: a colormap or string for the pixel fill, scalars and vectors have their defaults
            vector_scaling: how much to scale the vectors
        """
        # plot functions should fail gracefully
        if self.k > 2:
            print(
                f"GeometricImage::plot: Can only plot tensor order 0,1, or 2 images, but got k={self.k}"
            )
            return
        if self.k == 2 and self.D == 3:
            print(f"GeometricImage::plot: Cannot plot D=3, k=2 geometric images.")
            return

        ax = utils.setup_plot() if ax is None else ax

        if self.D == 1:
            # convert image to a 2D image that is N,1
            data_2d = self.data.reshape((len(self.data), 1) + (1,) * self.k)
            mul_img = 1
            if self.k == 1:
                mul_img = jnp.concatenate(
                    [jnp.ones_like(data_2d), jnp.zeros_like(data_2d)], axis=-1
                )
            elif self.k == 2 and self.parity == 0:  # kronecker delta coefficient
                mul_img = jnp.full((self.D, 1) + (2, 2), jnp.eye(2)[None, None])
            elif self.k == 2 and self.parity == 1:  # levi civita coefficient
                mul_img = jnp.full((self.D, 1) + (2, 2), LeviCivitaSymbol.get(2)[None, None])
            elif self.k > 2:
                print(f"GeometricImage::plot: Not implemented for D=1, k={self.k}")
                return

            # GeometricFilters must be square, so make it a GeometricImage
            image_2d = GeometricImage(
                data_2d * mul_img, self.parity, 2, self.is_torus[0], self.covariant_axes
            )
            image_2d.plot(
                ax, title, boxes, fill, symbols, vmin, vmax, colorbar, cmap, vector_scaling
            )
            return

        # This was breaking earlier with jax arrays, not sure why. I really don't want plotting to break,
        # so I will swap to numpy arrays just in case.
        key_array_transpose = np.array(self.key_array()).T  # (D,N**D)
        xs = key_array_transpose[0]
        ys = key_array_transpose[1]
        zs = key_array_transpose[2:]
        if self.D == 3:
            xs = xs + utils.XOFF * zs
            ys = ys + utils.YOFF * zs

        pixels = np.array(list(self.pixels()))

        if self.k == 0:
            vmin = np.min(pixels) if vmin is None else vmin
            vmax = np.max(pixels) if vmax is None else vmax
            utils.plot_scalars(
                ax,
                self.spatial_dims,
                xs,
                ys,
                pixels,
                boxes=boxes,
                fill=fill,
                symbols=symbols,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                colorbar=colorbar,
            )
        elif self.k == 1:
            vmin = 0.0 if vmin is None else vmin
            vmax = 2.0 if vmax is None else vmax
            utils.plot_vectors(
                ax,
                xs,
                ys,
                pixels,
                boxes=boxes,
                fill=fill,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                scaling=vector_scaling,
            )
        else:  # self.k == 2
            utils.plot_tensors(ax, xs, ys, pixels, boxes=boxes)

        utils.finish_plot(ax, title, xs, ys, self.D)

    def tree_flatten(
        self: Self,
    ) -> tuple[tuple[jnp.ndarray], dict[str, Union[int, Union[bool, tuple[bool]]]]]:
        """
        Helper function to define GeometricImage as a pytree so jax.jit handles it correctly. Children and aux_data
        must contain all the variables that are passed in __init__()
        """
        children = (self.data,)  # arrays / dynamic values
        aux_data = {
            "D": self.D,
            "parity": self.parity,
            "is_torus": self.is_torus,
            "covariant_axes": self.covariant_axes,
        }  # static values
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Helper function to define GeometricImage as a pytree so jax.jit handles it correctly.
        """
        return cls(*children, **aux_data)
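The tree_flatten and tree_unflatten methods that close the class are the hooks that jax.tree_util.register_pytree_node_class looks for; presumably GeometricImage is registered that way, although the decorator is not visible in this excerpt. The toy class below (TinyImage, hypothetical and much smaller than GeometricImage) sketches the same pattern: arrays go into children, static metadata goes into aux_data, and jax.jit can then accept and return the object. It uses a tuple for aux_data because JAX compares and hashes aux_data as part of the tree structure.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyImage:
    """Hypothetical stand-in for GeometricImage: one array plus static metadata."""

    def __init__(self, data, parity: int, D: int):
        self.data = data
        self.parity = parity
        self.D = D

    def tree_flatten(self):
        children = (self.data,)  # dynamic leaves, traced by jit
        aux_data = (self.parity, self.D)  # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        parity, D = aux_data
        return cls(*children, parity, D)


@jax.jit
def double(img: TinyImage) -> TinyImage:
    # jit flattens img, traces only img.data, and rebuilds a TinyImage on output
    return TinyImage(2 * img.data, img.parity, img.D)


out = double(TinyImage(jnp.ones((4, 4, 2)), parity=1, D=2))
```

Because D and parity live in aux_data, they stay ordinary Python values inside the jitted function, and changing them triggers a retrace rather than being treated as traced data.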
zeros(N: Union[int, tuple[int, ...]], k: int, parity: int, D: int, is_torus: Union[bool, tuple[bool]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> Self classmethod

Zero constructor for GeometricImage.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| N | Union[int, tuple[int, ...]] | length of all sides if an int, otherwise a tuple of the side lengths | required |
| k | int | the order of the tensor in each pixel, i.e. 0 (scalar), 1 (vector), 2 (matrix), etc. | required |
| parity | int | 0 or 1, 0 is normal vectors, 1 is pseudovectors | required |
| D | int | dimension of the image, and length of vectors or side length of matrices or tensors. | required |
| is_torus | Union[bool, tuple[bool]] | whether the datablock is a torus, used for convolutions | True |
| covariant_axes | Union[bool, tuple[bool, ...]] | which of k tensor axes are covariant, i.e. they rotate covariantly with the coordinate change. False for typical vectors, True for gradients. | False |

Returns:

| Type | Description |
|---|---|
| Self | constructed GeometricImage |

Source code in ginjax/geometric/geometric_image.py
@classmethod
def zeros(
    cls,
    N: Union[int, tuple[int, ...]],
    k: int,
    parity: int,
    D: int,
    is_torus: Union[bool, tuple[bool]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> Self:
    """
    Zero constructor for GeometricImage.

    args:
        N: length of all sides if an int, otherwise a tuple of the side lengths
        k: the order of the tensor in each pixel, i.e. 0 (scalar), 1 (vector), 2 (matrix), etc.
        parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
        D: dimension of the image, and length of vectors or side length of matrices or tensors.
        is_torus: whether the datablock is a torus, used for convolutions
        covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
            with the coordinate change. False for typical vectors, true for gradients.

    returns:
        constructed GeometricImage
    """
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D
    return cls(jnp.zeros(spatial_dims + (D,) * k), parity, D, is_torus, covariant_axes)
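zeros allocates a data block whose shape is the spatial dimensions followed by k copies of D. A stdlib-only sketch of that shape rule, where zeros_shape is a hypothetical helper mirroring the first two lines of the method:

```python
from typing import Union


def zeros_shape(N: Union[int, tuple[int, ...]], k: int, D: int) -> tuple[int, ...]:
    """Hypothetical helper: the data shape that GeometricImage.zeros would allocate."""
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D
    return spatial_dims + (D,) * k


print(zeros_shape(100, k=1, D=2))  # (100, 100, 2): a 2D vector image
print(zeros_shape((4, 6), k=2, D=2))  # (4, 6, 2, 2): non-square, matrix pixels
print(zeros_shape(8, k=0, D=3))  # (8, 8, 8): a 3D scalar image
```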
fill(N: Union[int, tuple[int, ...]], parity: int, D: int, fill: Union[jax.Array, float], is_torus: Union[bool, tuple[bool, ...]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> Self classmethod

Fill constructor: construct a geometric image with every pixel set to fill.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| N | Union[int, tuple[int, ...]] | length of all sides if an int, otherwise a tuple of the side lengths | required |
| parity | int | 0 or 1, 0 is normal vectors, 1 is pseudovectors | required |
| D | int | dimension of the image, and length of vectors or side length of matrices or tensors. | required |
| fill | Union[Array, float] | tensor to fill each pixel with; k is inferred from its shape, which must be (D,)*k (a float gives k=0) | required |
| is_torus | Union[bool, tuple[bool, ...]] | whether the datablock is a torus, used for convolutions. Defaults to True. | True |
| covariant_axes | Union[bool, tuple[bool, ...]] | which of k tensor axes are covariant, i.e. they rotate covariantly with the coordinate change. False for typical vectors, True for gradients. | False |

Returns:

| Type | Description |
|---|---|
| Self | Constructed GeometricImage |

Source code in ginjax/geometric/geometric_image.py
@classmethod
def fill(
    cls,
    N: Union[int, tuple[int, ...]],
    parity: int,
    D: int,
    fill: Union[jax.Array, float],
    is_torus: Union[bool, tuple[bool, ...]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> Self:
    """
    Fill constructor: construct a geometric image with every pixel set to fill.

    args:
        N: length of all sides if an int, otherwise a tuple of the side lengths
        parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
        D: dimension of the image, and length of vectors or side length of matrices or tensors.
        fill: tensor to fill the image with
        is_torus: whether the datablock is a torus, used for convolutions. Defaults to true.
        covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
            with the coordinate change. False for typical vectors, true for gradients.

    returns:
        Constructed GeometricImage
    """
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D

    k = (
        len(fill.shape)
        if (isinstance(fill, jnp.ndarray) or isinstance(fill, np.ndarray))
        else 0
    )
    data = jnp.stack([fill for _ in range(np.multiply.reduce(spatial_dims))]).reshape(
        spatial_dims + (D,) * k
    )
    return cls(data, parity, D, is_torus, covariant_axes)
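fill infers k from fill.shape and stacks one copy of fill per pixel, so fill must have shape (D,)*k for the final reshape to succeed. The NumPy sketch below reproduces that stack-and-reshape and checks it against plain broadcasting; broadcasting is shown only as an equivalent alternative that avoids the Python loop, not as what the library does.

```python
import numpy as np

spatial_dims = (3, 4)
D = 2
fill = np.array([1.0, -2.0])  # one vector pixel, shape (D,), so k = 1
k = fill.ndim

# what the method does: one copy per pixel, then reshape to spatial_dims + (D,)*k
stacked = np.stack([fill for _ in range(np.multiply.reduce(spatial_dims))]).reshape(
    spatial_dims + (D,) * k
)

# equivalent broadcast (a read-only view; copy it before writing to it)
broadcast = np.broadcast_to(fill, spatial_dims + fill.shape)
```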
__init__(data: jnp.ndarray, parity: int, D: int, is_torus: Union[bool, tuple[bool, ...]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> None

Constructor for GeometricImage. The data block has shape (N^D x D^k): for example, if N=100, D=2, and k=1, the shape is (100 x 100 x 2). The spatial dimensions don't have to be square.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | ndarray | image data, shape (spatial,tensor) | required |
| parity | int | 0 or 1, 0 is normal vectors, 1 is pseudovectors; other integers are reduced mod 2 | required |
| D | int | dimension of the image, and length of vectors or side length of matrices or tensors. | required |
| is_torus | Union[bool, tuple[bool, ...]] | whether the datablock is a torus, used for convolutions. Takes either a tuple of bools of length D specifying whether each dimension is toroidal, or simply True or False, which sets all dimensions to that value. | True |
| covariant_axes | Union[bool, tuple[bool, ...]] | which of the k tensor axes are covariant, i.e. they rotate covariantly with the coordinate change. False for typical vectors, True for gradients. A contraction should only be taken between 1 covariant axis and 1 contravariant axis, but for a flat Euclidean metric these vectors are numerically identical, so this is not enforced. | False |

Source code in ginjax/geometric/geometric_image.py
def __init__(
    self: Self,
    data: jnp.ndarray,
    parity: int,
    D: int,
    is_torus: Union[bool, tuple[bool, ...]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> None:
    """
    Constructor for GeometricImage. It will be (N^D x D^k), so if N=100, D=2, k=1, then it's
    (100 x 100 x 2). The spatial dimensions don't have to be square.

    args:
        data: image data, shape (spatial,tensor)
        parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
        D: dimension of the image, and length of vectors or side length of matrices or tensors.
        is_torus: whether the datablock is a torus, used for convolutions.
            Takes either a tuple of bools of length D specifying whether each dimension is toroidal,
            or simply True or False which sets all dimensions to that value.
        covariant_axes: which of k tensor axes are covariant, i.e. they rotate covariantly
            with the coordinate change. False for typical vectors, true for gradients. You
            can only take a contraction between 1 covariant axis and 1 contravariant axis,
            but for a flat Euclidean metric these vectors are numerically identical, so we will
            not enforce this.
    """
    self.D = D
    self.spatial_dims, self.k = parse_shape(data.shape, D)
    assert data.shape[D:] == self.k * (
        self.D,
    ), "GeometricImage: each pixel must be D cross D, k times"

    if isinstance(covariant_axes, bool):
        covariant_axes = (covariant_axes,) * self.k

    assert len(covariant_axes) == self.k

    self.covariant_axes = covariant_axes
    self.parity = parity % 2

    assert (isinstance(is_torus, tuple) and (len(is_torus) == D)) or isinstance(is_torus, bool)
    if isinstance(is_torus, bool):
        is_torus = (is_torus,) * D

    self.is_torus = is_torus

    self.data = jnp.copy(
        data
    )  # TODO: don't need to copy if data is already an immutable jnp array
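The constructor normalizes its arguments before storing them. A stdlib-only sketch of that normalization follows; it assumes parse_shape (not shown on this page) takes the first D entries of the shape as the spatial dims and counts the remaining entries as k, and normalize_metadata is a hypothetical helper.

```python
def normalize_metadata(shape, D, parity, is_torus=True, covariant_axes=False):
    """Hypothetical helper mirroring the argument handling in GeometricImage.__init__."""
    spatial_dims, k = tuple(shape[:D]), len(shape) - D  # assumed parse_shape behavior
    assert tuple(shape[D:]) == (D,) * k, "each pixel must be D cross D, k times"
    if isinstance(covariant_axes, bool):
        covariant_axes = (covariant_axes,) * k  # one flag per tensor axis
    assert len(covariant_axes) == k
    if isinstance(is_torus, bool):
        is_torus = (is_torus,) * D  # one flag per spatial axis
    assert len(is_torus) == D
    return spatial_dims, k, parity % 2, is_torus, covariant_axes


print(normalize_metadata((100, 100, 2), D=2, parity=3))
# ((100, 100), 1, 1, (True, True), (False,))
```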
copy() -> Self

Copy the geometric image.

Source code in ginjax/geometric/geometric_image.py
def copy(self: Self) -> Self:
    """
    Copy the geometric image.
    """
    return self.__class__(self.data, self.parity, self.D, self.is_torus, self.covariant_axes)
hash(indices: ArrayLike) -> tuple[jax.Array, ...]

Converts an array of indices to their pixels on the torus by modding the indices with the spatial dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| indices | ArrayLike | array of indices, shape (num_idx, D) to apply the remainder to | required |

Returns:

| Type | Description |
|---|---|
| tuple[Array, ...] | the pixel indices as a D-tuple of jax arrays |

Source code in ginjax/geometric/geometric_image.py
def hash(self: Self, indices: ArrayLike) -> tuple[jax.Array, ...]:
    """
    Converts an array of indices to their pixels on the torus by modding the indices with the
    spatial dimensions.

    args:
        indices: array of indices, shape (num_idx, D) to apply the remainder to

    returns:
        the pixel indices as a d-tuple of jax arrays
    """
    return hash(self.D, self.spatial_dims, indices)
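hash wraps raw indices onto the torus and returns one index array per spatial axis, which is the form NumPy and JAX advanced indexing expect. A NumPy sketch, assuming the module-level hash takes each column of indices modulo the matching spatial dimension (torus_hash is hypothetical):

```python
import numpy as np


def torus_hash(spatial_dims, indices):
    """Hypothetical helper: wrap (num_idx, D) indices onto the torus, one array per axis."""
    indices = np.asarray(indices)
    # NumPy's % follows Python's sign convention, so -1 % 3 == 2
    return tuple(indices[:, d] % N for d, N in enumerate(spatial_dims))


data = np.arange(12).reshape(3, 4)
rows, cols = torus_hash(data.shape, [[-1, 0], [3, 5]])
# (-1, 0) wraps to (2, 0) and (3, 5) wraps to (0, 1)
values = data[rows, cols]
```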
__getitem__(key: Any) -> jax.Array

Accessor for data values, so you can write image[key] where key is any index or slice that works on a JAX array. Note that JAX does not raise errors for out-of-bounds indexing.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | Any | JAX/numpy indexer, e.g. "0", "0,1,3", or "4:, 2:3, 0" | required |

Returns:

| Type | Description |
|---|---|
| Array | data from the specified index or slice. |

Source code in ginjax/geometric/geometric_image.py
def __getitem__(self: Self, key: Any) -> jax.Array:
    """
    Accessor for data values, so you can write image[key] where key is any index or slice that works
    on a JAX array. Note that JAX does not raise errors for out-of-bounds indexing.

    args:
        key: JAX/numpy indexer, i.e. "0", "0,1,3", "4:, 2:3, 0" etc.

    returns:
        data from the specified index or slice.
    """
    return self.data[key]
__setitem__(key: Any, val: Any) -> Self

Set the jax array data to the specified value. JAX arrays are immutable, so this reconstructs the data object with copying, and is potentially slow.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | Any | index or slice to access data | required |
| val | Any | value to set the data to | required |

Returns:

| Type | Description |
|---|---|
| Self | this geometric image, with its data replaced |

Source code in ginjax/geometric/geometric_image.py
def __setitem__(self: Self, key: Any, val: Any) -> Self:
    """
    Set the jax array data to the specified value. Jax arrays are immutable, so this
    reconstructs the data object with copying, and is potentially slow.

    args:
        key: index or slice to access data
        val: value to set the data to

    returns:
        the geometric image
    """
    self.data = self.data.at[key].set(val)
    return self
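__setitem__ relies on JAX's functional update syntax: x.at[key].set(val) returns a new array and leaves x untouched, and the method then rebinds self.data to the result. A minimal illustration with jax.numpy:

```python
import jax.numpy as jnp

old = jnp.zeros((2, 2))
new = old.at[0, 1].set(5.0)  # returns a new array; `old` is not modified
```

Inside jax.jit, XLA can often apply such updates in place, so the copying cost mostly matters in eagerly executed code.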
shape() -> tuple[int, ...]

Return the full shape of the data block.

Returns:

| Type | Description |
|---|---|
| tuple[int, ...] | The shape of the data block |

Source code in ginjax/geometric/geometric_image.py
def shape(self: Self) -> tuple[int, ...]:
    """
    Return the full shape of the data block

    returns:
        The shape of the data block
    """
    return self.data.shape
image_shape(plus_Ns: Optional[tuple[int, ...]] = None) -> tuple[int, ...]

Return the spatial part of the data block's shape, i.e. the axes that come before the k-tensor (pixel) axes.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| plus_Ns | Optional[tuple[int, ...]] | D-length tuple of amounts to add to each spatial dim | None |

Returns:

| Type | Description |
|---|---|
| tuple[int, ...] | the shape of the image, modified by plus_Ns |

Source code in ginjax/geometric/geometric_image.py
def image_shape(self: Self, plus_Ns: Optional[tuple[int, ...]] = None) -> tuple[int, ...]:
    """
    Return the shape of the data block that is not the ktensor shape, but what comes before that.

    args:
        plus_Ns: d-length tuple, N to add to each spatial dim

    returns:
        the shape of the image, modified by plus_Ns
    """
    plus_Ns = (0,) * self.D if (plus_Ns is None) else plus_Ns
    return tuple(N + plus_N for N, plus_N in zip(self.spatial_dims, plus_Ns))
image_size() -> int

Return the total number of pixels in the image.

Source code in ginjax/geometric/geometric_image.py
def image_size(self: Self) -> int:
    """
    Return the total number of pixels in the image.
    """
    return functools.reduce(lambda c, v: c * v, self.image_shape(), 1)
pixel_shape() -> tuple[int, ...]

Return the shape of the k-tensor part of the data block, i.e. the shape of one pixel.

Returns:

| Type | Description |
|---|---|
| tuple[int, ...] | the shape of the pixel |

Source code in ginjax/geometric/geometric_image.py
def pixel_shape(self: Self) -> tuple[int, ...]:
    """
    Return the shape of the data block that is the ktensor, aka the pixel of the image.

    returns:
        the shape of the pixel
    """
    return self.k * (self.D,)
pixel_size() -> int

Get the number of entries in each pixel, D**k; for example, a pixel of shape (D,D,D) has size D**3.

Returns:

| Type | Description |
|---|---|
| int | the size of the pixels |

Source code in ginjax/geometric/geometric_image.py
def pixel_size(self: Self) -> int:
    """
    Get the size of the pixel shape, i.e. (D,D,D) = D**3

    returns:
        the size of the pixels
    """
    return self.D**self.k
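The shape helpers split the data block into a spatial part and a pixel part. A stdlib-only sketch of the identities they satisfy, for a hypothetical D=2, k=2 image with spatial dims (4, 6); image_shape with plus_Ns simply adds a per-axis amount, for example to size a padded image:

```python
import math

D, k = 2, 2
spatial_dims = (4, 6)
data_shape = spatial_dims + (D,) * k  # what shape() would return

image_shape = spatial_dims  # image_shape()
padded_shape = tuple(N + p for N, p in zip(spatial_dims, (2, 2)))  # image_shape(plus_Ns=(2, 2))
pixel_shape = k * (D,)  # pixel_shape()
image_size = math.prod(image_shape)  # image_size()
pixel_size = D**k  # pixel_size()
```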
__str__() -> str

Returns:

| Type | Description |
|---|---|
| str | the string representation of the GeometricImage |

Source code in ginjax/geometric/geometric_image.py
def __str__(self: Self) -> str:
    """
    returns:
        the string representation of the GeometricImage
    """
    return "<{} object in D={} with spatial_dims={}, k={}, parity={}, is_torus={}, covariant_axes={}>".format(
        self.__class__,
        self.D,
        self.spatial_dims,
        self.k,
        self.parity,
        self.is_torus,
        self.covariant_axes,
    )
keys() -> Any

Return an iterator over the pixel indices (keys) of the GeometricImage, in row-major order.

Source code in ginjax/geometric/geometric_image.py
def keys(self: Self) -> Any:
    """
    Iterate over the keys of GeometricImage
    """
    return it.product(*list(range(N) for N in self.spatial_dims))
key_array() -> jax.Array

Returns:

| Type | Description |
|---|---|
| Array | the pixel indices as a jax array of shape (image_size, D), in the order given by keys() |

Source code in ginjax/geometric/geometric_image.py
def key_array(self: Self) -> jax.Array:
    """
    returns:
        the pixel indices as a jax array
    """
    # equivalent to the old pixels function
    return jnp.array([key for key in self.keys()], dtype=int)
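keys enumerates pixel indices with itertools.product, so they come out in row-major order (last axis fastest), and key_array stacks them into an (image_size, D) integer array; plot transposes that array to get one row of coordinates per axis. A stdlib-only sketch for spatial dims (2, 3):

```python
import itertools as it

spatial_dims = (2, 3)
keys = list(it.product(*(range(N) for N in spatial_dims)))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
xs, ys = zip(*keys)  # the same split plot makes via key_array().T
```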
pixels() -> Generator[jax.Array]

Iterate over the pixels of GeometricImage, in the same order as keys().

Returns:

| Type | Description |
|---|---|
| Generator[Array] | a generator of the pixels |

Source code in ginjax/geometric/geometric_image.py
def pixels(self: Self) -> Generator[jax.Array]:
    """
    Iterate over the pixels of GeometricImage.

    returns:
        a generator of the pixels
    """
    for key in self.keys():
        yield self[key]
items() -> Generator[tuple[Any, jax.Array]]

Iterate over the key, pixel pairs of GeometricImage.

Returns:

| Type | Description |
|---|---|
| Generator[tuple[Any, Array]] | a generator of pairs of the pixel index and its pixel |

Source code in ginjax/geometric/geometric_image.py
def items(self: Self) -> Generator[tuple[Any, jax.Array]]:
    """
    Iterate over the key, pixel pairs of GeometricImage.

    returns:
        a generator of pairs of the pixel index and its pixel
    """
    for key in self.keys():
        yield (key, self[key])
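Because keys follow row-major order, the pixels generator yields the same sequence as reshaping the data block to (image_size, *pixel_shape). A NumPy sketch of that equivalence, using a small vector image; the reshape form avoids a Python-level loop over pixels:

```python
import itertools as it

import numpy as np

data = np.arange(4 * 3 * 2).reshape(4, 3, 2)  # D=2 vector image, spatial dims (4, 3)
spatial_dims, pixel_shape = data.shape[:2], data.shape[2:]

# what pixels() does: index the data block once per key
pixels_loop = [data[key] for key in it.product(*(range(N) for N in spatial_dims))]
# the same pixels, in the same order, from a single reshape
pixels_flat = data.reshape((-1,) + pixel_shape)
```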
__eq__(other: object, rtol: float = TINY, atol: float = TINY) -> bool

Equality operator. Both images must have the same D, spatial dims, k, parity, is_torus, and covariant_axes, and their data must agree within the given tolerances (both default to TINY=1e-5).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | object | an object to compare to this GeometricImage | required |
| rtol | float | relative tolerance, passed to jnp.allclose | TINY |
| atol | float | absolute tolerance, passed to jnp.allclose | TINY |

Returns:

| Type | Description |
|---|---|
| bool | true if they are equal, false otherwise |

Source code in ginjax/geometric/geometric_image.py
def __eq__(self: Self, other: object, rtol: float = TINY, atol: float = TINY) -> bool:
    """
    Equality operator, must have same shape, parity, and data within the TINY=1e-5 tolerance.

    args:
        other: an object to compare to this GeometricImage
        rtol: relative tolerance, passed to jnp.allclose
        atol: absolute tolerance, passed to jnp.allclose

    returns:
        true if they are equal, false otherwise
    """
    if isinstance(other, GeometricImage):
        return (
            self.D == other.D
            and self.spatial_dims == other.spatial_dims
            and self.k == other.k
            and self.parity == other.parity
            and self.is_torus == other.is_torus
            and self.covariant_axes == other.covariant_axes
            and self.data.shape == other.data.shape
            and bool(jnp.allclose(self.data, other.data, rtol, atol))
        )
    else:
        return False
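The data comparison delegates to jnp.allclose with rtol and atol passed positionally, following NumPy's allclose semantics: elementwise, |a - b| <= atol + rtol * |b|. A NumPy sketch with both tolerances set to 1e-5, the TINY value quoted in the docstring:

```python
import numpy as np

TINY = 1e-5  # default tolerance quoted in the __eq__ docstring
a = np.array([1.0, 2.0])
close = np.allclose(a, a + 1e-6, TINY, TINY)  # difference well inside atol + rtol * |b|
far = np.allclose(a, a + 1e-3, TINY, TINY)  # difference exceeds the tolerance
```

Since rtol is scaled by |b|, the second argument, the comparison is not exactly symmetric for values near the tolerance boundary.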
__add__(other: Self) -> Self

Addition operator for GeometricImages. Both must have the same D, spatial dims, k, parity, is_torus, and covariant_axes. Returns a new GeometricImage.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | Self | other image to add to this one | required |

Returns:

| Type | Description |
|---|---|
| Self | a new GeometricImage that is the sum of this one and the other one |

Source code in ginjax/geometric/geometric_image.py
def __add__(self: Self, other: Self) -> Self:
    """
    Addition operator for GeometricImages. Both must be the same size and parity. Returns a new GeometricImage.

    args:
        other: other image to add to this one

    returns:
        a new GeometricImage that is the sum of this one and the other one
    """
    assert self.D == other.D
    assert self.spatial_dims == other.spatial_dims
    assert self.k == other.k
    assert self.parity == other.parity
    assert self.is_torus == other.is_torus
    assert self.covariant_axes == other.covariant_axes
    assert self.data.shape == other.data.shape
    return self.__class__(
        self.data + other.data, self.parity, self.D, self.is_torus, self.covariant_axes
    )
__sub__(other: Self) -> Self

Subtraction operator for GeometricImages. Both must have the same D, spatial dims, k, parity, is_torus, and covariant_axes. Returns a new GeometricImage.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | Self | other image to subtract from this one | required |

Returns:

| Type | Description |
|---|---|
| Self | a new GeometricImage that is the difference of this GeometricImage and the other one |

Source code in ginjax/geometric/geometric_image.py
def __sub__(self: Self, other: Self) -> Self:
    """
    Subtraction operator for GeometricImages. Both must be the same size and parity. Returns a new GeometricImage.

    args:
        other: other image to subtract from this one

    returns:
        a new GeometricImage that is the difference of this GeometricImage and the other one
    """
    assert self.D == other.D
    assert self.spatial_dims == other.spatial_dims
    assert self.k == other.k
    assert self.parity == other.parity
    assert self.is_torus == other.is_torus
    assert self.covariant_axes == other.covariant_axes
    assert self.data.shape == other.data.shape
    return self.__class__(
        self.data - other.data, self.parity, self.D, self.is_torus, self.covariant_axes
    )
__mul__(other: Union[Self, float, int]) -> Self

If other is a scalar, do scalar multiplication of the data. If it is another GeometricImage, take the tensor product at each pixel; the parities add and the covariant_axes are concatenated. Return the result as a new GeometricImage.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | GeometricImage or number | scalar or image to multiply by | required |

Returns:

| Type | Description |
|---|---|
| Self | a new GeometricImage that is the product of this GeometricImage with other |

Source code in ginjax/geometric/geometric_image.py
def __mul__(self: Self, other: Union[Self, float, int]) -> Self:
    """
    If other is a scalar, do scalar multiplication of the data. If it is another GeometricImage, do the tensor
    product at each pixel. Return the result as a new GeometricImage.

    args:
        other (GeometricImage or number): scalar or image to multiply by

    returns:
        a new GeometricImage that is the product of this GeometricImage with other
    """
    if isinstance(other, GeometricImage):
        assert self.D == other.D
        assert self.spatial_dims == other.spatial_dims
        assert self.is_torus == other.is_torus
        return self.__class__(
            mul(self.D, self.data, other.data),
            self.parity + other.parity,
            self.D,
            self.is_torus,
            self.covariant_axes + other.covariant_axes,
        )
    else:  # it's an integer or a float, or anything a JAX array can be multiplied by (like a jax.Array)
        return self.__class__(
            self.data * other, self.parity, self.D, self.is_torus, self.covariant_axes
        )
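For two images, __mul__ forms the tensor product pixel by pixel: tensor orders add, parities add (the constructor reduces them mod 2), and the covariant_axes tuples concatenate. A NumPy sketch of the data part for two vector images, assuming the module-level mul computes the per-pixel outer product that the docstring describes (pixel_tensor_product is hypothetical):

```python
import numpy as np


def pixel_tensor_product(D: int, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: per-pixel tensor product of two images' data blocks."""
    k_a, k_b = a.ndim - D, b.ndim - D
    # append k_b singleton axes to a and insert k_a singleton axes into b, then broadcast
    a_exp = a.reshape(a.shape + (1,) * k_b)
    b_exp = b.reshape(b.shape[:D] + (1,) * k_a + b.shape[D:])
    return a_exp * b_exp


u = np.random.default_rng(0).normal(size=(3, 3, 2))  # vector image, k=1
v = np.random.default_rng(1).normal(size=(3, 3, 2))  # vector image, k=1
uv = pixel_tensor_product(2, u, v)  # matrix image, k=2
```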
__rmul__(other: Union[Self, float, int]) -> Self

If other is a scalar, multiply the data by the scalar. This is necessary for doing scalar * image, and it should only be called in that case.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| other | GeometricImage or number | scalar or image to multiply by | required |

Returns:

| Type | Description |
|---|---|
| Self | a new GeometricImage that is the product of this GeometricImage with other |

Source code in ginjax/geometric/geometric_image.py
def __rmul__(self: Self, other: Union[Self, float, int]) -> Self:
    """
    If other is a scalar, multiply the data by the scalar. This is necessary for doing scalar * image, and it
    should only be called in that case.

    args:
        other (GeometricImage or number): scalar or image to multiply by

    returns:
        a new GeometricImage that is the product of this GeometricImage with other
    """
    return self * other
transpose(axes_permutation: Sequence[int]) -> Self

Transpose the tensor (pixel) axes, keeping the leading spatial axes unchanged.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| axes_permutation | Sequence[int] | new order of the k tensor axes, indexed from 0 (spatial axes excluded) | required |

Returns:

| Type | Description |
|---|---|
| Self | a new GeometricImage that has been transposed |

Source code in ginjax/geometric/geometric_image.py
def transpose(self: Self, axes_permutation: Sequence[int]) -> Self:
    """
    Transposes the axes of the tensor, keeping the image axes in the front the same

    args:
        axes_permutation: new axes order

    returns:
        a new GeometricImage that has been transposed
    """
    idx_shift = len(self.image_shape())
    new_indices = tuple(
        tuple(range(idx_shift)) + tuple(axis + idx_shift for axis in axes_permutation)
    )
    new_covariant_axes = tuple(self.covariant_axes[axis] for axis in axes_permutation)
    return self.__class__(
        jnp.transpose(self.data, new_indices),
        self.parity,
        self.D,
        self.is_torus,
        new_covariant_axes,
    )
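transpose shifts the requested tensor-axis permutation past the spatial axes. For a D=2 matrix image, axes_permutation=(1, 0) becomes the full permutation (0, 1, 3, 2), i.e. a per-pixel matrix transpose. A NumPy sketch of that index arithmetic:

```python
import numpy as np

data = np.arange(3 * 3 * 2 * 2).reshape(3, 3, 2, 2)  # D=2, k=2 image
axes_permutation = (1, 0)

idx_shift = 2  # number of spatial axes, i.e. len(image_shape())
new_indices = tuple(range(idx_shift)) + tuple(axis + idx_shift for axis in axes_permutation)
transposed = np.transpose(data, new_indices)
```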
convolve_with(filter_image: Self, stride: Union[int, tuple[int, ...]] = 1, padding: Optional[tuple[tuple[int, int]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1) -> Self

See convolve for a description of this function.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| filter_image | Self | the convolution filter, a GeometricImage with data shape (spatial,tensor) | required |
| stride | Union[int, tuple[int, ...]] | convolution stride, defaults to (1,)*self.D | 1 |
| padding | Optional[tuple[tuple[int, int]]] | either 'TORUS', 'VALID', 'SAME', or D length tuple of (upper,lower) pairs, defaults to 'TORUS' if image.is_torus, else 'SAME' | None |
| lhs_dilation | Optional[tuple[int, ...]] | amount of dilation to apply to image in each dimension D, also used for transposed convolution | None |
| rhs_dilation | Union[int, tuple[int, ...]] | amount of dilation to apply to filter in each dimension D, defaults to 1 | 1 |

Returns:

| Type | Description |
|---|---|
| Self | convolved image, a new GeometricImage with data shape (spatial,tensor) |

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=[2, 3, 4, 5])
def convolve_with(
    self: Self,
    filter_image: Self,
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[tuple[tuple[int, int]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
) -> Self:
    """
    See [convolve](functional_geometric_image.md#ginjax.geometric.functional_geometric_image.convolve)
    for a description of this function.

    args:
        filter_image: the convolution filter, a GeometricImage with data shape (spatial,tensor)
        stride: convolution stride, defaults to (1,)*self.D
        padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs,
            defaults to 'TORUS' if image.is_torus, else 'SAME'
        lhs_dilation: amount of dilation to apply to image in each dimension D, also used for transposed convolution
        rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1

    returns:
        convolved image, a new GeometricImage with data shape (spatial,tensor)
    """
    convolved_array = convolve(
        self.D,
        self.data[None, None],  # add batch, in_channels axes
        filter_image.data[None, None],  # add out_channels, in_channels axes
        self.is_torus,
        stride,
        padding,
        lhs_dilation,
        rhs_dilation,
    )
    return self.__class__(
        convolved_array[0, 0],  # ignore batch, out_channels axes
        self.parity + filter_image.parity,
        self.D,
        self.is_torus,
        self.covariant_axes + filter_image.covariant_axes,
    )
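With padding='TORUS' (the default for toroidal images) the image wraps around at its edges, so convolution commutes with translations on the torus. The NumPy sketch below illustrates that property for a 2D scalar image with a square, odd filter. The helper torus_correlate_2d is hypothetical and is not ginjax's convolve. It assumes cross-correlation (no kernel flip) with the filter centered on each output pixel.

```python
import numpy as np


def torus_correlate_2d(image: np.ndarray, filt: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: wrap-around cross-correlation of a 2D scalar image.

    out[i, j] = sum over centered offsets (a, b) in [-m, m]^2 of
                image[(i + a) % N0, (j + b) % N1] * filt[a + m, b + m]
    """
    side = filt.shape[0]
    assert filt.shape == (side, side) and side % 2 == 1, "filter must be square and odd"
    m = (side - 1) // 2
    out = np.zeros_like(image, dtype=float)
    for a in range(-m, m + 1):
        for b in range(-m, m + 1):
            # np.roll with shift -a brings image[i + a] to position i, wrapping at the edges
            out += filt[a + m, b + m] * np.roll(image, shift=(-a, -b), axis=(0, 1))
    return out


rng = np.random.default_rng(0)
img = rng.normal(size=(5, 5))
filt = rng.normal(size=(3, 3))
shift_then_conv = torus_correlate_2d(np.roll(img, (1, 2), axis=(0, 1)), filt)
conv_then_shift = np.roll(torus_correlate_2d(img, filt), (1, 2), axis=(0, 1))
print(np.allclose(shift_then_conv, conv_then_shift))  # True
```

Replacing the wrap-around np.roll with zero padding would break this identity near the image borders, which is why the torus setting matters for translation equivariance.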

max_pool(patch_len: int, use_norm: bool = True) -> Self

Perform max pooling over non-overlapping patches whose side length is patch_len. When use_norm is True, the max is determined by the norm of each pixel; for scalars this is the absolute value of the pixel. To use the plain (signed) max instead, set use_norm to False, which requires a scalar image.

Parameters:

Name Type Description Default
patch_len int

the side length of the patches, must evenly divide all spatial dims

required
use_norm bool

whether to use norm to calculate the max

True

Returns:

Type Description
Self

a new GeometricImage with the max pool applied

Source code in ginjax/geometric/geometric_image.py
def max_pool(self: Self, patch_len: int, use_norm: bool = True) -> Self:
    """
    Perform max pooling over non-overlapping patches whose side length is patch_len. When
    use_norm is True, the max is determined by the norm of each pixel; for scalars this is the
    absolute value of the pixel. To use the plain (signed) max instead, set use_norm to False
    (requires scalar images).

    args:
        patch_len: the side length of the patches, must evenly divide all spatial dims
        use_norm: whether to use norm to calculate the max

    returns:
        a new GeometricImage with the max pool applied
    """
    return self.__class__(
        max_pool(self.D, self.data, patch_len, use_norm),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
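The use_norm=True rule can be sketched in NumPy for a 2D image. Split the image into patch_len x patch_len patches, then keep the whole pixel with the largest norm in each patch. The helper max_pool_by_norm_2d is hypothetical and is not the library's max_pool. It assumes the pixel norm is the Frobenius norm over the tensor axes.

```python
import numpy as np


def max_pool_by_norm_2d(data: np.ndarray, patch_len: int) -> np.ndarray:
    """Hypothetical sketch: in each patch, keep the pixel with the largest norm.

    data has shape (N0, N1) + tensor_shape; N0 and N1 must be divisible by patch_len.
    """
    n0, n1 = data.shape[:2]
    tensor_shape = data.shape[2:]
    assert n0 % patch_len == 0 and n1 % patch_len == 0, "patch_len must divide the spatial dims"
    b0, b1 = n0 // patch_len, n1 // patch_len
    # (b0, p, b1, p, *tensor) -> (b0, b1, p*p, *tensor): one row of candidate pixels per patch
    patches = data.reshape((b0, patch_len, b1, patch_len) + tensor_shape)
    patches = np.moveaxis(patches, 2, 1).reshape((b0, b1, patch_len**2) + tensor_shape)
    # Frobenius norm over the tensor axes; for scalar images this is the absolute value
    norms = np.linalg.norm(patches.reshape((b0, b1, patch_len**2, -1)), axis=-1)
    winner = np.argmax(norms, axis=-1)  # (b0, b1): position of the largest-norm pixel
    rows, cols = np.meshgrid(np.arange(b0), np.arange(b1), indexing="ij")
    return patches[rows, cols, winner]


vec = np.zeros((4, 4, 2))
vec[0, 1] = [3.0, 4.0]   # norm 5
vec[1, 0] = [-6.0, 0.0]  # norm 6, the largest in the top-left patch
pooled = max_pool_by_norm_2d(vec, 2)
```

In this example, the pixel [-6, 0] wins its patch over [3, 4] because its norm is larger, even though neither of its components is the largest.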

average_pool(patch_len: int) -> Self

Perform an average pooling operation where the length of the side of each patch is patch_len. This is equivalent to a convolution with stride patch_len and 'VALID' padding, where every element of the filter is 1 divided by the number of pixels in the filter.

Parameters:

Name Type Description Default
patch_len int

the side length of the patches, must evenly divide every spatial dimension

required

Returns:

Type Description
Self

a new GeometricImage with the average pool applied

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=1)
def average_pool(self: Self, patch_len: int) -> Self:
    """
    Perform an average pooling operation where the length of the side of each patch is patch_len.
    This is equivalent to a convolution with stride patch_len and 'VALID' padding, where every
    element of the filter is 1 divided by the number of pixels in the filter.

    args:
        patch_len: the side length of the patches, must evenly divide every spatial dimension

    returns:
        a new GeometricImage with the average pool applied
    """
    return self.__class__(
        average_pool(self.D, self.data, patch_len),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
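The equivalence stated for average_pool can be checked numerically. This NumPy sketch uses hypothetical helpers for 2D images only. It computes the pool once as a block mean and once as an explicit convolution with stride patch_len, 'VALID' padding, and a constant filter of 1/patch_len**2. The two agree whenever patch_len divides the spatial dimensions.

```python
import numpy as np


def average_pool_2d(data: np.ndarray, patch_len: int) -> np.ndarray:
    # block mean: split each spatial axis into (blocks, patch_len) and average over the patch axes
    n0, n1 = data.shape[:2]
    blocks = data.reshape((n0 // patch_len, patch_len, n1 // patch_len, patch_len) + data.shape[2:])
    return blocks.mean(axis=(1, 3))


def strided_valid_conv_2d(data: np.ndarray, patch_len: int) -> np.ndarray:
    # explicit form: filter entries 1/patch_len**2, stride patch_len, 'VALID' padding
    n0, n1 = data.shape[:2]
    out = np.zeros((n0 // patch_len, n1 // patch_len) + data.shape[2:])
    weight = 1.0 / patch_len**2
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            window = data[i * patch_len:(i + 1) * patch_len, j * patch_len:(j + 1) * patch_len]
            out[i, j] = weight * window.sum(axis=(0, 1))
    return out


x = np.arange(16, dtype=float).reshape(4, 4)
print(average_pool_2d(x, 2))  # [[ 2.5  4.5] [10.5 12.5]]
```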

unpool(patch_len: int) -> Self

Replace each pixel with a (patch_len,)*self.D patch of copies of that pixel. This is also called nearest-neighbor unpooling.

Parameters:

Name Type Description Default
patch_len int

side length of the patch that each pixel is expanded into

required

Returns:

Type Description
Self

a new GeometricImage with the unpool applied

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=1)
def unpool(self: Self, patch_len: int) -> Self:
    """
    Replace each pixel with a (patch_len,)*self.D patch of copies of that pixel. This is also
    called nearest-neighbor unpooling.

    args:
        patch_len: side length of the patch that each pixel is expanded into

    returns:
        a new GeometricImage with the unpool applied
    """
    grow_filter = GeometricImage(jnp.ones((patch_len,) * self.D), 0, self.D)
    return self.convolve_with(
        grow_filter,
        padding=((patch_len - 1,) * 2,) * self.D,
        lhs_dilation=(patch_len,) * self.D,
    )
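The padding and lhs_dilation arguments in unpool can be checked in one dimension. Dilating by patch_len inserts patch_len - 1 zeros between samples, which gives length (N - 1) * patch_len + 1. Padding patch_len - 1 zeros on both sides and sliding an all-ones filter of length patch_len ('VALID') then yields N * patch_len outputs. Each window contains exactly one original sample, so the result is np.repeat. The helper below is a hypothetical NumPy sketch of that arithmetic, not the library code.

```python
import numpy as np


def unpool_via_dilated_conv_1d(x: np.ndarray, patch_len: int) -> np.ndarray:
    p = patch_len
    # lhs_dilation: p - 1 zeros between neighbouring samples
    dilated = np.zeros((len(x) - 1) * p + 1, dtype=x.dtype)
    dilated[::p] = x
    # padding ((p - 1, p - 1),)
    padded = np.pad(dilated, (p - 1, p - 1))
    ones = np.ones(p, dtype=x.dtype)
    # 'VALID' correlation with an all-ones filter of length p
    return np.array([padded[i:i + p] @ ones for i in range(len(padded) - p + 1)])


print(unpool_via_dilated_conv_1d(np.array([1.0, 2.0, 3.0]), 2))  # [1. 1. 2. 2. 3. 3.]
```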

times_scalar(scalar: float) -> Self

Scale the data by a scalar, returning a new GeometricImage object. Alias of the multiplication operator.

Parameters:

Name Type Description Default
scalar float

number to scale everything by

required

Returns:

Type Description
Self

a new GeometricImage scaled by the scalar

Source code in ginjax/geometric/geometric_image.py
def times_scalar(self: Self, scalar: float) -> Self:
    """
    Scale the data by a scalar, returning a new GeometricImage object. Alias of the multiplication operator.

    args:
        scalar: number to scale everything by

    returns:
        a new GeometricImage scaled by the scalar
    """
    return self * scalar

norm() -> Self

Calculate the norm of each pixel. The result is a scalar (k=0, parity 0) image.

Returns:

Type Description
Self

a new GeometricImage of the pixel norms

Source code in ginjax/geometric/geometric_image.py
@jax.jit
def norm(self: Self) -> Self:
    """
    Calculate the norm of each pixel. The result is a scalar (k=0, parity 0) image.

    returns:
        a new GeometricImage of the pixel norms
    """
    return self.__class__(norm(self.D, self.data), 0, self.D, self.is_torus)

normalize() -> Self

Scale the whole image by a single factor so that the largest pixel norm is 1. If the largest norm is not greater than TINY, the image is returned unscaled.

Returns:

Type Description
Self

a new GeometricImage divided by its max pixel norm

Source code in ginjax/geometric/geometric_image.py
def normalize(self: Self) -> Self:
    """
    Scale the whole image by a single factor so that the largest pixel norm is 1. If the largest
    norm is not greater than TINY, the image is returned unscaled.

    returns:
        a new GeometricImage scaled by the max norm
    """
    max_norm = float(jnp.max(self.norm().data))
    if max_norm > TINY:
        return self.times_scalar(1.0 / max_norm)
    else:
        return self.times_scalar(1.0)
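norm and normalize can be sketched in NumPy as follows, assuming the pixel norm is the Frobenius norm over the tensor axes. TINY here is a hypothetical stand-in for the library's own threshold, and the helper names are illustrative.

```python
import numpy as np

TINY = 1e-5  # hypothetical threshold, stands in for the library's TINY constant


def pixel_norms(data: np.ndarray, D: int) -> np.ndarray:
    # flatten the tensor axes of each pixel, then take the Euclidean (Frobenius) norm
    return np.linalg.norm(data.reshape(data.shape[:D] + (-1,)), axis=-1)


def normalize_sketch(data: np.ndarray, D: int) -> np.ndarray:
    max_norm = float(np.max(pixel_norms(data, D)))
    # one global factor for the whole image, skipped when the image is (numerically) zero
    return data / max_norm if max_norm > TINY else data


img = np.zeros((3, 3, 2))
img[0, 0] = [3.0, 4.0]
img[2, 1] = [0.0, 1.0]
out = normalize_sketch(img, 2)
```

Because a single factor is used, the relative sizes of pixels are preserved. Only the largest pixel ends up with norm 1; here [0, 1] becomes [0, 0.2].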

activation_function(function: Callable[[jnp.ndarray], jnp.ndarray]) -> Self

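Apply the specified activation function to the GeometricImage data. Only scalar (k=0) images are supported, because applying a nonlinearity to the components of a higher-order tensor would break equivariance.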

Parameters:

Name Type Description Default
function Callable[[ndarray], ndarray]

the activation function

required

Returns:

Type Description
Self

a new GeometricImage with the activation function applied

Source code in ginjax/geometric/geometric_image.py
def activation_function(self: Self, function: Callable[[jnp.ndarray], jnp.ndarray]) -> Self:
    """
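    Apply the specified activation function to the GeometricImage data. Only scalar (k=0) images
    are supported, because a componentwise nonlinearity on higher-order tensors breaks equivariance.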

    args:
        function: the activation function

    returns:
        a new GeometricImage with the activation function applied
    """
    assert (
        self.k == 0
    ), "Activation functions only implemented for k=0 tensors due to equivariance"
    return self.__class__(
        function(self.data), self.parity, self.D, self.is_torus, self.covariant_axes
    )
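The k=0 restriction can be seen with a 90-degree rotation. For a scalar image, rotation only moves pixels, so any pointwise function commutes with it. For a vector pixel, rotation also mixes components, and a componentwise ReLU no longer commutes. This is a minimal NumPy illustration, not library code.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


R90 = np.array([[0.0, -1.0], [1.0, 0.0]])  # rotation by 90 degrees

scalar_img = np.array([[1.0, -2.0], [-3.0, 4.0]])
# scalar image: rotation only permutes pixels, so a pointwise function commutes with it
scalar_commutes = np.allclose(relu(np.rot90(scalar_img)), np.rot90(relu(scalar_img)))

v = np.array([1.0, -1.0])
# vector pixel: rotation mixes components, so componentwise ReLU does not commute
vector_commutes = np.allclose(relu(R90 @ v), R90 @ relu(v))
print(scalar_commutes, vector_commutes)  # True False
```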

contract(i: int, j: int) -> Self

Use einsum to perform a Kronecker contraction (a trace) over two indices of the tensor, reducing the tensor order k by 2. Requires k >= 2.

Parameters:

Name Type Description Default
i int

first tensor index to contract

required
j int

second tensor index to contract

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def contract(self: Self, i: int, j: int) -> Self:
    """
    Use einsum to perform a Kronecker contraction (a trace) over two indices of the tensor,
    reducing the tensor order k by 2. Requires k >= 2.

    args:
        i: first index of tensor
        j: second index of tensor

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= 2
    idx_shift = len(self.image_shape())

    first, second = min(i, j), max(i, j)
    axes_ls = self.covariant_axes
    new_covariant_axes = axes_ls[:first] + axes_ls[first + 1 : second] + axes_ls[second + 1 :]
    return self.__class__(
        multicontract(self.data, ((i, j),), idx_shift),
        self.parity,
        self.D,
        self.is_torus,
        new_covariant_axes,
    )
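A Kronecker contraction of tensor indices i and j is a trace over those two axes. The hypothetical NumPy helpers below show the same operation with np.trace (the library uses einsum through multicontract) and reproduce the covariant_axes bookkeeping from contract. idx_shift is the number of spatial axes in front of the tensor axes.

```python
import numpy as np


def contract_sketch(data: np.ndarray, i: int, j: int, idx_shift: int) -> np.ndarray:
    # trace over tensor axes i and j, which sit after the idx_shift spatial axes
    return np.trace(data, axis1=i + idx_shift, axis2=j + idx_shift)


def contract_covariant_axes(covariant_axes: tuple, i: int, j: int) -> tuple:
    # drop the flags of the two contracted axes, keeping the others in order
    first, second = min(i, j), max(i, j)
    return covariant_axes[:first] + covariant_axes[first + 1:second] + covariant_axes[second + 1:]


rng = np.random.default_rng(2)
t2 = rng.normal(size=(4, 4, 2, 2))  # 4x4 image of k=2 tensors in D=2
traces = contract_sketch(t2, 0, 1, idx_shift=2)  # shape (4, 4)
```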

multicontract(indices: tuple[tuple[int, int], ...]) -> Self

Use einsum to perform several Kronecker contractions at once, one for each (i, j) pair in indices. Each pair reduces the tensor order k by 2.

Parameters:

Name Type Description Default
indices tuple[tuple[int, int], ...]

a tuple of (i, j) pairs of tensor indices to contract

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def multicontract(self: Self, indices: tuple[tuple[int, int], ...]) -> Self:
    """
    Use einsum to perform several Kronecker contractions at once, one for each (i, j) pair in
    indices. Each pair reduces the tensor order k by 2.

    args:
        indices: a tuple of (i, j) pairs of tensor indices to contract

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= 2
    idx_shift = len(self.image_shape())
    sorted_idxs = sorted(list(sum(indices, ())))
    new_cov_axes = tuple(
        self.covariant_axes[prev + 1 : next]
        for prev, next in zip([-1] + sorted_idxs, sorted_idxs + [self.k])
    )
    return self.__class__(
        multicontract(self.data, indices, idx_shift),
        self.parity,
        self.D,
        self.is_torus,
        sum(new_cov_axes, ()),
    )
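A general multicontract can be sketched with NumPy's einsum in its sublist form. Every axis gets an integer label, and the second index of each pair reuses the label of the first, so einsum sums over that pair. The helper names below are hypothetical. The covariant-axes helper reproduces the slicing logic of the method above, which keeps the runs of axes between contracted indices.

```python
import numpy as np


def multicontract_sketch(data: np.ndarray, indices: tuple, idx_shift: int) -> np.ndarray:
    labels = list(range(data.ndim))
    for i, j in indices:
        # the same label on both axes of a pair makes einsum sum along their diagonal
        labels[j + idx_shift] = labels[i + idx_shift]
    contracted = {idx_shift + m for pair in indices for m in pair}
    out_labels = [lab for axis, lab in enumerate(labels) if axis not in contracted]
    return np.einsum(data, labels, out_labels)


def multicontract_covariant_axes(covariant_axes: tuple, indices: tuple) -> tuple:
    k = len(covariant_axes)
    sorted_idxs = sorted(sum(indices, ()))
    # keep the runs of axes strictly between consecutive contracted indices
    kept = (covariant_axes[prev + 1:nxt] for prev, nxt in zip([-1] + sorted_idxs, sorted_idxs + [k]))
    return sum(kept, ())


rng = np.random.default_rng(3)
t4 = rng.normal(size=(3, 3, 2, 2, 2, 2))  # 3x3 image of k=4 tensors in D=2
both_traces = multicontract_sketch(t4, ((0, 1), (2, 3)), idx_shift=2)  # shape (3, 3)
```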

levi_civita_contract(indices: Union[tuple[int, ...], int]) -> Self

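Perform the Levi-Civita contraction: take the outer product with the Levi-Civita symbol, then perform D-1 contractions, pairing each of the given indices (in order) with the next index of the symbol. The outer product adds D indices and each contraction removes 2, so the resulting image has k = self.k + self.D - 2(self.D - 1) = self.k - self.D + 2, and its parity is increased by 1.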

Parameters:

Name Type Description Default
indices Union[tuple[int, ...], int]

the D-1 tensor indices to contract; a single int is accepted when D = 2

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def levi_civita_contract(self: Self, indices: Union[tuple[int, ...], int]) -> Self:
    """
    Perform the Levi-Civita contraction: take the outer product with the Levi-Civita symbol, then
    perform D-1 contractions. The resulting image has k = self.k - self.D + 2.

    args:
        indices: the D-1 tensor indices to contract; a single int is accepted when D = 2

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= (
        self.D - 1
    )  # so we have enough indices to work on since we perform D-1 contractions
    if not isinstance(indices, tuple):
        indices = (indices,)
    assert len(indices) == self.D - 1

    levi_civita = LeviCivitaSymbol.get(self.D)
    outer = jnp.tensordot(self.data, levi_civita, axes=0)

    # make contraction index pairs with one of specified indices, and index (in order) from the levi_civita symbol
    idx_shift = len(self.image_shape())
    zipped_indices = tuple(
        (i + idx_shift, j + idx_shift)
        for i, j in zip(indices, range(self.k, self.k + len(indices)))
    )
    return self.__class__(
        multicontract(outer, zipped_indices),
        self.parity + 1,
        self.D,
        self.is_torus,
        self.covariant_axes[: self.k - self.D + 2],  # right length, but maybe wrong
    )
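For D = 2 and a vector image (k = 1), the formula gives k = 1 - 2 + 2 = 1, so a vector image maps to a vector image. Contracting v_i with the first index of the Levi-Civita symbol gives w_k = sum_i v_i eps_ik = (-v_1, v_0). This is a 90-degree rotation of each vector, which transforms as a pseudovector, hence the parity change. The NumPy sketch below assumes eps_01 = 1. It is illustrative only and does not use the library's LeviCivitaSymbol.

```python
import numpy as np

# Levi-Civita symbol for D = 2: eps[0, 1] = 1, eps[1, 0] = -1
LEVI_CIVITA_2D = np.array([[0.0, 1.0], [-1.0, 0.0]])


def levi_civita_contract_2d_vectors(data: np.ndarray) -> np.ndarray:
    # data: (N0, N1, 2), a D = 2, k = 1 image; the result again has k = 1
    outer = np.tensordot(data, LEVI_CIVITA_2D, axes=0)  # (N0, N1, 2, 2, 2)
    # contract tensor index 0 (axis 2) with the first Levi-Civita index (axis 3)
    return np.einsum("xyiik->xyk", outer)


v = np.zeros((2, 2, 2))
v[0, 0] = [1.0, 0.0]
v[1, 1] = [2.0, 3.0]
w = levi_civita_contract_2d_vectors(v)  # w[1, 1] is [-3, 2]
```

Applying the contraction twice rotates by 180 degrees, so it returns -v.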

raise_lower(metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...], precision: Optional[jax.lax.Precision] = None) -> Self

Raise or lower the axes of the tensor according to the metric tensor, so that the result has the covariant axes given by axes.

Parameters:

Name Type Description Default
metric_tensor Self

the metric tensor g_ij, must have the same spatial shape as this image

required
metric_tensor_inv Self

the inverse metric tensor g^ij, must have the same spatial shape as this image

required
axes tuple[bool, ...]

desired covariant axes, one bool per tensor axis (True for covariant)

required
precision Optional[Precision]

precision used for einsum

None

Returns:

Type Description
Self

a new GeometricImage whose covariant axes are axes

Source code in ginjax/geometric/geometric_image.py
def raise_lower(
    self: Self,
    metric_tensor: Self,
    metric_tensor_inv: Self,
    axes: tuple[bool, ...],
    precision: Optional[jax.lax.Precision] = None,
) -> Self:
    """
    Raise or lower the axes of the tensor according to the metric tensor, so that the result has
    the covariant axes given by axes.

    args:
        metric_tensor: the metric tensor g_ij, must have the same spatial shape as this image
        metric_tensor_inv: the inverse metric tensor g^ij, must have the same spatial shape as this image
        axes: desired covariant axes, one bool per tensor axis (True for covariant)
        precision: precision used for einsum

    returns:
        new GeometricImage with correct axes
    """
    return self.__class__(
        raise_lower(
            self.data,
            metric_tensor.data,
            metric_tensor_inv.data,
            self.covariant_axes,
            axes,
            precision,
        ),
        self.parity,
        self.D,
        self.is_torus,
        axes,
    )
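For a vector image, raising and lowering reduce to a per-pixel matrix product with the metric or its inverse. The NumPy sketch below uses hypothetical helper names and handles vectors only. It lowers with g_ij, raises with g^ij, and checks that the round trip returns the original field for a non-Euclidean metric.

```python
import numpy as np


def lower_index(v_up: np.ndarray, g: np.ndarray) -> np.ndarray:
    # v_i = g_ij v^j at every pixel; v_up: (N0, N1, D), g: (N0, N1, D, D)
    return np.einsum("xyij,xyj->xyi", g, v_up)


def raise_index(v_down: np.ndarray, g_inv: np.ndarray) -> np.ndarray:
    # v^i = g^ij v_j at every pixel
    return np.einsum("xyij,xyj->xyi", g_inv, v_down)


rng = np.random.default_rng(4)
A = rng.normal(size=(3, 3, 2, 2))
g = A @ np.swapaxes(A, -1, -2) + np.eye(2)  # symmetric positive definite metric at each pixel
g_inv = np.linalg.inv(g)  # batched inverse over the pixel axes
v = rng.normal(size=(3, 3, 2))
round_trip = raise_index(lower_index(v, g), g_inv)
```

For the flat Euclidean metric (the identity at every pixel), lowering leaves the numbers unchanged. This is why covariant and contravariant vectors are numerically identical in that case.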

raise_lower_precise(metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...]) -> Self

Raise or lower the axes of the tensor according to the metric tensor and axes, using the highest precision for einsum.

Parameters:

Name Type Description Default
metric_tensor Self

the metric tensor g_ij, must have the same spatial shape as this image

required
metric_tensor_inv Self

the inverse metric tensor g^ij, must have the same spatial shape as this image

required
axes tuple[bool, ...]

desired covariant axes, one bool per tensor axis (True for covariant)

required

Returns:

Type Description
Self

new GeometricImage with correct axes

Source code in ginjax/geometric/geometric_image.py
def raise_lower_precise(
    self: Self, metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...]
) -> Self:
    """
    Raise or lower the axes of the tensor according to the metric tensor and axes, using the
    highest precision for einsum.

    args:
        metric_tensor: the metric tensor g_ij, must have the same spatial shape as this image
        metric_tensor_inv: the inverse metric tensor g^ij, must have the same spatial shape as this image
        axes: desired covariant axes, one bool per tensor axis (True for covariant)

    returns:
        new GeometricImage with correct axes
    """
    return self.raise_lower(metric_tensor, metric_tensor_inv, axes, jax.lax.Precision.HIGHEST)

times_group_element(gg: np.ndarray, precision: Optional[jax.lax.Precision] = None) -> Self

Apply a group element of O(D) to the geometric image. The action is applied first to the pixel locations, then to the pixel values themselves. The group element provided is the one that acts on contravariant axes; it is inverted to act on covariant axes as well.

Parameters:

Name Type Description Default
gg ndarray

a DxD matrix that acts on a contravariant vector v as gg @ v

required
precision Optional[Precision]

precision level for einsum, for equality tests use Precision.HIGHEST

None

Returns:

Type Description
Self

a new GeometricImage that has been rotated

Source code in ginjax/geometric/geometric_image.py
def times_group_element(
    self: Self,
    gg: np.ndarray,
    precision: Optional[jax.lax.Precision] = None,
) -> Self:
    """
    Apply a group element of O(D) to the geometric image. The action is applied first to the
    pixel locations, then to the pixel values themselves. The group element provided is the one
    that acts on contravariant axes; it is inverted to act on covariant axes as well.

    args:
        gg: a DxD matrix that acts on a contravariant vector v as gg @ v
        precision: precision level for einsum, for equality tests use Precision.HIGHEST

    returns:
        a new GeometricImage that has been rotated
    """
    assert self.k < 14
    assert gg.shape == (self.D, self.D)

    return self.__class__(
        times_group_element(self.D, self.data, self.parity, gg, self.covariant_axes, precision),
        self.parity,
        self.D,
        rotate_is_torus(self.is_torus, gg),
        self.covariant_axes,
    )
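The two-step action (move the pixels, then act on their values) can be sketched for an N x N contravariant vector image. The helper below is a hypothetical NumPy function, not times_group_element itself. It rotates centered (i, j) pixel indices about the image center and assumes gg maps the grid onto itself, as 90-degree rotations and axis flips do. The library's own index convention may differ.

```python
import numpy as np


def act_on_vector_image(data: np.ndarray, gg: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: apply an orthogonal 2x2 matrix gg to an (N, N, 2) vector image."""
    n = data.shape[0]
    center = (n - 1) / 2
    out = np.zeros_like(data)
    for i in range(n):
        for j in range(n):
            # 1) move the pixel: act on its (i, j) location measured from the image center
            new_i, new_j = np.rint(gg @ (np.array([i, j]) - center) + center).astype(int)
            # 2) act on the pixel value itself (contravariant: v -> gg @ v)
            out[new_i, new_j] = gg @ data[i, j]
    return out


g90 = np.array([[0, -1], [1, 0]])
e = np.zeros((3, 3, 2))
e[0, 1] = [1.0, 0.0]
rotated = act_on_vector_image(e, g90)  # the vector moves to pixel (1, 0) and becomes [0, 1]
```

Applying g90 four times, or g90 followed by its inverse g90.T, returns the original image.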

times_gg_precise(gg: np.ndarray) -> Self

Apply a group element of O(d) to the geometric image using the highest precision einsum. See times_group_element for more details.

Parameters:

Name Type Description Default
gg ndarray

a DxD matrix that acts on a contravariant vector v as gg @ v

required

Returns:

Type Description
Self

a new GeometricImage that has been rotated

Source code in ginjax/geometric/geometric_image.py
def times_gg_precise(self: Self, gg: np.ndarray) -> Self:
    """
    Apply a group element of O(d) to the geometric image using the highest precision einsum.
    See times_group_element for more details.

    args:
        gg: a DxD matrix that acts on a contravariant vector v as gg @ v

    returns:
        a new GeometricImage that has been rotated
    """
    return self.times_group_element(gg, jax.lax.Precision.HIGHEST)

translate(tau: jax.Array) -> Self

Translate the image on the torus; the image must be toroidal in every dimension. Translations on the data array use ij ordering. For example, a translation of [1,-1] moves the image down one row, then left one column.

Parameters:

Name Type Description Default
tau Array

the translation vector, length D

required

Returns:

Type Description
Self

a geometric image that has been translated

Source code in ginjax/geometric/geometric_image.py
def translate(self: Self, tau: jax.Array) -> Self:
    """
    Translate the image on the torus. Translations on the data matrix use ij ordering. For
    example, a translation of [1,-1] moves the image down one row, then left one column.

    args:
        tau: the translation vector, length D

    returns:
        a geometric image that has been translated
    """
    assert (
        self.is_torus == (True,) * self.D
    ), f"GeometricImage::translate: Image must be a torus, but got {self.is_torus}"
    assert (
        len(tau) == self.D
    ), f"GeometricImage::translate: {self.D}D image received {len(tau)}D translation"

    return self.__class__(
        translate(self.D, self.data, tau, 0),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
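The ij convention matches np.roll applied to the spatial axes. The hypothetical helper below sketches this under two assumptions: translate shifts every spatial axis by the matching entry of tau with wraparound, and it leaves the tensor axes of each pixel alone.

```python
import numpy as np


def translate_torus(data: np.ndarray, tau) -> np.ndarray:
    D = len(tau)
    # roll only the first D (spatial) axes; tensor axes of each pixel are untouched
    return np.roll(data, shift=tuple(int(t) for t in tau), axis=tuple(range(D)))


grid = np.arange(9).reshape(3, 3)
print(translate_torus(grid, [1, -1]))
# [[7 8 6]
#  [1 2 0]
#  [4 5 3]]
```

The value 0 starts at (0, 0) and ends at (1, 2): one row down and one column left, wrapping around the right edge.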

plot(ax: Optional[matplotlib.axes.Axes] = None, title: str = '', boxes: bool = False, fill: bool = True, symbols: bool = False, vmin: Optional[float] = None, vmax: Optional[float] = None, colorbar: bool = False, cmap: matplotlib.colors.Colormap | str | None = None, vector_scaling: float = 0.5) -> None

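Plot the geometric image. Only tensor orders k = 0, 1, and 2 are supported, and D = 3 images with k = 2 are not; in those cases a message is printed and nothing is plotted.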

Parameters:

Name Type Description Default
ax Optional[Axes]

matplotlib.pyplot Axes to plot this geometric image on; if None, a new plot is set up

None
title str

title of the plot

''
boxes bool

whether to plot boxes around each pixel

False
fill bool

whether to fill the pixels with an appropriate color

True
symbols bool

whether to fill the pixels with a symbol

False
vmin Optional[float]

min value to plot, everything below this is cut off. If None, uses the actual min for scalar images and 0.0 for vector images

None
vmax Optional[float]

max value to plot, everything above this is cut off. If None, uses the actual max for scalar images and 2.0 for vector images

None
colorbar bool

whether to plot a colorbar

False
cmap Colormap | str | None

a colormap or colormap name for the pixel fill; if None, scalar and vector images use their own defaults

None
vector_scaling float

scale factor applied to the plotted vectors

0.5
Source code in ginjax/geometric/geometric_image.py
def plot(
    self: Self,
    ax: Optional[matplotlib.axes.Axes] = None,
    title: str = "",
    boxes: bool = False,
    fill: bool = True,
    symbols: bool = False,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    colorbar: bool = False,
    cmap: matplotlib.colors.Colormap | str | None = None,
    vector_scaling: float = 0.5,
) -> None:
    """
    Plot the geometric image.

    args:
        ax: matplotlib.pyplot Axes to plot this geometric image on; if None, a new plot is set up
        title: title of the plot
        boxes: whether to plot boxes around each pixel
        fill: whether to fill the pixels with an appropriate color
        symbols: whether to fill the pixels with a symbol
        vmin: min value to plot, everything below this is cut off. If None, uses the actual min
            for scalars and 0.0 for vectors
        vmax: max value to plot, everything above this is cut off. If None, uses the actual max
            for scalars and 2.0 for vectors
        colorbar: whether to plot a colorbar
        cmap: a colormap or colormap name for the pixel fill; if None, scalars and vectors use
            their own defaults
        vector_scaling: scale factor applied to the plotted vectors
    """
    # plot functions should fail gracefully
    if self.k > 2:
        print(
            f"GeometricImage::plot: Can only plot tensor order 0,1, or 2 images, but got k={self.k}"
        )
        return
    if self.k == 2 and self.D == 3:
        print(f"GeometricImage::plot: Cannot plot D=3, k=2 geometric images.")
        return

    ax = utils.setup_plot() if ax is None else ax

    if self.D == 1:
        # convert image to a 2D image that is N,1
        data_2d = self.data.reshape((len(self.data), 1) + (1,) * self.k)
        mul_img = 1
        if self.k == 1:
            mul_img = jnp.concatenate(
                [jnp.ones_like(data_2d), jnp.zeros_like(data_2d)], axis=-1
            )
        elif self.k == 2 and self.parity == 0:  # kronecker delta coefficient
            mul_img = jnp.full((self.D, 1) + (2, 2), jnp.eye(2)[None, None])
        elif self.k == 2 and self.parity == 1:  # levi civita coefficient
            mul_img = jnp.full((self.D, 1) + (2, 2), LeviCivitaSymbol.get(2)[None, None])
        elif self.k > 2:
            print(f"GeometricImage::plot: Not implemented for D=1, k={self.k}")
            return

        # GeometricFilters must be square, so make it a GeometricImage
        image_2d = GeometricImage(
            data_2d * mul_img, self.parity, 2, self.is_torus[0], self.covariant_axes
        )
        image_2d.plot(
            ax, title, boxes, fill, symbols, vmin, vmax, colorbar, cmap, vector_scaling
        )
        return

    # This was breaking earlier with jax arrays, not sure why. I really don't want plotting to break,
    # so I will swap to numpy arrays just in case.
    key_array_transpose = np.array(self.key_array()).T  # (D,N**D)
    xs = key_array_transpose[0]
    ys = key_array_transpose[1]
    zs = key_array_transpose[2:]
    if self.D == 3:
        xs = xs + utils.XOFF * zs
        ys = ys + utils.YOFF * zs

    pixels = np.array(list(self.pixels()))

    if self.k == 0:
        vmin = np.min(pixels) if vmin is None else vmin
        vmax = np.max(pixels) if vmax is None else vmax
        utils.plot_scalars(
            ax,
            self.spatial_dims,
            xs,
            ys,
            pixels,
            boxes=boxes,
            fill=fill,
            symbols=symbols,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            colorbar=colorbar,
        )
    elif self.k == 1:
        vmin = 0.0 if vmin is None else vmin
        vmax = 2.0 if vmax is None else vmax
        utils.plot_vectors(
            ax,
            xs,
            ys,
            pixels,
            boxes=boxes,
            fill=fill,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
            scaling=vector_scaling,
        )
    else:  # self.k == 2
        utils.plot_tensors(ax, xs, ys, pixels, boxes=boxes)

    utils.finish_plot(ax, title, xs, ys, self.D)

tree_flatten() -> tuple[tuple[jnp.ndarray], dict[str, Union[int, Union[bool, tuple[bool]]]]]

Helper function to define GeometricImage as a pytree so jax.jit handles it correctly. Children and aux_data must contain all the variables that are passed to __init__().

Source code in ginjax/geometric/geometric_image.py
def tree_flatten(
    self: Self,
) -> tuple[tuple[jnp.ndarray], dict[str, Union[int, Union[bool, tuple[bool]]]]]:
    """
    Helper function to define GeometricImage as a pytree so jax.jit handles it correctly. Children and aux_data
    must contain all the variables that are passed in __init__()
    """
    children = (self.data,)  # arrays / dynamic values
    aux_data = {
        "D": self.D,
        "parity": self.parity,
        "is_torus": self.is_torus,
        "covariant_axes": self.covariant_axes,
    }  # static values
    return (children, aux_data)

tree_unflatten(aux_data, children) classmethod

Helper function to define GeometricImage as a pytree so jax.jit handles it correctly.

Source code in ginjax/geometric/geometric_image.py
@classmethod
def tree_unflatten(cls, aux_data, children):
    """
    Helper function to define GeometricImage as a pytree so jax.jit handles it correctly.
    """
    return cls(*children, **aux_data)
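The flatten/unflatten contract can be illustrated without JAX using a stand-in class. tree_unflatten rebuilds the object as cls(*children, **aux_data). Every constructor argument that is not a child must therefore appear in aux_data; otherwise it silently falls back to its default. ToyImage is hypothetical. In the library source, the classes are registered as pytrees with register_pytree_node_class from jax.tree_util (see the GeometricFilter decorator).

```python
class ToyImage:
    """Hypothetical stand-in for GeometricImage, showing only the pytree contract."""

    def __init__(self, data, parity, D, is_torus=True, covariant_axes=False):
        self.data = data
        self.parity = parity
        self.D = D
        self.is_torus = is_torus
        self.covariant_axes = covariant_axes

    def tree_flatten(self):
        children = (self.data,)  # dynamic leaves (arrays)
        aux_data = {  # static values, passed back to __init__ by keyword
            "D": self.D,
            "parity": self.parity,
            "is_torus": self.is_torus,
            "covariant_axes": self.covariant_axes,
        }
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


img = ToyImage([1.0, 2.0], 1, 2, (True, False), (False,))
children, aux = img.tree_flatten()
clone = ToyImage.tree_unflatten(aux, children)
```

If "covariant_axes" were left out of aux_data, clone.covariant_axes would come back as the default False instead of (False,).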

GeometricFilter

Bases: GeometricImage

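A subclass of GeometricImage for filters with square spatial dimensions. The constructor asserts that all spatial dimensions are equal; the side length is expected to be odd so that the filter has a center pixel.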
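GeometricFilter.nonempty_pixel_idxs (see the source below) measures pixel positions from the filter's center by subtracting (N - 1) / 2 from each index. The NumPy sketch below uses a hypothetical helper and assumes row-major ij index order. It shows why an odd side length is convenient: the offsets are integers and one pixel sits exactly at the origin. An even side length gives half-integer offsets and no center pixel.

```python
import numpy as np


def centered_offsets(side: int, D: int = 2) -> np.ndarray:
    # all pixel indices of a side**D filter in flattened row-major (ij) order
    grids = np.meshgrid(*([np.arange(side)] * D), indexing="ij")
    idxs = np.stack([g.ravel() for g in grids], axis=-1)  # (side**D, D)
    # shift so the center of the filter sits at the origin
    return idxs - (side - 1) / 2


odd = centered_offsets(3)   # offsets -1, 0, 1 in each dimension
even = centered_offsets(4)  # offsets -1.5, -0.5, 0.5, 1.5 in each dimension
```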

Source code in ginjax/geometric/geometric_image.py
@functools.total_ordering
@register_pytree_node_class
class GeometricFilter(GeometricImage):
    """
    A subclass of GeometricImage that enforces square, odd spatial dimensions.
    """

    def __init__(
        self: Self,
        data: jnp.ndarray,
        parity: int,
        D: int,
        is_torus: Union[bool, tuple[bool, ...]] = True,
        covariant_axes: Union[bool, tuple[bool, ...]] = False,
    ) -> None:
        """
        Constructor for GeometricFilter.

        args:
            data: the image data of shape (spatial,tensor). Spatial dimensions must be square, odd
            parity: parity of tensor, 0 for a normal tensor, 1 for a pseudo-tensor
            D: dimension of the image
            is_torus: which dimensions are toroidal
            covariant_axes: which of the k tensor axes are covariant, i.e. they transform covariantly
                with the coordinate change. False for typical vectors, True for gradients. You
                can only take a contraction between 1 covariant axis and 1 contravariant axis,
                but for a flat Euclidean metric these vectors are numerically identical, so we will
                not enforce this.
        """
        super().__init__(data, parity, D, is_torus, covariant_axes)
        assert (
            self.spatial_dims == (self.spatial_dims[0],) * self.D
        ), "GeometricFilter: Filters must be square."  # I could remove  this requirement in the future

    @classmethod
    def from_image(cls, geometric_image: GeometricImage) -> Self:
        """
        Constructor that copies a GeometricImage and returns a GeometricFilter

        args:
            geometric_image: the GeometricImage to copy

        returns:
            a new GeometricFilter copy
        """
        return cls(
            geometric_image.data,
            geometric_image.parity,
            geometric_image.D,
            geometric_image.is_torus,
            geometric_image.covariant_axes,
        )

    def bigness(self: Self) -> float:
        """
        Give a rough measure of a filter's size: the average length of the pixel keys, weighted by
        pixel norm. Filters with more weight at keys of greater length score higher.

        returns:
            the bigness value
        """
        norms = self.norm().data
        numerator = 0.0
        for key in self.key_array():
            numerator += jnp.linalg.norm(key * norms[tuple(key)], ord=2)

        denominator = float(jnp.sum(norms))
        return numerator / denominator

    def nonempty_pixels(self: Self) -> jax.Array:
        """
        Get the nonempty pixels as a true/false array.

        returns:
            a true/false array of flattened shape (image_size,)
        """
        return nonempty_pixels(self.D, self.data)

    def nonempty_pixel_idxs(self: Self) -> jax.Array:
        """
        Get the centered idxs of nonempty pixels, ordered in the flattened image order.

        returns:
            Nonempty pixels idxs, shape (num_pixels,D)
        """
        idxs = pixel_idxs(self.image_shape())
        idxs_centered = idxs - ((jnp.array(self.image_shape()).reshape((-1, self.D)) - 1) / 2)

        return idxs_centered[self.nonempty_pixels()]

    def __lt__(self: Self, other: Self) -> bool:
        """
        Compare two GeometricFilters on "bigness". The resulting definition may be slightly
        different, but I think it's a better definition. The order of comparisons is D, image shape
        in order, k, parity, (for k=2) which of the trace, antisymmetric, and symmetric parts are
        nonzero, distance of nonempty pixels from the center, and finally total pixel norms.

        args:
            other: the other GeometricFilter to compare to this one

        returns:
            returns self < other
        """
        if self.D != other.D:
            return self.D < other.D

        if self.image_shape() != other.image_shape():
            for N1, N2 in zip(self.image_shape(), other.image_shape()):
                if N1 != N2:
                    return N1 < N2

        if self.k != other.k:
            return self.k < other.k

        if self.parity != other.parity:
            return self.parity < other.parity

        self_sym_total = other_sym_total = 0
        self_antisym_total = other_antisym_total = 0
        self_trace_total = other_trace_total = 0
        if self.k == 2:  # works for D=2,3, and should work for higher D

            self_pixels = self.data.reshape((-1,) + (self.D,) * self.k)
            self_trace = jnp.einsum("...ii", self_pixels) / self.D
            self_trace_matrix = self_trace[:, None, None] * (jnp.eye(self.D)[None])
            self_antisym = (self_pixels - jnp.transpose(self_pixels, (0, 2, 1))) / 2
            self_sym = (self_pixels + jnp.transpose(self_pixels, (0, 2, 1))) / 2 - self_trace_matrix

            other_pixels = other.data.reshape((-1,) + (other.D,) * other.k)
            other_trace = jnp.einsum("...ii", other_pixels) / other.D
            other_trace_matrix = other_trace[:, None, None] * (jnp.eye(other.D)[None])
            other_antisym = (other_pixels - jnp.transpose(other_pixels, (0, 2, 1))) / 2
            other_sym = (
                other_pixels + jnp.transpose(other_pixels, (0, 2, 1))
            ) / 2 - other_trace_matrix

            # norm of elements along diagonal (except for last) and above
            sym_norm_f = jax.vmap(lambda x: jnp.linalg.norm(x[jnp.triu_indices(self.D)][:-1]))
            self_sym_total = float(jnp.sum(sym_norm_f(self_sym)))
            other_sym_total = float(jnp.sum(sym_norm_f(other_sym)))

            # norm of elements above the main diagonal
            antisym_norm_f = jax.vmap(lambda x: jnp.linalg.norm(x[jnp.triu_indices(self.D, 1)]))
            self_antisym_total = float(jnp.sum(antisym_norm_f(self_antisym)))
            other_antisym_total = float(jnp.sum(antisym_norm_f(other_antisym)))

            self_trace_total = float(jnp.sum(jnp.abs(self_trace)))
            other_trace_total = float(jnp.sum(jnp.abs(other_trace)))

            self_min_component = (
                int(self_trace_total != 0)
                + int(self_antisym_total != 0) * 10
                + int(self_sym_total != 0) * 100
            )
            other_min_component = (
                int(other_trace_total != 0)
                + int(other_antisym_total != 0) * 10
                + int(other_sym_total != 0) * 100
            )

            # check whether the filters are in the same irrep
            if self_min_component != other_min_component:
                return self_min_component < other_min_component

        self_nonempty_l1 = float(jnp.max(jnp.sum(jnp.abs(self.nonempty_pixel_idxs()), axis=1)))
        other_nonempty_l1 = float(jnp.max(jnp.sum(jnp.abs(other.nonempty_pixel_idxs()), axis=1)))

        self_nonempty_l2 = float(jnp.max(jnp.linalg.norm(self.nonempty_pixel_idxs(), axis=1)))
        other_nonempty_l2 = float(jnp.max(jnp.linalg.norm(other.nonempty_pixel_idxs(), axis=1)))

        # sort by l1 distance, then l2 distance
        if abs(self_nonempty_l1 - other_nonempty_l1) > TINY:
            return self_nonempty_l1 < other_nonempty_l1

        if abs(self_nonempty_l2 - other_nonempty_l2) > TINY:
            return self_nonempty_l2 < other_nonempty_l2

        if self.k == 2:
            if abs(self_sym_total - other_sym_total) > TINY:
                return self_sym_total < other_sym_total

            if abs(self_antisym_total - other_antisym_total) > TINY:
                return self_antisym_total < other_antisym_total

            if abs(self_trace_total - other_trace_total) > TINY:
                return self_trace_total < other_trace_total

        return float(jnp.sum(self.norm().data)) < float(jnp.sum(other.norm().data))

    def rectify(self: Self) -> Self:
        """
        Filters form an equivalence class up to multiplication by a scalar, so if the filter is
        negative we flip its sign.

        returns:
            the filter multiplied by -1 if it was negative, otherwise the filter itself
        """
        if self.k == 0:
            if jnp.sum(self.data) < 0:
                return self.times_scalar(-1)
        elif self.k == 1:
            if self.parity % 2 == 0:
                if (
                    jnp.sum(
                        jnp.einsum("...i,...i", self.key_array().reshape(self.shape()), self.data)
                    )
                    < 0
                ):
                    return self.times_scalar(-1)
            elif self.D == 2:
                if jnp.sum(jnp.cross(self.key_array().reshape(self.shape()), self.data)) < 0:
                    return self.times_scalar(-1)
        return self

    def plot(
        self: Self,
        ax: Optional[Any] = None,
        title: str = "",
        boxes: bool = True,
        fill: bool = True,
        symbols: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        colorbar: bool = False,
        cmap: matplotlib.colors.Colormap | str | None = None,
        vector_scaling: float = 0.33,
    ) -> None:
        """
        Plot the geometric filter. Uses different defaults for vmin, vmax, and vector_scaling than
        GeometricImage.

        args:
            ax: matplotlib.pyplot Axes to plot this geometric filter on
            title: title of the plot
            boxes: whether to plot boxes around each pixel
            fill: whether to fill the pixels with an appropriate color
            symbols: whether to fill the pixels with a symbol
            vmin: min value to plot, everything below this is cut off. If None, uses -3 for
                scalars and 0 otherwise.
            vmax: max value to plot, everything above this is cut off. If None, uses 3.
            colorbar: whether to plot a colorbar
            cmap: the matplotlib colormap, or its name, to use
            vector_scaling: how much to scale the vectors
        """
        if self.k == 0:
            vmin = -3.0 if vmin is None else vmin
            vmax = 3.0 if vmax is None else vmax
        else:
            vmin = 0.0 if vmin is None else vmin
            vmax = 3.0 if vmax is None else vmax

        super().plot(ax, title, boxes, fill, symbols, vmin, vmax, colorbar, cmap, vector_scaling)
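For k=2 filters, `__lt__` splits every pixel matrix into a trace part, a traceless symmetric part, and an antisymmetric part, and compares which of these are nonzero before looking at pixel distances. The following is a minimal standalone sketch of that decomposition on a plain `jax.numpy` array of shape (num_pixels, D, D); `decompose_k2` is a hypothetical helper, not part of ginjax.

```python
import jax.numpy as jnp


def decompose_k2(pixels: jnp.ndarray):
    """Split (num_pixels, D, D) matrices the same way the k == 2 branch of __lt__ does."""
    D = pixels.shape[-1]
    trace = jnp.einsum("...ii->...", pixels) / D  # trace divided by D, one value per pixel
    trace_matrix = trace[:, None, None] * jnp.eye(D)[None]
    transposed = jnp.transpose(pixels, (0, 2, 1))
    antisym = (pixels - transposed) / 2
    sym = (pixels + transposed) / 2 - trace_matrix  # symmetric with zero trace
    return trace_matrix, sym, antisym


pixels = jnp.arange(2 * 3 * 3, dtype=jnp.float32).reshape((2, 3, 3))
trace_matrix, sym, antisym = decompose_k2(pixels)

# The three parts add back up to the original matrices
print(bool(jnp.allclose(trace_matrix + sym + antisym, pixels)))  # True
```

Because the parts are independent, a purely antisymmetric filter gets a different `min_component` code than a purely symmetric one, so the two are ordered by component type before their norms are compared.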
zeros(N: Union[int, tuple[int, ...]], k: int, parity: int, D: int, is_torus: Union[bool, tuple[bool]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> Self classmethod

Zero constructor for GeometricImage.

Parameters:

Name Type Description Default
N Union[int, tuple[int, ...]]

length of all sides if an int, otherwise a tuple of the side lengths

required
k int

the order of the tensor in each pixel, i.e. 0 (scalar), 1 (vector), 2 (matrix), etc.

required
parity int

0 or 1, 0 is normal vectors, 1 is pseudovectors

required
D int

dimension of the image, and length of vectors or side length of matrices or tensors.

required
is_torus Union[bool, tuple[bool]]

whether the datablock is a torus, used for convolutions

True
covariant_axes Union[bool, tuple[bool, ...]]

which of the k tensor axes are covariant, i.e. they rotate covariantly with the coordinate change. False for typical vectors, True for gradients.

False

Returns:

Type Description
Self

constructed GeometricImage

Source code in ginjax/geometric/geometric_image.py
@classmethod
def zeros(
    cls,
    N: Union[int, tuple[int, ...]],
    k: int,
    parity: int,
    D: int,
    is_torus: Union[bool, tuple[bool]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> Self:
    """
    Zero constructor for GeometricImage.

    args:
        N: length of all sides if an int, otherwise a tuple of the side lengths
        k: the order of the tensor in each pixel, i.e. 0 (scalar), 1 (vector), 2 (matrix), etc.
        parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
        D: dimension of the image, and length of vectors or side length of matrices or tensors.
        is_torus: whether the datablock is a torus, used for convolutions
        covariant_axes: which of the k tensor axes are covariant, i.e. they rotate covariantly
            with the coordinate change. False for typical vectors, True for gradients.

    returns:
        constructed GeometricImage
    """
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D
    return cls(jnp.zeros(spatial_dims + (D,) * k), parity, D, is_torus, covariant_axes)
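The data block created by `zeros` has the spatial axes first, followed by k axes of length D, i.e. shape `spatial_dims + (D,) * k`. This sketch reproduces only that shape logic with plain `jax.numpy`; `zeros_block_shape` is a hypothetical helper for illustration.

```python
import jax.numpy as jnp


def zeros_block_shape(N, k: int, D: int) -> tuple:
    """Shape of the zero data block, mirroring the logic in GeometricImage.zeros."""
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D
    return spatial_dims + (D,) * k


# A 2D image on a 16x16 grid whose pixels are 2x2 matrices (k=2)
data = jnp.zeros(zeros_block_shape(16, k=2, D=2))
print(data.shape)  # (16, 16, 2, 2)

# Non-square images pass a tuple of side lengths, one per spatial dimension
print(zeros_block_shape((4, 5, 6), k=1, D=3))  # (4, 5, 6, 3)
```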
fill(N: Union[int, tuple[int, ...]], parity: int, D: int, fill: Union[jax.Array, float], is_torus: Union[bool, tuple[bool, ...]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> Self classmethod

Fill constructor: construct a geometric image with every pixel set to fill.

Parameters:

Name Type Description Default
N Union[int, tuple[int, ...]]

length of all sides if an int, otherwise a tuple of the side lengths

required
parity int

0 or 1, 0 is normal vectors, 1 is pseudovectors

required
D int

dimension of the image, and length of vectors or side length of matrices or tensors.

required
fill Union[Array, float]

tensor to fill the image with

required
is_torus Union[bool, tuple[bool, ...]]

whether the datablock is a torus, used for convolutions. Defaults to True.

True
covariant_axes Union[bool, tuple[bool, ...]]

which of the k tensor axes are covariant, i.e. they rotate covariantly with the coordinate change. False for typical vectors, True for gradients.

False

Returns:

Type Description
Self

Constructed GeometricImage

Source code in ginjax/geometric/geometric_image.py
@classmethod
def fill(
    cls,
    N: Union[int, tuple[int, ...]],
    parity: int,
    D: int,
    fill: Union[jax.Array, float],
    is_torus: Union[bool, tuple[bool, ...]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> Self:
    """
    Fill constructor: construct a geometric image with every pixel set to fill.

    args:
        N: length of all sides if an int, otherwise a tuple of the side lengths
        parity: 0 or 1, 0 is normal vectors, 1 is pseudovectors
        D: dimension of the image, and length of vectors or side length of matrices or tensors.
        fill: tensor to fill the image with
        is_torus: whether the datablock is a torus, used for convolutions. Defaults to True.
        covariant_axes: which of the k tensor axes are covariant, i.e. they rotate covariantly
            with the coordinate change. False for typical vectors, True for gradients.

    returns:
        Constructed GeometricImage
    """
    spatial_dims = N if isinstance(N, tuple) else (N,) * D
    assert len(spatial_dims) == D

    k = (
        len(fill.shape)
        if (isinstance(fill, jnp.ndarray) or isinstance(fill, np.ndarray))
        else 0
    )
    data = jnp.stack([fill for _ in range(np.multiply.reduce(spatial_dims))]).reshape(
        spatial_dims + (D,) * k
    )
    return cls(data, parity, D, is_torus, covariant_axes)
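`fill` infers k from the number of axes of `fill` (0 for a plain float) and then tiles one copy per pixel. The sketch below shows that tiling on plain arrays and compares it with `jnp.broadcast_to`, which builds the same block without a Python list of one entry per pixel; `broadcast_to` is a suggested alternative, not what ginjax does.

```python
import numpy as np
import jax.numpy as jnp

spatial_dims = (4, 4)
D = 2
fill = jnp.array([1.0, -2.0])  # one vector, so k = len(fill.shape) = 1
k = len(fill.shape)

# Tiling as in GeometricImage.fill: stack one copy per pixel, then reshape
stacked = jnp.stack([fill for _ in range(int(np.prod(spatial_dims)))]).reshape(
    spatial_dims + (D,) * k
)

# Alternative: broadcast the single pixel across the spatial axes
broadcast = jnp.broadcast_to(fill, spatial_dims + (D,) * k)

print(stacked.shape)  # (4, 4, 2)
print(bool(jnp.array_equal(stacked, broadcast)))  # True
```

For a plain float, the `isinstance` check yields k = 0, so the block shape is just `spatial_dims`.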
copy() -> Self

Copy the geometric image.

Source code in ginjax/geometric/geometric_image.py
def copy(self: Self) -> Self:
    """
    Copy the geometric image.
    """
    return self.__class__(self.data, self.parity, self.D, self.is_torus, self.covariant_axes)
hash(indices: ArrayLike) -> tuple[jax.Array, ...]

Converts an array of indices to their pixels on the torus by modding the indices with the spatial dimensions.

Parameters:

Name Type Description Default
indices ArrayLike

array of indices, shape (num_idx, D) to apply the remainder to

required

Returns:

Type Description
tuple[Array, ...]

the pixel indices as a D-tuple of jax arrays

Source code in ginjax/geometric/geometric_image.py
def hash(self: Self, indices: ArrayLike) -> tuple[jax.Array, ...]:
    """
    Converts an array of indices to their pixels on the torus by modding the indices with the
    spatial dimensions.

    args:
        indices: array of indices, shape (num_idx, D) to apply the remainder to

    returns:
        the pixel indices as a D-tuple of jax arrays
    """
    return hash(self.D, self.spatial_dims, indices)
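`hash` delegates to a module-level function that is not shown here. The sketch below is a hypothetical re-implementation (`torus_hash`) of the behavior the docstring describes: take each index modulo the side length of its axis, then split the (num_idx, D) array into a D-tuple that can index the data block directly.

```python
import jax.numpy as jnp


def torus_hash(spatial_dims: tuple, indices) -> tuple:
    """Wrap (num_idx, D) indices onto the torus and split them into a D-tuple of index arrays."""
    wrapped = jnp.remainder(jnp.asarray(indices), jnp.array(spatial_dims))
    return tuple(wrapped[:, d] for d in range(len(spatial_dims)))


data = jnp.arange(12).reshape((3, 4))
idxs = torus_hash((3, 4), [[0, 0], [-1, 5], [3, 4]])
# (-1, 5) wraps to (2, 1) and (3, 4) wraps to (0, 0)
print(data[idxs].tolist())  # [0, 9, 0]
```

`jnp.remainder` follows Python's sign convention, so negative indices such as -1 wrap to the far edge of the torus.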
__getitem__(key: Any) -> jax.Array

Accessor for data values, so you can write image[key] where key is an index or array slice. Note that JAX does not raise errors for out-of-bounds indexing.

Parameters:

Name Type Description Default
key Any

JAX/numpy indexer, e.g. "0", "0,1,3", "4:, 2:3, 0" etc.

required

Returns:

Type Description
Array

data from the specified index or slice.

Source code in ginjax/geometric/geometric_image.py
def __getitem__(self: Self, key: Any) -> jax.Array:
    """
    Accessor for data values, so you can write image[key] where key is an index or array slice.
    Note that JAX does not raise errors for out-of-bounds indexing.

    args:
        key: JAX/numpy indexer, e.g. "0", "0,1,3", "4:, 2:3, 0" etc.

    returns:
        data from the specified index or slice.
    """
    return self.data[key]
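The note about out-of-bounds indexing is general JAX behavior rather than something specific to ginjax: JAX documents that out-of-bounds reads are clamped to the valid range instead of raising, whereas NumPy raises an `IndexError`. A short demonstration on plain arrays:

```python
import numpy as np
import jax.numpy as jnp

data = jnp.arange(10)

# JAX clamps the out-of-bounds index to the last valid position instead of raising
print(int(data[11]))  # 9

# NumPy raises for the same read
try:
    np.arange(10)[11]
except IndexError as err:
    print("NumPy:", err)
```

Because of this, a typo in a pixel index can silently return an edge pixel; checking indices against `image_shape()` may be worthwhile when debugging.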
__setitem__(key: Any, val: Any) -> Self

Set the jax array data to the specified value. Jax arrays are immutable, so this reconstructs the data object with copying, and is potentially slow.

Parameters:

Name Type Description Default
key Any

index or slice to access data

required
val Any

value to set the data to

required

Returns:

Type Description
Self

the geometric image

Source code in ginjax/geometric/geometric_image.py
def __setitem__(self: Self, key: Any, val: Any) -> Self:
    """
    Set the jax array data to the specified value. Jax arrays are immutable, so this
    reconstructs the data object with copying, and is potentially slow.

    args:
        key: index or slice to access data
        val: value to set the data to

    returns:
        the geometric image
    """
    self.data = self.data.at[key].set(val)
    return self
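JAX arrays cannot be modified in place, so `__setitem__` uses the functional `.at[key].set(val)` update, which returns a new array, and then rebinds `self.data` to it. The sketch below shows the underlying JAX pattern on a plain array.

```python
import jax.numpy as jnp

data = jnp.zeros((3, 3))
updated = data.at[1, 2].set(5.0)  # returns a copy with one entry changed

print(float(data[1, 2]))  # 0.0, the original array is untouched
print(float(updated[1, 2]))  # 5.0
```

In `image[key] = val`, Python discards the value returned by `__setitem__`, so the returned `Self` only matters when the method is called explicitly.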
shape() -> tuple[int, ...]

Return the full shape of the data block

Returns:

Type Description
tuple[int, ...]

The shape of the data block

Source code in ginjax/geometric/geometric_image.py
def shape(self: Self) -> tuple[int, ...]:
    """
    Return the full shape of the data block

    returns:
        The shape of the data block
    """
    return self.data.shape
image_shape(plus_Ns: Optional[tuple[int, ...]] = None) -> tuple[int, ...]

Return the spatial shape of the data block, i.e. the axes that come before the k-tensor axes.

Parameters:

Name Type Description Default
plus_Ns Optional[tuple[int, ...]]

D-length tuple of amounts to add to each spatial dimension

None

Returns:

Type Description
tuple[int, ...]

the shape of the image, modified by plus_Ns

Source code in ginjax/geometric/geometric_image.py
def image_shape(self: Self, plus_Ns: Optional[tuple[int, ...]] = None) -> tuple[int, ...]:
    """
    Return the spatial shape of the data block, i.e. the axes that come before the k-tensor axes.

    args:
        plus_Ns: D-length tuple of amounts to add to each spatial dimension

    returns:
        the shape of the image, modified by plus_Ns
    """
    plus_Ns = (0,) * self.D if (plus_Ns is None) else plus_Ns
    return tuple(N + plus_N for N, plus_N in zip(self.spatial_dims, plus_Ns))
image_size() -> int

Return the total number of pixels in the image.

Source code in ginjax/geometric/geometric_image.py
def image_size(self: Self) -> int:
    """
    Return the total number of pixels in the image.
    """
    return functools.reduce(lambda c, v: c * v, self.image_shape(), 1)
pixel_shape() -> tuple[int, ...]

Return the shape of the k-tensor part of the data block, i.e. the shape of a single pixel.

Returns:

Type Description
tuple[int, ...]

the shape of the pixel

Source code in ginjax/geometric/geometric_image.py
def pixel_shape(self: Self) -> tuple[int, ...]:
    """
    Return the shape of the k-tensor part of the data block, i.e. the shape of a single pixel.

    returns:
        the shape of the pixel
    """
    return self.k * (self.D,)
pixel_size() -> int

Get the number of entries in a pixel, e.g. a pixel of shape (D,D,D) has size D**3.

Returns:

Type Description
int

the size of the pixels

Source code in ginjax/geometric/geometric_image.py
def pixel_size(self: Self) -> int:
    """
    Get the number of entries in a pixel, e.g. a pixel of shape (D,D,D) has size D**3.

    returns:
        the size of the pixels
    """
    return self.D**self.k
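The shape helpers partition the data block: the first D axes are the image (spatial) axes, and the remaining k axes, each of length D, form a pixel. The sketch below checks these relationships on a bare `jax.numpy` array standing in for `data`; the variable names only mirror the method names.

```python
import math

import jax.numpy as jnp

D, k = 3, 2
spatial_dims = (8, 8, 8)
data = jnp.zeros(spatial_dims + (D,) * k)

image_shape = data.shape[:D]  # like image_shape()
pixel_shape = data.shape[D:]  # like pixel_shape(), i.e. k * (D,)
image_size = math.prod(image_shape)  # like image_size()
pixel_size = D**k  # like pixel_size()

print(image_shape, pixel_shape)  # (8, 8, 8) (3, 3)
print(image_size, pixel_size)  # 512 9

# image_shape(plus_Ns) adds a per-axis amount to each spatial side length
plus_Ns = (2, 2, 2)
print(tuple(N + p for N, p in zip(image_shape, plus_Ns)))  # (10, 10, 10)
```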
__str__() -> str

Returns:

Type Description
str

the string representation of the GeometricImage

Source code in ginjax/geometric/geometric_image.py
def __str__(self: Self) -> str:
    """
    returns:
        the string representation of the GeometricImage
    """
    return "<{} object in D={} with spatial_dims={}, k={}, parity={}, is_torus={}, covariant_axes={}>".format(
        self.__class__,
        self.D,
        self.spatial_dims,
        self.k,
        self.parity,
        self.is_torus,
        self.covariant_axes,
    )
keys() -> Any

Return an iterator over the keys (pixel indices) of the GeometricImage.

Source code in ginjax/geometric/geometric_image.py
def keys(self: Self) -> Any:
    """
    Return an iterator over the keys (pixel indices) of the GeometricImage.
    """
    return it.product(*list(range(N) for N in self.spatial_dims))
key_array() -> jax.Array

Returns:

Type Description
Array

the pixel indices as a jax array

Source code in ginjax/geometric/geometric_image.py
def key_array(self: Self) -> jax.Array:
    """
    returns:
        the pixel indices as a jax array
    """
    # equivalent to the old pixels function
    return jnp.array([key for key in self.keys()], dtype=int)
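`keys()` walks the pixel indices with `itertools.product`, which yields them in row-major (C) order. The sketch below rebuilds `key_array()` for a small grid with plain Python and JAX, and checks it against the same indices produced by `jnp.indices`.

```python
import itertools as it

import jax.numpy as jnp

spatial_dims = (2, 3)

# Same construction as keys()/key_array(): every pixel index in row-major (C) order
key_array = jnp.array(list(it.product(*(range(N) for N in spatial_dims))), dtype=int)

# The same indices built from jnp.indices and flattened in C order
grid = jnp.indices(spatial_dims).reshape((len(spatial_dims), -1)).T

print(key_array.shape)  # (6, 2)
print(bool(jnp.array_equal(key_array, grid)))  # True
```

Because the order is row-major, row i of `key_array()` matches pixel i of `data.reshape((-1,) + pixel_shape)`, which is what lets `rectify` reshape `key_array()` to `self.shape()` for vector images.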
pixels() -> Generator[jax.Array]

Iterate over the pixels of GeometricImage.

Returns:

Type Description
Generator[Array]

a generator of the pixels

Source code in ginjax/geometric/geometric_image.py
def pixels(self: Self) -> Generator[jax.Array]:
    """
    Iterate over the pixels of GeometricImage.

    returns:
        a generator of the pixels
    """
    for key in self.keys():
        yield self[key]
items() -> Generator[tuple[Any, jax.Array]]

Iterate over the key, pixel pairs of GeometricImage.

Returns:

Type Description
Generator[tuple[Any, Array]]

a generator of pairs of the pixel index and its pixel

Source code in ginjax/geometric/geometric_image.py
def items(self: Self) -> Generator[tuple[Any, jax.Array]]:
    """
    Iterate over the key, pixel pairs of GeometricImage.

    returns:
        a generator of pairs of the pixel index and its pixel
    """
    for key in self.keys():
        yield (key, self[key])
__eq__(other: object, rtol: float = TINY, atol: float = TINY) -> bool

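Equality operator. The images must have the same D, spatial dimensions, k, parity, is_torus, and covariant_axes, and their data must match within the given tolerances (TINY=1e-5 by default). Returns False if other is not a GeometricImage.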

Parameters:

Name Type Description Default
other object

an object to compare to this GeometricImage

required
rtol float

relative tolerance, passed to jnp.allclose

TINY
atol float

absolute tolerance, passed to jnp.allclose

TINY

Returns:

Type Description
bool

True if they are equal, False otherwise

Source code in ginjax/geometric/geometric_image.py
def __eq__(self: Self, other: object, rtol: float = TINY, atol: float = TINY) -> bool:
    """
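    Equality operator. The images must have the same D, spatial dimensions, k, parity, is_torus,
    and covariant_axes, and their data must match within the given tolerances (TINY=1e-5 by
    default). Returns False if other is not a GeometricImage.

    args:
        other: an object to compare to this GeometricImage
        rtol: relative tolerance, passed to jnp.allclose
        atol: absolute tolerance, passed to jnp.allclose

    returns:
        True if they are equal, False otherwise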
    """
    if isinstance(other, GeometricImage):
        return (
            self.D == other.D
            and self.spatial_dims == other.spatial_dims
            and self.k == other.k
            and self.parity == other.parity
            and self.is_torus == other.is_torus
            and self.covariant_axes == other.covariant_axes
            and self.data.shape == other.data.shape
            and bool(jnp.allclose(self.data, other.data, rtol, atol))
        )
    else:
        return False
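`__eq__` passes `rtol` and `atol` positionally to `jnp.allclose`, which treats two entries as equal when |a - b| <= atol + rtol * |b|. The sketch below shows what the default tolerance accepts and rejects on plain arrays; the local `TINY` simply mirrors the 1e-5 value named in the docstring.

```python
import jax.numpy as jnp

TINY = 1e-5  # mirrors the default tolerance named in the docstring
a = jnp.array([1.0, 2.0, 3.0])

# A difference of 5e-6 is within atol + rtol * |b|
print(bool(jnp.allclose(a, a + 5e-6, TINY, TINY)))  # True

# A difference of 1e-3 is not
print(bool(jnp.allclose(a, a + 1e-3, TINY, TINY)))  # False
```

Only the data comparison is approximate; the structural checks on D, k, parity, is_torus, and covariant_axes are exact.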
__add__(other: Self) -> Self

Addition operator for GeometricImages. Both must have the same D, spatial dimensions, k, parity, is_torus, and covariant_axes. Returns a new GeometricImage.

Parameters:

Name Type Description Default
other Self

the other image to add to this one

required

Returns:

Type Description
Self

a new GeometricImage that is the sum of this one and the other one

Source code in ginjax/geometric/geometric_image.py
def __add__(self: Self, other: Self) -> Self:
    """
    Addition operator for GeometricImages. Both must have the same D, spatial dimensions, k,
    parity, is_torus, and covariant_axes. Returns a new GeometricImage.

    args:
        other: the other image to add to this one

    returns:
        a new GeometricImage that is the sum of this one and the other one
    """
    assert self.D == other.D
    assert self.spatial_dims == other.spatial_dims
    assert self.k == other.k
    assert self.parity == other.parity
    assert self.is_torus == other.is_torus
    assert self.covariant_axes == other.covariant_axes
    assert self.data.shape == other.data.shape
    return self.__class__(
        self.data + other.data, self.parity, self.D, self.is_torus, self.covariant_axes
    )
__sub__(other: Self) -> Self

Subtraction operator for GeometricImages. Both must have the same D, spatial dimensions, k, parity, is_torus, and covariant_axes. Returns a new GeometricImage.

Parameters:

Name Type Description Default
other Self

the other image to subtract from this one

required

Returns:

Type Description
Self

a new GeometricImage that is the difference of this GeometricImage and the other one

Source code in ginjax/geometric/geometric_image.py
def __sub__(self: Self, other: Self) -> Self:
    """
    Subtraction operator for GeometricImages. Both must have the same D, spatial dimensions, k,
    parity, is_torus, and covariant_axes. Returns a new GeometricImage.

    args:
        other: the other image to subtract from this one

    returns:
        a new GeometricImage that is the difference of this GeometricImage and the other one
    """
    assert self.D == other.D
    assert self.spatial_dims == other.spatial_dims
    assert self.k == other.k
    assert self.parity == other.parity
    assert self.is_torus == other.is_torus
    assert self.covariant_axes == other.covariant_axes
    assert self.data.shape == other.data.shape
    return self.__class__(
        self.data - other.data, self.parity, self.D, self.is_torus, self.covariant_axes
    )
__mul__(other: Union[Self, float, int]) -> Self

If other is a scalar, do scalar multiplication of the data. If it is another GeometricImage, do the tensor product at each pixel. Return the result as a new GeometricImage.

Parameters:

Name Type Description Default
other GeometricImage or number

scalar or image to multiply by

required

Returns:

Type Description
Self

a new GeometricImage that is the product of this GeometricImage with other

Source code in ginjax/geometric/geometric_image.py
def __mul__(self: Self, other: Union[Self, float, int]) -> Self:
    """
    If other is a scalar, do scalar multiplication of the data. If it is another GeometricImage, do the tensor
    product at each pixel. Return the result as a new GeometricImage.

    args:
        other (GeometricImage or number): scalar or image to multiply by

    returns:
        a new GeometricImage that is the product of this GeometricImage with other
    """
    if isinstance(other, GeometricImage):
        assert self.D == other.D
        assert self.spatial_dims == other.spatial_dims
        assert self.is_torus == other.is_torus
        return self.__class__(
            mul(self.D, self.data, other.data),
            self.parity + other.parity,
            self.D,
            self.is_torus,
            self.covariant_axes + other.covariant_axes,
        )
    else:  # it's an int, a float, or anything else a JAX array can be multiplied by (like a DeviceArray)
        return self.__class__(
            self.data * other, self.parity, self.D, self.is_torus, self.covariant_axes
        )
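When both operands are GeometricImages, the docstring says the result is the tensor product at each pixel, so a k1 image times a k2 image has k1 + k2 tensor axes. The module-level `mul` is not shown here; the sketch below assumes it is an ordinary outer product and illustrates that for two vector (k=1) images using `jnp.einsum`.

```python
import jax.numpy as jnp

D = 2
spatial_dims = (4, 4)
a = jnp.ones(spatial_dims + (D,))  # k = 1
b = jnp.arange(4 * 4 * D, dtype=jnp.float32).reshape(spatial_dims + (D,))  # k = 1

# Outer product of the two vectors at every pixel gives a k = 2 image
prod = jnp.einsum("...i,...j->...ij", a, b)
print(prod.shape)  # (4, 4, 2, 2)
```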
__rmul__(other: Union[Self, float, int]) -> Self

If other is a scalar, multiply the data by the scalar. This is necessary for doing scalar * image, and it should only be called in that case.

Parameters:

Name Type Description Default
other GeometricImage or number

scalar or image to multiply by

required

Returns:

Type Description
Self

a new GeometricImage that is the product of this GeometricImage with other

Source code in ginjax/geometric/geometric_image.py
def __rmul__(self: Self, other: Union[Self, float, int]) -> Self:
    """
    If other is a scalar, multiply the data by the scalar. This is necessary for doing scalar * image, and it
    should only be called in that case.

    args:
        other (GeometricImage or number): scalar or image to multiply by

    returns:
        a new GeometricImage that is the product of this GeometricImage with other
    """
    return self * other
transpose(axes_permutation: Sequence[int]) -> Self

Transpose the tensor axes of each pixel, leaving the leading image (spatial) axes unchanged.

Parameters:

Name Type Description Default
axes_permutation Sequence[int]

new order of the k tensor axes

required

Returns:

Type Description
Self

a new GeometricImage that has been transposed

Source code in ginjax/geometric/geometric_image.py
def transpose(self: Self, axes_permutation: Sequence[int]) -> Self:
    """
    Transpose the tensor axes of each pixel, leaving the leading image (spatial) axes unchanged.

    args:
        axes_permutation: new order of the k tensor axes

    returns:
        a new GeometricImage that has been transposed
    """
    idx_shift = len(self.image_shape())
    new_indices = tuple(
        tuple(range(idx_shift)) + tuple(axis + idx_shift for axis in axes_permutation)
    )
    new_covariant_axes = tuple(self.covariant_axes[axis] for axis in axes_permutation)
    return self.__class__(
        jnp.transpose(self.data, new_indices),
        self.parity,
        self.D,
        self.is_torus,
        new_covariant_axes,
    )
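`transpose` only permutes the k tensor axes: it shifts each entry of `axes_permutation` past the spatial axes and reorders `covariant_axes` to match. The sketch below repeats that index arithmetic on a plain array for a k=2 image, where swapping the two tensor axes transposes every pixel matrix; the `covariant_axes` value is a hypothetical example.

```python
import jax.numpy as jnp

spatial_dims = (3, 3)
D = 2
data = jnp.arange(3 * 3 * D * D).reshape(spatial_dims + (D, D))  # a k=2 image
covariant_axes = (True, False)  # hypothetical per-axis flags
axes_permutation = (1, 0)  # swap the two tensor axes

# Spatial axes stay in place; tensor axis numbers are shifted past them
idx_shift = len(spatial_dims)
new_indices = tuple(range(idx_shift)) + tuple(axis + idx_shift for axis in axes_permutation)
transposed = jnp.transpose(data, new_indices)
new_covariant_axes = tuple(covariant_axes[axis] for axis in axes_permutation)

print(new_indices)  # (0, 1, 3, 2)
print(new_covariant_axes)  # (False, True)
```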
convolve_with(filter_image: Self, stride: Union[int, tuple[int, ...]] = 1, padding: Optional[tuple[tuple[int, int]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1) -> Self

See convolve for a description of this function.

Parameters:

Name Type Description Default
filter_image Self

the convolution filter, a GeometricImage whose data has shape (spatial,tensor)

required
stride Union[int, tuple[int, ...]]

convolution stride, defaults to (1,)*self.D

1
padding Optional[tuple[tuple[int, int]]]

either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs, defaults to 'TORUS' if image.is_torus, else 'SAME'

None
lhs_dilation Optional[tuple[int, ...]]

amount of dilation to apply to image in each dimension D (also known as transposed convolution)

None
rhs_dilation Union[int, tuple[int, ...]]

amount of dilation to apply to filter in each dimension D, defaults to 1

1

Returns:

Type Description
Self

the convolved GeometricImage, whose data has shape (spatial,tensor)

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=[2, 3, 4, 5])
def convolve_with(
    self: Self,
    filter_image: Self,
    stride: Union[int, tuple[int, ...]] = 1,
    padding: Optional[tuple[tuple[int, int]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
) -> Self:
    """
    See [convolve](functional_geometric_image.md#ginjax.geometric.functional_geometric_image.convolve)
    for a description of this function.

    args:
        filter_image: the convolution filter, a GeometricImage whose data has shape (spatial,tensor)
        stride: convolution stride, defaults to (1,)*self.D
        padding: either 'TORUS','VALID', 'SAME', or D length tuple of (upper,lower) pairs,
            defaults to 'TORUS' if image.is_torus, else 'SAME'
        lhs_dilation: amount of dilation to apply to image in each dimension D (also known as
            transposed convolution)
        rhs_dilation: amount of dilation to apply to filter in each dimension D, defaults to 1

    returns:
        the convolved GeometricImage, whose data has shape (spatial,tensor)
    """
    convolved_array = convolve(
        self.D,
        self.data[None, None],  # add batch, in_channels axes
        filter_image.data[None, None],  # add out_channels, in_channels axes
        self.is_torus,
        stride,
        padding,
        lhs_dilation,
        rhs_dilation,
    )
    return self.__class__(
        convolved_array[0, 0],  # ignore batch, out_channels axes
        self.parity + filter_image.parity,
        self.D,
        self.is_torus,
        self.covariant_axes + filter_image.covariant_axes,
    )
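For a torus image, 'TORUS' padding makes the filter wrap around the edges. The sketch below illustrates the idea for a single 2D scalar image by wrap-padding with `jnp.pad(mode="wrap")` and then running a 'VALID' `jax.lax.conv_general_dilated`, which computes a cross-correlation. Whether ginjax's `convolve` pads exactly this way is an assumption, and `torus_conv2d` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def torus_conv2d(image: jnp.ndarray, filt: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate a 2D scalar image with an odd-sized filter, wrapping at the borders."""
    pad = filt.shape[0] // 2
    wrapped = jnp.pad(image, pad, mode="wrap")
    out = jax.lax.conv_general_dilated(
        wrapped[None, None],  # (batch, in_channels, H, W)
        filt[None, None],  # (out_channels, in_channels, h, w)
        window_strides=(1, 1),
        padding="VALID",
    )
    return out[0, 0]


image = jnp.arange(16, dtype=jnp.float32).reshape((4, 4))
shift_filter = jnp.zeros((3, 3)).at[1, 2].set(1.0)  # picks the right-hand neighbor
result = torus_conv2d(image, shift_filter)

# Each output pixel is the input pixel one column to the right, wrapping at the edge
print(bool(jnp.allclose(result, jnp.roll(image, -1, axis=1))))  # True
```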
max_pool(patch_len: int, use_norm: bool = True) -> Self

Perform a max pooling operation where the length of the side of each patch is patch_len. Max is determined by the norm of the pixel when use_norm is True. Note that for scalars, this will be the absolute value of the pixel. If you want to use the max instead, set use_norm to False (requires scalar images).

Parameters:

Name Type Description Default
patch_len int

the side length of the patches, must evenly divide all spatial dims

required
use_norm bool

whether to use norm to calculate the max

True

Returns:

Type Description
Self

a new GeometricImage with the max pool applied

Source code in ginjax/geometric/geometric_image.py
def max_pool(self: Self, patch_len: int, use_norm: bool = True) -> Self:
    """
    Perform a max pooling operation where the length of the side of each patch is patch_len. Max is determined
    by the norm of the pixel when use_norm is True. Note that for scalars, this will be the absolute value of
    the pixel. If you want to use the max instead, set use_norm to False (requires scalar images).

    args:
        patch_len: the side length of the patches, must evenly divide all spatial dims
        use_norm: whether to use norm to calculate the max

    returns:
        a new GeometricImage with the max pool applied
    """
    return self.__class__(
        max_pool(self.D, self.data, patch_len, use_norm),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
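One way to pool by norm is sketched below for a 2D vector image using only `jax.numpy`. It assumes the pooled value is the pixel with the largest norm in each patch (so vectors keep their direction and sign), which is a reading of the docstring rather than the module-level `max_pool` code; `max_pool_by_norm` is hypothetical.

```python
import jax.numpy as jnp


def max_pool_by_norm(data: jnp.ndarray, patch_len: int) -> jnp.ndarray:
    """Max pool a 2D vector image (N1, N2, D), keeping the largest-norm pixel of each patch."""
    N1, N2, D = data.shape
    p = patch_len
    # Gather each p x p patch onto its own axis: (N1 // p, N2 // p, p * p, D)
    patches = (
        data.reshape((N1 // p, p, N2 // p, p, D))
        .transpose((0, 2, 1, 3, 4))
        .reshape((N1 // p, N2 // p, p * p, D))
    )
    norms = jnp.linalg.norm(patches, axis=-1)  # (N1 // p, N2 // p, p * p)
    winner = jnp.argmax(norms, axis=-1)  # position of the largest-norm pixel in each patch
    idx = jnp.broadcast_to(winner[..., None, None], winner.shape + (1, D))
    return jnp.take_along_axis(patches, idx, axis=2)[:, :, 0, :]


data = jnp.array([[[1.0, 0.0], [0.0, -3.0]], [[2.0, 2.0], [0.0, 1.0]]])  # shape (2, 2, 2)
pooled = max_pool_by_norm(data, 2)
print(pooled.tolist())  # [[[0.0, -3.0]]]
```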
average_pool(patch_len: int) -> Self

Perform an average pooling operation where the length of the side of each patch is patch_len. This is equivalent to doing a convolution where each element of the filter is 1 over the number of pixels in the filter, the stride length is patch_len, and the padding is 'VALID'.

Parameters:

Name Type Description Default
patch_len int

the side length of the patches, must evenly divide all spatial dims

required

Returns:

Type Description
Self

a new GeometricImage with the average pool applied

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=1)
def average_pool(self: Self, patch_len: int) -> Self:
    """
    Perform an average pooling operation where the length of the side of each patch is patch_len. This is
    equivalent to doing a convolution where each element of the filter is 1 over the number of pixels in the
    filter, the stride length is patch_len, and the padding is 'VALID'.

    args:
        patch_len: the side length of the patches, must evenly divide all spatial dims

    returns:
        a new GeometricImage with the average pool applied
    """
    return self.__class__(
        average_pool(self.D, self.data, patch_len),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
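The docstring states that average pooling equals a 'VALID' convolution with stride patch_len and a constant filter of 1 / patch_len**D. The sketch below checks that equivalence for a single 2D scalar image against a reshape-and-mean implementation; both are plain JAX, not ginjax code.

```python
import jax
import jax.numpy as jnp

N, p = 4, 2
image = jnp.arange(N * N, dtype=jnp.float32).reshape((N, N))

# Average over non-overlapping p x p patches by reshaping
by_reshape = image.reshape((N // p, p, N // p, p)).mean(axis=(1, 3))

# The convolution described in the docstring: filter of 1 / p**2, stride p, 'VALID' padding
filt = jnp.full((p, p), 1.0 / p**2)
by_conv = jax.lax.conv_general_dilated(
    image[None, None], filt[None, None], window_strides=(p, p), padding="VALID"
)[0, 0]

print(by_reshape.tolist())  # [[2.5, 4.5], [10.5, 12.5]]
print(bool(jnp.allclose(by_conv, by_reshape)))  # True
```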
unpool(patch_len: int) -> Self

Each pixel turns into a (patch_len,)*self.D patch of that pixel. Also called "Nearest Neighbor" unpooling.

Parameters:

Name Type Description Default
patch_len int

side length of the patch that each pixel is expanded into

required

Returns:

Type Description
Self

a new GeometricImage with the unpool applied

Source code in ginjax/geometric/geometric_image.py
@functools.partial(jax.jit, static_argnums=1)
def unpool(self: Self, patch_len: int) -> Self:
    """
    Each pixel turns into a (patch_len,)*self.D patch of that pixel. Also called
    "Nearest Neighbor" unpooling.

    args:
        patch_len: side length of the patch of our unpooled images

    returns:
        a new GeometricImage with the unpool applied
    """
    grow_filter = GeometricImage(jnp.ones((patch_len,) * self.D), 0, self.D)
    return self.convolve_with(
        grow_filter,
        padding=((patch_len - 1,) * 2,) * self.D,
        lhs_dilation=(patch_len,) * self.D,
    )
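The method above builds the upsampled image with a dilated convolution: lhs_dilation spaces the input pixels patch_len apart, and a ones filter with patch_len - 1 padding on each side copies each pixel into its patch. By our reading, this gives the same values as repeating every pixel along each spatial axis, which the NumPy sketch below does; nearest_neighbor_unpool is a hypothetical helper, not a ginjax function.

```python
import numpy as np


def nearest_neighbor_unpool(data: np.ndarray, patch_len: int, D: int) -> np.ndarray:
    """Repeat every pixel patch_len times along each of the first D (spatial) axes."""
    out = data
    for axis in range(D):
        out = np.repeat(out, patch_len, axis=axis)
    return out


image = np.arange(4.0).reshape(2, 2)  # 2x2 scalar image [[0, 1], [2, 3]]
unpooled = nearest_neighbor_unpool(image, 2, D=2)
print(unpooled)
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

Axes after the first D are left untouched, so a (2, 2, 2) vector image unpools to (4, 4, 2) with patch_len=2.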
times_scalar(scalar: float) -> Self ¤

Scale the data by a scalar, returning a new GeometricImage object. Alias of the multiplication operator.

Parameters:

Name Type Description Default
scalar float

number to scale everything by

required

Returns:

Type Description
Self

a new GeometricImage scaled by the scalar

Source code in ginjax/geometric/geometric_image.py
def times_scalar(self: Self, scalar: float) -> Self:
    """
    Scale the data by a scalar, returning a new GeometricImage object. Alias of the multiplication operator.

    args:
        scalar: number to scale everything by

    returns:
        a new GeometricImage scaled by the scalar
    """
    return self * scalar
norm() -> Self ¤

Calculate the norm of each pixel's tensor. The result is a scalar (k=0) image.

Returns:

Type Description
Self

a new GeometricImage of the pixel norms

Source code in ginjax/geometric/geometric_image.py
@jax.jit
def norm(self: Self) -> Self:
    """
    Calculate the norm pixel-wise. This becomes a scalar image.

    returns:
        a new GeometricImage of all the pixels normed.
    """
    return self.__class__(norm(self.D, self.data), 0, self.D, self.is_torus)
normalize() -> Self ¤

Normalize the image so that its largest pixel norm is 1; every pixel is scaled by the same factor. If the largest norm is not above TINY, the image is returned unscaled.

Returns:

Type Description
Self

a new GeometricImage scaled by the max norm

Source code in ginjax/geometric/geometric_image.py
def normalize(self: Self) -> Self:
    """
    Normalize so that the max norm of each pixel is 1, and all other tensors are scaled appropriately

    returns:
        a new GeometricImage scaled by the max norm
    """
    max_norm = float(jnp.max(self.norm().data))
    if max_norm > TINY:
        return self.times_scalar(1.0 / max_norm)
    else:
        return self.times_scalar(1.0)
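The two methods above can be sketched in NumPy as follows. This assumes the per-pixel norm is the 2-norm of the flattened tensor (the Euclidean norm for vectors, the Frobenius norm for matrices). TINY here is a stand-in value: ginjax defines its own constant, and its norm helper is not shown on this page.

```python
import numpy as np

TINY = 1e-5  # stand-in threshold; ginjax's own TINY may differ


def pixel_norms(data: np.ndarray, D: int) -> np.ndarray:
    """Norm of the tensor in each pixel; the first D axes of data are spatial."""
    flat = data.reshape(data.shape[:D] + (-1,))  # flatten the tensor axes per pixel
    return np.linalg.norm(flat, axis=-1)


def normalize(data: np.ndarray, D: int) -> np.ndarray:
    """Scale the whole image by one factor so that its largest pixel norm is 1."""
    max_norm = float(pixel_norms(data, D).max())
    return data / max_norm if max_norm > TINY else data


vectors = np.array([[[3.0, 4.0], [0.0, 1.0]],
                    [[1.0, 0.0], [0.0, 0.0]]])  # 2x2 image of 2D vectors
norms = pixel_norms(vectors, D=2)  # pixel norms 5, 1, 1 and 0
unit = normalize(vectors, D=2)     # largest pixel becomes [0.6, 0.8]
```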
activation_function(function: Callable[[jnp.ndarray], jnp.ndarray]) -> Self ¤

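Apply the specified activation function to the GeometricImage. This is only implemented for scalar (k=0) images, because applying a nonlinearity to the components of a higher-order tensor would break equivariance.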

Parameters:

Name Type Description Default
function Callable[[ndarray], ndarray]

the activation function

required

Returns:

Type Description
Self

a new GeometricImage with the activation function applied

Source code in ginjax/geometric/geometric_image.py
def activation_function(self: Self, function: Callable[[jnp.ndarray], jnp.ndarray]) -> Self:
    """
    Apply the specified activation function to the GeometricImage

    args:
        function: the activation function

    returns:
        a new GeometricImage with the activation function applied
    """
    assert (
        self.k == 0
    ), "Activation functions only implemented for k=0 tensors due to equivariance"
    return self.__class__(
        function(self.data), self.parity, self.D, self.is_torus, self.covariant_axes
    )
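The k=0 restriction can be checked numerically. In the plain NumPy sketch below, with ReLU as an example activation, a componentwise ReLU applied to a vector does not commute with a quarter-turn rotation. For a scalar image, a rotation only moves pixels around, so a pointwise function commutes with it. This illustrates the equivariance argument and does not call ginjax.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


rotate_90 = np.array([[0.0, -1.0], [1.0, 0.0]])  # quarter-turn rotation matrix
v = np.array([1.0, -1.0])  # one vector pixel (k=1)

# Vector pixels: activating then rotating differs from rotating then activating.
print(relu(rotate_90 @ v))  # [1. 1.]
print(rotate_90 @ relu(v))  # [0. 1.]

# Scalar images: a rotation only permutes pixels, so a pointwise activation commutes with it.
scalar_image = np.array([[1.0, -2.0], [-3.0, 4.0]])
scalars_commute = np.array_equal(relu(np.rot90(scalar_image)), np.rot90(relu(scalar_image)))
```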
contract(i: int, j: int) -> Self ¤

Use einsum to perform a Kronecker contraction (a trace) over tensor indices i and j. Requires k >= 2; the result has tensor order k - 2.

Parameters:

Name Type Description Default
i int

first index of tensor

required
j int

second index of tensor

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def contract(self: Self, i: int, j: int) -> Self:
    """
    Use einsum to perform a kronecker contraction on two dimensions of the tensor

    args:
        i: first index of tensor
        j: second index of tensor

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= 2
    idx_shift = len(self.image_shape())

    first, second = min(i, j), max(i, j)
    axes_ls = self.covariant_axes
    new_covariant_axes = axes_ls[:first] + axes_ls[first + 1 : second] + axes_ls[second + 1 :]
    return self.__class__(
        multicontract(self.data, ((i, j),), idx_shift),
        self.parity,
        self.D,
        self.is_torus,
        new_covariant_axes,
    )
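For intuition, here is a NumPy sketch of a single contraction on a D=2 image of 2x2 matrices. Because the spatial axes come first, tensor index i lives at array axis i + D. Contracting two indices is a trace over those axes, and the covariant_axes flags of the two contracted axes are dropped in the same way as in the method above. The helper names are hypothetical.

```python
import numpy as np


def contract_pair(data: np.ndarray, i: int, j: int, D: int) -> np.ndarray:
    """Trace over tensor indices i and j of an image whose first D axes are spatial."""
    return np.trace(data, axis1=i + D, axis2=j + D)


def contracted_covariant_axes(axes: tuple, i: int, j: int) -> tuple:
    """Remove the flags of the two contracted axes, keeping the rest in order."""
    first, second = min(i, j), max(i, j)
    return axes[:first] + axes[first + 1 : second] + axes[second + 1 :]


rng = np.random.default_rng(0)
matrix_image = rng.normal(size=(3, 3, 2, 2))  # 3x3 image of 2x2 matrices (k=2, D=2)
traces = contract_pair(matrix_image, 0, 1, D=2)  # scalar image, shape (3, 3)
```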
multicontract(indices: tuple[tuple[int, int], ...]) -> Self ¤

Use einsum to perform one Kronecker contraction for each pair of tensor indices in indices. Requires k >= 2; each pair reduces the tensor order by 2.

Parameters:

Name Type Description Default
indices tuple[tuple[int, int], ...]

pairs of tensor indices to contract, e.g. ((0, 1), (2, 3))

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def multicontract(self: Self, indices: tuple[tuple[int, int], ...]) -> Self:
    """
    Use einsum to perform a kronecker contraction on two dimensions of the tensor

    args:
        indices: indices to contract

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= 2
    idx_shift = len(self.image_shape())
    sorted_idxs = sorted(list(sum(indices, ())))
    new_cov_axes = tuple(
        self.covariant_axes[prev + 1 : next]
        for prev, next in zip([-1] + sorted_idxs, sorted_idxs + [self.k])
    )
    return self.__class__(
        multicontract(self.data, indices, idx_shift),
        self.parity,
        self.D,
        self.is_torus,
        sum(new_cov_axes, ()),
    )
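The sketch below handles several disjoint pairs by building an einsum subscript string in which both indices of a pair share one letter, so einsum sums over them. The second helper reproduces the covariant_axes bookkeeping from the method above, keeping the flags that lie between the sorted contracted indices. This is plain NumPy under the same spatial-axes-first layout; the helper names are hypothetical, and the pairs are assumed not to share an index.

```python
import string

import numpy as np


def multicontract_sketch(data: np.ndarray, pairs: tuple, D: int) -> np.ndarray:
    """Contract each (a, b) pair of tensor indices; the first D axes are spatial."""
    k = data.ndim - D
    letters = list(string.ascii_letters[: D + k])
    spatial, tensor = letters[:D], letters[D:]
    for a, b in pairs:
        tensor[b] = tensor[a]  # a repeated letter makes einsum sum over that pair
    contracted = {idx for pair in pairs for idx in pair}
    kept = [tensor[t] for t in range(k) if t not in contracted]
    subscripts = "".join(spatial + tensor) + "->" + "".join(spatial + kept)
    return np.einsum(subscripts, data)


def remaining_covariant_axes(axes: tuple, pairs: tuple) -> tuple:
    """Keep the covariant flags of the axes that were not contracted."""
    sorted_idxs = sorted(sum(pairs, ()))
    kept = tuple(
        axes[prev + 1 : nxt] for prev, nxt in zip([-1] + sorted_idxs, sorted_idxs + [len(axes)])
    )
    return sum(kept, ())


rng = np.random.default_rng(0)
k4_image = rng.normal(size=(2, 2, 2, 2, 2, 2))  # 2x2 image of k=4 tensors in D=2
both_pairs = multicontract_sketch(k4_image, ((0, 2), (1, 3)), D=2)  # scalar image
```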
levi_civita_contract(indices: Union[tuple[int, ...], int]) -> Self ¤

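Perform the Levi-Civita contraction: take the outer product with the Levi-Civita symbol, then perform D-1 contractions, pairing each of the given indices with one of the symbol's first D-1 indices. The resulting image has k = self.k - self.D + 2, and its parity is incremented by 1.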

Parameters:

Name Type Description Default
indices Union[tuple[int, ...], int]

the D-1 tensor indices to contract (a single int is accepted when D = 2)

required

Returns:

Type Description
Self

a new GeometricImage contracted by those indices

Source code in ginjax/geometric/geometric_image.py
def levi_civita_contract(self: Self, indices: Union[tuple[int, ...], int]) -> Self:
    """
    Perform the Levi-Civita contraction. Outer product with the Levi-Civita Symbol, then perform D-1 contractions.
    Resulting image has k= self.k - self.D + 2

    args:
        indices: indices of tensor to perform contractions on

    returns:
        a new GeometricImage contracted by those indices
    """
    assert self.k >= (
        self.D - 1
    )  # so we have enough indices to work on since we perform D-1 contractions
    if not isinstance(indices, tuple):
        indices = (indices,)
    assert len(indices) == self.D - 1

    levi_civita = LeviCivitaSymbol.get(self.D)
    outer = jnp.tensordot(self.data, levi_civita, axes=0)

    # make contraction index pairs with one of specified indices, and index (in order) from the levi_civita symbol
    idx_shift = len(self.image_shape())
    zipped_indices = tuple(
        (i + idx_shift, j + idx_shift)
        for i, j in zip(indices, range(self.k, self.k + len(indices)))
    )
    return self.__class__(
        multicontract(outer, zipped_indices),
        self.parity + 1,
        self.D,
        self.is_torus,
        self.covariant_axes[: self.k - self.D + 2],  # right length, but maybe wrong
    )
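Here is a small NumPy sketch of the same construction, for intuition. levi_civita is a hypothetical helper that builds the dense symbol from permutation signs (ginjax's LeviCivitaSymbol.get is not reproduced here). In D=2, contracting a vector's only index with the symbol's first index turns the vector by a quarter turn, keeping k = 1 - 2 + 2 = 1. In D=3, contracting both indices of a k=2 tensor gives a vector (k = 1), and any symmetric part of the tensor drops out. The parity increment in the method above reflects that the Levi-Civita symbol is a pseudotensor.

```python
import itertools

import numpy as np


def levi_civita(D: int) -> np.ndarray:
    """Dense Levi-Civita symbol: D axes of length D, entries in {-1, 0, 1}."""
    eps = np.zeros((D,) * D)
    for perm in itertools.permutations(range(D)):
        # The sign of a permutation is (-1) ** (number of inverted pairs).
        inversions = sum(perm[a] > perm[b] for a in range(D) for b in range(a + 1, D))
        eps[perm] = (-1) ** inversions
    return eps


# D=2, k=1: pair tensor index 0 with the symbol's first index.
v = np.array([3.0, 4.0])
w = np.einsum("i,ij->j", v, levi_civita(2))  # [-4, 3], v turned by a quarter turn

# D=3, k=2: pair tensor indices (0, 1) with the symbol's first two indices.
rng = np.random.default_rng(0)
T = rng.normal(size=(3, 3))
dual = np.einsum("ij,ijk->k", T, levi_civita(3))  # a vector: k = 2 - 3 + 2 = 1
```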
raise_lower(metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...], precision: Optional[jax.lax.Precision] = None) -> Self ¤

Raise or lower the axes of the tensor according to the metric tensor, so that its covariant axes become axes.

Parameters:

Name Type Description Default
metric_tensor Self

the metric tensor g_ij; must have the same spatial shape as this image

required
metric_tensor_inv Self

the inverse metric tensor g^ij; must have the same spatial shape as this image

required
axes tuple[bool, ...]

the desired covariant_axes flags of the result (True for a covariant axis)

required
precision Optional[Precision]

precision used for einsum

None

Returns:

Type Description
Self

a new GeometricImage whose covariant axes match axes

Source code in ginjax/geometric/geometric_image.py
def raise_lower(
    self: Self,
    metric_tensor: Self,
    metric_tensor_inv: Self,
    axes: tuple[bool, ...],
    precision: Optional[jax.lax.Precision] = None,
) -> Self:
    """
    Raise or lower the axes of the tensor according to the metric tensor and axes.

    args:
        metric_tensor: the metric tensor g_ij, must be same spatial shape as this
        metric_tensor_inv: the inverse metric tensor, g^ij. Must be same spatial shape as this
        axes: desired covariant axes
        precision: precision used for einsum

    returns:
        new GeometricImage with correct axes
    """
    return self.__class__(
        raise_lower(
            self.data,
            metric_tensor.data,
            metric_tensor_inv.data,
            self.covariant_axes,
            axes,
            precision,
        ),
        self.parity,
        self.D,
        self.is_torus,
        axes,
    )
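As a sketch of what raising and lowering do pixel-wise, the NumPy code below lowers a contravariant vector image with v_i = g_ij v^j and raises it back with v^i = g^ij v_j. The metric image is a made-up symmetric positive-definite example; the exact einsum subscripts ginjax uses for higher-order tensors are not shown on this page.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 3, 2
# Made-up metric image: a symmetric positive-definite D x D matrix g_ij in every pixel.
A = rng.normal(size=(N, N, D, D))
metric = np.einsum("xyij,xykj->xyik", A, A) + np.eye(D)
metric_inv = np.linalg.inv(metric)  # g^ij, inverted pixel by pixel

v_upper = rng.normal(size=(N, N, D))                        # contravariant components v^j
v_lower = np.einsum("xyij,xyj->xyi", metric, v_upper)       # lower: v_i = g_ij v^j
v_raised = np.einsum("xyij,xyj->xyi", metric_inv, v_lower)  # raise: v^i = g^ij v_j
```

With the flat Euclidean metric (the identity in every pixel) lowering is a no-op, which is why covariant and contravariant components are numerically identical in that case.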
raise_lower_precise(metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...]) -> Self ¤

Raise or lower the axes of the tensor according to the metric tensor and axes, using the highest precision (jax.lax.Precision.HIGHEST) for einsum.

Parameters:

Name Type Description Default
metric_tensor Self

the metric tensor g_ij; must have the same spatial shape as this image

required
metric_tensor_inv Self

the inverse metric tensor g^ij; must have the same spatial shape as this image

required
axes tuple[bool, ...]

the desired covariant_axes flags of the result (True for a covariant axis)

required

Returns:

Type Description
Self

a new GeometricImage whose covariant axes match axes

Source code in ginjax/geometric/geometric_image.py
def raise_lower_precise(
    self: Self, metric_tensor: Self, metric_tensor_inv: Self, axes: tuple[bool, ...]
) -> Self:
    """
    Raise or lower the axes of the tensor according to the metric tensor and axes using the
    highest precision for einsum.

    args:
        metric_tensor: the metric tensor g_ij, must be same spatial shape as this
        metric_tensor_inv: the inverse metric tensor, g^ij. Must be same spatial shape as this
        axes: desired covariant axes

    returns:
        new GeometricImage with correct axes
    """
    return self.raise_lower(metric_tensor, metric_tensor_inv, axes, jax.lax.Precision.HIGHEST)
times_group_element(gg: np.ndarray, precision: Optional[jax.lax.Precision] = None) -> Self ¤

Apply a group element of O(d) to the geometric image. The action is first applied to the locations of the pixels, then to the tensors in the pixels. The group element provided is the one that acts on contravariant axes; it is inverted to act on covariant axes as well.

Parameters:

Name Type Description Default
gg ndarray

a DxD matrix that acts on a contravariant vector v as gg @ v

required
precision Optional[Precision]

precision level for einsum; use Precision.HIGHEST for equality tests

None

Returns:

Type Description
Self

a new GeometricImage that has been rotated

Source code in ginjax/geometric/geometric_image.py
def times_group_element(
    self: Self,
    gg: np.ndarray,
    precision: Optional[jax.lax.Precision] = None,
) -> Self:
    """
    Apply a group element of O(d) to the geometric image. First apply the action to the location
    of the pixels, then apply the action to the pixels themselves. The group element provided
    is the one that acts on contravariant axes, will be inverted to apply to covariant axes as
    well.

    args:
        gg: a DxD matrix that rotates a contravariant vector gg @ v
        precision: precision level for einsum, for equality tests use Precision.HIGHEST

    returns:
        a new GeometricImage that has been rotated
    """
    assert self.k < 14
    assert gg.shape == (self.D, self.D)

    return self.__class__(
        times_group_element(self.D, self.data, self.parity, gg, self.covariant_axes, precision),
        self.parity,
        self.D,
        rotate_is_torus(self.is_torus, gg),
        self.covariant_axes,
    )
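The two-step action described above can be sketched for a D=2 image of vectors. In this hypothetical NumPy helper, each pixel's centered coordinates are transformed to find its new location, and the vector stored there is multiplied by gg; a pseudovector (parity 1) also picks up det(gg). It only handles k=1 contravariant vectors and group elements that map the pixel grid onto itself, such as quarter turns and axis flips; ginjax's own times_group_element is more general.

```python
import numpy as np


def rotate_vector_image(img: np.ndarray, gg: np.ndarray, parity: int = 0) -> np.ndarray:
    """Apply an orthogonal 2x2 matrix gg to an (N, N, 2) image of contravariant vectors."""
    N = img.shape[0]
    center = (N - 1) / 2
    sign = np.rint(np.linalg.det(gg)) ** parity  # det(gg) is +1 or -1
    out = np.zeros_like(img)
    for i in range(N):
        for j in range(N):
            # 1) Move the pixel: act on its coordinates relative to the image center.
            new_i, new_j = np.rint(gg @ (np.array([i, j]) - center) + center).astype(int)
            # 2) Act on the vector stored in the pixel.
            out[new_i, new_j] = sign * (gg @ img[i, j])
    return out


quarter_turn = np.array([[0.0, -1.0], [1.0, 0.0]])
N = 5
ii, jj = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
radial = np.stack([ii - (N - 1) / 2, jj - (N - 1) / 2], axis=-1)  # v(x) = x
rotated_radial = rotate_vector_image(radial, quarter_turn)  # unchanged: radial fields are invariant
```

A radial field is a convenient check because moving the pixels and rotating the vectors cancel exactly, so an equivariant action leaves it unchanged.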
times_gg_precise(gg: np.ndarray) -> Self ¤

Apply a group element of O(d) to the geometric image using the highest precision einsum. See times_group_element for more details.

Parameters:

Name Type Description Default
gg ndarray

a DxD matrix that acts on a contravariant vector v as gg @ v

required

Returns:

Type Description
Self

a new GeometricImage that has been rotated

Source code in ginjax/geometric/geometric_image.py
def times_gg_precise(self: Self, gg: np.ndarray) -> Self:
    """
    Apply a group element of O(d) to the geometric image using the highest precision einsum.
    See times_group_element for more details.

    args:
        gg: a DxD matrix that rotates a contravariant vector gg @ v

    returns:
        a new GeometricImage that has been rotated
    """
    return self.times_group_element(gg, jax.lax.Precision.HIGHEST)
translate(tau: jax.Array) -> Self ¤

Translate the image on the torus. Translations act on the data array in ij (row, column) order: for example, a translation of [1,-1] moves the image down one row, then to the left one column.

Parameters:

Name Type Description Default
tau Array

the translation vector, length D

required

Returns:

Type Description
Self

a geometric image that has been translated

Source code in ginjax/geometric/geometric_image.py
def translate(self: Self, tau: jax.Array) -> Self:
    """
    Translate the image on the torus. Translations on the data matrix are ij ordering. For
    example, a translation of [1,-1] moves the image down one row, then to the left one column.

    args:
        tau: the translation vector, length D

    returns:
        a geometric image that has been translated
    """
    assert (
        self.is_torus == (True,) * self.D
    ), f"GeometricImage::translate: Image must be a torus, but got {self.is_torus}"
    assert (
        len(tau) == self.D
    ), f"GeometricImage::translate: {self.D}D image received {len(tau)}D translation"

    return self.__class__(
        translate(self.D, self.data, tau, 0),
        self.parity,
        self.D,
        self.is_torus,
        self.covariant_axes,
    )
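On a torus, translation is a cyclic shift of the spatial axes, so a NumPy sketch can use np.roll. This assumes, consistent with the docstring's example, that a positive entry of tau shifts toward higher indices (down for rows, right for columns). translate_torus is a hypothetical helper; ginjax's internal translate function is not shown here.

```python
import numpy as np


def translate_torus(data: np.ndarray, tau: tuple, D: int) -> np.ndarray:
    """Cyclically shift the first D (spatial) axes by tau; tensor axes are untouched."""
    return np.roll(data, shift=tuple(tau), axis=tuple(range(D)))


image = np.zeros((4, 4))
image[0, 0] = 1.0
moved = translate_torus(image, (1, -1), D=2)
# The pixel moves down one row and left one column, wrapping to row 1, column 3.
```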
tree_flatten() -> tuple[tuple[jnp.ndarray], dict[str, Union[int, Union[bool, tuple[bool]]]]] ¤

Helper function that defines GeometricImage as a pytree so that jax.jit handles it correctly. Together, children and aux_data must contain all the variables that are passed to __init__().

Source code in ginjax/geometric/geometric_image.py
def tree_flatten(
    self: Self,
) -> tuple[tuple[jnp.ndarray], dict[str, Union[int, Union[bool, tuple[bool]]]]]:
    """
    Helper function to define GeometricImage as a pytree so jax.jit handles it correctly. Children and aux_data
    must contain all the variables that are passed in __init__()
    """
    children = (self.data,)  # arrays / dynamic values
    aux_data = {
        "D": self.D,
        "parity": self.parity,
        "is_torus": self.is_torus,
        "covariant_axes": self.covariant_axes,
    }  # static values
    return (children, aux_data)
tree_unflatten(aux_data, children) classmethod ¤

Helper function to define GeometricImage as a pytree so jax.jit handles it correctly.

Source code in ginjax/geometric/geometric_image.py
@classmethod
def tree_unflatten(cls, aux_data, children):
    """
    Helper function to define GeometricImage as a pytree so jax.jit handles it correctly.
    """
    return cls(*children, **aux_data)
__init__(data: jnp.ndarray, parity: int, D: int, is_torus: Union[bool, tuple[bool, ...]] = True, covariant_axes: Union[bool, tuple[bool, ...]] = False) -> None ¤

Constructor for GeometricFilter.

Parameters:

Name Type Description Default
data ndarray

the filter data, of shape (spatial, tensor). The spatial dimensions must all be equal (square) and odd.

required
parity int

parity of the tensor: 0 for a normal tensor such as a vector, 1 for a pseudotensor such as a pseudovector

required
D int

dimension of the image

required
is_torus Union[bool, tuple[bool, ...]]

which dimensions are toroidal

True
covariant_axes Union[bool, tuple[bool, ...]]

which of the k tensor axes are covariant, i.e. transform covariantly under a change of coordinates. False for typical vectors, True for gradients. A contraction is only valid between one covariant axis and one contravariant axis, but for a flat Euclidean metric the two kinds of vectors are numerically identical, so we do not enforce this.

False
Source code in ginjax/geometric/geometric_image.py
def __init__(
    self: Self,
    data: jnp.ndarray,
    parity: int,
    D: int,
    is_torus: Union[bool, tuple[bool, ...]] = True,
    covariant_axes: Union[bool, tuple[bool, ...]] = False,
) -> None:
    """
    Constructor for GeometricFilter.

    args:
        data: the image data of shape (spatial,tensor). Spatial dimensions must be square, odd
        parity: parity of tensor, 0 for vector, 1 for pseudo-vector
        D: dimension of the image
        is_torus: which dimensions are toroidal
        covariant_axes: which of k tensor axes are covariant, i.e. they transform covariantly
            with the coordinate change. False for typical vectors, true for gradients. You
            can only take a contraction between 1 covariant axis and 1 contravariant axis,
            but for a flat Euclidean metric these vectors are numerically identical, so we will
            not enforce this.
    """
    super().__init__(data, parity, D, is_torus, covariant_axes)
    assert (
        self.spatial_dims == (self.spatial_dims[0],) * self.D
    ), "GeometricFilter: Filters must be square."  # I could remove  this requirement in the future
from_image(geometric_image: GeometricImage) -> Self classmethod ¤

Constructor that copies a GeometricImage and returns a GeometricFilter

Parameters:

Name Type Description Default
geometric_image GeometricImage

the GeometricImage to copy

required

Returns:

Type Description
Self

a new GeometricFilter copy

Source code in ginjax/geometric/geometric_image.py
@classmethod
def from_image(cls, geometric_image: GeometricImage) -> Self:
    """
    Constructor that copies a GeometricImage and returns a GeometricFilter

    args:
        geometric_image: the GeometricImage to copy

    returns:
        a new GeometricFilter copy
    """
    return cls(
        geometric_image.data,
        geometric_image.parity,
        geometric_image.D,
        geometric_image.is_torus,
        geometric_image.covariant_axes,
    )
bigness() -> float ¤

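Give a rough measure of a filter's size: sparser filters score lower, less sparse filters score higher. The value is the norm-weighted average of ||key|| over the filter's pixels, where the keys come from key_array().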

Returns:

Type Description
float

the bigness value

Source code in ginjax/geometric/geometric_image.py
def bigness(self: Self) -> float:
    """
    Gives an idea of size for a filter, sparser filters are smaller while less sparse filters are larger

    returns:
        the bigness value
    """
    norms = self.norm().data
    numerator = 0.0
    for key in self.key_array():
        numerator += jnp.linalg.norm(key * norms[tuple(key)], ord=2)

    denominator = float(jnp.sum(norms))
    return numerator / denominator
nonempty_pixels() -> jax.Array ¤

Get the nonempty pixels as a true/false array.

Returns:

Type Description
Array

a true/false array of flattened shape (image_size,)

Source code in ginjax/geometric/geometric_image.py
def nonempty_pixels(self: Self) -> jax.Array:
    """
    Get the nonempty pixels as a true/false array.

    returns:
        a true/false array of flattened shape (image_size,)
    """
    return nonempty_pixels(self.D, self.data)
nonempty_pixel_idxs() -> jax.Array ¤

Get the indices of the nonempty pixels, centered so that the middle of the image is the origin, in flattened image order.

Returns:

Type Description
Array

centered indices of the nonempty pixels, shape (num_pixels, D)

Source code in ginjax/geometric/geometric_image.py
def nonempty_pixel_idxs(self: Self) -> jax.Array:
    """
    Get the centered idxs of nonempty pixels, ordered in the flattened image order.

    returns:
        Nonempty pixels idxs, shape (num_pixels,D)
    """
    idxs = pixel_idxs(self.image_shape())
    idxs_centered = idxs - ((jnp.array(self.image_shape()).reshape((-1, self.D)) - 1) / 2)

    return idxs_centered[self.nonempty_pixels()]
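A NumPy sketch of the same computation: enumerate every pixel index in flattened (row-major) order, shift by (shape - 1) / 2 so the image center is the origin, and keep the rows whose pixel holds a nonzero tensor. Treating "nonempty" as "any component nonzero" is our assumption; ginjax's nonempty_pixels helper is not shown on this page and may use a tolerance.

```python
import numpy as np


def nonempty_pixel_idxs_sketch(data: np.ndarray, D: int) -> np.ndarray:
    """Centered indices, in flattened order, of the pixels holding a nonzero tensor."""
    spatial_shape = data.shape[:D]
    num_pixels = int(np.prod(spatial_shape))
    # Every pixel index in flattened (row-major) order, shape (num_pixels, D).
    idxs = np.stack(np.unravel_index(np.arange(num_pixels), spatial_shape), axis=-1)
    centered = idxs - (np.array(spatial_shape) - 1) / 2
    per_pixel = data.reshape(num_pixels, -1)  # flatten the tensor axes of each pixel
    nonempty = np.any(per_pixel != 0, axis=1)
    return centered[nonempty]


filt = np.zeros((3, 3))
filt[0, 2] = 1.0  # top-right corner
filt[1, 1] = 2.0  # center
result = nonempty_pixel_idxs_sketch(filt, D=2)  # rows [-1, 1] and [0, 0], in that order
```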
__lt__(other: Self) -> bool ¤

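Compare two GeometricFilters on "bigness". The resulting definition may be slightly different, but I think it's a better definition. The order of comparisons is D, image shape (axis by axis), k, parity, then for k=2 which of the trace, antisymmetric and symmetric traceless parts are nonzero, the largest distance of a nonempty pixel from the center (L1, then L2), for k=2 the norms of the symmetric, antisymmetric and trace parts, and finally total pixel norms.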

Parameters:

Name Type Description Default
other Self

the other GeometricFilter to compare to this one

required

Returns:

Type Description
bool

returns self < other

Source code in ginjax/geometric/geometric_image.py
def __lt__(self: Self, other: Self) -> bool:
    """
    Compare two GeometricFilters on "bigness". The resulting definition may be slightly
    different, but I think it's a better definition. The order of comparisons is D, image shape
    in order, k, parity, distance of nonempty pixels from the center, and finally total pixel
    norms.

    args:
        other: the other GeometricFilter to compare to this one

    returns:
        returns self < other
    """
    if self.D != other.D:
        return self.D < other.D

    if self.image_shape() != other.image_shape():
        for N1, N2 in zip(self.image_shape(), other.image_shape()):
            if N1 != N2:
                return N1 < N2

    if self.k != other.k:
        return self.k < other.k

    if self.parity != other.parity:
        return self.parity < other.parity

    self_sym_total = other_sym_total = 0
    self_antisym_total = other_antisym_total = 0
    self_trace_total = other_trace_total = 0
    if self.k == 2:  # works for D=2,3, and should work for higher D

        self_pixels = self.data.reshape((-1,) + (self.D,) * self.k)
        self_trace = jnp.einsum("...ii", self_pixels) / self.D
        self_trace_matrix = self_trace[:, None, None] * (jnp.eye(self.D)[None])
        self_antisym = (self_pixels - jnp.transpose(self_pixels, (0, 2, 1))) / 2
        self_sym = (self_pixels + jnp.transpose(self_pixels, (0, 2, 1))) / 2 - self_trace_matrix

        other_pixels = other.data.reshape((-1,) + (other.D,) * other.k)
        other_trace = jnp.einsum("...ii", other_pixels) / other.D
        other_trace_matrix = other_trace[:, None, None] * (jnp.eye(other.D)[None])
        other_antisym = (other_pixels - jnp.transpose(other_pixels, (0, 2, 1))) / 2
        other_sym = (
            other_pixels + jnp.transpose(other_pixels, (0, 2, 1))
        ) / 2 - other_trace_matrix

        # norm of elements along diagonal (except for last) and above
        sym_norm_f = jax.vmap(lambda x: jnp.linalg.norm(x[jnp.triu_indices(self.D)][:-1]))
        self_sym_total = float(jnp.sum(sym_norm_f(self_sym)))
        other_sym_total = float(jnp.sum(sym_norm_f(other_sym)))

        # norm of elements above the main diagonal
        antisym_norm_f = jax.vmap(lambda x: jnp.linalg.norm(x[jnp.triu_indices(self.D, 1)]))
        self_antisym_total = float(jnp.sum(antisym_norm_f(self_antisym)))
        other_antisym_total = float(jnp.sum(antisym_norm_f(other_antisym)))

        self_trace_total = float(jnp.sum(jnp.abs(self_trace)))
        other_trace_total = float(jnp.sum(jnp.abs(other_trace)))

        self_min_component = (
            int(self_trace_total != 0)
            + int(self_antisym_total != 0) * 10
            + int(self_sym_total != 0) * 100
        )
        other_min_component = (
            int(other_trace_total != 0)
            + int(other_antisym_total != 0) * 10
            + int(other_sym_total != 0) * 100
        )

        # check whether the filters are in the same irrep
        if self_min_component != other_min_component:
            return self_min_component < other_min_component

    self_nonempty_l1 = float(jnp.max(jnp.sum(jnp.abs(self.nonempty_pixel_idxs()), axis=1)))
    other_nonempty_l1 = float(jnp.max(jnp.sum(jnp.abs(other.nonempty_pixel_idxs()), axis=1)))

    self_nonempty_l2 = float(jnp.max(jnp.linalg.norm(self.nonempty_pixel_idxs(), axis=1)))
    other_nonempty_l2 = float(jnp.max(jnp.linalg.norm(other.nonempty_pixel_idxs(), axis=1)))

    # sort by l1 distance, then l2 distance
    if abs(self_nonempty_l1 - other_nonempty_l1) > TINY:
        return self_nonempty_l1 < other_nonempty_l1

    if abs(self_nonempty_l2 - other_nonempty_l2) > TINY:
        return self_nonempty_l2 < other_nonempty_l2

    if self.k == 2:
        if abs(self_sym_total - other_sym_total) > TINY:
            return self_sym_total < other_sym_total

        if abs(self_antisym_total - other_antisym_total) > TINY:
            return self_antisym_total < other_antisym_total

        if abs(self_trace_total - other_trace_total) > TINY:
            return self_trace_total < other_trace_total

    return float(jnp.sum(self.norm().data)) < float(jnp.sum(other.norm().data))
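For k=2 filters the comparison above splits each pixel's matrix into a trace part, an antisymmetric part, and a symmetric traceless part. The NumPy sketch below shows that split for a single D x D matrix, plus a component code in the spirit of self_min_component (1 for trace, 10 for antisymmetric, 100 for symmetric traceless). It simplifies the method's norms: here a part counts as present if any entry exceeds a tolerance, which is our choice, not ginjax's.

```python
import numpy as np


def decompose_2tensor(M: np.ndarray):
    """Split a D x D matrix into trace, antisymmetric and symmetric traceless parts."""
    D = M.shape[0]
    trace_part = np.trace(M) / D * np.eye(D)
    antisym = (M - M.T) / 2
    sym_traceless = (M + M.T) / 2 - trace_part
    return trace_part, antisym, sym_traceless


def component_code(M: np.ndarray, tol: float = 1e-8) -> int:
    """1 * (trace part present) + 10 * (antisymmetric) + 100 * (symmetric traceless)."""
    trace_part, antisym, sym_traceless = decompose_2tensor(M)
    present = [np.abs(part).max() > tol for part in (trace_part, antisym, sym_traceless)]
    return int(present[0]) + 10 * int(present[1]) + 100 * int(present[2])


M = np.array([[1.0, 2.0], [0.0, 3.0]])
parts = decompose_2tensor(M)  # the three parts sum back to M
```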
rectify() -> Self ¤

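Filters form equivalence classes up to multiplication by a scalar, so if a filter is "negative" we flip its sign. A scalar filter is negative if its entries sum to less than zero; a vector filter (parity 0) if the sum of key · v over its pixels is negative; a 2D pseudovector filter if the sum of key × v is negative. Other filters are returned unchanged.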

Returns:

Type Description
Self

the filter, with its sign flipped if it was negative

Source code in ginjax/geometric/geometric_image.py
def rectify(self: Self) -> Self:
    """
    Filters form an equivalence class up to multiplication by a scalar, so if it's negative we want to flip the sign

    returns:
        a new GeometricImage that has been scaled
    """
    if self.k == 0:
        if jnp.sum(self.data) < 0:
            return self.times_scalar(-1)
    elif self.k == 1:
        if self.parity % 2 == 0:
            if (
                jnp.sum(
                    jnp.einsum("...i,...i", self.key_array().reshape(self.shape()), self.data)
                )
                < 0
            ):
                return self.times_scalar(-1)
        elif self.D == 2:
            if jnp.sum(jnp.cross(self.key_array().reshape(self.shape()), self.data)) < 0:
                return self.times_scalar(-1)
    return self
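The sign rule for ordinary (parity 0) vector filters can be sketched as below: the filter is flipped when the sum of key · v over its pixels is negative, that is, when it points toward the center on balance. The keys here are centered pixel coordinates, which is our assumption about what key_array() returns; centered_keys and rectify_vector_filter are hypothetical helpers.

```python
import numpy as np


def centered_keys(N: int) -> np.ndarray:
    """Centered (row, column) coordinates of every pixel of an N x N filter."""
    ii, jj = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
    return np.stack([ii - (N - 1) / 2, jj - (N - 1) / 2], axis=-1)


def rectify_vector_filter(data: np.ndarray) -> np.ndarray:
    """Flip the sign of an (N, N, 2) vector filter if sum(key . v) < 0."""
    keys = centered_keys(data.shape[0])
    outwardness = np.sum(np.einsum("...i,...i", keys, data))
    return -data if outwardness < 0 else data


inward = -centered_keys(3)  # every vector points toward the center
rectified = rectify_vector_filter(inward)  # flipped to point outward
```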
plot(ax: Optional[Any] = None, title: str = '', boxes: bool = True, fill: bool = True, symbols: bool = True, vmin: Optional[float] = None, vmax: Optional[float] = None, colorbar: bool = False, cmap: matplotlib.colors.Colormap | str | None = None, vector_scaling: float = 0.33) -> None ¤

Plot the geometric filter. Has different default vmin, vmax, vector_scalings than GeometricImage.

Parameters:

Name Type Description Default
ax Optional[Any]

matplotlib.pyplot Axes to plot this geometric filter on

None
title str

title of the plot

''
boxes bool

whether to plot boxes around each pixel

True
fill bool

whether to fill the pixels with an appropriate color

True
symbols bool

whether to fill the pixels with a symbol

True
vmin Optional[float]

min value to plot, everything below this is cut off. If None, will use -3 for scalars and 0 otherwise.

None
vmax Optional[float]

max value to plot, everything above this is cut off. If None, will use 3

None
colorbar bool

whether to plot a colorbar

False
cmap Colormap | str | None

the colormap to use, passed through to GeometricImage.plot

None
vector_scaling float

how much to scale the vectors

0.33
Source code in ginjax/geometric/geometric_image.py
def plot(
    self: Self,
    ax: Optional[Any] = None,
    title: str = "",
    boxes: bool = True,
    fill: bool = True,
    symbols: bool = True,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    colorbar: bool = False,
    cmap: matplotlib.colors.Colormap | str | None = None,
    vector_scaling: float = 0.33,
) -> None:
    """
    Plot the geometric filter. Has different default vmin, vmax, vector_scalings than
    GeometricImage.

    args:
        ax: matplotlib.pyplot Axes to plot this geometric filter on
        title: title of the plot
        boxes: whether to plot boxes around each pixel
        fill: whether to fill the pixels with an appropriate color
        symbols: whether to fill the pixels with a symbol
        vmin: min value to plot, everything below this is cut off. If None, will use -3 for
            scalars and 0 otherwise.
        vmax: max value to plot, everything above this is cut off. If None, will use 3
        colorbar: whether to plot a colorbar
        vector_scaling: how much to scale the vectors
    """
    if self.k == 0:
        vmin = -3.0 if vmin is None else vmin
        vmax = 3.0 if vmax is None else vmax
    else:
        vmin = 0.0 if vmin is None else vmin
        vmax = 3.0 if vmax is None else vmax

    super().plot(ax, title, boxes, fill, symbols, vmin, vmax, colorbar, cmap, vector_scaling)

get_kronecker_delta_image(N: int, D: int) -> GeometricImage ¤

Get an image with a Kronecker Delta in every pixel.

Parameters:

Name Type Description Default
N int

the side length of the image

required
D int

the dimension of the image

required

Returns:

Type Description
GeometricImage

a new GeometricImage of shape (N,)*D + (D, D), with the DxD identity matrix in every pixel.

Source code in ginjax/geometric/geometric_image.py
def get_kronecker_delta_image(N: int, D: int) -> GeometricImage:
    """
    Get an image with a Kronecker Delta in every pixel.

    args:
        N: the sidelength of the image
        D: the dimension of the image

    returns:
        a new GeometricImage.
    """
    return GeometricImage(
        jnp.stack([KroneckerDeltaSymbol.get(D, 2) for _ in range(N**D)]).reshape(
            ((N,) * D + (D,) * 2)
        ),
        0,
        D,
        covariant_axes=(True, False),  # could also be False,True, its symmetric.
    )
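The same array can be built without stacking N**D copies by broadcasting one identity matrix over the spatial axes. The NumPy sketch below does that and checks it against the stack-and-reshape construction used above. kronecker_delta_image is a hypothetical name, and the .copy() gives a writable array, since np.broadcast_to returns a read-only view.

```python
import numpy as np


def kronecker_delta_image(N: int, D: int) -> np.ndarray:
    """Array of shape (N,)*D + (D, D) holding the D x D identity in every pixel."""
    return np.broadcast_to(np.eye(D), (N,) * D + (D, D)).copy()


delta = kronecker_delta_image(4, 2)
stacked = np.stack([np.eye(2) for _ in range(4**2)]).reshape((4,) * 2 + (2,) * 2)
```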

get_metric_inverse(metric_tensor: GeometricImage, eps: float = TINY) -> GeometricImage ¤

Given a metric tensor image, invert the matrix in each pixel to get the inverse metric tensor. This converts g_ij -> g^ij.

Parameters:

Name Type Description Default
metric_tensor GeometricImage

the current metric tensor image

required
eps float

small value added to each eigenvalue before it is inverted, to prevent division by zero

TINY

Returns:

Type Description
GeometricImage

the inverse metric tensor image

Source code in ginjax/geometric/geometric_image.py
def get_metric_inverse(metric_tensor: GeometricImage, eps: float = TINY) -> GeometricImage:
    """
    Given a metric tensor image, invert the matrix in each pixel to get the inverse metric tensor.
    This converts g_ij -> g^ij.

    args:
        metric_tensor: the current metric tensor image
        eps: to prevent dividing by zero, add eps to the denominator.

    returns:
        the inverse metric tensor image
    """
    D = metric_tensor.D
    # (..., D, D) -> (..., D), (..., D, D)
    eigvals, eigvecs = jnp.linalg.eigh(metric_tensor.data, symmetrize_input=False)

    eigvals_inv = 1.0 / (eigvals + eps)  # (...,D)
    S_diag = jax.vmap(jnp.diag)(eigvals_inv.reshape((-1, D))).reshape(eigvals.shape + (D,))
    # compute U S U^T in every pixel, where S holds the inverted eigenvalues on its diagonal

    inverse_data = jnp.einsum(
        "...ij,...jk,...kl->...il", eigvecs, S_diag, jnp.moveaxis(eigvecs, -1, -2)
    )
    return GeometricImage(inverse_data, 0, D, metric_tensor.is_torus, (False, False))
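A NumPy sketch of the same eigendecomposition-based inverse, checked against np.linalg.inv. It assumes each pixel's metric is symmetric (eigh only reads one triangle) and uses eps=1e-8 as a stand-in for TINY. Computing U diag(1/(lambda + eps)) U^T in a single einsum avoids building the diagonal matrices explicitly; this is a design choice of the sketch, not of ginjax.

```python
import numpy as np


def metric_inverse_sketch(metric: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Invert the symmetric D x D matrix in every pixel as U diag(1 / (lambda + eps)) U^T."""
    eigvals, eigvecs = np.linalg.eigh(metric)  # shapes (..., D) and (..., D, D)
    inv_eigvals = 1.0 / (eigvals + eps)
    # (U diag(s) U^T)_ik = sum_j U_ij s_j U_kj
    return np.einsum("...ij,...j,...kj->...ik", eigvecs, inv_eigvals, eigvecs)


rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3, 2, 2))
metric = np.einsum("xyij,xykj->xyik", A, A) + np.eye(2)  # made-up SPD metric image
metric_inv = metric_inverse_sketch(metric)
```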