Skip to content

Models

ginjax.models ¤

MultiImageModule ¤

Bases: Module

A model that takes a MultiImage and aux_data as input and returns a MultiImage and aux_data. Models that inherit from this class also accept and return aux_data, even if they do not use it.

Source code in ginjax/models.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
class MultiImageModule(eqx.Module):
    """
    Base class for models whose call takes a MultiImage plus optional aux_data and returns a
    MultiImage plus optional aux_data. Subclasses keep this signature, accepting and returning
    aux_data even when they have no use for it.
    """

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Layer callable. This base implementation hands both arguments back unchanged.

        args:
            x: the input
            aux_data: data used for stuff like batch norm

        returns:
            the output MultiImage and aux_data
        """
        passthrough = (x, aux_data)
        return passthrough
__call__(x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Layer callable

Parameters:

Name Type Description Default
x MultiImage

the input

required
aux_data Optional[State]

data used for stuff like batch norm

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage and aux_data

Source code in ginjax/models.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def __call__(
    self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Layer callable. This base implementation hands both arguments back unchanged.

    args:
        x: the input
        aux_data: data used for stuff like batch norm

    returns:
        the output MultiImage and aux_data
    """
    passthrough = (x, aux_data)
    return passthrough

AnyDimensionalModel ¤

Bases: MultiImageModule

A MultiImage model that implements a convertD function, which converts the model to work on input of a different dimension. It also provides the helper function transfer_weights to carry this out.

Source code in ginjax/models.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
class AnyDimensionalModel(MultiImageModule):
    """
    A MultiImage model that implements a convertD function that can convert to work on different
    dimensional input. This also provides the helper functions transfer_weights to get this done.
    """

    @staticmethod
    def _extend_weights(
        old_weights_block: jax.Array,
        filter_key: tuple[tuple[bool, ...], int],
        old_filters: geom.MultiImage,
        new_filters: geom.MultiImage,
    ) -> jax.Array:
        """
        Pad a block of weights belonging to old_filters so that it lines up with new_filters.
        The existing weights are kept and every added slot is filled with the mean of the old
        weights of its group. The two groups are the offcenter weights (from a set of filters that
        has a center filter; the center weight itself is left out of the mean) and the balanced
        weights (from a set of filters which has no center filter).

        args:
            old_weights_block: the old weights, shape (out_c,in_c,n_filters)
            filter_key: the key for the filters we are extending weights for
            old_filters: the old filters
            new_filters: the new filters

        returns:
            the weights associated with the new filters

        raises:
            NotImplementedError: if the tensor order given by filter_key is not 0, 1, or 2
        """
        tensor_order = len(filter_key[0])
        if tensor_order not in {0, 1, 2}:
            raise NotImplementedError()

        channel_shape = old_weights_block.shape[:2]  # (out_c,in_c)

        def mean_fill(group: jax.Array, n_extra: int) -> jax.Array:
            # n_extra copies of the group's mean over the filter axis, (out_c,in_c,n_extra)
            return jnp.full(channel_shape + (n_extra,), jnp.mean(group, axis=2, keepdims=True))

        center = None
        offcenter = None
        balanced = None
        extra_unbalanced = 0
        extra_balanced = 0
        if tensor_order == 0:
            # first weight is the center filter, the rest are offcenter
            center = old_weights_block[:, :, :1]
            offcenter = old_weights_block[:, :, 1:]
            extra_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
        elif tensor_order == 1:
            balanced = old_weights_block
            extra_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
        else:
            # tensor_order == 2: the first set of filters follows the scalar filters
            assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
            n_scalar_old = len(old_filters[(), 0])
            center = old_weights_block[:, :, :1]
            offcenter = old_weights_block[:, :, 1:n_scalar_old]
            extra_unbalanced = len(new_filters[(), 0]) - n_scalar_old

            balanced = old_weights_block[:, :, n_scalar_old:]
            # gap between new filters and (old filters plus the additional unbalanced filters)
            extra_balanced = len(new_filters[filter_key]) - (
                len(old_filters[filter_key]) + extra_unbalanced
            )

        assert extra_unbalanced >= 0
        assert extra_balanced >= 0

        # zero-width placeholder for a group that is absent
        empty = jnp.zeros(channel_shape + (0,))

        if center is None or offcenter is None:
            unbalanced_out = empty
        else:
            unbalanced_out = jnp.concatenate(
                [center, offcenter, mean_fill(offcenter, extra_unbalanced)], axis=2
            )

        if balanced is None:
            balanced_out = empty
        else:
            balanced_out = jnp.concatenate([balanced, mean_fill(balanced, extra_balanced)], axis=2)

        return jnp.concatenate([unbalanced_out, balanced_out], axis=2)

    @staticmethod
    def volume_rescale_weights(
        old_filter_triple: tuple[jax.Array, jax.Array, int],
        new_filter_triple: tuple[jax.Array, jax.Array, int],
        verbose: bool = False,
    ) -> jax.Array:
        """
        Rescale the new weights so that the sum of the weights times the filters adds up to the
        same value for the old filters and the new filters (which are likely a higher dimension).
        Filters whose spatial sum is zero keep their weight unscaled.

        args:
            old_filter_triple: tuple of (old filters, old weights, old dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            new_filter_triple: tuple of (new filters, new weights, new dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            verbose: whether to print the old weights and ratios

        return:
            jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
        """
        filters_before, weights_before, D_before = old_filter_triple
        filters_after, weights_after, D_after = new_filter_triple

        # both are (out_c,in_c)
        target_sum = get_filter_sum(D_before, filters_before, weights_before)
        current_sum = get_filter_sum(D_after, filters_after, weights_after)

        # Dont rescale filters that always sum to 0: sum each old filter over its spatial axes,
        # flatten what is left, and flag the filters whose remaining norm is nonzero.
        summed_over_space = jnp.sum(filters_before, axis=tuple(range(1, 1 + D_before)))
        flattened = summed_over_space.reshape((len(summed_over_space), -1))
        per_filter_norm = jnp.linalg.norm(flattened, axis=1)  # (n_filters,)
        rescale_mask = (per_filter_norm != 0)[None, None]  # (1,1,n_filters)

        # (out_c,in_c), geom.TINY is added to the denominator
        channel_ratio = target_sum / (current_sum + geom.TINY)
        # flagged filters get the channel ratio, the others get 1 (out_c,in_c,n_filters)
        ratios = rescale_mask * channel_ratio[..., None] + (~rescale_mask)

        if verbose:
            print("old weights", weights_before.shape, weights_before)
            print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

        return weights_after * ratios

    @staticmethod
    def compat_flex_rescale_weights(
        old_filter_triple: tuple[jax.Array, jax.Array, int],
        new_filter_triple: tuple[jax.Array, jax.Array, int],
        verbose: bool = False,
    ) -> jax.Array:
        """
        Do compatibility rescaling, now with one extra free parameter. The dimension must grow by
        exactly one. Sidelength 3 filters are handled for D=1 to D=2 and for D=2 to D=3, and
        sidelength 2 filters are handled by dividing the old weights by 2**D_increase. The new
        weights and verbose are not used by this function.

        args:
            old_filter_triple: tuple of (old filters, old weights, old dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            new_filter_triple: tuple of (new filters, new weights, new dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            verbose: whether to print the old weights and ratios

        return:
            jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

        raises:
            ValueError: if the sidelengths are neither all 3 nor all 2, or if they are 3 and the
                dimensions are not 1 to 2 or 2 to 3
        """
        old_filters, alpha, old_D = old_filter_triple  # old weights are alpha
        new_filters, _, new_D = new_filter_triple

        old_k = old_filters.ndim - (1 + old_D)
        new_k = new_filters.ndim - (1 + new_D)
        assert (
            old_k == new_k
        ), f"compat_flex_rescale_weights: old_filters k={old_k}, new_filters k={new_k}"

        D_increase = new_D - old_D
        assert D_increase == 1

        old_sides = old_filters.shape[1 : 1 + old_D]
        new_sides = new_filters.shape[1 : 1 + new_D]

        def all_sides_are(n: int) -> bool:
            # both the old and the new filters have sidelength n along every spatial axis
            return old_sides == (n,) * old_D and new_sides == (n,) * new_D

        if all_sides_are(3):
            if (old_D, new_D) == (1, 2):
                assert alpha.shape[2] == 2  # should be 2 filters
                ratio = 1 / 3
                a0 = alpha[..., 0]
                a1 = alpha[..., 1]

                return jnp.stack(
                    [a0 + (-2 + 4 * ratio) * a1, (1 - 2 * ratio) * a1, ratio * a1],
                    axis=-1,
                )

            if (old_D, new_D) == (2, 3):
                # need to get first 4 new_weights from first 3 old_weights
                a0 = alpha[..., 0]
                a1 = alpha[..., 1]
                a2 = alpha[..., 2]
                z = (a2 * 4 - a1) / 9

                leading = jnp.stack(
                    [a0 - 2 * a1 + 4 * a2 - 8 * z, a1 - 2 * a2 + 4 * z, a2 - 2 * z, z],
                    axis=-1,
                )

                # filters are in flipped order for some reason
                pair_of_ones = jnp.ones_like(alpha[..., :2])
                symmetric_traceless = pair_of_ones * alpha[..., 4:5]
                along_trace = pair_of_ones * alpha[..., 3:4]

                return jnp.concatenate([leading, symmetric_traceless, along_trace], axis=-1)

            raise ValueError()

        if all_sides_are(2):
            return alpha / (2**D_increase)

        # TODO: I could check that the condition holds?
        raise ValueError()

    @staticmethod
    def compatibility_norm_rescale_weights(
        old_filter_triple: tuple[jax.Array, jax.Array, int],
        new_filter_triple: tuple[jax.Array, jax.Array, int],
        verbose: bool = False,
    ) -> jax.Array:
        """
        Rescale the weight coefficients so that they are compatible with the particular embedding.
        This algorithm has an implicit assumption that we are using orthoplex filters.

        WARNING: This is the old version which works on the norms of the tensors.

        The weights are solved one filter at a time, from the last filter to the first, and the
        result is then checked with an assertion: the rescaled new filter norms, summed over the
        filters and over the added spatial axes, must match the weighted old filter norms.

        args:
            old_filter_triple: tuple of (old filters, old weights, old dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            new_filter_triple: tuple of (new filters, new weights, new dimension), where the
                weights have shape (out_channels,in_channels,num_filters)
            verbose: whether to print the new weights and the updated weights

        return:
            jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
        """
        old_filters, old_weights, old_D = old_filter_triple
        new_filters, new_weights, new_D = new_filter_triple

        # Convert filters to the norm of the filters. This assumes 2 things:
        # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
        # 2. the sign of the filters are positive
        old_filters = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )
        new_filters = jnp.linalg.norm(
            new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # assert the filters are already in ascending order by number of pixels.
        # So for orthoplex, this means innermost to outermost
        filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
        assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

        D_increase = new_D - old_D
        assert D_increase > 0

        # first, reduce the filters to the nonempty pixel filters.
        nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
            new_filters.shape[: 1 + new_D]
        )
        # (n_filters,old_spatial)
        collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

        # (n_filters,old_spatial)
        collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))
        # (n_filters,old_spatial)

        # (out_c,in_c,spatial)
        old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
        for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

            # get the outermost pixel of collapsed filter i
            # (old_spatial_size,) true/falses whether the pixel is nonempty
            nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
            farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

            # with current weight for filter i and collapsed sum of updated_weights,
            # calculate new weight to equal old weight
            updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
            # (out_c,in_c,n_filters,old_spatial)
            scaled_collapsed_ff = get_scaled_filters(
                old_D, collapsed_ff, jnp.array(updated_weights)
            )
            # (out_c,in_c,old_spatial)
            collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
            # (out_c,in_c)
            collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
                :, :, farthest_pixel_idx
            ]
            # assume that old_weights_val = new_weights_val. The old weight and new weight are
            # the same at this point, otherwise filter value could be different, but it wont be
            # for normalize and gaussian at least.
            old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
                :, :, farthest_pixel_idx
            ]
            # this should really be new_ff_val, assume they are equal, see above
            old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

            # set updated weights
            updated_weights[:, :, i] = (
                -(collapsed_val - old_weights_val) + old_weights_val
            ) / old_norm_ff_val

        updated_weights = jnp.array(updated_weights)

        # now we check that we did it right
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
        # (out_c,in_c,old_spatial)
        scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

        # (n_filters,old_spatial)
        # old_filters already holds norms at this point, so this norm is over a length-1 axis
        old_norm_ff = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
            axis=-1,
        )

        # (out_c,in_c,n_filters,old_spatial)
        old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
        # (out_c,in_c,old_spatial)
        old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

        diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
        # name this function in the message so a failure is not confused with the failure of
        # compatibility_rescale_weights, which performs an analogous check
        diff_message = f"AnyDimensionalModel::compatibility_norm_rescale_weights: Diff is {diff}"

        assert jnp.allclose(
            scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
        ), diff_message

        if verbose:
            print("new weights:", new_weights)
            print("updated weights:", updated_weights)

        return updated_weights

    @staticmethod
    def compatibility_rescale_weights(
        old_filter_triple: tuple[jax.Array, jax.Array, int],
        new_filter_triple: tuple[jax.Array, jax.Array, int],
        verbose: bool = False,
    ) -> jax.Array:
        """
        Rescale the weight coefficients so that they are compatible with the particular embedding.
        This algorithm has an implicit assumption that we are using orthoplex filters. This
        implements Algorithm 1: Orthoplex filter weight scaling.

        args:
            old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
                the old filters, and the old dimension
            new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
                the new filters, and the new dimension
            verbose: whether to print the old weights and ratios

        return:
            jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
        """
        old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
        new_filters, new_weights, new_D = new_filter_triple
        k = old_filters.ndim - (1 + old_D)
        assert k == new_filters.ndim - (
            1 + new_D
        ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

        D_increase = new_D - old_D
        assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

        # old/new_filters shape (n_filters,spatial,tensor)

        # we have filters ell=0,1,...,L
        # same number of filters
        assert len(old_filters) == len(
            new_filters
        ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
        L = len(old_filters) - 1
        L_plus = len(old_filters)  # more useful for iterating

        new_filters_proj_tensors = (
            new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
        )

        # currently special case N=2 because its so different
        if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
            assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
            assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

            alpha_prime = old_weights / (2**D_increase)

        else:  # filters are odd, and in particular 2L + 1 square
            # largest filter goes up to the border
            assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
            assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

            # (n_filters,new_spatial)
            new_filters_proj_norm = jnp.linalg.norm(
                new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
            )

            # (n_filters,old_spatial)
            old_filters_norm = jnp.linalg.norm(
                old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
            )

            # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
            alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
            for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
                j_d_centered = (z,) + (0,) * (old_D - 1)
                j_dplus_centered = (z,) + (0,) * (new_D - 1)

                j_d = tuple(x + L for x in j_d_centered)
                j_dplus = tuple(x + L for x in j_dplus_centered)

                # (out_c,in_c,n_filters,new_spatial)
                scaled_new_filters = (
                    alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
                )
                # sum over filters, spatial dims (out_c,in_c,old_spatial)
                # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
                collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

                # alpha_prime = (alpha * C_z - sum) / (C'_z)
                alpha_prime[:, :, z] = (
                    old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
                ) / new_filters_proj_norm[z, *j_dplus]

            alpha_prime = jnp.array(alpha_prime)

        # now we check that we did it right
        # (out_c,in_c,n_filters,new_spatial,proj_tensor)
        scaled_new_filters = (
            alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
        )
        # (out_c,in_c,old_spatial,proj_tensor)
        collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

        # (out_c,in_c,n_filters,old_spatial,tensor)
        scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
        # (out_c,in_c,old_spatial,tensor)
        scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

        diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
        diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

        assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

        if verbose:
            print("old weights:", old_weights)
            print("updated weights:", alpha_prime)

        return alpha_prime

    @staticmethod
    def _transfer_conv_weights(
        weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
        old_filters: geom.MultiImage,
        new_filters: geom.MultiImage,
        rescale: geom.Rescaling,
        verbose: bool = False,
    ) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
        """
        Transfer the conv weights from old filters to new filters of possibly a different dimension.
        Every weights block is first extended to the number of new filters and then adjusted
        according to rescale:

        - geom.Rescaling.VOLUME: the positive and negative parts of the weights are passed
          separately through volume_rescale_weights and the two results are added.
        - geom.Rescaling.COMPATIBILITY: the filters whose spatial sum is nonzero are passed through
          compatibility_rescale_weights, the remaining filters keep their extended weights.
        - geom.Rescaling.COMPAT_FLEX: the block is passed through compat_flex_rescale_weights.
        - anything else: the extended weights are used as they are.

        args:
            weights: a weights dictionary from a layers.ConvContract layer
            old_filters: the old filters that the weights came from
            new_filters: the new filters that we will be using the weights for
            rescale: type of rescaling to perform on the weights
            verbose: forwarded to the rescaling functions, default to False.

        returns:
            a new weights dictionary
        """
        transferred = {}

        for (in_k, in_p), per_output in weights.items():
            block_for_input = {}
            transferred[(in_k, in_p)] = block_for_input

            for (out_k, out_p), old_block in per_output.items():
                # the filters used for this pair: concatenated tensor orders, summed parity mod 2
                filter_key = (in_k + out_k, (in_p + out_p) % 2)

                extended_block = AnyDimensionalModel._extend_weights(
                    old_block, filter_key, old_filters, new_filters
                )

                old_basis = old_filters[filter_key]
                new_basis = new_filters[filter_key]

                if rescale is geom.Rescaling.VOLUME:
                    positive_part = AnyDimensionalModel.volume_rescale_weights(
                        (old_basis, jax.nn.relu(old_block), old_filters.D),
                        (new_basis, jax.nn.relu(extended_block), new_filters.D),
                        verbose,
                    )
                    negative_part = AnyDimensionalModel.volume_rescale_weights(
                        (old_basis, -jax.nn.relu(-old_block), old_filters.D),
                        (new_basis, -jax.nn.relu(-extended_block), new_filters.D),
                        verbose,
                    )
                    result_block = positive_part + negative_part
                elif rescale is geom.Rescaling.COMPATIBILITY:
                    # Dont rescale filters that always sum to 0.
                    spatial_axes = tuple(range(1, 1 + new_filters.D))
                    summed_over_space = jnp.sum(new_basis, axis=spatial_axes)  # (n_filters,tensor)
                    flattened = summed_over_space.reshape((len(summed_over_space), -1))
                    selected = jnp.linalg.norm(flattened, axis=1) != 0  # (n_filters,)

                    old_triple = (old_basis[selected], old_block[:, :, selected], old_filters.D)
                    new_triple = (
                        new_basis[selected],
                        extended_block[:, :, selected],
                        new_filters.D,
                    )
                    rescaled_subset = AnyDimensionalModel.compatibility_rescale_weights(
                        old_triple, new_triple, verbose
                    )

                    # write the rescaled weights back into their slots, the rest stay as extended
                    result_block = extended_block.at[:, :, selected].set(rescaled_subset)
                elif rescale is geom.Rescaling.COMPAT_FLEX:
                    result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                        (old_basis, old_block, old_filters.D),
                        (new_basis, extended_block, new_filters.D),
                        verbose,
                    )
                else:
                    result_block = extended_block

                block_for_input[(out_k, out_p)] = result_block

        return transferred

    def transfer_weights(
        self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
    ) -> Self:
        """
        Transfer the weights and biases from this model (the old model) to a new model. This
        allows converting between dimensions as well. The steps are: remember the new model's conv
        filters, overwrite the new model's leaves with the leaves of this model, restore the
        remembered conv filters, and finally replace the conv weights with the weights produced by
        _transfer_conv_weights.

        In the future, it may make sense for the updates to be defined on the individual layers, and
        then the tree_at recursively calls those functions.

        args:
            new_model: the new model, its conv filters are kept and can have a different D
            rescale: type of rescaling to perform on the weights
            verbose: forwarded to _transfer_conv_weights, default to False.

        returns:
            a new model with the old weights except conv weights which are adjusted, and new filters
        """

        def is_conv(node) -> bool:
            return isinstance(node, layers.ConvContract)

        def conv_layers(model) -> list:
            # every ConvContract layer of the model, treated as a leaf of the tree
            candidates = jax.tree_util.tree_leaves(model, is_leaf=is_conv)
            return [node for node in candidates if is_conv(node)]

        def get_filters(model) -> list:
            return [layer.invariant_filters for layer in conv_layers(model)]

        def get_conv_weights(model) -> list:
            return [layer.weights for layer in conv_layers(model)]

        def get_all_weights(model) -> list:
            return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

        # remember the new filters before they get overwritten
        new_filters = get_filters(new_model)

        # copy over the leaves of the old model, then put the new filters back
        new_model = eqx.tree_at(get_all_weights, new_model, get_all_weights(self))
        new_model = eqx.tree_at(get_filters, new_model, new_filters)

        # adjust the conv weights layer by layer, pairing each old layer with its new filters
        layer_triples = zip(get_conv_weights(self), get_filters(self), new_filters)
        new_weights = [
            AnyDimensionalModel._transfer_conv_weights(
                old_weight, old_filter, new_filter, rescale, verbose
            )
            for old_weight, old_filter, new_filter in layer_triples
        ]

        return eqx.tree_at(get_conv_weights, new_model, new_weights)

    def convertD(
        self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
    ) -> Self:
        """
        Placeholder function, must be overwritten by the inheriting class. This base version
        always raises.

        An overriding implementation constructs a new model with filters in a higher dimension.
        This only works for equivariant models.

        args:
            conv_filters: the new conv filters we are swapping to, probably in a higher dimension
            rescale: type of rescaling to perform on the weights
            key: key to initialize the weights, since they are overruled it won't matter

        returns:
            a new model with new filters but the old weights

        raises:
            NotImplementedError: always, naming the class that did not override this method
        """
        message = (
            f"AnyDimensionalModel::convertD: derived class {self.__class__} does not implement convertD."
        )
        raise NotImplementedError(message)
__call__(x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Layer callable

Parameters:

Name Type Description Default
x MultiImage

the input

required
aux_data Optional[State]

data used for stuff like batch norm

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage and aux_data

Source code in ginjax/models.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def __call__(
    self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Layer callable. This version is the identity: both arguments are handed back untouched.

    args:
        x: the input
        aux_data: data used for stuff like batch norm, defaults to None

    returns:
        the pair (x, aux_data), exactly as they were passed in
    """
    passthrough = (x, aux_data)
    return passthrough
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Given a set of weights associated with old_filters, extend the weights to new_filters.
    For offcenter weights (associated with a set of filters that has a center filter) and for
    balanced weights (associated with a set of filters which has no center filter), the new
    weights are the average of the old weights.

    How the last axis of old_weights_block is interpreted depends on k = len(filter_key[0]):
        k == 0: first entry is the center weight, the rest are offcenter weights
        k == 1: every entry is a balanced weight
        k == 2: the first len(old_filters[(), 0]) entries are laid out like the k == 0 case,
            the remaining entries are balanced weights

    The returned array keeps the dtype of old_weights_block.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for
        old_filters: the old filters
        new_filters: the new filters

    returns:
        the weights associated with the new filters, shape (out_c,in_c,n_new_filters)

    raises:
        NotImplementedError: if k is not 0, 1, or 2
    """
    k = len(filter_key[0])
    if k not in {0, 1, 2}:
        raise NotImplementedError(
            f"AnyDimensionalModel::_extend_weights: only k in (0, 1, 2) is supported, got k={k}"
        )

    n_add_unbalanced = 0
    n_add_balanced = 0
    center_weight = None
    offcenter_old_weights = None
    balanced_weights = None
    if k == 0:
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:]
        n_add_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 1:
        balanced_weights = old_weights_block
        n_add_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 2:
        # for k==2, the first set of filters follows the scalar filters
        assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        n_old_unbalanced = len(old_filters[(), 0])
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:n_old_unbalanced]
        n_add_unbalanced = len(new_filters[(), 0]) - n_old_unbalanced

        balanced_weights = old_weights_block[:, :, n_old_unbalanced:]
        # gap between new filters and (old filters plus the additional unbalanced filter)
        n_add_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + n_add_unbalanced
        )

    assert n_add_unbalanced >= 0
    assert n_add_balanced >= 0

    channel_shape = old_weights_block.shape[:2]
    # Zero-width placeholders must carry the dtype of the incoming weights. With the default
    # dtype of jnp.zeros the final concatenate could promote the result to a different dtype
    # than the weights that were passed in.
    weights_dtype = old_weights_block.dtype

    new_unbalanced_weights = jnp.zeros(channel_shape + (0,), dtype=weights_dtype)
    if center_weight is not None and offcenter_old_weights is not None:
        # TODO: check what happens when n_add_unbalanced = 0
        # every added filter starts at the mean of the existing offcenter weights
        additional_weights = jnp.full(
            channel_shape + (n_add_unbalanced,),
            jnp.mean(offcenter_old_weights, axis=2, keepdims=True),
        )

        new_unbalanced_weights = jnp.concatenate(
            [center_weight, offcenter_old_weights, additional_weights], axis=2
        )

    new_balanced_weights = jnp.zeros(channel_shape + (0,), dtype=weights_dtype)
    if balanced_weights is not None:
        # every added filter starts at the mean of the existing balanced weights
        additional_weights = jnp.full(
            channel_shape + (n_add_balanced,),
            jnp.mean(balanced_weights, axis=2, keepdims=True),
        )

        new_balanced_weights = jnp.concatenate([balanced_weights, additional_weights], axis=2)

    # unbalanced weights always come first, matching the layout read above
    return jnp.concatenate([new_unbalanced_weights, new_balanced_weights], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Scale the new weights per (out_channel,in_channel) pair by the ratio of the old filter sum
    to the new filter sum, so that the weighted filters add up to the same value before and
    after the change of filters (which are likely a higher dimension). Filters whose spatial
    sum is exactly zero are left unscaled.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), where the weights
            have shape (out_channels,in_channels,num_filters)
        new_filter_triple: tuple of (new filters, new weights, new dimension), where the weights
            have shape (out_channels,in_channels,num_filters)
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    src_filters, src_weights, src_D = old_filter_triple
    dst_filters, dst_weights, dst_D = new_filter_triple

    # weighted filter totals before and after, both (out_c,in_c)
    src_total = get_filter_sum(src_D, src_filters, src_weights)
    dst_total = get_filter_sum(dst_D, dst_filters, dst_weights)

    # Find the old filters that do not sum to 0 over the spatial axes, only those get rescaled.
    spatial_axes = tuple(range(1, 1 + src_D))
    per_filter_sum = jnp.sum(src_filters, axis=spatial_axes)  # (n_filters,tensor)
    flat_per_filter_sum = per_filter_sum.reshape((len(per_filter_sum), -1))
    per_filter_norm = jnp.linalg.norm(flat_per_filter_sum, axis=1)  # (n_filters,)
    rescale_mask = (per_filter_norm != 0)[None, None]  # (1,1,n_filters)

    # geom.TINY is added to the denominator before dividing, (out_c,in_c)
    channel_ratios = src_total / (dst_total + geom.TINY)
    # masked filters use the channel ratio, all other filters use a factor of 1
    ratios = rescale_mask * channel_ratios[..., None] + (~rescale_mask)  # (out_c,in_c,n_filters)

    if verbose:
        print("old weights", src_weights.shape, src_weights)
        print("ratios", ratios.shape, ratios)

    return dst_weights * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Do compatibility rescaling, now with one extra free parameter. For now this is only defined for a dimension increase of one: sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and sidelength 2 filters.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. For now this is only defined
    for sidelength 3 filters for D=1 to D=2.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compat_flex_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase == 1

    if (
        old_filters.shape[1 : 1 + old_D] == (3,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (3,) * new_D
    ):
        if old_D == 1 and new_D == 2:
            assert old_weights.shape[2] == 2  # should be 2 filters
            ratio = 1 / 3

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0] + (-2 + 4 * ratio) * old_weights[..., 1],
                    (1 - 2 * ratio) * old_weights[..., 1],
                    ratio * old_weights[..., 1],
                ],
                axis=-1,
            )
        elif old_D == 2 and new_D == 3:
            # need to get first 4 new_weights from first 3 old_weights

            z = (old_weights[..., 2] * 4 - old_weights[..., 1]) / 9

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0]
                    - 2 * old_weights[..., 1]
                    + 4 * old_weights[..., 2]
                    - 8 * z,
                    old_weights[..., 1] - 2 * old_weights[..., 2] + 4 * z,
                    old_weights[..., 2] - 2 * z,
                    z,
                ],
                axis=-1,
            )

            # filters are in flipped order for some reason
            symmetric_traceless = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 4:5]
            along_trace = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 3:4]

            alpha_prime = jnp.concatenate(
                [alpha_prime, symmetric_traceless, along_trace], axis=-1
            )
        else:
            raise ValueError()
    elif (
        old_filters.shape[1 : 1 + old_D] == (2,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (2,) * new_D
    ):
        alpha_prime = old_weights / (2**D_increase)
    else:
        raise ValueError()

    # TODO: I could check that the condition holds?

    return alpha_prime
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    WARNING: This is the old version which works on the norms of the tensors.

    The weights are solved one filter at a time, from the last filter to the first, and the
    result is verified at the end: the new filters scaled by the updated weights and summed
    over the leading (new_D - old_D) spatial axes must match the old scaled filters.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), where the weights
            have shape (out_channels,in_channels,num_filters)
        new_filter_triple: tuple of (new filters, new weights, new dimension), where the weights
            have shape (out_channels,in_channels,num_filters)
        verbose: whether to print the new weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the filters are not sorted by filter_raw_sum, if new_D <= old_D, or
            if the final consistency check fails
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # From here on old_filters / new_filters hold one norm per pixel, (n_filters,spatial).
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    # NOTE(review): relies on geom.nonempty_pixels, defined elsewhere; presumably it flags the
    # nonzero pixels of each filter, confirm against its definition.
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # sum out the leading D_increase spatial axes so the result lives on the old spatial grid
    # (n_filters,old_spatial)
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # same collapse, applied to the filter norms themselves
    # (n_filters,old_spatial)
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))
    # (n_filters,old_spatial)

    # old filters scaled by the old weights and summed over the filter axis
    # (out_c,in_c,spatial)
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    # Filters with a larger index are already solved when filter i is processed; filters with a
    # smaller index still have weight 0 in updated_weights.
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # largest flattened index among the nonempty pixels
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # value of the collapsed new filters at the farthest pixel
        # (out_c,in_c)
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        # value of the old scaled filters at the same pixel, (out_c,in_c)
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights
        # the numerator is algebraically 2 * old_weights_val - collapsed_val
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # old_filters already holds per-pixel norms (see the top of the function), so this reshape
    # adds a trailing axis and the norm over it is the elementwise absolute value
    # (n_filters,old_spatial)
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    # largest absolute mismatch, only used for the assertion message
    # NOTE(review): the message names compatibility_rescale_weights rather than this function.
    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. The rescale argument selects how the transferred weights are rescaled, for example so that the sum of the filter basis of a particular order scaled by the weights is equal for the old filters and the new.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Build a weights dictionary for new_filters out of one that belongs to old_filters, where
    the new filters are possibly of a different dimension. Every weights block is first
    extended with _extend_weights and then rescaled according to rescale:
        VOLUME: positive and negative parts of the weights are rescaled separately with
            volume_rescale_weights and added back together
        COMPATIBILITY: filters with a nonzero spatial sum are rescaled with
            compatibility_rescale_weights, the other filters keep their extended weights
        COMPAT_FLEX: rescaled with compat_flex_rescale_weights
        anything else: the extended weights are used as they are

    args:
        weights: a weights dictionary from a layers.ConvContract layer, keyed by input (k,parity)
            and then by output (k,parity)
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights
        verbose: passed on to the rescaling functions, defaults to False

    returns:
        a new weights dictionary with the same keys as weights
    """
    transferred = {}

    for in_key, per_output in weights.items():
        in_k, in_p = in_key
        transferred_block = {}
        transferred[(in_k, in_p)] = transferred_block

        for out_key, old_block in per_output.items():
            out_k, out_p = out_key
            # the filter basis used to map this input type to this output type
            filter_key = (in_k + out_k, (in_p + out_p) % 2)

            extended_block = AnyDimensionalModel._extend_weights(
                old_block, filter_key, old_filters, new_filters
            )

            old_basis = old_filters[filter_key]
            new_basis = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # relu(w) keeps the positive part of the weights, -relu(-w) the negative part
                positive_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, jax.nn.relu(old_block), old_filters.D),
                    (new_basis, jax.nn.relu(extended_block), new_filters.D),
                    verbose,
                )
                negative_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, -jax.nn.relu(-old_block), old_filters.D),
                    (new_basis, -jax.nn.relu(-extended_block), new_filters.D),
                    verbose,
                )
                result_block = positive_part + negative_part
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Filters that always sum to 0 over the spatial axes are not rescaled.
                spatial_axes = tuple(range(1, 1 + new_filters.D))
                per_filter_sum = jnp.sum(new_basis, axis=spatial_axes)  # (n_filters,tensor)
                per_filter_norm = jnp.linalg.norm(
                    per_filter_sum.reshape((len(per_filter_sum), -1)), axis=1
                )  # (n_filters,)
                keep = per_filter_norm != 0  # (n_filters,)

                old_triple = (old_basis[keep], old_block[:, :, keep], old_filters.D)
                new_triple = (new_basis[keep], extended_block[:, :, keep], new_filters.D)
                rescaled = AnyDimensionalModel.compatibility_rescale_weights(
                    old_triple, new_triple, verbose
                )

                # write the rescaled weights back, leaving the zero-sum filters untouched
                result_block = extended_block.at[:, :, keep].set(rescaled)
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_basis, old_block, old_filters.D),
                    (new_basis, extended_block, new_filters.D),
                    verbose,
                )
            else:
                result_block = extended_block

            transferred_block[(out_k, out_p)] = result_block

    return transferred
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new models conv filters to the new conv filters, then doing any conv filter related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Transfer the weights and biases from this (old) model to a new model. This allows converting
    between dimensions as well. This works by copying the leaves of this model into the new
    model, then resetting the new model's conv filters to the filters it was constructed with,
    then doing any conv filter related weight scaling.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the new model, already constructed with the conv filters it should keep; these
            can have a different D than the filters of this model
        rescale: type of rescaling to perform on the weights
        verbose: forwarded to the conv weight transfer, print the ratio of the squared sum of
            filters new/old after transferring the weights, defaults to False.

    returns:
        a new model with the old weights except conv weights which are adjusted, and new filters
    """

    def is_conv(node) -> bool:
        return isinstance(node, layers.ConvContract)

    def conv_layers(model) -> list:
        # stop the traversal at every ConvContract so that the layer itself is returned
        return [
            node for node in jax.tree_util.tree_leaves(model, is_leaf=is_conv) if is_conv(node)
        ]

    def get_filters(model) -> list:
        return [layer.invariant_filters for layer in conv_layers(model)]

    def get_conv_weights(model) -> list:
        return [layer.weights for layer in conv_layers(model)]

    def get_all_weights(model) -> list:
        # NOTE(review): tree_leaves returns every leaf of the model; is_leaf=eqx.is_array does not
        # filter out non-array leaves. Confirm that copying those as well is intended.
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    # get the new filters, they must be saved before the copy below overwrites them
    new_filters = get_filters(new_model)

    # now replace all jax arrays
    new_model = eqx.tree_at(get_all_weights, new_model, get_all_weights(self))

    # now reset the proper conv filters
    new_model = eqx.tree_at(get_filters, new_model, new_filters)

    # now set the proper weights, pairing each old conv layer with its new filters
    new_weights = [
        AnyDimensionalModel._transfer_conv_weights(
            weight, old_filter, new_filter, rescale, verbose
        )
        for weight, old_filter, new_filter in zip(
            get_conv_weights(self), get_filters(self), new_filters
        )
    ]

    return eqx.tree_at(get_conv_weights, new_model, new_weights)
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Placeholder function, must be overwritten by the inheriting class.

Construct a new model with filters in a higher dimension. This only works for equivariant models.

Parameters:

Name Type Description Default
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

type of rescaling to perform on the weights

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
def convertD(
    self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
) -> Self:
    """
    Placeholder, every inheriting class must provide its own implementation.

    An implementation constructs a new model with filters in a higher dimension. This only works
    for equivariant models.

    args:
        conv_filters: the new conv filters we are swapping to, probably in a higher dimension
        rescale: type of rescaling to perform on the weights
        key: key to initialize the weights, since they are overruled it won't matter
        kwargs: additional model specific arguments, unused by this placeholder

    returns:
        a new model with new filters but the old weights

    raises:
        NotImplementedError: always, because this base implementation is only a placeholder
    """
    derived_class = self.__class__
    raise NotImplementedError(
        "AnyDimensionalModel::convertD: "
        f"derived class {derived_class} does not implement convertD."
    )

ConvBlock ¤

Bases: MultiImageModule

A convolution block consisting of a convolution, a nonlinearity, and a GroupNorm/BatchNorm. Can be equivariant or not, in typical order or in preactivation order.

Source code in ginjax/models.py
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
class ConvBlock(MultiImageModule):
    """
    One convolution paired with a nonlinearity and, optionally, a GroupNorm or a BatchNorm. The
    block can be equivariant or not, and runs either in the typical order (convolution, norm,
    nonlinearity) or in preactivation order (norm, nonlinearity, convolution).
    """

    conv: layers.ConvContract | layers.LayerWrapper
    group_norm: layers.GroupNorm | layers.LayerWrapper | None
    batch_norm: layers.LayerWrapperAux | None
    nonlinearity: layers.VectorNeuronNonlinear | layers.LayerWrapper | Callable

    D: int = eqx.field(static=True)
    equivariant: bool = eqx.field(static=True)
    use_batch_norm: bool = eqx.field(static=True)
    use_group_norm: bool = eqx.field(static=True)
    preactivation_order: bool = eqx.field(static=True)

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        use_bias: Union[bool, str] = "auto",
        activation_f: Optional[Union[Callable, str]] = jax.nn.gelu,
        equivariant: bool = True,
        conv_filters: Optional[geom.MultiImage] = None,
        kernel_size: Optional[Union[int, Sequence[int]]] = None,
        use_group_norm: bool = False,
        use_batch_norm: bool = False,
        preactivation_order: bool = False,
        key: Any = None,
        **conv_kwargs: Any,
    ) -> None:
        """
        Build the convolution, the optional norm layers and the nonlinearity.

        args:
            D: the dimension of the space
            input_keys: MultiImage Signature of input
            output_keys: MultiImage Signature of output
            use_bias: whether to use a bias
            activation_f: the type of activation function
            equivariant: whether it is equivariant
            conv_filters: the invariant filters if it is equivariant
            kernel_size: sidelength(s) of the kernel if not equivariant
            use_group_norm: whether to use GroupNorm
            use_batch_norm: whether to use BatchNorm, can only be for non-equivariant
            preactivation_order: whether to use preactivation order
            key: jax.random key, split between the convolution and the nonlinearity
            conv_kwargs: further key word args that will be passed to the convolution
        """
        self.D = D
        self.equivariant = equivariant
        self.use_group_norm = use_group_norm
        self.use_batch_norm = use_batch_norm
        self.preactivation_order = preactivation_order

        conv_key, activation_key = random.split(key)
        self.conv = make_conv(
            D,
            input_keys,
            output_keys,
            use_bias,
            equivariant,
            conv_filters,
            kernel_size,
            key=conv_key,
            **conv_kwargs,
        )

        # the non-equivariant norms are sized by the second entry of the first output key
        if not use_group_norm:
            self.group_norm = None
        elif equivariant:
            self.group_norm = layers.LayerNorm(output_keys, D)
        else:
            wrapped_group_norm = eqx.nn.GroupNorm(1, output_keys[0][1])
            self.group_norm = layers.LayerWrapper(wrapped_group_norm, output_keys)

        if not use_batch_norm:
            self.batch_norm = None
        else:
            wrapped_batch_norm = eqx.nn.BatchNorm(
                output_keys[0][1], axis_name=["pmap_batch", "batch"]
            )
            self.batch_norm = layers.LayerWrapperAux(wrapped_batch_norm, output_keys)

        self.nonlinearity = handle_activation(
            activation_f, equivariant, output_keys, D, activation_key
        )

    def _normalize(
        self: Self, x: geom.MultiImage, batch_stats: Optional[eqx.nn.State]
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Apply the configured norm, if any. GroupNorm takes precedence over BatchNorm when both
        flags are set, and only BatchNorm updates the batch stats.

        args:
            x: the input
            batch_stats: data for batch norm

        returns:
            the normalized MultiImage and batch stats
        """
        if self.use_group_norm:
            assert self.group_norm is not None
            return self.group_norm(x), batch_stats

        if self.use_batch_norm:
            assert self.batch_norm is not None
            x, batch_stats = self.batch_norm(x, batch_stats)

        return x, batch_stats

    def __call__(
        self: Self, x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Run the block on a MultiImage.

        args:
            x: the input
            batch_stats: data for batch norm

        returns:
            the output MultiImage and batch stats
        """
        if self.preactivation_order:
            # norm -> nonlinearity -> convolution
            x, batch_stats = self._normalize(x, batch_stats)
            x = self.nonlinearity(x)
            x = self.conv(x)
            return x, batch_stats

        # convolution -> norm -> nonlinearity
        x = self.conv(x)
        x, batch_stats = self._normalize(x, batch_stats)
        x = self.nonlinearity(x)
        return x, batch_stats
__init__(D: int, input_keys: geom.Signature, output_keys: geom.Signature, use_bias: Union[bool, str] = 'auto', activation_f: Optional[Union[Callable, str]] = jax.nn.gelu, equivariant: bool = True, conv_filters: Optional[geom.MultiImage] = None, kernel_size: Optional[Union[int, Sequence[int]]] = None, use_group_norm: bool = False, use_batch_norm: bool = False, preactivation_order: bool = False, key: Any = None, **conv_kwargs: Any) -> None ¤

Constructor for ConvBlock

Parameters:

Name Type Description Default
D int

the dimension of the space

required
input_keys Signature

MultiImage Signature of input

required
output_keys Signature

MultiImage Signature of output

required
use_bias Union[bool, str]

whether to use a bias

'auto'
activation_f Optional[Union[Callable, str]]

the type of activation function

gelu
equivariant bool

whether it is equivariant

True
conv_filters Optional[MultiImage]

the invariant filters if it is equivariant

None
kernel_size Optional[Union[int, Sequence[int]]]

sidelength(s) of the kernel if not equivariant

None
use_group_norm bool

whether to use GroupNorm

False
use_batch_norm bool

whether to use BatchNorm, can only be for non-equivariant

False
preactivation_order bool

whether to use preactivation order

False
key Any

jax.random key

None
conv_kwargs Any

further key word args that will be passed to the convolution

{}
Source code in ginjax/models.py
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
def __init__(
    self: Self,
    D: int,
    input_keys: geom.Signature,
    output_keys: geom.Signature,
    use_bias: Union[bool, str] = "auto",
    activation_f: Optional[Union[Callable, str]] = jax.nn.gelu,
    equivariant: bool = True,
    conv_filters: Optional[geom.MultiImage] = None,
    kernel_size: Optional[Union[int, Sequence[int]]] = None,
    use_group_norm: bool = False,
    use_batch_norm: bool = False,
    preactivation_order: bool = False,
    key: Any = None,
    **conv_kwargs: Any,
) -> None:
    """
    Build the convolution, the optional norm layers and the nonlinearity of a ConvBlock.

    args:
        D: the dimension of the space
        input_keys: MultiImage Signature of input
        output_keys: MultiImage Signature of output
        use_bias: whether to use a bias
        activation_f: the type of activation function
        equivariant: whether it is equivariant
        conv_filters: the invariant filters if it is equivariant
        kernel_size: sidelength(s) of the kernel if not equivariant
        use_group_norm: whether to use GroupNorm
        use_batch_norm: whether to use BatchNorm, can only be for non-equivariant
        preactivation_order: whether to use preactivation order
        key: jax.random key, split between the convolution and the nonlinearity
        conv_kwargs: further key word args that will be passed to the convolution
    """
    self.D = D
    self.equivariant = equivariant
    self.use_group_norm = use_group_norm
    self.use_batch_norm = use_batch_norm
    self.preactivation_order = preactivation_order

    conv_key, activation_key = random.split(key)
    self.conv = make_conv(
        D,
        input_keys,
        output_keys,
        use_bias,
        equivariant,
        conv_filters,
        kernel_size,
        key=conv_key,
        **conv_kwargs,
    )

    # the non-equivariant norms are sized by the second entry of the first output key
    if not use_group_norm:
        self.group_norm = None
    elif equivariant:
        self.group_norm = layers.LayerNorm(output_keys, D)
    else:
        wrapped_group_norm = eqx.nn.GroupNorm(1, output_keys[0][1])
        self.group_norm = layers.LayerWrapper(wrapped_group_norm, output_keys)

    if not use_batch_norm:
        self.batch_norm = None
    else:
        wrapped_batch_norm = eqx.nn.BatchNorm(
            output_keys[0][1], axis_name=["pmap_batch", "batch"]
        )
        self.batch_norm = layers.LayerWrapperAux(wrapped_batch_norm, output_keys)

    self.nonlinearity = handle_activation(
        activation_f, equivariant, output_keys, D, activation_key
    )
__call__(x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Layer callable

Parameters:

Name Type Description Default
x MultiImage

the input

required
batch_stats Optional[State]

data for batch norm

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage and batch stats

Source code in ginjax/models.py
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
def __call__(
    self: Self, x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Run the block on a MultiImage, in preactivation order (norm, nonlinearity, convolution) or in
    the typical order (convolution, norm, nonlinearity).

    args:
        x: the input
        batch_stats: data for batch norm

    returns:
        the output MultiImage and batch stats
    """

    def normalize(image, stats):
        # GroupNorm takes precedence over BatchNorm; only BatchNorm touches the batch stats
        if self.use_group_norm:
            assert self.group_norm is not None
            return self.group_norm(image), stats

        if self.use_batch_norm:
            assert self.batch_norm is not None
            image, stats = self.batch_norm(image, stats)

        return image, stats

    if self.preactivation_order:
        x, batch_stats = normalize(x, batch_stats)
        x = self.nonlinearity(x)
        return self.conv(x), batch_stats

    x = self.conv(x)
    x, batch_stats = normalize(x, batch_stats)
    return self.nonlinearity(x), batch_stats

UNet ¤

Bases: AnyDimensionalModel

Implementation of the UNet: https://arxiv.org/abs/1505.04597. This model defaults to the equivariant version, but can also be the non-equivariant version.

Source code in ginjax/models.py
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
class UNet(AnyDimensionalModel):
    """
    Implementation of the UNet: https://arxiv.org/abs/1505.04597.
    This model defaults to the equivariant version, but can also be the non-equivariant version.
    """

    # conv blocks applied first, at the input resolution
    embedding: list[ConvBlock]
    # one (pooling layer, conv blocks) pair per downsample level
    downsample_blocks: list[tuple[layers.MaxNormPool, list[ConvBlock]]]
    # one (upsampling convolution, conv blocks) pair per upsample level
    upsample_blocks: list[tuple[layers.ConvContract | layers.LayerWrapper, list[ConvBlock]]]
    # final convolution from mid_keys to the output keys
    decode: layers.ConvContract | layers.LayerWrapper

    # static configuration; convertD reads most of these to rebuild the model
    D: int = eqx.field(static=True)
    equivariant: bool = eqx.field(static=True)
    use_bias: bool | str = eqx.field(static=True)
    activation_f: Callable | str | None = eqx.field(static=True)
    use_group_norm: bool = eqx.field(static=True)
    use_batch_norm: bool = eqx.field(static=True)
    input_keys: geom.Signature = eqx.field(static=True)
    output_keys: geom.Signature = eqx.field(static=True)
    mid_keys: geom.Signature = eqx.field(static=True)
    padding_mode: str = eqx.field(static=True)

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        depth: int,
        num_downsamples: int = 4,
        num_conv: int = 2,
        use_bias: Union[bool, str] = "auto",
        activation_f: Callable | str | None = jax.nn.gelu,
        equivariant: bool = True,
        conv_filters: Optional[geom.MultiImage] = None,
        upsample_filters: Optional[geom.MultiImage] = None,
        kernel_size: Optional[Union[int, Sequence[int]]] = None,
        use_group_norm: bool = False,
        use_batch_norm: bool = False,
        mid_keys: Optional[geom.Signature] = None,
        padding_mode: str = "ZEROS",
        key: Any = None,
    ) -> None:
        """
        Constructor for the UNet.

        args:
            D: the dimension of the space
            input_keys: the MultiImage Signature for the input
            output_keys: the MultiImage Signature for the output
            depth: the number of channels at the highest level of the unet. This is ignored if
                mid_keys is provided
            num_downsamples: number of downsample levels, each one a max pool followed by num_conv
                convolution blocks
            num_conv: number of convolutions per level, must be positive
            use_bias: whether to use a bias
            activation_f: the activation function
            equivariant: whether to be equivariant
            conv_filters: the invariant filters for the equivariant version
            upsample_filters: the filters passed to the upsampling convolutions
            kernel_size: sidelength(s) for the non-equivariant version
            use_group_norm: whether to use GroupNorm
            use_batch_norm: whether to use the BatchNorm, only for non-equivariant version
            mid_keys: types of images and number of channels for the mid layers, as a baseline
            padding_mode: used for non-equivariant models, padding mode to pass to convolutions
            key: jax.random key, required even though it defaults to None
        """
        assert num_conv > 0
        assert key is not None

        # store the caller's signatures before the non-equivariant branch flattens the local ones
        self.input_keys = input_keys
        self.output_keys = output_keys
        if equivariant:
            if mid_keys is None:
                mid_keys = geom.signature_union(input_keys, output_keys, depth)

            assert not use_batch_norm, "UNet::init Batch Norm cannot be used with equivariant model"
        else:
            if mid_keys is None:
                mid_keys = geom.Signature(((((), 0), depth),))

            # use these keys along the way, then for the final output use self.output_keys
            # each key contributes channels * D**len(k) scalar channels
            # (assumes len(k) is the tensor order of the image -- TODO confirm)
            input_keys_size = sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)
            input_keys = geom.Signature(((((), 0), input_keys_size),))
            output_key_size = sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)
            output_keys = geom.Signature(((((), 0), output_key_size),))

        self.D = D
        self.equivariant = equivariant
        self.use_bias = use_bias
        self.activation_f = activation_f
        self.use_group_norm = use_group_norm
        self.use_batch_norm = use_batch_norm
        self.mid_keys = mid_keys
        self.padding_mode = padding_mode

        # embedding layers
        # the first block maps input_keys to mid_keys, the remaining ones keep mid_keys
        self.embedding = []
        for conv_idx in range(num_conv):
            in_keys = input_keys if conv_idx == 0 else mid_keys
            key, subkey = random.split(key)
            self.embedding.append(
                ConvBlock(
                    self.D,
                    in_keys,
                    mid_keys,
                    use_bias,
                    activation_f,
                    equivariant,
                    conv_filters,
                    kernel_size,
                    use_group_norm,
                    use_batch_norm,
                    padding_mode=padding_mode,
                    key=subkey,
                )
            )

        # downsample levels: level `downsample` uses 2**downsample times the mid_keys channels
        self.downsample_blocks = []
        for downsample in range(1, num_downsamples + 1):
            down_layers = (layers.MaxNormPool(2, equivariant), [])

            for conv_idx in range(num_conv):
                out_keys = geom.Signature(
                    tuple((k_p, _depth * (2**downsample)) for k_p, _depth in mid_keys)
                )
                if conv_idx == 0:
                    # the first block of a level takes the channel count of the previous level
                    in_keys = geom.Signature(
                        tuple((k_p, _depth * (2 ** (downsample - 1))) for k_p, _depth in mid_keys)
                    )
                else:
                    in_keys = out_keys

                key, subkey = random.split(key)
                down_layers[1].append(
                    ConvBlock(
                        self.D,
                        in_keys,
                        out_keys,
                        use_bias,
                        activation_f,
                        equivariant,
                        conv_filters,
                        kernel_size,
                        use_group_norm,
                        use_batch_norm,
                        padding_mode=padding_mode,
                        key=subkey,
                    )
                )

            self.downsample_blocks.append(down_layers)

        # upsample levels, built from the deepest level back up to the input resolution
        self.upsample_blocks = []
        for upsample in reversed(range(num_downsamples)):
            # the upsampling convolution halves the channels: 2**(upsample + 1) -> 2**upsample
            in_keys = geom.Signature(
                tuple((k_p, _depth * (2 ** (upsample + 1))) for k_p, _depth in mid_keys)
            )
            out_keys = geom.Signature(
                tuple((k_p, _depth * (2**upsample)) for k_p, _depth in mid_keys)
            )
            key, subkey = random.split(key)
            # perform the transposed convolution. For non-equivariant, padding and stride should
            # instead be the padding and stride for the forward direction convolution.
            if equivariant:
                padding = ((1, 1),) * self.D
                stride = (1,) * self.D
                upsample_kernel_size = None  # ignored for equivariant
            else:
                padding = "VALID"
                stride = (2,) * self.D
                upsample_kernel_size = (2,) * self.D  # kernel size of the downsample

            up_layers = (
                make_conv(
                    self.D,
                    in_keys,
                    out_keys,
                    use_bias,
                    equivariant,
                    upsample_filters,
                    upsample_kernel_size,
                    stride,
                    padding,
                    (2,) * self.D,  # lhs_dilation
                    padding_mode=padding_mode,
                    key=subkey,
                ),
                [],
            )

            for conv_idx in range(num_conv):
                out_keys = geom.Signature(
                    tuple((k_p, _depth * (2**upsample)) for k_p, _depth in mid_keys)
                )
                if conv_idx == 0:  # due to adding the residual layer back, in_c is doubled again
                    in_keys = geom.Signature(
                        tuple((k_p, _depth * (2 ** (upsample + 1))) for k_p, _depth in mid_keys)
                    )
                else:
                    in_keys = out_keys

                key, subkey = random.split(key)
                up_layers[1].append(
                    ConvBlock(
                        self.D,
                        in_keys,
                        out_keys,
                        use_bias,
                        activation_f,
                        equivariant,
                        conv_filters,
                        kernel_size,
                        use_group_norm,
                        use_batch_norm,
                        padding_mode=padding_mode,
                        key=subkey,
                    )
                )

            self.upsample_blocks.append(up_layers)

        key, subkey = random.split(key)

        # final convolution from the mid_keys channels to the (possibly flattened) output keys
        self.decode = make_conv(
            self.D,
            mid_keys,
            output_keys,
            use_bias,
            equivariant,
            conv_filters,
            kernel_size,
            padding_mode=padding_mode,
            key=subkey,
        )

    def convertD(
        self: Self,
        conv_filters: geom.MultiImage,
        rescale: geom.Rescaling,
        key: jax.Array,
        **kwargs,
    ) -> Self:
        """
        Construct a new model with filters in a higher dimension. This only works for equivariant
        models.

        args:
            conv_filters: the new conv filters we are swapping to, probably in a higher dimension
            rescale: type of rescaling to perform on the weights, passed to transfer_weights
            key: key to initialize the weights, since they are overruled it won't matter
            kwargs: must contain "upsample_filters", the filters given to the new model's
                upsampling convolutions

        returns:
            a new model with new filters but the old weights
        """
        assert self.equivariant
        assert "upsample_filters" in kwargs
        # positional arguments follow the order of the __init__ signature
        new_model = self.__class__(
            conv_filters.D,
            self.input_keys,
            self.output_keys,
            0,  # ignored since mid_keys is provided
            len(self.downsample_blocks),
            len(self.embedding),
            self.use_bias,
            self.activation_f,
            self.equivariant,
            conv_filters,
            kwargs["upsample_filters"],
            0,  # ignored for equivariant model
            self.use_group_norm,
            self.use_batch_norm,
            self.mid_keys,
            self.padding_mode,
            key,
        )

        return self.transfer_weights(new_model, rescale)

    def __call__(
        self: Self, x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Callable function for UNet

        args:
            x: the input MultiImage
            batch_stats: batch stats for BatchNorm if present

        returns:
            the output MultiImage and batch_stats
        """
        if not self.equivariant:
            x = x.to_scalar_multi_image()

        for layer in self.embedding:
            x, batch_stats = layer(x, batch_stats)

        # keep the image from before each pooling, it is concatenated back in on the way up
        residual_multi_images = []
        for max_pool_layer, conv_blocks in self.downsample_blocks:
            residual_multi_images.append(x)
            x = max_pool_layer(x)
            for layer in conv_blocks:
                x, batch_stats = layer(x, batch_stats)

        # walk back up, pairing each upsample level with the most recently saved image first
        for (upsample_layer, conv_blocks), residual_multi_image in zip(
            self.upsample_blocks, reversed(residual_multi_images)
        ):
            upsample_x = upsample_layer(x)
            x = upsample_x.concat(residual_multi_image)
            for layer in conv_blocks:
                x, batch_stats = layer(x, batch_stats)

        x = self.decode(x)
        if self.equivariant:
            out = x
        else:
            # convert the scalar channels back using the caller's original output signature
            out = geom.MultiImage.from_scalar_multi_image(x, self.output_keys)

        return out, batch_stats
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Grow a block of conv weights so that it has one weight per filter in new_filters.

    The old weights are kept as they are. Each additional weight is filled with the mean (over
    the filter axis) of the old weights of the same group: the off-center weights for a group
    that has a center filter, or the balanced weights for a group with no center filter.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for
        old_filters: the old filters
        new_filters: the new filters

    returns:
        the weights associated with the new filters
    """
    tensor_order = len(filter_key[0])
    if tensor_order not in {0, 1, 2}:
        raise NotImplementedError()

    channel_shape = old_weights_block.shape[:2]  # (out_c,in_c)

    center = None
    offcenter = None
    balanced = None
    extra_unbalanced = 0
    extra_balanced = 0

    if tensor_order == 0:
        # first weight belongs to the center filter, the rest are off-center
        center = old_weights_block[:, :, :1]
        offcenter = old_weights_block[:, :, 1:]
        extra_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif tensor_order == 1:
        # no center filter, every weight is a balanced weight
        balanced = old_weights_block
        extra_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    else:
        # tensor_order == 2: the leading filters follow the scalar filters, so the number of
        # unbalanced weights is read off the scalar (k=0, parity 0) filters
        scalar_key = ((), 0)
        assert scalar_key in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        old_scalar_count = len(old_filters[scalar_key])

        center = old_weights_block[:, :, :1]
        offcenter = old_weights_block[:, :, 1:old_scalar_count]
        extra_unbalanced = len(new_filters[scalar_key]) - old_scalar_count

        balanced = old_weights_block[:, :, old_scalar_count:]
        # whatever growth is not explained by the additional unbalanced filters
        extra_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + extra_unbalanced
        )

    assert extra_unbalanced >= 0
    assert extra_balanced >= 0

    def mean_fill(group: jax.Array, n_extra: int) -> jax.Array:
        # (out_c,in_c,n_extra) block where every entry is the per-channel mean of group
        return jnp.full(channel_shape + (n_extra,), jnp.mean(group, axis=2, keepdims=True))

    no_weights = jnp.zeros(channel_shape + (0,))

    if center is None or offcenter is None:
        unbalanced_out = no_weights
    else:
        # TODO: check what happens when extra_unbalanced = 0
        unbalanced_out = jnp.concatenate(
            [center, offcenter, mean_fill(offcenter, extra_unbalanced)], axis=2
        )

    if balanced is None:
        balanced_out = no_weights
    else:
        balanced_out = jnp.concatenate([balanced, mean_fill(balanced, extra_balanced)], axis=2)

    return jnp.concatenate([unbalanced_out, balanced_out], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Scale the new weights so that the weighted filter sum matches between the old filters and
    the new filters (which are likely a higher dimension). Filters whose spatial sum is zero
    are left unscaled.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # weighted filter sums, each (out_c,in_c)
    old_total = get_filter_sum(old_D, old_filters, old_weights)
    new_total = get_filter_sum(new_D, new_filters, new_weights)

    # Find the filters that do not sum to zero over the spatial axes; only those get rescaled.
    # NOTE(review): the mask is built from the old filters but applied to the new weights, so
    # this presumably relies on both having the same number of filters -- confirm.
    old_spatial_axes = tuple(range(1, 1 + old_D))
    per_filter_sum = jnp.sum(old_filters, axis=old_spatial_axes)  # (n_filters,tensor)
    flat_sum = per_filter_sum.reshape((len(per_filter_sum), -1))
    sum_magnitude = jnp.linalg.norm(flat_sum, axis=1)  # (n_filters,)
    rescale_mask = (sum_magnitude != 0)[None, None]  # (1,1,n_filters)

    # geom.TINY is added to the denominator of the (out_c,in_c) ratio
    channel_ratios = old_total / (new_total + geom.TINY)
    # masked filters use the channel ratio, the remaining filters use a factor of 1
    # (out_c,in_c,n_filters)
    ratios = rescale_mask * channel_ratios[..., None] + (~rescale_mask)

    if verbose:
        print("old weights", old_weights.shape, old_weights)
        print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

    return new_weights * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Do compatibility rescaling, now with one extra free parameter. This is defined for sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and for sidelength 2 filters when the dimension increases by one.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. For now this is only defined
    for sidelength 3 filters for D=1 to D=2.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compat_flex_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase == 1

    if (
        old_filters.shape[1 : 1 + old_D] == (3,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (3,) * new_D
    ):
        if old_D == 1 and new_D == 2:
            assert old_weights.shape[2] == 2  # should be 2 filters
            ratio = 1 / 3

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0] + (-2 + 4 * ratio) * old_weights[..., 1],
                    (1 - 2 * ratio) * old_weights[..., 1],
                    ratio * old_weights[..., 1],
                ],
                axis=-1,
            )
        elif old_D == 2 and new_D == 3:
            # need to get first 4 new_weights from first 3 old_weights

            z = (old_weights[..., 2] * 4 - old_weights[..., 1]) / 9

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0]
                    - 2 * old_weights[..., 1]
                    + 4 * old_weights[..., 2]
                    - 8 * z,
                    old_weights[..., 1] - 2 * old_weights[..., 2] + 4 * z,
                    old_weights[..., 2] - 2 * z,
                    z,
                ],
                axis=-1,
            )

            # filters are in flipped order for some reason
            symmetric_traceless = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 4:5]
            along_trace = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 3:4]

            alpha_prime = jnp.concatenate(
                [alpha_prime, symmetric_traceless, along_trace], axis=-1
            )
        else:
            raise ValueError()
    elif (
        old_filters.shape[1 : 1 + old_D] == (2,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (2,) * new_D
    ):
        alpha_prime = old_weights / (2**D_increase)
    else:
        raise ValueError()

    # TODO: I could check that the condition holds?

    return alpha_prime
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    WARNING: This is the old version which works on the norms of the tensors.

    The weights are solved one filter at a time, from the last filter to the first. Afterwards
    the new filter norms, scaled by the updated weights and summed over the filters and the
    first (new_D - old_D) spatial axes, are checked against the old filter norms scaled by the
    old weights.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the new weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the per-filter nonempty pixel counts are not in ascending order, if
            the new dimension is not larger than the old dimension, or if the final
            consistency check fails (rtol=1e-3, atol=1e-3)
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # From here on old_filters/new_filters hold one nonnegative value per pixel.
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    # presumably geom.nonempty_pixels returns a flattened per-filter mask -- TODO confirm
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # sum the nonempty mask over the first D_increase spatial axes
    # (n_filters,old_spatial)
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # sum the new filter norms over the same axes
    # (n_filters,old_spatial)
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))

    # old filters scaled by the old weights, summed over the filter axis
    # (out_c,in_c,spatial)
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    # weights that have not been solved yet stay 0 and so do not contribute to the sums below
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # largest flat index among the nonempty pixels
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # value of the collapsed sum at the farthest pixel
        # (out_c,in_c)
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights
        # overwrites the temporary value assigned at the top of this iteration
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # old_filters already holds nonnegative norms here, so taking the norm over the added
    # trailing axis leaves the values unchanged
    # (n_filters,old_spatial)
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    # NOTE(review): the message names compatibility_rescale_weights rather than this function
    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    Sidelength 2 filters are special-cased: the old weights are divided by
    2**(new_D - old_D). Otherwise the filters must have sidelength 2L+1, where L+1 is the
    number of filters, and the weights are solved from filter L down to filter 0. In both
    cases the result is checked before it is returned.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the old weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the old and new filters differ in number of tensor axes or in
            number of filters, if the new dimension is not larger than the old one, if the
            filter sidelengths are not as expected, or if the final consistency check fails
            (rtol=1e-3, atol=1e-3)
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    # number of axes left over after the filter axis and the spatial axes
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    # restrict the new filters' tensor components to the range 0..old_D
    # NOTE(review): the k slices are passed as a tuple nested inside the subscript; confirm
    # that this slices each of the k tensor axes as intended
    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            # pixel at offset z along the first axis and 0 along the others, relative to the
            # filter center
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            # shift by L to turn the centered offsets into array indices
            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # this check uses the projected tensors themselves, not their norms
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # sum over the filter axis and the first D_increase spatial axes
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. The rescale argument selects how the extended weights are rescaled for the new filters (volume, compatibility, or flexible compatibility rescaling); for any other value the extended weights are used without rescaling.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, default to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Transfer the conv weights from old filters to new filters of possibly a different dimension.
    Every weights block is first extended to the number of new filters, then rescaled according
    to rescale. VOLUME, COMPATIBILITY and COMPAT_FLEX each use the matching rescaling function
    of AnyDimensionalModel; for any other value the extended weights are kept as they are.

    args:
        weights: a weights dictionary from a layers.ConvContract layer
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights
        verbose: passed on to the rescaling functions, default to False.

    returns:
        a new weights dictionary
    """
    transferred = {}

    for (in_k, in_p), blocks_by_output in weights.items():
        converted = {}
        transferred[(in_k, in_p)] = converted

        for (out_k, out_p), old_block in blocks_by_output.items():
            # the filters for this block have the combined tensor order and parity
            filter_key = (in_k + out_k, (in_p + out_p) % 2)

            extended_block = AnyDimensionalModel._extend_weights(
                old_block, filter_key, old_filters, new_filters
            )

            old_basis = old_filters[filter_key]
            new_basis = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # rescale the positive and the negative parts of the weights separately
                positive_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, jax.nn.relu(old_block), old_filters.D),
                    (new_basis, jax.nn.relu(extended_block), new_filters.D),
                    verbose,
                )
                negative_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, -jax.nn.relu(-old_block), old_filters.D),
                    (new_basis, -jax.nn.relu(-extended_block), new_filters.D),
                    verbose,
                )
                result_block = positive_part + negative_part
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Filters whose spatial sum is zero are not rescaled.
                new_spatial_axes = tuple(range(1, 1 + new_filters.D))
                per_filter_sum = jnp.sum(new_basis, axis=new_spatial_axes)  # (n_filters,tensor)
                sum_magnitude = jnp.linalg.norm(
                    per_filter_sum.reshape((len(per_filter_sum), -1)), axis=1
                )  # (n_filters,)
                nonzero_mask = sum_magnitude != 0  # (n_filters,)

                old_triple = (
                    old_basis[nonzero_mask],
                    old_block[:, :, nonzero_mask],
                    old_filters.D,
                )
                new_triple = (
                    new_basis[nonzero_mask],
                    extended_block[:, :, nonzero_mask],
                    new_filters.D,
                )
                rescaled_subset = AnyDimensionalModel.compatibility_rescale_weights(
                    old_triple, new_triple, verbose
                )

                # write the rescaled weights back into their slots, the rest stay as extended
                result_block = extended_block.at[:, :, nonzero_mask].set(rescaled_subset)
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_basis, old_block, old_filters.D),
                    (new_basis, extended_block, new_filters.D),
                    verbose,
                )
            else:
                result_block = extended_block

            converted[(out_k, out_p)] = result_block

    return transferred
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new model's conv filters to the new conv filters, then doing any conv-filter-related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, default to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Copy this model's parameters onto new_model, whose conv filters may be of a different
    dimension. Every jax array of new_model is first overwritten by the corresponding array of
    this model, then new_model's own invariant filters are put back, and finally the weights of
    every conv layer are recomputed from this model's weights and the old/new filter pair.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the freshly built model that receives the weights, its conv filters are kept
        rescale: type of rescaling to perform on the conv weights
        verbose: forwarded to the per-layer conv weight transfer, defaults to False

    returns:
        new_model with this model's weights, except conv weights which are adjusted to
        new_model's filters
    """

    def is_conv_layer(node):
        return isinstance(node, layers.ConvContract)

    def conv_layers(model):
        # stop the traversal at ConvContract nodes so the layers themselves come back as leaves
        leaves = jax.tree_util.tree_leaves(model, is_leaf=is_conv_layer)
        return [leaf for leaf in leaves if is_conv_layer(leaf)]

    def filters_of(model):
        return [layer.invariant_filters for layer in conv_layers(model)]

    def conv_weights_of(model):
        return [layer.weights for layer in conv_layers(model)]

    def arrays_of(model):
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    # remember the filters of the new model before they get overwritten below
    target_filters = filters_of(new_model)

    # overwrite every array of the new model with the arrays of this model
    result = eqx.tree_at(arrays_of, new_model, arrays_of(self))

    # the previous step also replaced the filters, so restore the new model's filters
    result = eqx.tree_at(filters_of, result, target_filters)

    # recompute the conv weights layer by layer from the old weights and the filter pairs
    adjusted_weights = []
    for weights, source_filters, dest_filters in zip(
        conv_weights_of(self), filters_of(self), target_filters
    ):
        adjusted_weights.append(
            AnyDimensionalModel._transfer_conv_weights(
                weights, source_filters, dest_filters, rescale, verbose
            )
        )

    return eqx.tree_at(conv_weights_of, result, adjusted_weights)
__init__(D: int, input_keys: geom.Signature, output_keys: geom.Signature, depth: int, num_downsamples: int = 4, num_conv: int = 2, use_bias: Union[bool, str] = 'auto', activation_f: Callable | str | None = jax.nn.gelu, equivariant: bool = True, conv_filters: Optional[geom.MultiImage] = None, upsample_filters: Optional[geom.MultiImage] = None, kernel_size: Optional[Union[int, Sequence[int]]] = None, use_group_norm: bool = False, use_batch_norm: bool = False, mid_keys: Optional[geom.Signature] = None, padding_mode: str = 'ZEROS', key: Any = None) -> None ¤

Constructor for the UNet.

Parameters:

Name Type Description Default
D int

the dimension of the space

required
input_keys Signature

the MultiImage Signature for the input

required
output_keys Signature

the MultiImage Signature for the output

required
depth int

the number of channels at the highest level of the unet. This is overwritten if mid_keys is provided

required
num_downsamples int

number of convolution blocks followed by a max pool

4
num_conv int

number of convolutions per level

2
use_bias Union[bool, str]

whether to use a bias

'auto'
activation_f Callable | str | None

the activation function

gelu
equivariant bool

whether to be equivariant

True
conv_filters Optional[MultiImage]

the invariant filters for the equivariant version

None
kernel_size Optional[Union[int, Sequence[int]]]

sidelength(s) for the non-equivariant version

None
use_group_norm bool

whether to use GroupNorm

False
use_batch_norm bool

whether to use the BatchNorm, only for non-equivariant version

False
mid_keys Optional[Signature]

types of images and number of channels for the mid layers, as a baseline

None
padding_mode str

used for non-equivariant models, padding mode to pass to convolutions

'ZEROS'
key Any

jax.random key

None
Source code in ginjax/models.py
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
def __init__(
    self: Self,
    D: int,
    input_keys: geom.Signature,
    output_keys: geom.Signature,
    depth: int,
    num_downsamples: int = 4,
    num_conv: int = 2,
    use_bias: Union[bool, str] = "auto",
    activation_f: Callable | str | None = jax.nn.gelu,
    equivariant: bool = True,
    conv_filters: Optional[geom.MultiImage] = None,
    upsample_filters: Optional[geom.MultiImage] = None,
    kernel_size: Optional[Union[int, Sequence[int]]] = None,
    use_group_norm: bool = False,
    use_batch_norm: bool = False,
    mid_keys: Optional[geom.Signature] = None,
    padding_mode: str = "ZEROS",
    key: Any = None,
) -> None:
    """
    Constructor for the UNet. Builds, in order, the embedding conv blocks, the downsample
    blocks (a pooling layer plus conv blocks per level), the upsample blocks (an upsampling
    convolution plus conv blocks per level), and the final decode convolution.

    args:
        D: the dimension of the space
        input_keys: the MultiImage Signature for the input
        output_keys: the MultiImage Signature for the output
        depth: the number of channels at the highest level of the unet. This is overwritten if
            mid_keys is provided
        num_downsamples: number of convolution blocks followed by a max pool
        num_conv: number of convolutions per level, must be positive
        use_bias: whether to use a bias
        activation_f: the activation function
        equivariant: whether to be equivariant
        conv_filters: the invariant filters for the equivariant version
        upsample_filters: the filters passed to the upsampling convolutions in place of
            conv_filters
        kernel_size: sidelength(s) for the non-equivariant version
        use_group_norm: whether to use GroupNorm
        use_batch_norm: whether to use the BatchNorm, only for non-equivariant version
        mid_keys: types of images and number of channels for the mid layers, as a baseline
        padding_mode: used for non-equivariant models, padding mode to pass to convolutions
        key: jax.random key, required despite the None default
    """
    assert num_conv > 0
    assert key is not None

    # store the caller's signatures before the non-equivariant branch rebinds the local names
    self.input_keys = input_keys
    self.output_keys = output_keys
    if equivariant:
        if mid_keys is None:
            mid_keys = geom.signature_union(input_keys, output_keys, depth)

        assert not use_batch_norm, "UNet::init Batch Norm cannot be used with equivariant model"
    else:
        if mid_keys is None:
            mid_keys = geom.Signature(((((), 0), depth),))

        # use these keys along the way, then for the final output use self.output_keys
        # each entry contributes channels * D**len(k) to a single ((), 0) entry
        input_keys_size = sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)
        input_keys = geom.Signature(((((), 0), input_keys_size),))
        output_key_size = sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)
        output_keys = geom.Signature(((((), 0), output_key_size),))

    self.D = D
    self.equivariant = equivariant
    self.use_bias = use_bias
    self.activation_f = activation_f
    self.use_group_norm = use_group_norm
    self.use_batch_norm = use_batch_norm
    self.mid_keys = mid_keys
    self.padding_mode = padding_mode

    # embedding layers
    # the first block maps input_keys to mid_keys, the remaining ones keep mid_keys
    self.embedding = []
    for conv_idx in range(num_conv):
        in_keys = input_keys if conv_idx == 0 else mid_keys
        key, subkey = random.split(key)
        self.embedding.append(
            ConvBlock(
                self.D,
                in_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                kernel_size,
                use_group_norm,
                use_batch_norm,
                padding_mode=padding_mode,
                key=subkey,
            )
        )

    # one (pooling layer, conv blocks) pair per level; level `downsample` uses the mid_keys
    # channel counts multiplied by 2**downsample
    self.downsample_blocks = []
    for downsample in range(1, num_downsamples + 1):
        down_layers = (layers.MaxNormPool(2, equivariant), [])

        for conv_idx in range(num_conv):
            out_keys = geom.Signature(
                tuple((k_p, _depth * (2**downsample)) for k_p, _depth in mid_keys)
            )
            if conv_idx == 0:
                # the first conv of a level takes the channel count of the previous level
                in_keys = geom.Signature(
                    tuple((k_p, _depth * (2 ** (downsample - 1))) for k_p, _depth in mid_keys)
                )
            else:
                in_keys = out_keys

            key, subkey = random.split(key)
            down_layers[1].append(
                ConvBlock(
                    self.D,
                    in_keys,
                    out_keys,
                    use_bias,
                    activation_f,
                    equivariant,
                    conv_filters,
                    kernel_size,
                    use_group_norm,
                    use_batch_norm,
                    padding_mode=padding_mode,
                    key=subkey,
                )
            )

        self.downsample_blocks.append(down_layers)

    # one (upsampling conv, conv blocks) pair per level, built from the deepest level upwards
    self.upsample_blocks = []
    for upsample in reversed(range(num_downsamples)):
        # the upsampling conv halves the channel multiplier: 2**(upsample + 1) -> 2**upsample
        in_keys = geom.Signature(
            tuple((k_p, _depth * (2 ** (upsample + 1))) for k_p, _depth in mid_keys)
        )
        out_keys = geom.Signature(
            tuple((k_p, _depth * (2**upsample)) for k_p, _depth in mid_keys)
        )
        key, subkey = random.split(key)
        # perform the transposed convolution. For non-equivariant, padding and stride should
        # instead be the padding and stride for the forward direction convolution.
        if equivariant:
            padding = ((1, 1),) * self.D
            stride = (1,) * self.D
            upsample_kernel_size = None  # ignored for equivariant
        else:
            padding = "VALID"
            stride = (2,) * self.D
            upsample_kernel_size = (2,) * self.D  # kernel size of the downsample

        up_layers = (
            make_conv(
                self.D,
                in_keys,
                out_keys,
                use_bias,
                equivariant,
                upsample_filters,
                upsample_kernel_size,
                stride,
                padding,
                (2,) * self.D,  # lhs_dilation
                padding_mode=padding_mode,
                key=subkey,
            ),
            [],
        )

        for conv_idx in range(num_conv):
            out_keys = geom.Signature(
                tuple((k_p, _depth * (2**upsample)) for k_p, _depth in mid_keys)
            )
            if conv_idx == 0:  # due to adding the residual layer back, in_c is doubled again
                in_keys = geom.Signature(
                    tuple((k_p, _depth * (2 ** (upsample + 1))) for k_p, _depth in mid_keys)
                )
            else:
                in_keys = out_keys

            key, subkey = random.split(key)
            up_layers[1].append(
                ConvBlock(
                    self.D,
                    in_keys,
                    out_keys,
                    use_bias,
                    activation_f,
                    equivariant,
                    conv_filters,
                    kernel_size,
                    use_group_norm,
                    use_batch_norm,
                    padding_mode=padding_mode,
                    key=subkey,
                )
            )

        self.upsample_blocks.append(up_layers)

    key, subkey = random.split(key)

    # final convolution from mid_keys to the (possibly flattened) output_keys
    self.decode = make_conv(
        self.D,
        mid_keys,
        output_keys,
        use_bias,
        equivariant,
        conv_filters,
        kernel_size,
        padding_mode=padding_mode,
        key=subkey,
    )
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Construct a new model with filters in a higher dimension. This only works for equivariant models.

Parameters:

Name Type Description Default
old_conv_filters

the current conv filters for the model

required
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

whether to force the sum of the filters in the new dimension to be equal

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
def convertD(
    self: Self,
    conv_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    key: jax.Array,
    **kwargs,
) -> Self:
    """
    Build a new UNet of the same configuration around a different set of filters, probably in
    a higher dimension, then transfer this model's weights onto it. This only works for
    equivariant models.

    args:
        conv_filters: the new conv filters we are swapping to, their D is the new model's D
        rescale: type of rescaling passed on to transfer_weights
        key: key to initialize the weights, since they are overruled it won't matter
        kwargs: must contain "upsample_filters", the filters for the new model's upsampling
            convolutions

    returns:
        a new model with new filters but the old weights
    """
    assert self.equivariant
    assert "upsample_filters" in kwargs

    model_cls = self.__class__
    new_upsample_filters = kwargs["upsample_filters"]
    # the constructor sizes are recovered from the layers this model already holds
    n_downsamples = len(self.downsample_blocks)
    n_conv = len(self.embedding)

    rebuilt = model_cls(
        conv_filters.D,
        self.input_keys,
        self.output_keys,
        0,  # depth, ignored since mid_keys is provided
        n_downsamples,
        n_conv,
        self.use_bias,
        self.activation_f,
        self.equivariant,
        conv_filters,
        new_upsample_filters,
        0,  # kernel_size, ignored for equivariant model
        self.use_group_norm,
        self.use_batch_norm,
        self.mid_keys,
        self.padding_mode,
        key,
    )

    return self.transfer_weights(rebuilt, rescale)
__call__(x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Callable function for UNet

Parameters:

Name Type Description Default
x MultiImage

the input MultiImage

required
batch_stats Optional[State]

batch stats for BatchNorm if present

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage and batch_stats

Source code in ginjax/models.py
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
def __call__(
    self: Self, x: geom.MultiImage, batch_stats: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Run the UNet: embedding blocks, then the downsample path (saving the input of every level),
    then the upsample path (concatenating the saved images back in, deepest first), then the
    decode convolution.

    args:
        x: the input MultiImage
        batch_stats: batch stats for BatchNorm if present, threaded through every conv block

    returns:
        the output MultiImage and batch_stats
    """
    if not self.equivariant:
        x = x.to_scalar_multi_image()

    for block in self.embedding:
        x, batch_stats = block(x, batch_stats)

    # image entering each level, kept for the skip connections
    skips = []
    for pool, level_blocks in self.downsample_blocks:
        skips.append(x)
        x = pool(x)
        for block in level_blocks:
            x, batch_stats = block(x, batch_stats)

    # walk back up, pairing each upsample level with the most recently saved image
    for (upsample, level_blocks), skip in zip(self.upsample_blocks, skips[::-1]):
        x = upsample(x).concat(skip)
        for block in level_blocks:
            x, batch_stats = block(x, batch_stats)

    x = self.decode(x)
    if not self.equivariant:
        # the non-equivariant model works on scalar channels, convert back to output_keys
        x = geom.MultiImage.from_scalar_multi_image(x, self.output_keys)

    return x, batch_stats

DilResNet ¤

Bases: AnyDimensionalModel

The Dilated ResNet from https://arxiv.org/abs/2112.15275.

Source code in ginjax/models.py
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
class DilResNet(AnyDimensionalModel):
    """
    The Dilated ResNet from https://arxiv.org/abs/2112.15275. The model is a two-block encoder,
    a sequence of residual blocks that each hold seven conv blocks with dilations
    1, 2, 4, 8, 4, 2, 1, and a two-block decoder.
    """

    # the layers of the model
    encoder: list[ConvBlock]
    blocks: list[list[ConvBlock]]
    decoder: list[ConvBlock]

    # configuration, stored as static fields
    D: int = eqx.field(static=True)
    output_keys: geom.Signature = eqx.field(static=True)
    input_keys: geom.Signature = eqx.field(static=True)
    use_bias: bool | str = eqx.field(static=True)
    activation_f: Callable | str | None = eqx.field(static=True)
    equivariant: bool = eqx.field(static=True)
    use_group_norm: bool = eqx.field(static=True)
    mid_keys: geom.Signature = eqx.field(static=True)
    padding_mode: str = eqx.field(static=True)

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        depth: int,
        num_blocks: int = 4,
        use_bias: bool | str = "auto",
        activation_f: Callable | str | None = jax.nn.relu,
        equivariant: bool = True,
        conv_filters: geom.MultiImage | None = None,
        kernel_size: int | Sequence[int] | None = None,
        use_group_norm: bool = False,
        mid_keys: geom.Signature | None = None,
        padding_mode: str = "ZEROS",
        key: Any = None,
    ) -> None:
        """
        Constructor for the DilatedResNet

        args:
            D: the dimension of the space
            input_keys: the MultiImage Signature for the input
            output_keys: the MultiImage Signature for the output
            depth: the number of channels used to build mid_keys, only used when mid_keys is
                not provided
            num_blocks: number of resnet blocks
            use_bias: whether to use a bias
            activation_f: the activation function
            equivariant: whether to be equivariant
            conv_filters: the invariant filters for the equivariant version
            kernel_size: sidelength(s) for the non-equivariant version, only passed to the
                dilation blocks
            use_group_norm: whether to use GroupNorm, only passed to the dilation blocks
            mid_keys: types of images and number of channels for the mid layers, as a baseline
            padding_mode: used for non-equivariant models, padding mode to pass to convolutions
            key: jax.random key, it is split below so it must be provided despite the None
                default
        """
        self.D = D
        self.equivariant = equivariant
        # store the caller's signatures before the non-equivariant branch rebinds the names
        self.output_keys = output_keys
        self.input_keys = input_keys

        if equivariant:
            if mid_keys is None:
                mid_keys = geom.signature_union(input_keys, output_keys, depth)
        else:
            if mid_keys is None:
                mid_keys = geom.Signature(((((), 0), depth),))

            # use these keys along the way, then for the final output use self.output_keys
            # each entry contributes channels * D**len(k) to a single ((), 0) entry
            input_keys = geom.Signature(
                ((((), 0), sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)),)
            )
            output_keys = geom.Signature(
                ((((), 0), sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)),)
            )

        self.use_bias = use_bias
        self.activation_f = activation_f
        self.use_group_norm = use_group_norm
        self.mid_keys = mid_keys
        self.padding_mode = padding_mode

        # encoder
        # two blocks, input_keys -> mid_keys -> mid_keys, with 1 in the kernel_size position
        key, subkey1, subkey2 = random.split(key, num=3)
        self.encoder = [
            ConvBlock(
                D,
                input_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey1,
            ),
            ConvBlock(
                D,
                mid_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey2,
            ),
        ]

        self.blocks = []
        for _ in range(num_blocks):
            # dCNN block
            # seven conv blocks whose dilation grows to 8 and then shrinks back to 1
            dilation_block = []
            for dilation in [1, 2, 4, 8, 4, 2, 1]:
                key, subkey = random.split(key)
                dilation_block.append(
                    ConvBlock(
                        D,
                        mid_keys,
                        mid_keys,
                        use_bias,
                        activation_f,
                        equivariant,
                        conv_filters,
                        kernel_size,
                        use_group_norm,
                        rhs_dilation=(dilation,) * D,
                        padding_mode=padding_mode,
                        key=subkey,
                    )
                )

            self.blocks.append(dilation_block)

        # decoder: mid_keys -> mid_keys -> output_keys, the last block gets None as activation
        key, subkey1, subkey2 = random.split(key, num=3)
        self.decoder = [
            ConvBlock(
                D,
                mid_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey1,
            ),
            ConvBlock(
                D,
                mid_keys,
                output_keys,
                use_bias,
                None,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey2,
            ),
        ]

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Callable for this layer. Applies the encoder, then every dilation block with a residual
        connection around it, then the decoder.

        args:
            x: the input MultiImage
            aux_data: unused, needed for compliance. It is not passed to the layers and is
                returned unchanged.

        returns:
            the output MultiImage, aux_data
        """
        if not self.equivariant:
            x = x.to_scalar_multi_image()

        for layer in self.encoder:
            x, _ = layer(x)

        for dilation_block in self.blocks:
            # keep the block input so it can be added back after the block
            residual_x = x.copy()

            for layer in dilation_block:
                x, _ = layer(x)

            x = x + residual_x

        for layer in self.decoder:
            x, _ = layer(x)

        if self.equivariant:
            out = x
        else:
            # the non-equivariant model works on scalar channels, convert back to output_keys
            out = geom.MultiImage.from_scalar_multi_image(x, self.output_keys)

        return out, aux_data

    def convertD(
        self: Self,
        conv_filters: geom.MultiImage,
        rescale: geom.Rescaling,
        key: jax.Array,
        **kwargs,
    ) -> Self:
        """
        Construct a new model with filters in a higher dimension. This only works for equivariant
        models.

        args:
            conv_filters: the new conv filters we are swapping to, probably in a higher dimension
            rescale: type of rescaling passed on to transfer_weights
            key: key to initialize the weights, since they are overruled it won't matter
            kwargs: not read by this method

        returns:
            a new model with new filters but the old weights
        """
        assert self.equivariant

        # rebuild with the stored configuration, only D and the filters come from conv_filters
        new_model = self.__class__(
            conv_filters.D,
            self.input_keys,
            self.output_keys,
            0,  # ignored since mid_keys is provided
            len(self.blocks),
            self.use_bias,
            self.activation_f,
            self.equivariant,
            conv_filters,
            0,  # ignored for equivariant model
            self.use_group_norm,
            self.mid_keys,
            self.padding_mode,
            key,
        )

        return self.transfer_weights(new_model, rescale)
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Given a set of weights associated with old_filters, extend the weights to new_filters.
    The weights are split into an unbalanced group (a center weight followed by offcenter
    weights) and a balanced group. Each group is padded up to the new number of filters, where
    every added weight is the mean of the existing offcenter weights (unbalanced group) or of
    the existing balanced weights (balanced group). The result keeps the dtype of
    old_weights_block unless the inputs themselves force a promotion.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for, only keys whose first
            element has length 0, 1, or 2 are supported
        old_filters: the old filters, must contain the ((), 0) filters when extending k=2
        new_filters: the new filters, must have at least as many filters as old_filters

    returns:
        the weights associated with the new filters, shape (out_c,in_c,n_new_filters)

    raises:
        NotImplementedError: if the first element of filter_key has length other than 0, 1, 2
    """
    k = len(filter_key[0])
    if k not in {0, 1, 2}:
        raise NotImplementedError()

    n_add_unbalanced = 0
    n_add_balanced = 0
    center_weight = None
    offcenter_old_weights = None
    balanced_weights = None
    if k == 0:
        # all weights are unbalanced: the first is the center, the rest are offcenter
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:]
        n_add_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 1:
        # all weights are balanced
        balanced_weights = old_weights_block
        n_add_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 2:
        # for k==2, the first set of filters follows the scalar filters
        assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        n_old_unbalanced = len(old_filters[(), 0])
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:n_old_unbalanced]
        n_add_unbalanced = len(new_filters[(), 0]) - n_old_unbalanced

        balanced_weights = old_weights_block[:, :, n_old_unbalanced:]
        # gap between new filters and (old filters plus the additional unbalanced filter)
        n_add_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + n_add_unbalanced
        )

    assert n_add_unbalanced >= 0
    assert n_add_balanced >= 0

    channel_shape = old_weights_block.shape[:2]
    # Zero-width placeholder for a group that is absent. It is created with the dtype of the
    # weights so that concatenating it below does not promote e.g. float16 weights to the
    # default float dtype.
    empty_group = jnp.zeros(channel_shape + (0,), dtype=old_weights_block.dtype)

    new_unbalanced_weights = empty_group
    if center_weight is not None and offcenter_old_weights is not None:
        # When n_add_unbalanced is 0 the fill below has zero width and the concatenation
        # returns the old unbalanced weights unchanged.
        # NOTE(review): if there are no offcenter weights the mean is taken over an empty axis
        # and any added weights are NaN - confirm callers always provide offcenter weights.
        additional_weights = jnp.full(
            channel_shape + (n_add_unbalanced,),
            jnp.mean(offcenter_old_weights, axis=2, keepdims=True),
        )

        new_unbalanced_weights = jnp.concatenate(
            [center_weight, offcenter_old_weights, additional_weights], axis=2
        )

    new_balanced_weights = empty_group
    if balanced_weights is not None:
        additional_weights = jnp.full(
            channel_shape + (n_add_balanced,),
            jnp.mean(balanced_weights, axis=2, keepdims=True),
        )

        new_balanced_weights = jnp.concatenate([balanced_weights, additional_weights], axis=2)

    return jnp.concatenate([new_unbalanced_weights, new_balanced_weights], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the new weights so that the filter sum (weights times filters, as computed by
    get_filter_sum) matches between the old filters and the new filters, which are likely of a
    higher dimension. Weights of old filters whose spatial sum is zero are left unscaled.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    source_filters, source_weights, source_D = old_filter_triple
    target_filters, target_weights, target_D = new_filter_triple

    # Find the old filters that do not always sum to 0, only those get rescaled.
    # Sum each filter over its spatial axes, (n_filters,tensor)
    per_filter_total = jnp.sum(source_filters, axis=tuple(range(1, 1 + source_D)))
    # norm of the flattened remainder, (n_filters,)
    per_filter_norm = jnp.linalg.norm(
        per_filter_total.reshape((len(per_filter_total), -1)), axis=1
    )
    rescalable = (per_filter_norm != 0)[None, None]  # (1,1,n_filters)

    # both sums are (out_c,in_c)
    source_total = get_filter_sum(source_D, source_filters, source_weights)
    target_total = get_filter_sum(target_D, target_filters, target_weights)
    # geom.TINY is added to the denominator before dividing
    channel_ratios = source_total / (target_total + geom.TINY)

    # rescalable filters get the channel ratio, all others get a factor of 1
    # (out_c,in_c,n_filters)
    ratios = rescalable * channel_ratios[..., None] + (~rescalable)

    if verbose:
        print("old weights", source_weights.shape, source_weights)
        print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

    return target_weights * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

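Do compatibility rescaling, now with one extra free parameter. The dimension must increase by exactly one. This is defined for sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and for sidelength 2 filters; any other combination raises a ValueError.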

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. The dimension must increase by
    exactly one. Supported cases:

    - sidelength 3 filters, D=1 to D=2: 2 old weights are mapped to 3 new weights
    - sidelength 3 filters, D=2 to D=3: the first 5 old weights are mapped to 8 new weights
    - sidelength 2 filters: the old weights are divided by 2**(new_D - old_D)

    Only the shapes of the filters are used; the new weights are ignored.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters)
        new_filter_triple: tuple of (new filters, new weights, new dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters)
        verbose: whether to print the old weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        ValueError: if the filter sidelengths or the dimensions are not one of the supported cases
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, _, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    new_k = new_filters.ndim - (1 + new_D)
    assert (
        k == new_k
    ), f"compat_flex_rescale_weights: old_filters k={k}, new_filters k={new_k}"

    D_increase = new_D - old_D
    assert D_increase == 1, f"compat_flex_rescale_weights: D_increase={D_increase}, must be 1"

    old_sidelengths = old_filters.shape[1 : 1 + old_D]
    new_sidelengths = new_filters.shape[1 : 1 + new_D]

    if old_sidelengths == (3,) * old_D and new_sidelengths == (3,) * new_D:
        if old_D == 1 and new_D == 2:
            assert (
                old_weights.shape[2] == 2
            ), f"compat_flex_rescale_weights: expected 2 old filters, got {old_weights.shape[2]}"
            ratio = 1 / 3  # value chosen for the extra free parameter

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0] + (-2 + 4 * ratio) * old_weights[..., 1],
                    (1 - 2 * ratio) * old_weights[..., 1],
                    ratio * old_weights[..., 1],
                ],
                axis=-1,
            )
        elif old_D == 2 and new_D == 3:
            # need to get first 4 new_weights from first 3 old_weights

            z = (old_weights[..., 2] * 4 - old_weights[..., 1]) / 9

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0]
                    - 2 * old_weights[..., 1]
                    + 4 * old_weights[..., 2]
                    - 8 * z,
                    old_weights[..., 1] - 2 * old_weights[..., 2] + 4 * z,
                    old_weights[..., 2] - 2 * z,
                    z,
                ],
                axis=-1,
            )

            # filters are in flipped order for some reason
            # each of old weights 4 and 3 is duplicated onto two new weights
            symmetric_traceless = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 4:5]
            along_trace = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 3:4]

            alpha_prime = jnp.concatenate(
                [alpha_prime, symmetric_traceless, along_trace], axis=-1
            )
        else:
            raise ValueError(
                "compat_flex_rescale_weights: sidelength 3 filters are only supported for "
                f"D=1 to D=2 or D=2 to D=3, got D={old_D} to D={new_D}"
            )
    elif old_sidelengths == (2,) * old_D and new_sidelengths == (2,) * new_D:
        alpha_prime = old_weights / (2**D_increase)
    else:
        raise ValueError(
            f"compat_flex_rescale_weights: unsupported filter sidelengths old={old_sidelengths}, "
            f"new={new_sidelengths}; only all-3 or all-2 sidelengths are supported"
        )

    # TODO: I could check that the condition holds?

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    WARNING: This is the old version which works on the norms of the tensors.

    The weights are solved one filter at a time, from the last filter to the first. For each
    filter, the new filters are summed over their leading (new_D - old_D) spatial axes and the
    weight is chosen so the collapsed value at that filter's farthest nonempty pixel matches the
    old scaled filters. A final assertion checks that the collapsed result matches everywhere.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters)
        new_filter_triple: tuple of (new filters, new weights, new dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters)
        verbose: whether to print the new weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # From here on old_filters/new_filters hold per-pixel norms, shape (n_filters,spatial).
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    # NOTE(review): presumably geom.nonempty_pixels returns a per-pixel boolean mask, so this
    # counts nonempty pixels per filter -- confirm against geom.
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # collapse the leading D_increase spatial axes: (n_filters,old_spatial)
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # same collapse applied to the filter norms: (n_filters,old_spatial)
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))

    # (out_c,in_c,spatial)
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # largest flattened index among the nonempty pixels
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # (out_c,in_c)
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights; the numerator equals 2 * old_weights_val - collapsed_val
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # (n_filters,old_spatial)
    # NOTE(review): old_filters was already replaced by its per-pixel norm at the top of this
    # function, so this second norm looks redundant.
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    # NOTE(review): the message names compatibility_rescale_weights, not this function.
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. The rescale argument selects how the transferred weights are adjusted: volume rescaling, compatibility rescaling, or flexible compatibility rescaling. For any other value the transferred weights are used as they are.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, default to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Carry a ConvContract weights dictionary over from one set of filters to another, possibly of
    a different dimension. Every weight block is first extended to the new filters and then
    adjusted according to rescale.

    args:
        weights: a weights dictionary from a layers.ConvContract layer, keyed by input (k,parity)
            and then by output (k,parity)
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights; values other than VOLUME,
            COMPATIBILITY and COMPAT_FLEX leave the extended weights unchanged
        verbose: passed through to the rescaling functions, default to False.

    returns:
        a new weights dictionary with the same keys as weights
    """
    transferred = {}

    for in_key, blocks_by_output in weights.items():
        in_k, in_p = in_key
        transferred_row = {}
        transferred[(in_k, in_p)] = transferred_row

        for out_key, old_block in blocks_by_output.items():
            out_k, out_p = out_key
            # the filter used for this block: combined k of input and output, parities add mod 2
            filter_key = (in_k + out_k, (in_p + out_p) % 2)

            extended_block = AnyDimensionalModel._extend_weights(
                old_block, filter_key, old_filters, new_filters
            )

            old_basis = old_filters[filter_key]
            new_basis = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # positive and negative parts of the weights are rescaled separately, then
                # recombined
                positive_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, jax.nn.relu(old_block), old_filters.D),
                    (new_basis, jax.nn.relu(extended_block), new_filters.D),
                    verbose,
                )
                negative_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, -jax.nn.relu(-old_block), old_filters.D),
                    (new_basis, -jax.nn.relu(-extended_block), new_filters.D),
                    verbose,
                )
                result_block = positive_part + negative_part
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Only filters whose spatial sum is nonzero get rescaled.
                spatial_axes = tuple(range(1, 1 + new_filters.D))
                per_filter_total = jnp.sum(new_basis, axis=spatial_axes)  # (n_filters,tensor)
                per_filter_norm = jnp.linalg.norm(
                    per_filter_total.reshape((len(per_filter_total), -1)), axis=1
                )  # (n_filters,)
                keep = per_filter_norm != 0  # (n_filters,)

                rescaled_subset = AnyDimensionalModel.compatibility_rescale_weights(
                    (old_basis[keep], old_block[:, :, keep], old_filters.D),
                    (new_basis[keep], extended_block[:, :, keep], new_filters.D),
                    verbose,
                )

                # write the rescaled weights back, the masked-out filters keep extended weights
                result_block = extended_block.at[:, :, keep].set(rescaled_subset)
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_basis, old_block, old_filters.D),
                    (new_basis, extended_block, new_filters.D),
                    verbose,
                )
            else:
                result_block = extended_block

            transferred_row[(out_k, out_p)] = result_block

    return transferred
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new models conv filters to the new conv filters, then doing any conv filter related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, default to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Copy the weights and biases of this (old) model into a new model, which may be of a different
    dimension. All leaves of this model are written into the new model, the new model's own conv
    filters are then put back, and finally the conv weights are replaced by the transferred and
    rescaled versions.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the new model, whose conv filters can have a different D
        rescale: type of rescaling to perform on the weights
        verbose: passed through to the conv weight transfer, default to False.

    returns:
        a new model with the old weights except conv weights which are adjusted, and new filters
    """

    def is_conv(node):
        return isinstance(node, layers.ConvContract)

    def get_filters(model):
        # invariant filters of every ConvContract layer, in tree order
        return [
            node.invariant_filters
            for node in jax.tree_util.tree_leaves(model, is_leaf=is_conv)
            if is_conv(node)
        ]

    def get_all_weights(model):
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    def get_conv_weights(model):
        # weights dictionary of every ConvContract layer, in tree order
        return [
            node.weights
            for node in jax.tree_util.tree_leaves(model, is_leaf=is_conv)
            if is_conv(node)
        ]

    # remember the new model's filters before they get overwritten below
    new_filters = get_filters(new_model)

    # overwrite every leaf of the new model with the matching leaf of this model
    new_model = eqx.tree_at(get_all_weights, new_model, get_all_weights(self))

    # put the new model's conv filters back
    new_model = eqx.tree_at(get_filters, new_model, new_filters)

    # transfer the conv weights layer by layer, pairing each old filter with its new filter
    conv_weights = get_conv_weights(self)
    new_weights = [
        AnyDimensionalModel._transfer_conv_weights(
            weight, old_filter, new_filter, rescale, verbose
        )
        for weight, old_filter, new_filter in zip(conv_weights, get_filters(self), new_filters)
    ]
    new_model = eqx.tree_at(get_conv_weights, new_model, new_weights)

    return new_model
__init__(D: int, input_keys: geom.Signature, output_keys: geom.Signature, depth: int, num_blocks: int = 4, use_bias: bool | str = 'auto', activation_f: Callable | str | None = jax.nn.relu, equivariant: bool = True, conv_filters: geom.MultiImage | None = None, kernel_size: int | Sequence[int] | None = None, use_group_norm: bool = False, mid_keys: geom.Signature | None = None, padding_mode: str = 'ZEROS', key: Any = None) -> None ¤

Constructor for the DilatedResNet

Parameters:

Name Type Description Default
D int

the dimension of the space

required
input_keys Signature

the MultiImage Signature for the input

required
output_keys Signature

the MultiImage Signature for the output

required
depth int

the number of channels at the highest level of the unet

required
num_blocks int

number of resnet blocks

4
use_bias bool | str

whether to use a bias

'auto'
activation_f Callable | str | None

the activation function

relu
equivariant bool

whether to be equivariant

True
conv_filters MultiImage | None

the invariant filters for the equivariant version

None
kernel_size int | Sequence[int] | None

sidelength(s) for the non-equivariant version

None
use_group_norm bool

whether to use GroupNorm

False
mid_keys Signature | None

types of images and number of channels for the mid layers, as a baseline

None
padding_mode str

used for non-equivariant models, padding mode to pass to convolutions

'ZEROS'
key Any

jax.random key

None
Source code in ginjax/models.py
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
def __init__(
    self: Self,
    D: int,
    input_keys: geom.Signature,
    output_keys: geom.Signature,
    depth: int,
    num_blocks: int = 4,
    use_bias: bool | str = "auto",
    activation_f: Callable | str | None = jax.nn.relu,
    equivariant: bool = True,
    conv_filters: geom.MultiImage | None = None,
    kernel_size: int | Sequence[int] | None = None,
    use_group_norm: bool = False,
    mid_keys: geom.Signature | None = None,
    padding_mode: str = "ZEROS",
    key: Any = None,
) -> None:
    """
    Build a DilatedResNet: a two-layer encoder, `num_blocks` residual blocks that each stack
    seven ConvBlocks with dilations 1, 2, 4, 8, 4, 2, 1, and a two-layer decoder whose last
    layer has no activation.

    args:
        D: the dimension of the space
        input_keys: the MultiImage Signature for the input
        output_keys: the MultiImage Signature for the output
        depth: number of channels for the mid layers, only used when mid_keys is None
        num_blocks: number of residual dilation blocks
        use_bias: whether to use a bias, forwarded to every ConvBlock
        activation_f: the activation function for every layer except the last decoder layer
        equivariant: whether to be equivariant
        conv_filters: the invariant filters for the equivariant version
        kernel_size: sidelength(s) for the non-equivariant version, used by the dilation blocks
        use_group_norm: whether to use GroupNorm in the dilation blocks
        mid_keys: types of images and number of channels for the mid layers, as a baseline
        padding_mode: used for non-equivariant models, padding mode to pass to convolutions
        key: jax.random key; it is handed to random.split right away, so presumably a real key
            must be supplied despite the None default
    """
    # Store the caller-facing signatures; the local names may be flattened to scalars below.
    self.D = D
    self.equivariant = equivariant
    self.input_keys = input_keys
    self.output_keys = output_keys

    if mid_keys is None:
        if equivariant:
            mid_keys = geom.signature_union(input_keys, output_keys, depth)
        else:
            mid_keys = geom.Signature(((((), 0), depth),))

    if not equivariant:
        # Non-equivariant layers see one scalar image type: every order-k image with c channels
        # contributes c * D**k scalar channels. self.output_keys restores the layout at the end.
        scalar_in_channels = sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)
        input_keys = geom.Signature(((((), 0), scalar_in_channels),))
        scalar_out_channels = sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)
        output_keys = geom.Signature(((((), 0), scalar_out_channels),))

    self.mid_keys = mid_keys
    self.use_bias = use_bias
    self.activation_f = activation_f
    self.use_group_norm = use_group_norm
    self.padding_mode = padding_mode

    def unit_block(
        in_sig: geom.Signature, out_sig: geom.Signature, act: Any, block_key: Any
    ) -> ConvBlock:
        # Encoder/decoder layers pass the literal 1 in the positional slot where the dilation
        # blocks pass kernel_size.
        return ConvBlock(
            D,
            in_sig,
            out_sig,
            use_bias,
            act,
            equivariant,
            conv_filters,
            1,
            padding_mode=padding_mode,
            key=block_key,
        )

    # encoder
    key, enc_key_a, enc_key_b = random.split(key, num=3)
    self.encoder = [
        unit_block(input_keys, mid_keys, activation_f, enc_key_a),
        unit_block(mid_keys, mid_keys, activation_f, enc_key_b),
    ]

    # residual dilation blocks; keys are split one layer at a time, in construction order
    dilation_stacks = []
    for _ in range(num_blocks):
        stack = []
        for rate in (1, 2, 4, 8, 4, 2, 1):
            key, layer_key = random.split(key)
            stack.append(
                ConvBlock(
                    D,
                    mid_keys,
                    mid_keys,
                    use_bias,
                    activation_f,
                    equivariant,
                    conv_filters,
                    kernel_size,
                    use_group_norm,
                    rhs_dilation=(rate,) * D,
                    padding_mode=padding_mode,
                    key=layer_key,
                )
            )
        dilation_stacks.append(stack)
    self.blocks = dilation_stacks

    # decoder; the final layer maps to output_keys and has no activation
    key, dec_key_a, dec_key_b = random.split(key, num=3)
    self.decoder = [
        unit_block(mid_keys, mid_keys, activation_f, dec_key_a),
        unit_block(mid_keys, output_keys, None, dec_key_b),
    ]
__call__(x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Callable for this layer

Parameters:

Name Type Description Default
x MultiImage

the input MultiImage

required
aux_data Optional[State]

unused, needed for compliance

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage, aux_data

Source code in ginjax/models.py
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
def __call__(
    self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Apply the encoder, the residual dilation blocks, and the decoder to x.

    args:
        x: the input MultiImage
        aux_data: not used by this model, returned unchanged

    returns:
        the output MultiImage and the aux_data that was passed in
    """
    # The non-equivariant layers were built on a scalar-only signature.
    features = x if self.equivariant else x.to_scalar_multi_image()

    for conv in self.encoder:
        features, _ = conv(features)

    for stack in self.blocks:
        # keep a copy of the block input for the skip connection
        skip = features.copy()
        for conv in stack:
            features, _ = conv(features)
        features = features + skip

    for conv in self.decoder:
        features, _ = conv(features)

    if not self.equivariant:
        # undo the scalar flattening using the caller-facing output signature
        features = geom.MultiImage.from_scalar_multi_image(features, self.output_keys)

    return features, aux_data
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Construct a new model with filters in a higher dimension. This only works for equivariant models.

Parameters:

Name Type Description Default
old_conv_filters

the current conv filters for the model

required
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

whether to force the sum of the filters in the new dimension to be equal

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
def convertD(
    self: Self,
    conv_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    key: jax.Array,
    **kwargs,
) -> Self:
    """
    Build a model of the same class around a different set of conv filters (probably in a
    higher dimension) and copy this model's weights into it. Only valid for equivariant models.

    args:
        conv_filters: the new conv filters we are swapping to; their D becomes the new model's D
        rescale: the rescaling option forwarded to transfer_weights
        key: key to initialize the new model's weights before they are overwritten
        **kwargs: accepted but not used by this implementation

    returns:
        the result of transfer_weights on the new model, i.e. new filters with the old weights
    """
    assert self.equivariant

    target_D = conv_filters.D
    n_blocks = len(self.blocks)
    # mid_keys is passed explicitly, so the constructor never reads depth; kernel_size is only
    # documented for the non-equivariant version. Both get a placeholder of 0.
    placeholder_depth = 0
    placeholder_kernel_size = 0

    # Arguments stay positional, in the constructor's declared order.
    new_model = self.__class__(
        target_D,
        self.input_keys,
        self.output_keys,
        placeholder_depth,
        n_blocks,
        self.use_bias,
        self.activation_f,
        self.equivariant,
        conv_filters,
        placeholder_kernel_size,
        self.use_group_norm,
        self.mid_keys,
        self.padding_mode,
        key,
    )

    return self.transfer_weights(new_model, rescale)

ResNet ¤

Bases: AnyDimensionalModel

A typical ResNet.

Source code in ginjax/models.py
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
class ResNet(AnyDimensionalModel):
    """
    A typical ResNet: a two-layer encoder, `num_blocks` residual blocks of `num_conv` ConvBlocks
    each, and a two-layer decoder whose last layer has no activation.
    """

    encoder: list[ConvBlock]
    blocks: list[list[ConvBlock]]
    decoder: list[ConvBlock]

    D: int = eqx.field(static=True)
    equivariant: bool = eqx.field(static=True)
    # The caller-facing signatures, declared once each.
    output_keys: geom.Signature = eqx.field(static=True)
    input_keys: geom.Signature = eqx.field(static=True)
    use_bias: bool | str = eqx.field(static=True)
    activation_f: Callable | str = eqx.field(static=True)
    use_group_norm: bool = eqx.field(static=True)
    preactivation_order: bool = eqx.field(static=True)
    mid_keys: geom.Signature = eqx.field(static=True)
    padding_mode: str = eqx.field(static=True)

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        depth: int,
        num_blocks: int = 8,
        num_conv: int = 2,
        use_bias: bool | str = "auto",
        activation_f: Callable | str = jax.nn.gelu,
        equivariant: bool = True,
        conv_filters: geom.MultiImage | None = None,
        kernel_size: int | Sequence[int] | None = None,
        use_group_norm: bool = True,
        preactivation_order: bool = True,
        mid_keys: geom.Signature | None = None,
        padding_mode: str = "ZEROS",
        key: Any = None,
    ) -> None:
        """
        Constructor for the ResNet

        args:
            D: the dimension of the space
            input_keys: the MultiImage Signature for the input
            output_keys: the MultiImage Signature for the output
            depth: the number of channels for the mid layers, only used when mid_keys is None
            num_blocks: number of resnet blocks
            num_conv: number of convolutions per block
            use_bias: whether to use a bias
            activation_f: the activation function for every layer except the last decoder layer
            equivariant: whether to be equivariant
            conv_filters: the invariant filters for the equivariant version
            kernel_size: sidelength(s) for the non-equivariant version
            use_group_norm: whether to use GroupNorm in the residual blocks
            preactivation_order: whether to use preactivation order in the residual blocks
            mid_keys: types of images and number of channels for the mid layers, as a baseline
            padding_mode: for non-equivariant, pass 'TOROIDAL' if all sides are toroidal
            key: jax.random key; it is handed to random.split right away, so presumably a real
                key must be supplied despite the None default
        """
        # Store the caller-facing signatures; the local names may be flattened to scalars below.
        self.D = D
        self.equivariant = equivariant
        self.output_keys = output_keys
        self.input_keys = input_keys

        if equivariant:
            if mid_keys is None:
                mid_keys = geom.signature_union(input_keys, output_keys, depth)
        else:
            if mid_keys is None:
                mid_keys = geom.Signature(((((), 0), depth),))

            # use these keys along the way, then for the final output use self.output_keys.
            # Every order-k image with c channels contributes c * D**k scalar channels.
            input_keys = geom.Signature(
                ((((), 0), sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)),)
            )
            output_keys = geom.Signature(
                ((((), 0), sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)),)
            )

        self.use_bias = use_bias
        self.activation_f = activation_f
        self.use_group_norm = use_group_norm
        self.preactivation_order = preactivation_order
        self.mid_keys = mid_keys
        self.padding_mode = padding_mode

        # encoder
        key, subkey1, subkey2 = random.split(key, num=3)
        self.encoder = [
            ConvBlock(
                D,
                input_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey1,
            ),
            ConvBlock(
                D,
                mid_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey2,
            ),
        ]

        self.blocks = []
        for _ in range(num_blocks):
            # residual block made of num_conv ConvBlocks
            block = []
            for _ in range(num_conv):
                key, subkey = random.split(key)
                block.append(
                    ConvBlock(
                        D,
                        mid_keys,
                        mid_keys,
                        use_bias,
                        activation_f,
                        equivariant,
                        conv_filters,
                        kernel_size,
                        use_group_norm,
                        preactivation_order=preactivation_order,
                        padding_mode=padding_mode,
                        key=subkey,
                    )
                )

            self.blocks.append(block)

        # decoder; the final layer maps to output_keys and has no activation
        key, subkey1, subkey2 = random.split(key, num=3)
        self.decoder = [
            ConvBlock(
                D,
                mid_keys,
                mid_keys,
                use_bias,
                activation_f,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey1,
            ),
            ConvBlock(
                D,
                mid_keys,
                output_keys,
                use_bias,
                None,
                equivariant,
                conv_filters,
                1,
                padding_mode=padding_mode,
                key=subkey2,
            ),
        ]

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Callable for this layer

        args:
            x: the input MultiImage
            aux_data: unused, needed for compliance; returned unchanged

        returns:
            the output MultiImage and aux_data
        """
        if not self.equivariant:
            # the non-equivariant layers were built on a scalar-only signature
            x = x.to_scalar_multi_image()

        for layer in self.encoder:
            x, _ = layer(x)

        for block in self.blocks:
            # keep a copy of the block input for the skip connection
            residual_x = x.copy()

            for layer in block:
                x, _ = layer(x)

            x = x + residual_x

        for layer in self.decoder:
            x, _ = layer(x)

        if self.equivariant:
            out = x
        else:
            # undo the scalar flattening using the caller-facing output signature
            out = geom.MultiImage.from_scalar_multi_image(x, self.output_keys)

        return out, aux_data

    def convertD(
        self: Self,
        conv_filters: geom.MultiImage,
        rescale: geom.Rescaling,
        key: jax.Array,
        **kwargs,
    ) -> Self:
        """
        Construct a new model with filters in a higher dimension. This only works for equivariant
        models.

        args:
            conv_filters: the new conv filters we are swapping to, probably in a higher dimension
            rescale: the rescaling option forwarded to transfer_weights
            key: key to initialize the weights, since they are overruled it won't matter
            **kwargs: accepted but not used by this implementation

        returns:
            a new model with new filters but the old weights
        """
        assert self.equivariant

        # With no residual blocks there is no block to measure; num_conv is then never read by
        # the constructor because its outer loop runs zero times.
        num_conv = len(self.blocks[0]) if self.blocks else 0

        new_model = self.__class__(
            conv_filters.D,
            self.input_keys,
            self.output_keys,
            0,  # ignored since mid_keys is provided
            len(self.blocks),
            num_conv,
            self.use_bias,
            self.activation_f,
            self.equivariant,
            conv_filters,
            0,  # ignored for equivariant model
            self.use_group_norm,
            self.preactivation_order,
            self.mid_keys,
            self.padding_mode,
            key,
        )

        return self.transfer_weights(new_model, rescale)
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Pad a block of weights so that it lines up with new_filters instead of old_filters.
    The weights split into an "unbalanced" part (a center weight followed by offcenter weights)
    and a "balanced" part (no center weight). Each part is extended by repeating the mean of
    its existing offcenter / balanced weights once per added filter.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for; only tensor orders
            0, 1 and 2 (the length of filter_key[0]) are implemented
        old_filters: the old filters
        new_filters: the new filters

    returns:
        the weights associated with the new filters, concatenated along the last axis as
        unbalanced part followed by balanced part
    """
    tensor_order = len(filter_key[0])
    if tensor_order not in (0, 1, 2):
        raise NotImplementedError()

    # Parts that do not exist for this tensor order stay None / 0.
    center = offcenter = balanced = None
    extra_unbalanced = extra_balanced = 0

    if tensor_order == 0:
        center = old_weights_block[:, :, :1]
        offcenter = old_weights_block[:, :, 1:]
        extra_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif tensor_order == 1:
        balanced = old_weights_block
        extra_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    else:
        # for k==2, the first set of filters follows the scalar filters
        assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        scalar_key = ((), 0)
        n_scalar_old = len(old_filters[scalar_key])
        center = old_weights_block[:, :, :1]
        offcenter = old_weights_block[:, :, 1:n_scalar_old]
        extra_unbalanced = len(new_filters[scalar_key]) - n_scalar_old

        balanced = old_weights_block[:, :, n_scalar_old:]
        # whatever growth is not explained by the added unbalanced filters is balanced
        extra_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + extra_unbalanced
        )

    assert extra_unbalanced >= 0
    assert extra_balanced >= 0

    channel_dims = old_weights_block.shape[:2]

    def mean_filler(source: jax.Array, count: int) -> jax.Array:
        # (out_c,in_c,count) array holding the per-channel-pair mean of source
        return jnp.full(channel_dims + (count,), jnp.mean(source, axis=2, keepdims=True))

    # An absent part is a zero-width array so the final concatenate always sees two pieces.
    unbalanced_out = jnp.zeros(channel_dims + (0,))
    if center is not None and offcenter is not None:
        # TODO: check what happens when extra_unbalanced = 0
        unbalanced_out = jnp.concatenate(
            [center, offcenter, mean_filler(offcenter, extra_unbalanced)], axis=2
        )

    balanced_out = jnp.zeros(channel_dims + (0,))
    if balanced is not None:
        balanced_out = jnp.concatenate(
            [balanced, mean_filler(balanced, extra_balanced)], axis=2
        )

    return jnp.concatenate([unbalanced_out, balanced_out], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Scale the new weights so that the weighted filter sum of the new filters matches the
    weighted filter sum of the old filters (the new filters are likely a higher dimension).
    Weights of old filters whose spatial sum is exactly zero are left unscaled.

    args:
        old_filter_triple: tuple of the old filters, the old weights
            (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order
        new_filter_triple: tuple of the new filters, the new weights
            (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # weighted filter sums, both (out_c,in_c)
    old_total = get_filter_sum(old_D, old_filters, old_weights)
    new_total = get_filter_sum(new_D, new_filters, new_weights)

    # Find the filters that do not sum to 0 over space; only those get rescaled.
    spatial_axes = tuple(range(1, 1 + old_D))
    per_filter_sum = jnp.sum(old_filters, axis=spatial_axes)  # (n_filters,tensor)
    flattened_sum = per_filter_sum.reshape((len(per_filter_sum), -1))
    sum_norm = jnp.linalg.norm(flattened_sum, axis=1)  # (n_filters,)
    has_volume = (sum_norm != 0)[None, None]  # (1,1,n_filters)

    # geom.TINY is added to the denominator before dividing; result is (out_c,in_c)
    channel_ratios = old_total / (new_total + geom.TINY)
    # ratio where the filter has volume, 1 elsewhere (out_c,in_c,n_filters)
    ratios = has_volume * channel_ratios[..., None] + (~has_volume)

    if verbose:
        print("old weights", old_weights.shape, old_weights)
        print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

    return new_weights * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Do compatibility rescaling, now with one extra free parameter. The dimension must increase by exactly 1. This is defined for sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and for sidelength 2 filters.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. The new weights are computed
    from the old weights alone; the dimension must increase by exactly 1. Supported cases are
    sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and sidelength 2 filters.

    args:
        old_filter_triple: tuple of the old filters, the old weights
            (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order
        new_filter_triple: tuple of the new filters, the new weights, and the new dimension, in
            that order; only the filters' shape and the dimension are read
        verbose: accepted for a uniform signature, not used here

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        ValueError: if the filter sidelengths or the dimension pair are not a supported case
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, _, new_D = new_filter_triple

    # tensor order = axes left over after the filter axis and the D spatial axes
    old_k = old_filters.ndim - (1 + old_D)
    new_k = new_filters.ndim - (1 + new_D)
    assert (
        old_k == new_k
    ), f"compat_flex_rescale_weights: old_filters k={old_k}, new_filters k={new_k}"

    dim_step = new_D - old_D
    assert dim_step == 1

    old_sides = old_filters.shape[1 : 1 + old_D]
    new_sides = new_filters.shape[1 : 1 + new_D]
    both_side3 = old_sides == (3,) * old_D and new_sides == (3,) * new_D

    if not both_side3:
        both_side2 = old_sides == (2,) * old_D and new_sides == (2,) * new_D
        if both_side2:
            return old_weights / (2**dim_step)
        raise ValueError()

    if old_D == 1 and new_D == 2:
        assert old_weights.shape[2] == 2  # should be 2 filters
        ratio = 1 / 3
        w0 = old_weights[..., 0]
        w1 = old_weights[..., 1]

        return jnp.stack(
            [
                w0 + (-2 + 4 * ratio) * w1,
                (1 - 2 * ratio) * w1,
                ratio * w1,
            ],
            axis=-1,
        )

    if old_D == 2 and new_D == 3:
        # need to get first 4 new_weights from first 3 old_weights
        w0 = old_weights[..., 0]
        w1 = old_weights[..., 1]
        w2 = old_weights[..., 2]
        z = (w2 * 4 - w1) / 9

        leading = jnp.stack(
            [
                w0 - 2 * w1 + 4 * w2 - 8 * z,
                w1 - 2 * w2 + 4 * z,
                w2 - 2 * z,
                z,
            ],
            axis=-1,
        )

        # filters are in flipped order for some reason: old weight 4 is repeated first,
        # then old weight 3, two copies each
        pair_of_ones = jnp.ones_like(old_weights[..., :2])
        symmetric_traceless = pair_of_ones * old_weights[..., 4:5]
        along_trace = pair_of_ones * old_weights[..., 3:4]

        return jnp.concatenate([leading, symmetric_traceless, along_trace], axis=-1)

    # TODO: I could check that the condition holds?
    raise ValueError()
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension, in that order

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension, in that order

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    WARNING: This is the old version which works on the norms of the tensors.

    Working from the last filter to the first, each weight is solved for so that the new
    filters, scaled by the updated weights and summed over the added spatial axes, match the
    old filters scaled by the old weights. The result is verified with an assertion at the end.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters).
        new_filter_triple: tuple of (new filters, new weights, new dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters).
        verbose: whether to print the new weights and the updated weights

    returns:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the new filters are not sorted in ascending order by number of
            nonempty pixels, if the new dimension is not larger than the old dimension, or if
            the final compatibility check fails (rtol=1e-3, atol=1e-3).
    """
    # NOTE: each triple is unpacked as (filters, weights, D)
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # After this point old_filters/new_filters hold one scalar (the tensor norm) per pixel.
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    # only conversions to a strictly higher dimension are supported
    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # sum the 0/1 occupancy over the first D_increase spatial axes
    # (n_filters,old_spatial)
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # sum the filter norms over the first D_increase spatial axes
    # (n_filters,old_spatial)
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))

    # (out_c,in_c,spatial)
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    # Weights for filters that have not been visited yet are still zero, so they do not
    # contribute to the collapsed sums computed inside the loop.
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # largest flat index among the nonempty pixels
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # (out_c,in_c)
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # old_filters already holds per-pixel norms here, so this reshape adds a trailing axis of
    # size 1 and the norm over it is the elementwise absolute value.
    # (n_filters,old_spatial)
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    # NOTE(review): the message names compatibility_rescale_weights, although it is emitted
    # from compatibility_norm_rescale_weights.
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    For filters with side length 2 the old weights are divided by 2**(new_D - old_D). For odd
    side lengths the new weights are solved for from the last filter to the first, so that the
    contribution of the filters that were already solved can be subtracted. In both cases the
    result is verified with an assertion before it is returned.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters).
        new_filter_triple: tuple of (new filters, new weights, new dimension), in that order.
            The weights have shape (out_channels,in_channels,num_filters).
        verbose: whether to print the old weights and the updated weights

    returns:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the old and new filters have a different tensor order or a
            different number of filters, if the new dimension is not larger than the old
            dimension, if the filter side lengths are not all 2 or all 2L+1 (with L+1
            filters), or if the final compatibility check fails (rtol=1e-3, atol=1e-3).
    """
    # NOTE: each triple is unpacked as (filters, weights, D)
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    # k is the number of trailing tensor axes after the filter axis and the D spatial axes
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    # restrict every tensor axis of the new filters to its first old_D components
    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            # pixel at offset z from the center along the first spatial axis, for the old
            # and new dimension respectively
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            # shift by L to turn centered offsets into array indices
            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # sum over the filter axis and the first D_increase spatial axes
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. If a rescaling type is selected, the weights are scaled so that the sum of the filter basis of a particular order scaled by the weights is equal for the old filters and the new.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Transfer the conv weights from old filters to new filters of possibly a different dimension.
    Depending on rescale, the weights are additionally scaled so that the sum of the filter basis
    of a particular order scaled by the weights is equal for the old filters and the new.

    Every weight block is first extended with AnyDimensionalModel._extend_weights and then:

    - Rescaling.VOLUME: the positive and negative parts of the weights are rescaled separately
      with volume_rescale_weights and added back together.
    - Rescaling.COMPATIBILITY: only the filters whose spatial sum is nonzero are rescaled with
      compatibility_rescale_weights; the other weights keep their extended values. If no filter
      has a nonzero spatial sum, the extended weights are used unchanged.
    - Rescaling.COMPAT_FLEX: the block is rescaled with compat_flex_rescale_weights.
    - anything else: the extended weights are used unchanged.

    args:
        weights: a weights dictionary from a layers.ConvContract layer
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights
        verbose: forwarded to the rescaling functions, which print diagnostics when it is
            True, defaults to False.

    returns:
        a new weights dictionary
    """
    new_weights = {}

    for (in_k, in_p), in_weights in weights.items():
        new_weights[(in_k, in_p)] = {}
        for (out_k, out_p), old_weights_block in in_weights.items():
            # the filter used between an input and output type: keys are concatenated and
            # parities are added mod 2
            filter_k = in_k + out_k
            filter_key = (filter_k, (in_p + out_p) % 2)

            new_weights_block = AnyDimensionalModel._extend_weights(
                old_weights_block, filter_key, old_filters, new_filters
            )

            old_filter_block = old_filters[filter_key]
            new_filter_block = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # split into positive and negative parts, rescale each, then recombine
                pos_weights = AnyDimensionalModel.volume_rescale_weights(
                    (old_filter_block, jax.nn.relu(old_weights_block), old_filters.D),
                    (new_filter_block, jax.nn.relu(new_weights_block), new_filters.D),
                    verbose,
                )
                neg_weights = AnyDimensionalModel.volume_rescale_weights(
                    (old_filter_block, -jax.nn.relu(-old_weights_block), old_filters.D),
                    (new_filter_block, -jax.nn.relu(-new_weights_block), new_filters.D),
                    verbose,
                )
                scaled_weights_block = pos_weights + neg_weights
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Dont rescale filters that always sum to 0.
                # (n_filters,tensor)
                spatial_sum = jnp.sum(new_filter_block, axis=tuple(range(1, 1 + new_filters.D)))
                # (n_filters,)
                spatial_sum_norm = jnp.linalg.norm(
                    spatial_sum.reshape((len(spatial_sum), -1)), axis=1
                )
                nonzero_mask = spatial_sum_norm != 0  # (n_filters,)

                scaled_weights_block = new_weights_block
                # If every filter sums to zero there is nothing to rescale. Calling
                # compatibility_rescale_weights with zero filters would trip its shape
                # assertions, so keep the extended weights as they are.
                if bool(jnp.any(nonzero_mask)):
                    updated_weights_block = AnyDimensionalModel.compatibility_rescale_weights(
                        (
                            old_filter_block[nonzero_mask],
                            old_weights_block[:, :, nonzero_mask],
                            old_filters.D,
                        ),
                        (
                            new_filter_block[nonzero_mask],
                            new_weights_block[:, :, nonzero_mask],
                            new_filters.D,
                        ),
                        verbose,
                    )

                    scaled_weights_block = scaled_weights_block.at[:, :, nonzero_mask].set(
                        updated_weights_block
                    )
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                scaled_weights_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_filter_block, old_weights_block, old_filters.D),
                    (new_filter_block, new_weights_block, new_filters.D),
                    verbose,
                )
            else:
                scaled_weights_block = new_weights_block

            new_weights[(in_k, in_p)][(out_k, out_p)] = scaled_weights_block

    return new_weights
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new models conv filters to the new conv filters, then doing any conv filter related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Copy the weights and biases of this (old) model into new_model, which may live in a
    different dimension.

    The steps are: remember the conv filters of new_model, overwrite every leaf of new_model
    with the corresponding leaf of this model, put the remembered conv filters back, and finally
    replace the conv weights with the ones produced by _transfer_conv_weights.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the new model, whose conv filters are kept
        rescale: type of rescaling to perform on the weights
        verbose: forwarded to _transfer_conv_weights, defaults to False.

    returns:
        a new model with the old weights except conv weights which are adjusted, and new filters
    """

    def _is_conv_layer(node):
        return isinstance(node, layers.ConvContract)

    def _conv_layers(model):
        # treat ConvContract layers as leaves so they show up whole in the leaf list
        return [
            node
            for node in jax.tree_util.tree_leaves(model, is_leaf=_is_conv_layer)
            if _is_conv_layer(node)
        ]

    def _conv_filters(model):
        return [conv.invariant_filters for conv in _conv_layers(model)]

    def _conv_weights(model):
        return [conv.weights for conv in _conv_layers(model)]

    def _every_leaf(model):
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    # the filters of the new model must survive the wholesale leaf copy below
    target_filters = _conv_filters(new_model)

    # copy every leaf of the old model over, then restore the new filters
    new_model = eqx.tree_at(_every_leaf, new_model, _every_leaf(self))
    new_model = eqx.tree_at(_conv_filters, new_model, target_filters)

    # adjust the conv weights layer by layer
    source_weights = _conv_weights(self)
    source_filters = _conv_filters(self)
    adjusted_weights = []
    for weight, old_filter, new_filter in zip(source_weights, source_filters, target_filters):
        adjusted_weights.append(
            AnyDimensionalModel._transfer_conv_weights(
                weight, old_filter, new_filter, rescale, verbose
            )
        )

    return eqx.tree_at(_conv_weights, new_model, adjusted_weights)
__init__(D: int, input_keys: geom.Signature, output_keys: geom.Signature, depth: int, num_blocks: int = 8, num_conv: int = 2, use_bias: bool | str = 'auto', activation_f: Callable | str = jax.nn.gelu, equivariant: bool = True, conv_filters: geom.MultiImage | None = None, kernel_size: int | Sequence[int] | None = None, use_group_norm: bool = True, preactivation_order: bool = True, mid_keys: geom.Signature | None = None, padding_mode: str = 'ZEROS', key: Any = None) -> None ¤

Constructor for the ResNet

Parameters:

Name Type Description Default
D int

the dimension of the space

required
input_keys Signature

the MultiImage Signature for the input

required
output_keys Signature

the MultiImage Signature for the output

required
depth int

the number of channels in the mid layers of the ResNet, used to build mid_keys when it is not provided

required
num_blocks int

number of resnet blocks

8
num_conv int

number of convolutions per block

2
use_bias bool | str

whether to use a bias

'auto'
activation_f Callable | str

the activation function

gelu
equivariant bool

whether to be equivariant

True
conv_filters MultiImage | None

the invariant filters for the equivariant version

None
kernel_size int | Sequence[int] | None

sidelength(s) for the non-equivariant version

None
use_group_norm bool

whether to use GroupNorm

True
preactivation_order bool

whether to use preactivation order

True
mid_keys Signature | None

types of images and number of channels for the mid layers, as a baseline

None
padding_mode str

for non-equivariant, pass 'TOROIDAL' if all sides are toroidal

'ZEROS'
key Any

jax.random key

None
Source code in ginjax/models.py
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
def __init__(
    self: Self,
    D: int,
    input_keys: geom.Signature,
    output_keys: geom.Signature,
    depth: int,
    num_blocks: int = 8,
    num_conv: int = 2,
    use_bias: bool | str = "auto",
    activation_f: Callable | str = jax.nn.gelu,
    equivariant: bool = True,
    conv_filters: geom.MultiImage | None = None,
    kernel_size: int | Sequence[int] | None = None,
    use_group_norm: bool = True,
    preactivation_order: bool = True,
    mid_keys: geom.Signature | None = None,
    padding_mode: str = "ZEROS",
    key: Any = None,
) -> None:
    """
    Constructor for the ResNet

    The model is a two layer encoder, num_blocks residual blocks of num_conv convolutions each,
    and a two layer decoder. The encoder and decoder layers use kernel size 1.

    args:
        D: the dimension of the space
        input_keys: the MultiImage Signature for the input
        output_keys: the MultiImage Signature for the output
        depth: the number of channels used to build mid_keys when mid_keys is not given
        num_blocks: number of resnet blocks
        num_conv: number of convolutions per block
        use_bias: whether to use a bias
        activation_f: the activation function
        equivariant: whether to be equivariant
        conv_filters: the invariant filters for the equivariant version
        kernel_size: sidelength(s) for the non-equivariant version
        use_group_norm: whether to use GroupNorm
        preactivation_order: whether to use preactivation order
        mid_keys: types of images and number of channels for the mid layers, as a baseline
        padding_mode: for non-equivariant, pass 'TOROIDAL' if all sides are toroidal
        key: jax.random key, it is split to initialize every layer
    """
    # Keep the signatures the caller gave us; the local names may be replaced by flattened
    # scalar signatures further down.
    self.D = D
    self.equivariant = equivariant
    self.output_keys = output_keys
    self.input_keys = input_keys

    if mid_keys is None:
        if equivariant:
            mid_keys = geom.signature_union(input_keys, output_keys, depth)
        else:
            mid_keys = geom.Signature(((((), 0), depth),))

    if not equivariant:
        # use these keys along the way, then for the final output use self.output_keys
        input_keys = geom.Signature(
            ((((), 0), sum(in_c * (D ** len(k)) for (k, _), in_c in input_keys)),)
        )
        output_keys = geom.Signature(
            ((((), 0), sum(out_c * (D ** len(k)) for (k, _), out_c in output_keys)),)
        )

    self.use_bias = use_bias
    self.activation_f = activation_f
    self.use_group_norm = use_group_norm
    self.preactivation_order = preactivation_order
    self.mid_keys = mid_keys
    self.padding_mode = padding_mode

    def size_one_block(in_sig, out_sig, activation, block_key):
        # the encoder and decoder layers only differ in signatures, activation and key
        return ConvBlock(
            D,
            in_sig,
            out_sig,
            use_bias,
            activation,
            equivariant,
            conv_filters,
            1,
            padding_mode=padding_mode,
            key=block_key,
        )

    # encoder
    key, enc_key_a, enc_key_b = random.split(key, num=3)
    self.encoder = [
        size_one_block(input_keys, mid_keys, activation_f, enc_key_a),
        size_one_block(mid_keys, mid_keys, activation_f, enc_key_b),
    ]

    # residual blocks
    residual_blocks = []
    for _ in range(num_blocks):
        conv_stack = []
        for _ in range(num_conv):
            key, conv_key = random.split(key)
            conv_stack.append(
                ConvBlock(
                    D,
                    mid_keys,
                    mid_keys,
                    use_bias,
                    activation_f,
                    equivariant,
                    conv_filters,
                    kernel_size,
                    use_group_norm,
                    preactivation_order=preactivation_order,
                    padding_mode=padding_mode,
                    key=conv_key,
                )
            )

        residual_blocks.append(conv_stack)

    self.blocks = residual_blocks

    # decoder, the last layer has no activation
    key, dec_key_a, dec_key_b = random.split(key, num=3)
    self.decoder = [
        size_one_block(mid_keys, mid_keys, activation_f, dec_key_a),
        size_one_block(mid_keys, output_keys, None, dec_key_b),
    ]
__call__(x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None) -> tuple[geom.MultiImage, Optional[eqx.nn.State]] ¤

Callable for this layer

Parameters:

Name Type Description Default
x MultiImage

the input MultiImage

required
aux_data Optional[State]

unused, needed for compliance

None

Returns:

Type Description
tuple[MultiImage, Optional[State]]

the output MultiImage and aux_data

Source code in ginjax/models.py
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
def __call__(
    self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
    """
    Run the ResNet: encoder, residual blocks with skip connections, then decoder.

    args:
        x: the input MultiImage
        aux_data: unused, needed for compliance, returned as it was passed in

    returns:
        the output MultiImage and aux_data
    """
    # the non-equivariant model works on the scalar version of the input
    hidden = x if self.equivariant else x.to_scalar_multi_image()

    for conv in self.encoder:
        hidden, _ = conv(hidden)

    for residual_block in self.blocks:
        skip = hidden.copy()

        for conv in residual_block:
            hidden, _ = conv(hidden)

        hidden = hidden + skip

    for conv in self.decoder:
        hidden, _ = conv(hidden)

    if not self.equivariant:
        # convert the scalar output back to the requested output signature
        hidden = geom.MultiImage.from_scalar_multi_image(hidden, self.output_keys)

    return hidden, aux_data
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Construct a new model with filters in a higher dimension. This only works for equivariant models.

Parameters:

Name Type Description Default
old_conv_filters

the current conv filters for the model

required
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

whether to force the sum of the filters in the new dimension to be equal

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
def convertD(
    self: Self,
    conv_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    key: jax.Array,
    **kwargs,
) -> Self:
    """
    Construct a new model with filters in a higher dimension. This only works for equivariant
    models.

    A fresh model of the same class is built for the dimension of conv_filters with the same
    hyperparameters as this model, and then the weights of this model are transferred to it
    with transfer_weights.

    args:
        conv_filters: the new conv filters we are swapping to, probably in a higher dimension
        rescale: type of rescaling to perform on the conv weights during the transfer
        key: key to initialize the weights, since they are overruled it won't matter
        **kwargs: accepted but not used by this implementation

    returns:
        a new model with new filters but the old weights

    raises:
        AssertionError: if this model is not equivariant
    """
    assert self.equivariant

    # The number of convolutions per block is not stored, so read it off the first block.
    # A model built with num_blocks=0 has no block to inspect; the value is then irrelevant
    # because no block is constructed.
    num_conv = len(self.blocks[0]) if self.blocks else 0

    new_model = self.__class__(
        conv_filters.D,
        self.input_keys,
        self.output_keys,
        0,  # ignored since mid_keys is provided
        len(self.blocks),
        num_conv,
        self.use_bias,
        self.activation_f,
        self.equivariant,
        conv_filters,
        0,  # ignored for equivariant model
        self.use_group_norm,
        self.preactivation_order,
        self.mid_keys,
        self.padding_mode,
        key,
    )

    return self.transfer_weights(new_model, rescale)

ModelWrapper ¤

Bases: MultiImageModule

This wraps a typical CNN so that it is a MultiImage model. This model will take an input MultiImage, convert it to a jax array, feed it through the model, then convert it to the appropriate output MultiImage at the end.

Source code in ginjax/models.py
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
class ModelWrapper(MultiImageModule):
    """
    Adapter that lets a typical array-based CNN be used as a MultiImage model. The input
    MultiImage is converted to a scalar MultiImage and its single array is fed through the wrapped
    model; the array that comes back is converted into a MultiImage with the requested output
    signature.
    """

    model: eqx.Module

    D: int = eqx.field(static=True)
    output_keys: geom.Signature = eqx.field(static=True)
    output_is_torus: Union[bool, tuple[bool, ...]] = eqx.field(static=True)
    pass_aux_data: bool = eqx.field(static=True)

    def __init__(
        self: Self,
        D: int,
        model: eqx.Module,
        output_keys: geom.Signature,
        output_is_torus: Union[bool, tuple[bool, ...]],
        pass_aux_data: bool = False,
    ) -> None:
        """
        Construct the model wrapper.

        args:
            D: the dimension of the space
            model: a vanilla cnn model, should input and output images of shape (channels,spatial)
            output_keys: signature for the output MultiImage
            output_is_torus: toroidal structure of the output MultiImage
            pass_aux_data: whether the model expects and outputs aux_data
        """
        self.D = D
        # the wrapped model is called directly in __call__, so it has to be callable
        assert callable(model)
        self.model = model
        self.output_keys = output_keys
        self.output_is_torus = output_is_torus
        self.pass_aux_data = pass_aux_data

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Run the wrapped model on a MultiImage.

        args:
            x: the input MultiImage
            aux_data: forwarded to and replaced by the wrapped model when pass_aux_data is set,
                otherwise returned untouched

        returns:
            the output MultiImage and aux_data
        """
        scalar_input = x.to_scalar_multi_image()[((), 0)]
        assert callable(self.model)

        if not self.pass_aux_data:
            raw_output = self.model(scalar_input)
        else:
            raw_output, aux_data = self.model(scalar_input, aux_data)

        # wrap the raw array as a scalar MultiImage, then split it into the output signature
        scalar_output = geom.MultiImage({(0, 0): raw_output}, self.D, self.output_is_torus)
        return scalar_output.from_scalar_multi_image(self.output_keys), aux_data
__init__(D: int, model: eqx.Module, output_keys: geom.Signature, output_is_torus: Union[bool, tuple[bool, ...]], pass_aux_data: bool = False) -> None ¤

Construct the model wrapper.

Parameters:

Name Type Description Default
D int

the dimension of the space

required
model Module

a vanilla cnn model, should input and output images of shape (channels,spatial)

required
output_keys Signature

signature for the output MultiImage

required
output_is_torus Union[bool, tuple[bool, ...]]

toroidal structure of the output MultiImage

required
pass_aux_data bool

whether the model expects and outputs aux_data

False
Source code in ginjax/models.py
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
def __init__(
    self: Self,
    D: int,
    model: eqx.Module,
    output_keys: geom.Signature,
    output_is_torus: Union[bool, tuple[bool, ...]],
    pass_aux_data: bool = False,
) -> None:
    """
    Store the wrapped model together with the information needed to rebuild a MultiImage from
    its output.

    args:
        D: the dimension of the space
        model: a vanilla cnn model, should input and output images of shape (channels,spatial)
        output_keys: signature for the output MultiImage
        output_is_torus: toroidal structure of the output MultiImage
        pass_aux_data: whether the model expects and outputs aux_data
    """
    # reject a model that cannot be called before storing anything
    assert callable(model)

    self.D = D
    self.model = model
    self.output_keys = output_keys
    self.output_is_torus = output_is_torus
    self.pass_aux_data = pass_aux_data

GroupAverage ¤

Bases: MultiImageModule

Model that takes in a different model and performs group averaging to make it an equivariant model. It can either always average, so that it is equivariant during training as well, or only average at inference time, to test whether training a non-equivariant model and then group averaging helps. This will reveal whether the data set is indeed an equivariant data set.

Source code in ginjax/models.py
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
class GroupAverage(MultiImageModule):
    """
    Model that takes in a different model and performs group averaging to make it an equivariant
    model. It can either always average, so that it is equivariant during training as well, or
    only average at inference time, to test whether training a non-equivariant model and then
    group averaging helps. This will reveal whether the data set is indeed an equivariant data
    set.
    """

    model: MultiImageModule
    inference: bool

    # static to prevent this from being converted to a traced jax array
    operators: list[np.ndarray] = eqx.field(static=True)
    always_average: bool = eqx.field(static=True)

    def __init__(
        self: Self,
        model: MultiImageModule,
        operators: list[np.ndarray],
        always_average: bool = False,
        inference: bool = False,
    ) -> None:
        """
        Construct the group averaging model.

        args:
            model: the model whose outputs are averaged
            operators: the group elements to average over
            always_average: average on every call, regardless of the inference flag
            inference: average when set, used when always_average is False
        """
        self.model = model
        self.operators = operators
        self.always_average = always_average
        self.inference = inference

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        """
        Apply the wrapped model, averaging over the operators when averaging is active.

        args:
            x: the input MultiImage
            aux_data: passed unchanged to every call of the wrapped model

        returns:
            the (possibly averaged) output MultiImage, and the aux_data returned by the wrapped
            model (from the last operator when averaging)
        """
        averaging = (self.always_average or self.inference) and len(self.operators) > 0
        if not averaging:
            return self.model(x, aux_data)

        total = None
        out_aux = None
        for operator in self.operators:
            # transform the input, run the model, then undo the transform with the transpose
            prediction, out_aux = self.model(x.times_group_element(operator), aux_data)
            aligned = prediction.times_group_element(operator.T)
            if total is None:
                total = aligned
            else:
                total = total + aligned

        assert total is not None
        return total / len(self.operators), out_aux

LastStepIdentity ¤

Bases: AnyDimensionalModel

Source code in ginjax/models.py
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
class LastStepIdentity(AnyDimensionalModel):
    """
    Model without weights whose output keeps, for every image block of the input, only the final
    slice along the leading axis. When residual is set that slice is replaced by zeros of the
    same shape.
    """

    residual: bool = eqx.field(static=True)

    def __init__(self: Self, residual: bool = False):
        """
        Construct the model.

        args:
            residual: output zeros instead of the final slice of the input
        """
        self.residual = residual

    def convertD(
        self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
    ) -> Self:
        """
        Convert model to a different dimension. This model holds no filters or weights, so the
        arguments are not used and the result only carries over the residual flag.

        args:
            conv_filters: the new conv filters, unused
            rescale: the type of rescaling, unused
            key: key to initialize the weights, unused

        returns:
            a new model of the same class with the same residual setting
        """
        return type(self)(self.residual)

    def __call__(
        self: Self, x: geom.MultiImage, batch_stats: eqx.nn.State | None = None
    ) -> tuple[geom.MultiImage, eqx.nn.State | None]:
        """
        Callable function.

        args:
            x: the input MultiImage
            batch_stats: batch stats for BatchNorm if present, returned untouched

        returns:
            the output MultiImage and batch_stats
        """
        result = x.empty()
        for (k, parity), block in x.items():
            final_slice = block[-1:]
            if self.residual:
                # a residual model contributes zeros so that adding it changes nothing
                final_slice = jnp.zeros_like(final_slice)

            result.append(k, parity, final_slice)

        return result, batch_stats
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Given a set of weights associated with old_filters, extend the weights to new_filters.
    The old weights are kept as they are and one weight is appended for every additional filter.
    For offcenter weights (associated with a set of filters that has a center filter) and for
    balanced weights (associated with a set of filters which has no center filter), the appended
    weights are the average of the old weights of that group. The center weight is copied and
    does not take part in the average.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for
        old_filters: the old filters
        new_filters: the new filters

    returns:
        the weights associated with the new filters, in the dtype of old_weights_block

    raises:
        NotImplementedError: if the filter order k is not 0, 1, or 2
    """
    k = len(filter_key[0])
    if k not in {0, 1, 2}:
        raise NotImplementedError(f"_extend_weights: k must be 0, 1, or 2, got k={k}")

    n_add_unbalanced = 0
    n_add_balanced = 0
    center_weight = None
    offcenter_old_weights = None
    balanced_weights = None
    if k == 0:
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:]
        n_add_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 1:
        balanced_weights = old_weights_block
        n_add_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 2:
        # for k==2, the first set of filters follows the scalar filters
        assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        n_old_unbalanced = len(old_filters[(), 0])
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:n_old_unbalanced]
        n_add_unbalanced = len(new_filters[(), 0]) - n_old_unbalanced

        balanced_weights = old_weights_block[:, :, n_old_unbalanced:]
        # gap between new filters and (old filters plus the additional unbalanced filter)
        n_add_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + n_add_unbalanced
        )

    assert n_add_unbalanced >= 0
    assert n_add_balanced >= 0

    channel_shape = old_weights_block.shape[:2]
    # Zero-width placeholder for a group that is absent. It carries the dtype of the incoming
    # weights, otherwise the final concatenate would promote weights that are not in the
    # default float dtype.
    empty_group = jnp.zeros(channel_shape + (0,), dtype=old_weights_block.dtype)

    new_unbalanced_weights = empty_group
    if center_weight is not None and offcenter_old_weights is not None:
        # TODO: check what happens when n_add_unbalanced = 0
        additional_weights = jnp.full(
            channel_shape + (n_add_unbalanced,),
            jnp.mean(offcenter_old_weights, axis=2, keepdims=True),
        )

        new_unbalanced_weights = jnp.concatenate(
            [center_weight, offcenter_old_weights, additional_weights], axis=2
        )

    new_balanced_weights = empty_group
    if balanced_weights is not None:
        additional_weights = jnp.full(
            channel_shape + (n_add_balanced,),
            jnp.mean(balanced_weights, axis=2, keepdims=True),
        )

        new_balanced_weights = jnp.concatenate([balanced_weights, additional_weights], axis=2)

    # unbalanced weights come first, matching the layout read from old_weights_block above
    return jnp.concatenate([new_unbalanced_weights, new_balanced_weights], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the new weights so that the filter sum (weights times filters) of the new filters
    matches the filter sum of the old filters. Filters whose spatial sum is zero are left
    unscaled.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the old weights and ratios

    returns:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_ff, old_ww, old_dim = old_filter_triple
    new_ff, new_ww, new_dim = new_filter_triple

    # filter sums for the old and the new set, both (out_c,in_c)
    old_total = get_filter_sum(old_dim, old_ff, old_ww)
    new_total = get_filter_sum(new_dim, new_ff, new_ww)

    # A filter whose spatial sum vanishes always sums to 0, so scaling it is pointless.
    spatial_axes = tuple(range(1, 1 + old_dim))
    per_filter_sum = jnp.sum(old_ff, axis=spatial_axes)  # (n_filters,tensor)
    flat_sum = per_filter_sum.reshape((len(per_filter_sum), -1))
    per_filter_norm = jnp.linalg.norm(flat_sum, axis=1)  # (n_filters,)
    rescale_mask = (per_filter_norm != 0)[None, None]  # (1,1,n_filters)

    # TINY guards against dividing by a new filter sum of zero, (out_c,in_c)
    channel_ratio = old_total / (new_total + geom.TINY)
    # masked-in filters get the channel ratio, all the others get 1 (out_c,in_c,n_filters)
    ratios = rescale_mask * channel_ratio[..., None] + (~rescale_mask)

    if verbose:
        print("old weights", old_ww.shape, old_ww)
        print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

    return new_ww * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Do compatibility rescaling, now with one extra free parameter. For now this is only defined for sidelength 3 filters for D=1 to D=2.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. For now this is only defined
    for sidelength 3 filters for D=1 to D=2.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compat_flex_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase == 1

    if (
        old_filters.shape[1 : 1 + old_D] == (3,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (3,) * new_D
    ):
        if old_D == 1 and new_D == 2:
            assert old_weights.shape[2] == 2  # should be 2 filters
            ratio = 1 / 3

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0] + (-2 + 4 * ratio) * old_weights[..., 1],
                    (1 - 2 * ratio) * old_weights[..., 1],
                    ratio * old_weights[..., 1],
                ],
                axis=-1,
            )
        elif old_D == 2 and new_D == 3:
            # need to get first 4 new_weights from first 3 old_weights

            z = (old_weights[..., 2] * 4 - old_weights[..., 1]) / 9

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0]
                    - 2 * old_weights[..., 1]
                    + 4 * old_weights[..., 2]
                    - 8 * z,
                    old_weights[..., 1] - 2 * old_weights[..., 2] + 4 * z,
                    old_weights[..., 2] - 2 * z,
                    z,
                ],
                axis=-1,
            )

            # filters are in flipped order for some reason
            symmetric_traceless = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 4:5]
            along_trace = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 3:4]

            alpha_prime = jnp.concatenate(
                [alpha_prime, symmetric_traceless, along_trace], axis=-1
            )
        else:
            raise ValueError()
    elif (
        old_filters.shape[1 : 1 + old_D] == (2,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (2,) * new_D
    ):
        alpha_prime = old_weights / (2**D_increase)
    else:
        raise ValueError()

    # TODO: I could check that the condition holds?

    return alpha_prime
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    The new filters are summed over their first (new_D - old_D) spatial axes. The weights are
    then solved one filter at a time, from the last filter to the first, and the result is
    asserted to reproduce the weighted old filters.

    WARNING: This is the old version which works on the norms of the tensors.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the new weights and the updated weights

    returns:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # From here on old_filters / new_filters hold one norm per pixel, the tensor axes are gone.
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # (n_filters,old_spatial)
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # (n_filters,old_spatial)
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))
    # (n_filters,old_spatial)

    # (out_c,in_c,spatial)
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # the largest flat index among the nonempty pixels
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # (out_c,in_c)
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # (n_filters,old_spatial)
    # old_filters already holds the per-pixel norms at this point, so taking the norm again
    # leaves the values unchanged
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    # NOTE(review): the message names compatibility_rescale_weights, not this function
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. If rescale is true, then scale the weights so that the sum of the filter basis of a particular order scaled by the weights is equal for the old filters and the new.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Carry the weights of one conv layer over from old_filters to new_filters, which may live in a
    different dimension. Every (input key, output key) weight block is first extended to the new
    filters with _extend_weights and is then adjusted according to rescale.

    args:
        weights: a weights dictionary from a layers.ConvContract layer
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights. VOLUME, COMPATIBILITY and
            COMPAT_FLEX each dispatch to their own rescaling routine, any other value keeps the
            extended weights as they are.
        verbose: forwarded to the rescaling routines, defaults to False.

    returns:
        a new weights dictionary with the same nested keys as weights
    """
    transferred = {}

    for in_key, per_output_blocks in weights.items():
        in_k, in_p = in_key
        transferred[(in_k, in_p)] = {}

        for out_key, old_block in per_output_blocks.items():
            out_k, out_p = out_key
            # the filters used for this block: concatenated k and the parities added mod 2
            filter_key = (in_k + out_k, (in_p + out_p) % 2)

            extended_block = AnyDimensionalModel._extend_weights(
                old_block, filter_key, old_filters, new_filters
            )

            old_filter_block = old_filters[filter_key]
            new_filter_block = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # positive and negative parts of the weights are rescaled separately, then summed
                positive_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_filter_block, jax.nn.relu(old_block), old_filters.D),
                    (new_filter_block, jax.nn.relu(extended_block), new_filters.D),
                    verbose,
                )
                negative_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_filter_block, -jax.nn.relu(-old_block), old_filters.D),
                    (new_filter_block, -jax.nn.relu(-extended_block), new_filters.D),
                    verbose,
                )
                result_block = positive_part + negative_part
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Filters whose spatial sum is exactly zero are left out of the rescaling.
                # (n_filters,tensor)
                summed_over_space = jnp.sum(
                    new_filter_block, axis=tuple(range(1, 1 + new_filters.D))
                )
                # (n_filters,)
                summed_norm = jnp.linalg.norm(
                    summed_over_space.reshape((summed_over_space.shape[0], -1)), axis=1
                )
                keep = summed_norm != 0  # (n_filters,)

                old_triple = (old_filter_block[keep], old_block[:, :, keep], old_filters.D)
                new_triple = (new_filter_block[keep], extended_block[:, :, keep], new_filters.D)
                rescaled_subset = AnyDimensionalModel.compatibility_rescale_weights(
                    old_triple, new_triple, verbose
                )

                # write the rescaled weights back, the masked-out filters keep the extended weights
                result_block = extended_block.at[:, :, keep].set(rescaled_subset)
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_filter_block, old_block, old_filters.D),
                    (new_filter_block, extended_block, new_filters.D),
                    verbose,
                )
            else:
                result_block = extended_block

            transferred[(in_k, in_p)][(out_k, out_p)] = result_block

    return transferred
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new model's conv filters to the new conv filters, then doing any conv filter related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Transfer the weights and biases from this (old) model to a new model. This allows converting
    between dimensions as well. This works by copying all leaves from the old model to the new
    model, then resetting the new model's conv filters to its own original conv filters, then
    doing any conv filter related weight scaling.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the new model, its layers.ConvContract layers hold the conv filters that the
            returned model will use, these can have a different D than the filters of self
        rescale: type of rescaling to perform on the weights
        verbose: forwarded to _transfer_conv_weights, defaults to False.

    returns:
        a new model with the old weights except conv weights which are adjusted, and new filters
    """

    def is_conv(node) -> bool:
        return isinstance(node, layers.ConvContract)

    def conv_layers(model) -> list:
        # treating ConvContract as a leaf stops the traversal at the conv layers themselves
        return [
            node for node in jax.tree_util.tree_leaves(model, is_leaf=is_conv) if is_conv(node)
        ]

    def get_filters(model) -> list:
        return [conv.invariant_filters for conv in conv_layers(model)]

    def get_conv_weights(model) -> list:
        return [conv.weights for conv in conv_layers(model)]

    def get_all_weights(model) -> list:
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    # remember the new filters before they get overwritten by the old model's leaves
    new_filters = get_filters(new_model)

    # now replace all leaves of the new model with those of the old model
    new_model = eqx.tree_at(get_all_weights, new_model, get_all_weights(self))

    # now reset the proper conv filters
    new_model = eqx.tree_at(get_filters, new_model, new_filters)

    # now set the proper weights, strict zip so that a mismatch in the number of conv layers
    # between the two models raises instead of silently dropping layers
    new_weights = [
        AnyDimensionalModel._transfer_conv_weights(
            weight, old_filter, new_filter, rescale, verbose
        )
        for weight, old_filter, new_filter in zip(
            get_conv_weights(self), get_filters(self), new_filters, strict=True
        )
    ]
    new_model = eqx.tree_at(get_conv_weights, new_model, new_weights)

    return new_model
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Convert model to a different dimension.

Parameters:

Name Type Description Default
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

whether to force the sum of the filters in the new dimension to be equal

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
def convertD(
    self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
) -> Self:
    """
    Convert model to a different dimension. This model is rebuilt from its residual flag alone, so
    conv_filters, rescale, key and kwargs are accepted but not used here.

    args:
        conv_filters: the new conv filters we are swapping to, probably in a higher dimension
        rescale: type of rescaling to perform on the weights
        key: key to initialize the weights

    returns:
        a new instance of the same class with the same residual setting
    """
    model_cls = self.__class__
    return model_cls(self.residual)
__call__(x: geom.MultiImage, batch_stats: eqx.nn.State | None = None) -> tuple[geom.MultiImage, eqx.nn.State | None] ¤

Callable function.

Parameters:

Name Type Description Default
x MultiImage

the input MultiImage

required
batch_stats State | None

batch stats for BatchNorm if present

None

Returns:

Type Description
tuple[MultiImage, State | None]

the output MultiImage and batch_stats

Source code in ginjax/models.py
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
def __call__(
    self: Self, x: geom.MultiImage, batch_stats: eqx.nn.State | None = None
) -> tuple[geom.MultiImage, eqx.nn.State | None]:
    """
    Callable function. For every (k, parity) block of the input, keep only the last entry along
    the leading axis, or zeros of that shape when the model is residual.

    args:
        x: the input MultiImage
        batch_stats: batch stats for BatchNorm if present, returned unchanged

    returns:
        the output MultiImage and batch_stats
    """
    result = x.empty()
    for (k, parity), img_block in x.items():
        last_entry = img_block[-1:]
        if self.residual:
            # a residual model contributes nothing, so the block that gets added is all zeros
            last_entry = jnp.zeros_like(last_entry)

        result.append(k, parity, last_entry)

    return result, batch_stats

SimpleConvSeries ¤

Bases: AnyDimensionalModel

Simple convolution model consisting of a series of ConvBlocks, with all but the last with a gelu vector neuron nonlinearity.

Source code in ginjax/models.py
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
class SimpleConvSeries(AnyDimensionalModel):
    """
    Simple convolution model consisting of a series of ConvBlocks. Every block except the last is
    built with activation_f, the last block is built without an activation.
    """

    layers: list[ConvBlock]

    D: int = eqx.field(static=True)
    input_keys: geom.Signature = eqx.field(static=True)
    target_keys: geom.Signature = eqx.field(static=True)
    width: int = eqx.field(static=True)
    depth: int = eqx.field(static=True)
    use_bias: bool | str = eqx.field(static=True)
    activation_f: Callable | str | None = eqx.field(static=True)

    def __init__(
        self: Self,
        input_keys: geom.Signature,
        target_keys: geom.Signature,
        conv_filters: geom.MultiImage,
        width: int,
        depth: int,
        use_bias: bool | str,
        activation_f: Callable | str | None,
        key: jax.Array,
    ) -> None:
        """
        Build depth ConvBlocks: depth - 1 hidden blocks followed by one output block.

        args:
            input_keys: signature of the input MultiImage
            target_keys: signature of the output MultiImage
            conv_filters: the conv filters, their D is used as the dimension of the model
            width: passed to geom.signature_union to build the hidden signature
            depth: total number of ConvBlocks, must be at least 1
            use_bias: bias setting passed to every ConvBlock
            activation_f: activation passed to every ConvBlock except the last
            key: random key, split into one subkey per block
        """
        assert depth >= 1
        self.D = conv_filters.D
        self.input_keys = input_keys
        self.target_keys = target_keys
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.activation_f = activation_f

        mid_keys = geom.signature_union(input_keys, target_keys, width) if depth > 1 else input_keys

        # the first subkey goes to the last block, the remaining ones to the hidden blocks
        subkey_last, *subkeys = random.split(key, num=depth)

        blocks = []
        # Only the first block sees the model input. Every later block receives the output of
        # the block before it, which has the hidden signature.
        block_in_keys = input_keys
        for subkey in subkeys:
            blocks.append(
                ConvBlock(
                    self.D,
                    block_in_keys,
                    mid_keys,
                    use_bias,
                    activation_f,
                    True,
                    conv_filters,
                    key=subkey,
                )
            )
            block_in_keys = mid_keys

        blocks.append(
            ConvBlock(
                self.D, mid_keys, target_keys, use_bias, None, True, conv_filters, key=subkey_last
            )
        )
        self.layers = blocks

    def convertD(
        self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
    ) -> Self:
        """
        Construct a new model with filters in a higher dimension.

        args:
            conv_filters: the new conv filters we are swapping to, probably in a higher dimension
            rescale: how to rescale the filter weights
            key: key to initialize the weights of the new model before the old weights are
                transferred onto it

        returns:
            a new model with new filters but the old weights
        """
        new_model = self.__class__(
            self.input_keys,
            self.target_keys,
            conv_filters,
            self.width,
            self.depth,
            self.use_bias,
            self.activation_f,
            key,
        )

        return self.transfer_weights(new_model, rescale, verbose=False)

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: eqx.nn.State | None = None
    ) -> tuple[geom.MultiImage, eqx.nn.State | None]:
        """
        Apply the blocks one after another, threading aux_data through each of them.

        args:
            x: the input MultiImage
            aux_data: auxiliary state passed to and returned from every block

        returns:
            the output MultiImage and aux_data
        """
        for layer in self.layers:
            x, aux_data = layer(x, aux_data)

        return x, aux_data
_extend_weights(old_weights_block: jax.Array, filter_key: tuple[tuple[bool, ...], int], old_filters: geom.MultiImage, new_filters: geom.MultiImage) -> jax.Array staticmethod ¤

Given a set of weights associated with old_filters, extend the weights to new_filters. For offcenter weights (associated with a set of filters that has a center filter) and for balanced weights (associated with a set of filters which has no center filter), the new weights are the average of the old weights.

Parameters:

Name Type Description Default
old_weights_block Array

the old weights, shape (out_c,in_c,n_filters)

required
filter_key tuple[tuple[bool, ...], int]

the key for the filters we are extending weights for

required
old_filters MultiImage

the old filters

required
new_filters MultiImage

the new filters

required

Returns:

Type Description
Array

the weights associated with the new filters

Source code in ginjax/models.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@staticmethod
def _extend_weights(
    old_weights_block: jax.Array,
    filter_key: tuple[tuple[bool, ...], int],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
) -> jax.Array:
    """
    Given a set of weights associated with old_filters, extend the weights to new_filters.
    The old weights are kept, and each additional filter gets the mean of the old weights of its
    group. The groups are the offcenter weights (the weights after the center weight in a set of
    filters that has a center filter) and the balanced weights (associated with a set of filters
    which has no center filter). The result has the dtype of old_weights_block.

    args:
        old_weights_block: the old weights, shape (out_c,in_c,n_filters)
        filter_key: the key for the filters we are extending weights for, only k in {0,1,2}
            is supported
        old_filters: the old filters
        new_filters: the new filters

    returns:
        the weights associated with the new filters, ordered as center weight, old offcenter
        weights, additional offcenter weights, old balanced weights, additional balanced weights
    """
    k = len(filter_key[0])
    if k not in {0, 1, 2}:
        raise NotImplementedError(f"_extend_weights: only k in {{0, 1, 2}} is supported, got k={k}")

    n_add_unbalanced = 0
    n_add_balanced = 0
    center_weight = None
    offcenter_old_weights = None
    balanced_weights = None
    if k == 0:
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:]
        n_add_unbalanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    elif k == 1:
        balanced_weights = old_weights_block
        n_add_balanced = len(new_filters[filter_key]) - len(old_filters[filter_key])
    else:
        # for k==2, the first set of filters follows the scalar filters
        assert ((), 0) in old_filters, "_extend_weights needs k=0 filters if it includes k=2"
        n_old_unbalanced = len(old_filters[(), 0])
        center_weight = old_weights_block[:, :, :1]
        offcenter_old_weights = old_weights_block[:, :, 1:n_old_unbalanced]
        n_add_unbalanced = len(new_filters[(), 0]) - n_old_unbalanced

        balanced_weights = old_weights_block[:, :, n_old_unbalanced:]
        # gap between new filters and (old filters plus the additional unbalanced filter)
        n_add_balanced = len(new_filters[filter_key]) - (
            len(old_filters[filter_key]) + n_add_unbalanced
        )

    assert n_add_unbalanced >= 0
    assert n_add_balanced >= 0

    channel_shape = old_weights_block.shape[:2]
    # Empty placeholder for a missing group. It carries the dtype of the weights so that the
    # final concatenate does not promote them to the default float dtype.
    empty_group = jnp.zeros(channel_shape + (0,), dtype=old_weights_block.dtype)

    def with_mean_padding(group_weights: jax.Array, n_extra: int) -> jax.Array:
        # append n_extra copies of the mean over the filter axis, n_extra == 0 appends nothing
        padding = jnp.full(
            channel_shape + (n_extra,),
            jnp.mean(group_weights, axis=2, keepdims=True),
        )
        return jnp.concatenate([group_weights, padding], axis=2)

    new_unbalanced_weights = empty_group
    if center_weight is not None and offcenter_old_weights is not None:
        # NOTE(review): with no offcenter weights at all the mean is NaN, which only matters
        # when n_add_unbalanced > 0
        new_unbalanced_weights = jnp.concatenate(
            [center_weight, with_mean_padding(offcenter_old_weights, n_add_unbalanced)], axis=2
        )

    new_balanced_weights = empty_group
    if balanced_weights is not None:
        new_balanced_weights = with_mean_padding(balanced_weights, n_add_balanced)

    return jnp.concatenate([new_unbalanced_weights, new_balanced_weights], axis=2)
volume_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weights so that the sum of the weights times the filters add up to the same value for the old filters and the new filters (which are likely a higher dimension).

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@staticmethod
def volume_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Scale the new weights per (out_channel,in_channel) pair by the ratio of the old filter sum to
    the new filter sum, both computed with get_filter_sum. Filters whose spatial sum is exactly
    zero are scaled by 1 instead.

    args:
        old_filter_triple: tuple of the old filters, the old weights (shape
            (out_channels,in_channels,num_filters)), and the old dimension
        new_filter_triple: tuple of the new filters, the new weights (shape
            (out_channels,in_channels,num_filters)), and the new dimension
        verbose: whether to print the old weights and ratios

    returns:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # both totals are (out_c,in_c)
    old_total = get_filter_sum(old_D, old_filters, old_weights)
    new_total = get_filter_sum(new_D, new_filters, new_weights)

    # Find the filters that always sum to 0, those are not rescaled.
    # NOTE(review): the mask is built from the old filters but applied to the new weights, so
    # this presumably expects the same number of old and new filters, confirm with callers.
    spatial_axes = tuple(range(1, 1 + old_D))
    per_filter_sum = jnp.sum(old_filters, axis=spatial_axes)  # (n_filters,tensor)
    per_filter_norm = jnp.linalg.norm(
        per_filter_sum.reshape((per_filter_sum.shape[0], -1)), axis=1
    )  # (n_filters,)
    rescale_mask = (per_filter_norm != 0)[None, None]  # (1,1,n_filters)

    # geom.TINY is added to the denominator before dividing, (out_c,in_c)
    channel_ratios = old_total / (new_total + geom.TINY)
    # ratio where the mask is set, 1 everywhere else, (out_c,in_c,n_filters)
    ratios = rescale_mask * channel_ratios[..., None] + (~rescale_mask)

    if verbose:
        print("old weights", old_weights.shape, old_weights)
        print("ratios", ratios.shape, ratios)  # (out_c,in_c,n_filters)

    return new_weights * ratios
compat_flex_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Do compatibility rescaling, now with one extra free parameter. For now this is only defined for a dimension increase of one: sidelength 3 filters going from D=1 to D=2 or from D=2 to D=3, and sidelength 2 filters.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
@staticmethod
def compat_flex_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Do compatibility rescaling, now with one extra free parameter. For now this is only defined
    for sidelength 3 filters for D=1 to D=2.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compat_flex_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase == 1

    if (
        old_filters.shape[1 : 1 + old_D] == (3,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (3,) * new_D
    ):
        if old_D == 1 and new_D == 2:
            assert old_weights.shape[2] == 2  # should be 2 filters
            ratio = 1 / 3

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0] + (-2 + 4 * ratio) * old_weights[..., 1],
                    (1 - 2 * ratio) * old_weights[..., 1],
                    ratio * old_weights[..., 1],
                ],
                axis=-1,
            )
        elif old_D == 2 and new_D == 3:
            # need to get first 4 new_weights from first 3 old_weights

            z = (old_weights[..., 2] * 4 - old_weights[..., 1]) / 9

            alpha_prime = jnp.stack(
                [
                    old_weights[..., 0]
                    - 2 * old_weights[..., 1]
                    + 4 * old_weights[..., 2]
                    - 8 * z,
                    old_weights[..., 1] - 2 * old_weights[..., 2] + 4 * z,
                    old_weights[..., 2] - 2 * z,
                    z,
                ],
                axis=-1,
            )

            # filters are in flipped order for some reason
            symmetric_traceless = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 4:5]
            along_trace = jnp.ones_like(old_weights[..., :2]) * old_weights[..., 3:4]

            alpha_prime = jnp.concatenate(
                [alpha_prime, symmetric_traceless, along_trace], axis=-1
            )
        else:
            raise ValueError()
    elif (
        old_filters.shape[1 : 1 + old_D] == (2,) * old_D
        and new_filters.shape[1 : 1 + new_D] == (2,) * new_D
    ):
        alpha_prime = old_weights / (2**D_increase)
    else:
        raise ValueError()

    # TODO: I could check that the condition holds?

    return alpha_prime
compatibility_norm_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters.

WARNING: This is the old version which works on the norms of the tensors.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of the old filters, the old weights (shape (out_channels,in_channels,num_filters)), and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of the new filters, the new weights (shape (out_channels,in_channels,num_filters)), and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
@staticmethod
def compatibility_norm_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters.

    WARNING: This is the old version which works on the norms of the tensors.

    The weights are solved one filter at a time, from the last filter to the first, so that the
    new norm-filters, summed over the added spatial axes and scaled by the returned weights,
    match the old norm-filters scaled by the old weights. The match is asserted before returning.

    args:
        old_filter_triple: tuple of (old filters, old weights, old dimension), unpacked in that
            order. The weights have shape (out_channels,in_channels,num_filters).
        new_filter_triple: tuple of (new filters, new weights, new dimension), unpacked in that
            order. The weights have shape (out_channels,in_channels,num_filters).
        verbose: whether to print the new weights and the updated weights

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

    raises:
        AssertionError: if the new filters are not ordered by ascending nonempty-pixel count, if
            the new dimension is not larger than the old one, or if the final check fails
    """
    old_filters, old_weights, old_D = old_filter_triple
    new_filters, new_weights, new_D = new_filter_triple

    # Convert filters to the norm of the filters. This assumes 2 things:
    # 1. tensors in each pixel differ only by norm. True for nonzero filters of a single irrep
    # 2. the sign of the filters are positive
    # From here on old_filters/new_filters are rebound to these per-pixel norms: the tensor axes
    # are flattened into one trailing axis which the norm then reduces away.
    old_filters = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
    )
    new_filters = jnp.linalg.norm(
        new_filters.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
    )

    # assert the filters are already in ascending order by number of pixels.
    # So for orthoplex, this means innermost to outermost
    # NOTE(review): presumably geom.nonempty_pixels returns a boolean mask of nonzero pixels and
    # the trailing 1 is the number of leading (filter) axes -- confirm against geom.
    filter_raw_sum = jnp.sum(1 * geom.nonempty_pixels(new_D, new_filters, 1), axis=-1)
    assert sorted(list(filter_raw_sum)) == list(filter_raw_sum)

    D_increase = new_D - old_D
    assert D_increase > 0

    # first, reduce the filters to the nonempty pixel filters.
    # `1 *` turns the mask into numbers so it can be summed below.
    nonempty_pixel_filter = 1 * geom.nonempty_pixels(new_D, new_filters, 1).reshape(
        new_filters.shape[: 1 + new_D]
    )
    # (n_filters,old_spatial): count of nonempty pixels along the first D_increase spatial axes
    collapsed_nonempty_ff = jnp.sum(nonempty_pixel_filter, axis=tuple(range(1, 1 + D_increase)))

    # (n_filters,old_spatial): new norm-filters summed over the first D_increase spatial axes
    collapsed_ff = jnp.sum(new_filters, axis=tuple(range(1, 1 + D_increase)))
    # (n_filters,old_spatial)

    # (out_c,in_c,spatial): the old norm-filters scaled by the old weights, summed over filters.
    # This is the target that the collapsed new filters have to reproduce.
    old_scaled_ff = jnp.sum(get_scaled_filters(old_D, old_filters, old_weights), axis=2)

    # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
    updated_weights = np.zeros(new_weights.shape[:2] + (len(new_filters),))
    for i in reversed(range(len(filter_raw_sum))):  # starting with the outermost filter...

        # get the outermost pixel of collapsed filter i
        # (old_spatial_size,) true/falses whether the pixel is nonempty
        nonempty_pixels = geom.nonempty_pixels(old_D, collapsed_nonempty_ff[i]).ravel()
        # largest flattened index among the nonempty pixels of collapsed filter i
        farthest_pixel_idx = jnp.max(jnp.arange(len(nonempty_pixels))[nonempty_pixels])

        # with current weight for filter i and collapsed sum of updated_weights,
        # calculate new weight to equal old weight
        updated_weights[:, :, i] = new_weights[:, :, i]  # temp set weight to current weight
        # (out_c,in_c,n_filters,old_spatial)
        scaled_collapsed_ff = get_scaled_filters(
            old_D, collapsed_ff, jnp.array(updated_weights)
        )
        # (out_c,in_c,old_spatial)
        collapsed_sum = jnp.sum(scaled_collapsed_ff, axis=2)
        # (out_c,in_c): value of the current collapsed sum at the farthest pixel
        collapsed_val = collapsed_sum.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # assume that old_weights_val = new_weights_val. The old weight and new weight are
        # the same at this point, otherwise filter value could be different, but it wont be
        # for normalize and gaussian at least.
        old_weights_val = old_scaled_ff.reshape(collapsed_sum.shape[:2] + (-1,))[
            :, :, farthest_pixel_idx
        ]
        # this should really be new_ff_val, assume they are equal, see above
        old_norm_ff_val = old_filters[i].ravel()[farthest_pixel_idx]

        # set updated weights
        # Algebraically this is (2 * old_weights_val - collapsed_val) / old_norm_ff_val.
        updated_weights[:, :, i] = (
            -(collapsed_val - old_weights_val) + old_weights_val
        ) / old_norm_ff_val

    updated_weights = jnp.array(updated_weights)

    # now we check that we did it right
    # (out_c,in_c,n_filters,old_spatial)
    scaled_collapsed_ff = get_scaled_filters(old_D, collapsed_ff, updated_weights)
    # (out_c,in_c,old_spatial)
    scaled_collapsed_ff = jnp.sum(scaled_collapsed_ff, axis=2)

    # (n_filters,old_spatial)
    # old_filters already holds norms at this point, so the reshape only appends a length-1 axis
    # and this second norm reduces to an absolute value.
    old_norm_ff = jnp.linalg.norm(
        old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)),
        axis=-1,
    )

    # (out_c,in_c,n_filters,old_spatial)
    old_scaled_filters = get_scaled_filters(old_D, old_norm_ff, old_weights)
    # (out_c,in_c,old_spatial)
    old_scaled_filters = jnp.sum(old_scaled_filters, axis=2)

    diff = jnp.max(jnp.abs(scaled_collapsed_ff - old_scaled_filters))
    # NOTE(review): the message names compatibility_rescale_weights although this is the norm
    # variant -- looks like a copy-paste leftover.
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(
        scaled_collapsed_ff, old_scaled_filters, rtol=1e-3, atol=1e-3
    ), diff_message

    if verbose:
        print("new weights:", new_weights)
        print("updated weights:", updated_weights)

    return updated_weights
compatibility_rescale_weights(old_filter_triple: tuple[jax.Array, jax.Array, int], new_filter_triple: tuple[jax.Array, jax.Array, int], verbose: bool = False) -> jax.Array staticmethod ¤

Rescale the weight coefficients so that they are compatible with the particular embedding. This algorithm has an implicit assumption that we are using orthoplex filters. This implements Algorithm 1: Orthoplex filter weight scaling.

Parameters:

Name Type Description Default
old_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the old filters, and the old dimension

required
new_filter_triple tuple[Array, Array, int]

tuple of weights (shape (out_channels,in_channels,num_filters)), the new filters, and the new dimension

required
verbose bool

whether to print the old weights and ratios

False
return

jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling

Source code in ginjax/models.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
@staticmethod
def compatibility_rescale_weights(
    old_filter_triple: tuple[jax.Array, jax.Array, int],
    new_filter_triple: tuple[jax.Array, jax.Array, int],
    verbose: bool = False,
) -> jax.Array:
    """
    Rescale the weight coefficients so that they are compatible with the particular embedding.
    This algorithm has an implicit assumption that we are using orthoplex filters. This
    implements Algorithm 1: Orthoplex filter weight scaling.

    args:
        old_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the old filters, and the old dimension
        new_filter_triple: tuple of weights (shape (out_channels,in_channels,num_filters)),
            the new filters, and the new dimension
        verbose: whether to print the old weights and ratios

    return:
        jax array of rescaled weights (out_channels,in_channels,num_filters) after rescaling
    """
    old_filters, old_weights, old_D = old_filter_triple  # old weights are alpha
    new_filters, new_weights, new_D = new_filter_triple
    k = old_filters.ndim - (1 + old_D)
    assert k == new_filters.ndim - (
        1 + new_D
    ), f"compatibility_rescale_weights: old_filters k={k}, new_filters k={new_filters.ndim - (1 + new_D)}"

    D_increase = new_D - old_D
    assert D_increase > 0, f"compatibility_rescale_weights: D_increase={D_increase}"

    # old/new_filters shape (n_filters,spatial,tensor)

    # we have filters ell=0,1,...,L
    # same number of filters
    assert len(old_filters) == len(
        new_filters
    ), f"compatibility_rescale_weights: len old_filters={len(old_filters)}, len new_filters={len(new_filters)}"
    L = len(old_filters) - 1
    L_plus = len(old_filters)  # more useful for iterating

    new_filters_proj_tensors = (
        new_filters[..., (slice(0, old_D),) * k] if k > 0 else new_filters
    )

    # currently special case N=2 because its so different
    if old_filters.shape[1] == 2 or new_filters.shape[1] == 2:
        assert (2,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert (2,) * new_D == new_filters.shape[1 : 1 + new_D]

        alpha_prime = old_weights / (2**D_increase)

    else:  # filters are odd, and in particular 2L + 1 square
        # largest filter goes up to the border
        assert ((2 * L) + 1,) * old_D == old_filters.shape[1 : 1 + old_D]
        assert ((2 * L) + 1,) * new_D == new_filters.shape[1 : 1 + new_D]

        # (n_filters,new_spatial)
        new_filters_proj_norm = jnp.linalg.norm(
            new_filters_proj_tensors.reshape(new_filters.shape[: 1 + new_D] + (-1,)), axis=-1
        )

        # (n_filters,old_spatial)
        old_filters_norm = jnp.linalg.norm(
            old_filters.reshape(old_filters.shape[: 1 + old_D] + (-1,)), axis=-1
        )

        # use np so we can easily edit it (out_c,in_c,n_nonzero_filters)
        alpha_prime = np.zeros(new_weights.shape[:2] + (L_plus,))
        for z in reversed(range(L_plus)):  # iterates from L,L-1,...,0
            j_d_centered = (z,) + (0,) * (old_D - 1)
            j_dplus_centered = (z,) + (0,) * (new_D - 1)

            j_d = tuple(x + L for x in j_d_centered)
            j_dplus = tuple(x + L for x in j_dplus_centered)

            # (out_c,in_c,n_filters,new_spatial)
            scaled_new_filters = (
                alpha_prime[..., *((None,) * new_D)] * new_filters_proj_norm[None, None]
            )
            # sum over filters, spatial dims (out_c,in_c,old_spatial)
            # since alpha_prime are only nonzero for z+1, this is the proper sum over ell=z+1 to L
            collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

            # alpha_prime = (alpha * C_z - sum) / (C'_z)
            alpha_prime[:, :, z] = (
                old_weights[:, :, z] * old_filters_norm[z, *j_d] - collapsed_ff[:, :, *j_d]
            ) / new_filters_proj_norm[z, *j_dplus]

        alpha_prime = jnp.array(alpha_prime)

    # now we check that we did it right
    # (out_c,in_c,n_filters,new_spatial,proj_tensor)
    scaled_new_filters = (
        alpha_prime[..., *((None,) * (new_D + k))] * new_filters_proj_tensors[None, None]
    )
    # (out_c,in_c,old_spatial,proj_tensor)
    collapsed_ff = jnp.sum(scaled_new_filters, axis=tuple(range(2, 2 + 1 + D_increase)))

    # (out_c,in_c,n_filters,old_spatial,tensor)
    scaled_old_filters = old_weights[..., *((None,) * (old_D + k))] * old_filters[None, None]
    # (out_c,in_c,old_spatial,tensor)
    scaled_old_filters = jnp.sum(scaled_old_filters, axis=2)

    diff = jnp.max(jnp.abs(collapsed_ff - scaled_old_filters))
    diff_message = f"AnyDimensionalModel::compatibility_rescale_weights: Diff is {diff}"

    assert jnp.allclose(collapsed_ff, scaled_old_filters, rtol=1e-3, atol=1e-3), diff_message

    if verbose:
        print("old weights:", old_weights)
        print("updated weights:", alpha_prime)

    return alpha_prime
_transfer_conv_weights(weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]], old_filters: geom.MultiImage, new_filters: geom.MultiImage, rescale: geom.Rescaling, verbose: bool = False) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]] staticmethod ¤

Transfer the conv weights from old filters to new filters of possibly a different dimension. Depending on the rescale mode, the weights may then be scaled so that the sum of the filter basis of a particular order scaled by the weights is equal for the old filters and the new.

Parameters:

Name Type Description Default
weights dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a weights dictionary from a layers.ConvContract layer

required
old_filters MultiImage

the old filters that the weights came from

required
new_filters MultiImage

the new filters that we will be using the weights for

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], Array]]

a new weights dictionary

Source code in ginjax/models.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
@staticmethod
def _transfer_conv_weights(
    weights: dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]],
    old_filters: geom.MultiImage,
    new_filters: geom.MultiImage,
    rescale: geom.Rescaling,
    verbose: bool = False,
) -> dict[tuple[tuple[bool, ...], int], dict[tuple[tuple[bool, ...], int], jax.Array]]:
    """
    Carry a ConvContract weights dictionary over from one set of filters to another, possibly of
    a different dimension. Every weight block is first extended to the new filters and then
    rescaled according to `rescale`: VOLUME rescales the positive and negative parts separately,
    COMPATIBILITY rescales only the filters whose spatial sum is nonzero, COMPAT_FLEX rescales
    the whole block, and any other value leaves the extended block untouched.

    args:
        weights: a weights dictionary from a layers.ConvContract layer
        old_filters: the old filters that the weights came from
        new_filters: the new filters that we will be using the weights for
        rescale: type of rescaling to perform on the weights
        verbose: passed through to the rescaling routine, defaults to False

    returns:
        a new weights dictionary with the same keys as `weights`
    """
    transferred = {}

    for in_key, blocks_by_output in weights.items():
        in_k, in_p = in_key
        out_blocks = {}
        transferred[(in_k, in_p)] = out_blocks

        for out_key, source_block in blocks_by_output.items():
            out_k, out_p = out_key
            # the filter used for this pair: concatenated tensor key and summed parity mod 2
            filter_key = (in_k + out_k, (in_p + out_p) % 2)

            extended_block = AnyDimensionalModel._extend_weights(
                source_block, filter_key, old_filters, new_filters
            )
            old_basis = old_filters[filter_key]
            new_basis = new_filters[filter_key]

            if rescale is geom.Rescaling.VOLUME:
                # rescale the positive and the negative weights separately, then recombine
                positive_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, jax.nn.relu(source_block), old_filters.D),
                    (new_basis, jax.nn.relu(extended_block), new_filters.D),
                    verbose,
                )
                negative_part = AnyDimensionalModel.volume_rescale_weights(
                    (old_basis, -jax.nn.relu(-source_block), old_filters.D),
                    (new_basis, -jax.nn.relu(-extended_block), new_filters.D),
                    verbose,
                )
                result_block = positive_part + negative_part
            elif rescale is geom.Rescaling.COMPATIBILITY:
                # Filters that always sum to 0 are left alone.
                spatial_axes = tuple(range(1, 1 + new_filters.D))
                # (n_filters,tensor)
                summed = jnp.sum(new_basis, axis=spatial_axes)
                # (n_filters,)
                summed_norm = jnp.linalg.norm(summed.reshape((len(summed), -1)), axis=1)
                keep = summed_norm != 0  # (n_filters,)

                rescaled_subset = AnyDimensionalModel.compatibility_rescale_weights(
                    (old_basis[keep], source_block[:, :, keep], old_filters.D),
                    (new_basis[keep], extended_block[:, :, keep], new_filters.D),
                    verbose,
                )

                # write the rescaled weights back into their slots of the extended block
                result_block = extended_block.at[:, :, keep].set(rescaled_subset)
            elif rescale is geom.Rescaling.COMPAT_FLEX:
                result_block = AnyDimensionalModel.compat_flex_rescale_weights(
                    (old_basis, source_block, old_filters.D),
                    (new_basis, extended_block, new_filters.D),
                    verbose,
                )
            else:
                result_block = extended_block

            out_blocks[(out_k, out_p)] = result_block

    return transferred
transfer_weights(new_model: Self, rescale: geom.Rescaling, verbose: bool = False) -> Self ¤

Transfer the weights and biases from an old model to a new model. This allows converting between dimensions as well. This works by copying all jax arrays from the old model to the new model, then resetting the new model's conv filters to the new conv filters, then doing any conv-filter-related weight scaling.

In the future, it may make sense for the updates to be defined on the individual layers, and then the tree_at recursively calls those functions.

Parameters:

Name Type Description Default
old_model

the old model

required
new_model Self

the new model

required
old_conv_filters

the convolution filters used in the old model

required
conv_filters

the convolution filters to use in the new model, can have different D

required
rescale Rescaling

type of rescaling to perform on the weights

required
verbose bool

print the ratio of the squared sum of filters new/old after transferring the weights, defaults to False.

False

Returns:

Type Description
Self

a new model with the old weights except conv weights which are adjusted, and new filters

Source code in ginjax/models.py
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def transfer_weights(
    self: Self, new_model: Self, rescale: geom.Rescaling, verbose: bool = False
) -> Self:
    """
    Copy the weights and biases of this (old) model into `new_model`, which may have been built
    with conv filters of a different dimension. All arrays of this model are copied over first,
    then the conv filters that `new_model` was built with are put back, and finally the conv
    weights are transferred (and possibly rescaled) filter by filter.

    In the future, it may make sense for the updates to be defined on the individual layers, and
    then the tree_at recursively calls those functions.

    args:
        new_model: the model that receives the weights, its own conv filters are kept
        rescale: type of rescaling to perform on the conv weights
        verbose: passed through to the conv weight transfer, defaults to False

    returns:
        a new model with the old weights except conv weights which are adjusted, and new filters
    """

    def is_conv(node):
        return isinstance(node, layers.ConvContract)

    def conv_layers(model):
        # treat every ConvContract as a leaf so the layer objects themselves are returned
        return [
            node for node in jax.tree_util.tree_leaves(model, is_leaf=is_conv) if is_conv(node)
        ]

    def get_filters(model):
        return [layer.invariant_filters for layer in conv_layers(model)]

    def get_conv_weights(model):
        return [layer.weights for layer in conv_layers(model)]

    def get_all_weights(model):
        return jax.tree_util.tree_leaves(model, is_leaf=eqx.is_array)

    # remember the filters of the new model before they get overwritten below
    target_filters = get_filters(new_model)

    # overwrite every array in the new model with the matching array of this model
    new_model = eqx.tree_at(get_all_weights, new_model, get_all_weights(self))

    # that also replaced the conv filters, so restore the ones the new model came with
    new_model = eqx.tree_at(get_filters, new_model, target_filters)

    # transfer the conv weights layer by layer, pairing old and new filters
    source_weights = get_conv_weights(self)
    source_filters = get_filters(self)
    transferred = []
    for weight, old_filter, new_filter in zip(source_weights, source_filters, target_filters):
        transferred.append(
            AnyDimensionalModel._transfer_conv_weights(
                weight, old_filter, new_filter, rescale, verbose
            )
        )

    return eqx.tree_at(get_conv_weights, new_model, transferred)
convertD(conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs) -> Self ¤

Construct a new model with filters in a higher dimension.

Parameters:

Name Type Description Default
conv_filters MultiImage

the new conv filters we are swapping to, probably in a higher dimension

required
rescale Rescaling

how to rescale the filter weights

required
key Array

key to initialize the weights, since they are overruled it won't matter

required

Returns:

Type Description
Self

a new model with new filters but the old weights

Source code in ginjax/models.py
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
def convertD(
    self: Self, conv_filters: geom.MultiImage, rescale: geom.Rescaling, key: jax.Array, **kwargs
) -> Self:
    """
    Build a model of the same class around different conv filters and move this model's weights
    into it.

    args:
        conv_filters: the new conv filters we are swapping to, probably in a higher dimension
        rescale: how to rescale the filter weights
        key: key to initialize the weights of the fresh model, they are replaced by the
            transferred weights afterwards
        **kwargs: accepted but not used by this implementation

    returns:
        a new model with new filters but the old weights
    """
    # same constructor arguments as this model, except for the filters and the key
    constructor_args = (
        self.input_keys,
        self.target_keys,
        conv_filters,
        self.width,
        self.depth,
        self.use_bias,
        self.activation_f,
        key,
    )
    rebuilt = self.__class__(*constructor_args)

    return self.transfer_weights(rebuilt, rescale, verbose=False)

handle_activation(activation_f: Optional[Union[Callable, str]], equivariant: bool, input_keys: geom.Signature, D: int, key: ArrayLike) -> Callable[[Any], geom.MultiImage] ¤

Parse what activation function to use, return the appropriate callable

Parameters:

Name Type Description Default
activation_f Optional[Union[Callable, str]]

the type of activation, either a callable or a string name from ACTIVATION_REGISTRY

required
equivariant bool

whether to use an equivariant activation

required
input_keys Signature

the layers input keys

required
D int

dimension of the model

required
key ArrayLike

jax.random key

required

Returns:

Type Description
Callable[[Any], MultiImage]

A layer that performs the specified activation function

Source code in ginjax/models.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def handle_activation(
    activation_f: Optional[Union[Callable, str]],
    equivariant: bool,
    input_keys: geom.Signature,
    D: int,
    key: ArrayLike,
) -> Callable[[Any], geom.MultiImage]:
    """
    Turn an activation specification into the callable layer that applies it.

    args:
        activation_f: the type of activation, either a callable or a string name from
            ACTIVATION_REGISTRY, None means no activation
        equivariant: whether to use an equivariant activation
        input_keys: the layers input keys
        D: dimension of the model
        key: jax.random key, only passed on to the equivariant nonlinearity

    returns:
        A layer that performs the specified activation function. Without an activation this is
        a plain identity function when equivariant, and a wrapped eqx.nn.Identity otherwise.
    """
    # no activation requested: return an identity of the matching flavor
    if activation_f is None:
        if equivariant:
            return lambda x: x
        return layers.LayerWrapper(eqx.nn.Identity(), input_keys)

    # a string names an entry of the registry, anything else is used as given
    if isinstance(activation_f, str):
        assert activation_f in ACTIVATION_REGISTRY
        nonlinearity = ACTIVATION_REGISTRY[activation_f]
    else:
        nonlinearity = activation_f

    if equivariant:
        return layers.VectorNeuronNonlinear(input_keys, D, nonlinearity, key=key)
    return layers.LayerWrapper(nonlinearity, input_keys)

make_conv(D: int, input_keys: geom.Signature, target_keys: geom.Signature, use_bias: Union[str, bool], equivariant: bool, invariant_filters: Optional[geom.MultiImage] = None, kernel_size: Optional[Union[int, Sequence[int]]] = None, stride: Union[tuple[int, ...], int] = 1, padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None, lhs_dilation: Optional[tuple[int, ...]] = None, rhs_dilation: Union[int, tuple[int, ...]] = 1, padding_mode: str = 'ZEROS', key: Any = None) -> Union[layers.ConvContract, layers.LayerWrapper] ¤

Factory for convolution layer which makes ConvContract if equivariant and makes a regular conv otherwise.

Parameters:

Name Type Description Default
D int

dimension of the space

required
input_keys Signature

MultiImage Signature of input

required
target_keys Signature

MultiImage Signature of output

required
use_bias Union[str, bool]

whether to use a bias

required
equivariant bool

whether to use an equivariant layer or normal layer

required
invariant_filters Optional[MultiImage]

filters used for equivariant layer

None
kernel_size Optional[Union[int, Sequence[int]]]

sidelength(s) of kernel, only used for non-equivariant layer

None
stride Union[tuple[int, ...], int]

convolution stride

1
padding Optional[Union[str, int, tuple[tuple[int, int], ...]]]

convolution padding

None
lhs_dilation Optional[tuple[int, ...]]

left hand side dilation for transpose convolution

None
rhs_dilation Union[int, tuple[int, ...]]

right hand side dilation for dilated convolutions

1
padding_mode str

for non-equivariant convolutions, define padding mode that is passed to conv. For equivariant, this is a variable of the input

'ZEROS'
key Any

jax.random key

None

Returns:

Type Description
Union[ConvContract, LayerWrapper]

either ConvContract or a LayerWrapper around an equinox convolution

Source code in ginjax/models.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def make_conv(
    D: int,
    input_keys: geom.Signature,
    target_keys: geom.Signature,
    use_bias: Union[str, bool],
    equivariant: bool,
    invariant_filters: Optional[geom.MultiImage] = None,
    kernel_size: Optional[Union[int, Sequence[int]]] = None,
    stride: Union[tuple[int, ...], int] = 1,
    padding: Optional[Union[str, int, tuple[tuple[int, int], ...]]] = None,
    lhs_dilation: Optional[tuple[int, ...]] = None,
    rhs_dilation: Union[int, tuple[int, ...]] = 1,
    padding_mode: str = "ZEROS",
    key: Any = None,  # any instead of arraylike because split cannot handle None
) -> Union[layers.ConvContract, layers.LayerWrapper]:
    """
    Convolution layer factory: builds a ConvContract when equivariant, and otherwise a
    LayerWrapper around an equinox convolution (a transpose convolution if lhs_dilation is set).

    args:
        D: dimension of the space
        input_keys: MultiImage Signature of input
        target_keys: MultiImage Signature of output
        use_bias: whether to use a bias, "auto" becomes True for the non-equivariant layer
        equivariant: whether to use an equivariant layer or normal layer
        invariant_filters: filters used for equivariant layer, required when equivariant
        kernel_size: sidelength(s) of kernel, only used for non-equivariant layer where it is
            required
        stride: convolution stride
        padding: convolution padding, None becomes "SAME" for the non-equivariant layer
        lhs_dilation: left hand side dilation for transpose convolution
        rhs_dilation: right hand side dilation for dilated convolutions
        padding_mode: for non-equivariant convolutions, define padding mode that is passed to conv.
            For equivariant, this is a variable of the input
        key: jax.random key

    returns:
        either ConvContract or a LayerWrapper around an equinox convolution
    """
    if equivariant:
        assert invariant_filters is not None
        return layers.ConvContract(
            input_keys,
            target_keys,
            invariant_filters,
            use_bias,
            stride,
            padding,
            lhs_dilation,
            rhs_dilation,
            key,
        )

    # The plain convolution needs exactly one input and one output entry, both with key ((), 0).
    assert kernel_size is not None
    assert len(input_keys) == len(target_keys) == 1
    assert input_keys[0][0] == target_keys[0][0] == ((), 0)

    if padding is None:
        padding = "SAME"
    if padding != "SAME":
        padding_mode = "ZEROS"  # only implemented for SAME
    if use_bias == "auto":
        use_bias = True
    assert isinstance(use_bias, bool)

    in_channels = input_keys[0][1]
    out_channels = target_keys[0][1]

    if lhs_dilation is None:
        conv = eqx.nn.Conv(
            D,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            rhs_dilation,
            use_bias=use_bias,
            padding_mode=padding_mode,
            key=key,
        )
    else:
        # if there is lhs_dilation, assume its a transpose convolution. Only its presence is
        # used here, the value itself is not passed on.
        conv = eqx.nn.ConvTranspose(
            D,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation=rhs_dilation,
            use_bias=use_bias,
            padding_mode=padding_mode,
            key=key,
        )

    return layers.LayerWrapper(conv, input_keys)

count_params(model: eqx.Module) -> int ¤

Count the number of parameters in the model

Parameters:

Name Type Description Default
model Module

model to measure

required

Returns:

Type Description
int

number of parameters

Source code in ginjax/models.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def count_params(model: eqx.Module) -> int:
    """
    Count the number of parameters in the model. Every array leaf of the model is counted, except
    for the arrays held in the invariant_filters of any ConvContract layer.

    args:
        model: model to measure

    returns:
        number of parameters
    """

    def _is_conv_contract(node) -> bool:
        return isinstance(node, layers.ConvContract)

    def _num_array_elements(tree) -> int:
        # total element count over every array leaf of the given pytree
        leaves = jax.tree_util.tree_leaves(tree, is_leaf=eqx.is_array)
        return sum(leaf.size for leaf in leaves if eqx.is_array(leaf))

    everything = _num_array_elements(model)

    # Treat each ConvContract as a leaf so that its invariant_filters can be pulled out whole.
    conv_layers = [
        node
        for node in jax.tree_util.tree_leaves(model, is_leaf=_is_conv_contract)
        if _is_conv_contract(node)
    ]
    fixed_filters = [conv.invariant_filters for conv in conv_layers]

    # The filters are stored as arrays but are not params, so remove them from the total.
    return everything - _num_array_elements(fixed_filters)

get_scaled_filters(D: int, filter_block: jax.Array, weights: jax.Array) -> jax.Array ¤

For a set of filters and a block of weights, scale the filters by the weights.

Consider writing a filters subclass of MultiImage that has these functions defined on it.

Parameters:

Name Type Description Default
D int

the dimension

required
filter_block Array

the geometric filters data block, (n_filters,spatial,tensor)

required
weights Array

the block of weights, shape (out_c,in_c,n_filters)

required

Returns:

Type Description
Array

array of scaled filters, shape (out_c,in_c,n_filters,spatial,tensor)

Source code in ginjax/models.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def get_scaled_filters(D: int, filter_block: jax.Array, weights: jax.Array) -> jax.Array:
    """
    Scale every filter in a block of filters by its matching entry in a block of weights. No
    summation is performed, so the n_filters axis is kept in the result.

    Consider writing a filters subclass of MultiImage that has these functions defined on it.

    args:
        D: the dimension
        filter_block: the geometric filters data block, (n_filters,spatial,tensor)
        weights: the block of weights, shape (out_c,in_c,n_filters)

    returns:
        array of scaled filters, shape (out_c,in_c,n_filters,spatial,tensor)
    """
    # drop the leading n_filters axis, what remains is the (spatial,tensor) shape of one filter
    single_filter_shape = filter_block.shape[1:]
    _, tensor_order = geom.parse_shape(single_filter_shape, D)

    # pad the weights with singleton axes so they broadcast over the spatial and tensor axes:
    # (out_c,in_c,n_filters) -> (out_c,in_c,n_filters,(1,)*D,(1,)*k)
    singleton_axes = (1,) * D + (1,) * tensor_order
    broadcast_weights = weights.reshape(weights.shape + singleton_axes)

    # (1,1,n_filters,spatial,tensor) * weights -> (out_c,in_c,n_filters,spatial,tensor)
    return filter_block[None, None] * broadcast_weights

get_filter_sum(D: int, filter_block: jax.Array, weights: jax.Array) -> jax.Array ¤

For a set of filters and possibly a block of weights, calculate the sum of the filters scaled by the weights, then take the tensor norm.

Consider writing a filters subclass of MultiImage that has these functions defined on it.

Parameters:

Name Type Description Default
D int

dimension of the space

required
filter_block Array

the geometric filters data block, (n_filters,spatial,tensor)

required
weights Array

the block of weights, shape (out_c,in_c,n_filters)

required

Returns:

Type Description
Array

array of filter sums, shape (out_c,in_c)

Source code in ginjax/models.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def get_filter_sum(D: int, filter_block: jax.Array, weights: jax.Array) -> jax.Array:
    """
    Scale a set of filters by a block of weights, sum the result over the filter and spatial
    axes, then take the norm of the remaining tensor for each (out_c,in_c) pair.

    Consider writing a filters subclass of MultiImage that has these functions defined on it.

    args:
        D: dimension of the space
        filter_block: the geometric filters data block, (n_filters,spatial,tensor)
        weights: the block of weights, shape (out_c,in_c,n_filters)

    returns:
        array of filter sums, shape (out_c,in_c)
    """
    # axis 2 is n_filters and the D axes after it are spatial, all of them get summed away
    summed_axes = tuple(range(2, 3 + D))

    # (out_c,in_c,n_filters,spatial,tensor) -> (out_c,in_c,tensor)
    tensor_totals = jnp.sum(get_scaled_filters(D, filter_block, weights), axis=summed_axes)

    # collapse the tensor axes into one so that the vector norm along the last axis is the
    # Frobenius norm of the tensor: (out_c,in_c,tensor) -> (out_c,in_c,tensor_size)
    leading_shape = tensor_totals.shape[:2]
    flattened = tensor_totals.reshape(leading_shape + (-1,))

    # (out_c,in_c,tensor_size) -> (out_c,in_c)
    return jnp.linalg.norm(flattened, axis=-1)