Skip to content

Losses

ginjax.ml.losses ¤

timestep_smse_loss(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, n_steps: int, reduce: Optional[str] = 'mean') -> jax.Array ¤

Returns the loss for each timestep. The loss is summed over the channels and averaged over the spatial dimensions and, when reduce is 'mean', over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
n_steps int

number of timesteps; every channel count must be a multiple of this

required
reduce Optional[str]

how to reduce over the batch: one of 'mean', 'max', or None. 'max' returns the per-step losses of the batch entry with the largest total loss.

'mean'

Returns:

Type Description
Array

the loss array with shape (batch,n_steps) if reduce is None, otherwise (n_steps,)

Source code in ginjax/ml/losses.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def timestep_smse_loss(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    n_steps: int,
    reduce: Optional[str] = "mean",
) -> jax.Array:
    """
    Returns loss for each timestep. Loss is summed over the channels, and mean over spatial dimensions
    and the batch.

    The channel axis of every image block is reshaped to (channels // n_steps, n_steps), so channel
    index c is attributed to timestep c % n_steps.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        n_steps: number of timesteps, all channels should be a multiple of this
        reduce: how to reduce over the batch, one of "mean", "max", or None. "mean" averages the
            per-step losses over the batch; "max" returns the per-step losses of the single batch
            entry whose total loss (summed over steps) is largest; None returns every batch entry.

    returns:
        the loss array with shape (batch,n_steps) if reduce is None or (n_steps,)
    """
    reduce_options = {"mean", "max", None}
    assert (
        reduce in reduce_options
    ), f"timestep_smse_loss: reduce={reduce} must be one of {reduce_options}"
    # Both inputs (not just the prediction) must carry exactly a batch and a channel axis.
    assert (
        multi_image_x.get_n_leading() == multi_image_y.get_n_leading() == 2
    ), "timestep_smse_loss: MultiImages must have batch and channel axes"

    spatial_size = np.multiply.reduce(multi_image_x.get_spatial_dims())
    batch = multi_image_x.get_L()
    loss_per_step = jnp.zeros((batch, n_steps))
    for image_a, image_b in zip(
        multi_image_x.values(), multi_image_y.values()
    ):  # loop over image types
        # split channels into (channel groups, timesteps): (batch, c // n_steps, n_steps, ...)
        image_a = image_a.reshape((batch, -1, n_steps) + image_a.shape[2:])
        image_b = image_b.reshape((batch, -1, n_steps) + image_b.shape[2:])
        # sum over channel groups (axis 1) and all spatial/tensor axes (3+), keep (batch, n_steps)
        loss = (
            jnp.sum((image_a - image_b) ** 2, axis=(1,) + tuple(range(3, image_a.ndim)))
            / spatial_size
        )
        loss_per_step = loss_per_step + loss

    if reduce == "mean":
        return jnp.mean(loss_per_step, axis=0)
    elif reduce == "max":
        # pick the batch entry with the largest total loss and return its per-step losses
        return loss_per_step[jnp.argmax(jnp.sum(loss_per_step, axis=1))]
    else:
        return loss_per_step

smse_loss(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: Optional[str] = 'mean') -> jax.Array ¤

Sum of mean squared error loss. The sum is over the channels, and the mean is over the spatial dimensions. The mean is also taken over the batch if reduce == 'mean'; if reduce is None, the per-example losses are returned.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce Optional[str]

how to reduce over batch. Either "mean" or None.

'mean'

Returns:

Type Description
Array

the loss value

Source code in ginjax/ml/losses.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def smse_loss(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: Optional[str] = "mean",
) -> jax.Array:
    """
    Sum of mean squared error loss. The sum is over the channels, the mean is over the spatial
    dimensions. Mean is also taken over batch if reduce == 'mean', or it returns each loss if
    reduce is None.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.

    returns:
        the loss value: a scalar if reduce == "mean", otherwise an array of shape (batch,)
    """
    reduce_options = {"mean", None}
    assert reduce in reduce_options, f"smse_loss: reduce={reduce} must be one of {reduce_options}"
    # Both inputs (not just the prediction) must carry exactly a batch and a channel axis.
    assert (
        multi_image_x.get_n_leading() == multi_image_y.get_n_leading() == 2
    ), "smse_loss: MultiImages must have batch and channel axes"

    spatial_size = np.multiply.reduce(multi_image_x.get_spatial_dims())
    loss_per_batch = jnp.zeros(multi_image_x.get_L())
    for image_a, image_b in zip(multi_image_x.values(), multi_image_y.values()):
        # sum over every non-batch axis, then divide by the pixel count -> (batch,)
        loss = jnp.sum((image_a - image_b) ** 2, axis=tuple(range(1, image_a.ndim))) / spatial_size
        loss_per_batch = loss_per_batch + loss

    return jnp.mean(loss_per_batch) if reduce == "mean" else loss_per_batch

normalized_smse_loss(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: str | None = 'mean', eps: float = 1e-05) -> jax.Array ¤

Pointwise normalized loss. We find the norm of each channel at each spatial point of the true value and divide the tensor by that norm. Then we take the l2 loss, mean over the spatial dimensions, sum over the channels, then mean over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce str | None

how to reduce over batch. Either "mean" or None.

'mean'
eps float

ensure that we aren't dividing by 0 norm

1e-05

Returns:

Type Description
Array

the loss value

Source code in ginjax/ml/losses.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def normalized_smse_loss(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: str | None = "mean",
    eps: float = 1e-5,
) -> jax.Array:
    """
    Pointwise normalized loss. We find the norm of each channel at each spatial point of the true value
    and divide the tensor by that norm. Then we take the l2 loss, mean over the spatial dimensions, sum
    over the channels, then mean over the batch.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.
        eps: ensure that we aren't dividing by 0 norm

    returns:
        the loss value: a scalar if reduce == "mean", otherwise an array of shape (batch,)
    """
    reduce_options = {"mean", None}
    assert (
        reduce in reduce_options
    ), f"normalized_smse_loss: reduce={reduce} must be one of {reduce_options}"
    assert (
        multi_image_x.get_n_leading() == multi_image_y.get_n_leading() == 2
    ), "normalized_smse_loss: MultiImages must have batch and channel axes"

    spatial_size = np.multiply.reduce(multi_image_x.get_spatial_dims())

    order_loss = jnp.zeros(multi_image_x.get_L())
    for (k, parity), img_block in multi_image_y.items():
        # squared norm of the target tensor at each point, kept broadcastable: (b,c,spatial,(1,)*k)
        norm = geom.norm(multi_image_y.D + 2, img_block, keepdims=True) ** 2
        # squared error scaled by the target's squared norm; eps guards against a zero norm
        normalized_l2 = ((multi_image_x[(k, parity)] - img_block) ** 2) / (norm + eps)
        # sum over channel, spatial and tensor axes, divide by pixel count -> (b,)
        order_loss = order_loss + (
            jnp.sum(normalized_l2, axis=tuple(range(1, img_block.ndim))) / spatial_size
        )

    return jnp.mean(order_loss) if reduce == "mean" else order_loss

nrmse_per_pixel_loss(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: str | None = 'mean', eps: float = 0) -> jax.Array ¤

The normalized root mean squared error. The error is relative to the second input per pixel.

The average is taken over every pixel and channel. If reduce is 'mean', it is also taken over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce str | None

how to reduce over batch. Either "mean" or None.

'mean'
eps float

epsilon to add to the denominator to avoid divide by zero errors

0

Returns:

Type Description
Array

average root mean squared error with respect to the second input.

Source code in ginjax/ml/losses.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def nrmse_per_pixel_loss(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: str | None = "mean",
    eps: float = 0,
) -> jax.Array:
    """
    The normalized root mean squared error. The error is relative to the second input per pixel.

    For every (batch, channel, pixel) entry this computes ||x - y|| / (||y|| + eps), where the norm
    is taken over the tensor axes. The average is taken over each pixel, and channel (every entry of
    every image type is weighted equally). If reduce is 'mean' it is also taken over the batch.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.
        eps: epsilon to add to the denominator to avoid divide by zero errors. With the default
            of 0, pixels where the target norm is zero produce inf/nan.

    returns:
        average root mean squared error with respect to the second input; a scalar if
        reduce == "mean", otherwise an array of shape (batch,).
    """
    reduce_options = {"mean", None}
    assert (
        reduce in reduce_options
    ), f"nrmse_per_pixel_loss: reduce={reduce} must be one of {reduce_options}"
    assert (
        multi_image_x.get_n_leading() == multi_image_y.get_n_leading() == 2
    ), "nrmse_per_pixel_loss: MultiImages must have batch and channel axes"

    batch = multi_image_x.get_L()
    D = multi_image_x.D
    # Seed with an empty (batch, 0) array so the result keeps its dtype/shape even with no image types.
    pieces = [jnp.zeros((batch, 0))]
    for image_a, image_b in zip(multi_image_x.values(), multi_image_y.values()):
        diff_norm = geom.norm(D + 2, image_a - image_b)  # (batch,channels,spatial)
        image_b_norm = geom.norm(D + 2, image_b)  # (batch,channels,spatial)
        rel_error = diff_norm / (image_b_norm + eps)
        pieces.append(rel_error.reshape((batch, -1)))  # (batch,channels*spatial)

    # concatenate once instead of re-copying inside the loop
    error_per_batch = jnp.concatenate(pieces, axis=1)
    error_per_batch = jnp.mean(error_per_batch, axis=1)  # mean over channels, spatial -> (batch,)

    return jnp.mean(error_per_batch) if reduce == "mean" else error_per_batch

nrmse_loss(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: str | None = 'mean', eps: float = 0) -> jax.Array ¤

The normalized root mean squared error. This definition follows the standard one used in literature where the norm is taken over the entire difference image and reference image before doing the division diff / reference. We then take the mean over the channels, and then reduce over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce str | None

how to reduce over batch. Either "mean" or None.

'mean'
eps float

epsilon to add to the denominator to avoid divide by zero errors

0

Returns:

Type Description
Array

average root mean squared error with respect to the second input.

Source code in ginjax/ml/losses.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def nrmse_loss(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: str | None = "mean",
    eps: float = 0,
) -> jax.Array:
    """
    The normalized root mean squared error. This definition follows the standard one used in
    literature where the norm is taken over the entire difference image and reference image
    before doing the division diff / reference. We then take the mean over the channels,
    and then reduce over the batch.

    For each (batch, channel) pair this computes ||x - y|| / (||y|| + eps), with the norm taken over
    all spatial and tensor entries of that channel. Channels of all image types are pooled and
    averaged with equal weight.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.
        eps: epsilon to add to the denominator to avoid divide by zero errors. With the default
            of 0, an all-zero target channel produces inf/nan.

    returns:
        average root mean squared error with respect to the second input; a scalar if
        reduce == "mean", otherwise an array of shape (batch,).
    """
    reduce_options = {"mean", None}
    assert (
        reduce in reduce_options
    ), f"nrmse_loss: reduce={reduce} must be one of {reduce_options}"
    assert (
        multi_image_x.get_n_leading() == multi_image_y.get_n_leading() == 2
    ), "nrmse_loss: MultiImages must have batch and channel axes"

    batch = multi_image_x.get_L()

    # Seed with an empty (batch, 0) array so the result keeps its dtype/shape even with no image types.
    pieces = [jnp.zeros((batch, 0))]
    for image_a, image_b in zip(multi_image_x.values(), multi_image_y.values()):
        # reshape to (batch,channels,spatial*tensor)
        image_a = image_a.reshape(image_a.shape[:2] + (-1,))
        image_b = image_b.reshape(image_b.shape[:2] + (-1,))
        diff_norms = jnp.linalg.norm(image_a - image_b, axis=2)  # (batch,channels)
        target_norms = jnp.linalg.norm(image_b, axis=2)  # (batch,channels)
        pieces.append(diff_norms / (target_norms + eps))

    # concatenate once instead of re-copying inside the loop
    error_per_batch = jnp.concatenate(pieces, axis=1)
    error_per_batch = jnp.mean(error_per_batch, axis=1)  # (batch,channels) -> (batch,)

    return jnp.mean(error_per_batch) if reduce == "mean" else error_per_batch

l2_rel_error(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: str | None = 'mean', eps: float = 0) -> jax.Array ¤

The relative error, taken as a norm over the entire difference image divided by the norm over the entire reference image. We then take the mean over the image types, and then reduce over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce str | None

how to reduce over batch. Either "mean" or None.

'mean'
eps float

epsilon to add to the denominator to avoid divide by zero errors

0

Returns:

Type Description
Array

average percent relative error with respect to the second input.

Source code in ginjax/ml/losses.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def l2_rel_error(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: str | None = "mean",
    eps: float = 0,
) -> jax.Array:
    """
    Whole-image relative error expressed as a percentage.

    This is nrmse_loss scaled by 100: for each channel, the norm of the difference image is divided
    by the norm of the reference image. Those ratios are averaged over the channels of all image
    types and then reduced over the batch according to `reduce`.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.
        eps: epsilon to add to the denominator to avoid divide by zero errors

    returns:
        average percent relative error with respect to the second input.
    """
    fractional_error = nrmse_loss(multi_image_x, multi_image_y, reduce=reduce, eps=eps)
    percent = 100
    return fractional_error * percent

l2_per_pixel_rel_error(multi_image_x: geom.MultiImage, multi_image_y: geom.MultiImage, reduce: str | None = 'mean', eps: float = 0) -> jax.Array ¤

Average per tensor relative error as a percentage. The error is relative to the second input per pixel.

The average is taken over every pixel and channel. If reduce is 'mean', it is also taken over the batch.

Parameters:

Name Type Description Default
multi_image_x MultiImage

predicted data, image_blocks are shape (batch,channels,spatial,tensor)

required
multi_image_y MultiImage

target data, image_blocks are shape (batch,channels,spatial,tensor)

required
reduce str | None

how to reduce over batch. Either "mean" or None.

'mean'
eps float

epsilon to add to the denominator to avoid divide by zero errors

0

Returns:

Type Description
Array

average percent relative error with respect to the second input.

Source code in ginjax/ml/losses.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def l2_per_pixel_rel_error(
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    reduce: str | None = "mean",
    eps: float = 0,
) -> jax.Array:
    """
    Per-pixel relative error expressed as a percentage.

    This is nrmse_per_pixel_loss scaled by 100. At every pixel of every channel, the norm of the
    tensor difference is divided by the norm of the reference tensor. The results are averaged over
    pixels and channels, and also over the batch when reduce is 'mean'.

    args:
        multi_image_x: predicted data, image_blocks are shape (batch,channels,spatial,tensor)
        multi_image_y: target data, image_blocks are shape (batch,channels,spatial,tensor)
        reduce: how to reduce over batch. Either "mean" or None.
        eps: epsilon to add to the denominator to avoid divide by zero errors

    returns:
        average percent relative error with respect to the second input.
    """
    per_pixel_fraction = nrmse_per_pixel_loss(multi_image_x, multi_image_y, reduce=reduce, eps=eps)
    percent = 100
    return per_pixel_fraction * percent