Skip to content

Stopping conditions

ginjax.ml.stopping_conditions ¤

StopCondition ¤

Base StopCondition.

Source code in ginjax/ml/stopping_conditions.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class StopCondition:
    """
    Base StopCondition.

    Subclasses override `stop`, which is called once per epoch and decides whether training should
    end. They may also record the model they consider best in `best_model`.
    """

    best_model: Optional[eqx.Module]
    verbose: int

    def __init__(self: Self, verbose: int = 0) -> None:
        """
        StopCondition constructor.

        args:
            verbose: verbose level, one of 0,1,2. 0 prints nothing, 1 prints every 10% of total
                epochs, and 2 prints every epoch.
        """
        assert verbose in {0, 1, 2}
        self.best_model = None
        self.verbose = verbose

    def stop(
        self: Self,
        model: eqx.Module,
        current_epoch: int,
        train_loss: Optional[jax.Array],
        val_loss: Optional[jax.Array],
        epoch_time: float,
    ) -> bool:
        """
        Decide whether training should stop. The base implementation always says to stop;
        subclasses override it with a real criterion.

        args:
            model: the current model
            current_epoch: current epoch
            train_loss: current training loss, or None if unavailable
            val_loss: current validation loss, or None if unavailable
            epoch_time: how long the epoch took

        returns:
            whether to stop (always True here)
        """
        return True

    def log_status(
        self: Self,
        epoch: int,
        train_loss: Optional[ArrayLike],
        val_loss: Optional[ArrayLike],
        epoch_time: float,
    ) -> None:
        """
        Print a one-line summary of an epoch.

        Each loss that is not None is included. If both losses are None, nothing is printed.

        args:
            epoch: the epoch number to report
            train_loss: training loss, or None to omit it
            val_loss: validation loss, or None to omit it
            epoch_time: how long the epoch took
        """
        if train_loss is None and val_loss is None:
            return

        parts = [f"Epoch {epoch}"]
        if train_loss is not None:
            parts.append(f"Train: {train_loss:.7f}")
        # Previously a validation-only epoch was silently dropped; report it as well.
        if val_loss is not None:
            parts.append(f"Val: {val_loss:.7f}")
        parts.append(f"Epoch time: {epoch_time:.5f}")
        print(" ".join(parts))
__init__(verbose: int = 0) -> None ¤

StopCondition constructor.

Parameters:

Name Type Description Default
verbose int

verbose level, one of 0,1,2. 0 prints nothing, 1 prints every 10% of total epochs, and 2 prints every epoch.

0
Source code in ginjax/ml/stopping_conditions.py
19
20
21
22
23
24
25
26
27
28
29
def __init__(self: Self, verbose: int = 0) -> None:
    """
    Construct a StopCondition.

    args:
        verbose: how chatty to be; must be 0, 1, or 2. 0 prints nothing, 1 prints every 10% of
            total epochs, and 2 prints every epoch.
    """
    allowed_levels = {0, 1, 2}
    assert verbose in allowed_levels
    # No model has been recorded yet; best_model is assigned before verbose, as before.
    self.best_model, self.verbose = None, verbose

EpochStop ¤

Bases: StopCondition

Stop when enough epochs have passed.

Source code in ginjax/ml/stopping_conditions.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class EpochStop(StopCondition):
    """
    Stop when enough epochs have passed.
    """

    def __init__(self: Self, epochs: int, verbose: int = 0) -> None:
        """
        EpochStop constructor.

        args:
            epochs: epoch limit
            verbose: verbose level, one of 0,1,2. 0 prints nothing, 1 prints every 10% of total
                epochs, and 2 prints every epoch.
        """
        super().__init__(verbose=verbose)
        self.epochs = epochs

    def stop(
        self: Self,
        model: eqx.Module,
        current_epoch: int,
        train_loss: Optional[jax.Array],
        val_loss: Optional[jax.Array],
        epoch_time: float,
    ) -> bool:
        """
        Stops if current_epoch is greater than or equal to the specified stop epoch, and calls
        log_status depending on the verbose level. The model passed in is always recorded as
        best_model.

        args:
            model: the current model, saved every epoch
            current_epoch: current epoch
            train_loss: current training loss
            val_loss: current validation loss
            epoch_time: how long the epoch took

        returns:
            whether to stop
        """
        self.best_model = model

        # With verbose == 1, log roughly every 10% of the total epochs. The interval is clamped to
        # at least 1, so short runs log every epoch and epochs == 0 no longer divides by zero.
        log_interval = max(1, self.epochs // 10)
        if self.verbose == 2 or (self.verbose == 1 and current_epoch % log_interval == 0):
            self.log_status(current_epoch, train_loss, val_loss, epoch_time)

        return current_epoch >= self.epochs
__init__(epochs: int, verbose: int = 0) -> None ¤

EpochStop constructor.

Parameters:

Name Type Description Default
epochs int

epoch limit

required
verbose int

verbose level, one of 0,1,2. 0 prints nothing, 1 prints every 10% of total epochs, and 2 prints every epoch.

0
Source code in ginjax/ml/stopping_conditions.py
62
63
64
65
66
67
68
69
70
71
72
def __init__(self: Self, epochs: int, verbose: int = 0) -> None:
    """
    Build an EpochStop condition.

    args:
        epochs: the epoch count at which training stops
        verbose: verbose level, one of 0,1,2. 0 prints nothing, 1 prints every 10% of total
            epochs, and 2 prints every epoch.
    """
    # Explicit two-argument super() because this function is not defined inside a class body.
    parent_init = super(EpochStop, self).__init__
    parent_init(verbose=verbose)
    self.epochs = epochs
stop(model: eqx.Module, current_epoch: int, train_loss: Optional[jax.Array], val_loss: Optional[jax.Array], epoch_time: float) -> bool ¤

Stops if current_epoch is greater than or equal to the specified stop epoch, and calls log_status depending on the verbose level.

Parameters:

Name Type Description Default
model Module

the current model, saved every epoch

required
current_epoch int

current epoch

required
train_loss Optional[Array]

current training loss

required
val_loss Optional[Array]

current validation loss

required
epoch_time float

how long the epoch took

required

Returns:

Type Description
bool

whether to stop

Source code in ginjax/ml/stopping_conditions.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def stop(
    self: Self,
    model: eqx.Module,
    current_epoch: int,
    train_loss: Optional[jax.Array],
    val_loss: Optional[jax.Array],
    epoch_time: float,
) -> bool:
    """
    Stops if current_epoch is greater than or equal to the specified stop epoch, and calls
    log_status depending on the verbose level. The model passed in is always recorded as
    best_model.

    args:
        model: the current model, saved every epoch
        current_epoch: current epoch
        train_loss: current training loss
        val_loss: current validation loss
        epoch_time: how long the epoch took

    returns:
        whether to stop
    """
    self.best_model = model

    # With verbose == 1, log roughly every 10% of the total epochs. The interval is clamped to
    # at least 1, so short runs log every epoch and epochs == 0 no longer divides by zero.
    log_interval = max(1, self.epochs // 10)
    if self.verbose == 2 or (self.verbose == 1 and current_epoch % log_interval == 0):
        self.log_status(current_epoch, train_loss, val_loss, epoch_time)

    return current_epoch >= self.epochs

TrainLoss ¤

Bases: StopCondition

Stop when the training error stops improving after patience number of epochs.

Source code in ginjax/ml/stopping_conditions.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class TrainLoss(StopCondition):
    """
    Stop when the training error stops improving after patience number of epochs.
    """

    def __init__(self: Self, patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None:
        """
        TrainLoss constructor.

        args:
            patience: how many epochs of non-improvement to wait before stopping
            min_delta: the minimum decrease to count as an improvement
            verbose: the verbose level, one of 0,1. 0 don't log, 1 log on improvement.
        """
        super().__init__(verbose=verbose)
        self.patience = patience
        self.min_delta = min_delta
        # Start at +inf so that any finite first loss counts as an improvement.
        self.best_train_loss = jnp.inf
        self.epochs_since_best = 0

    def stop(
        self: Self,
        model: eqx.Module,
        current_epoch: int,
        train_loss: Optional[jax.Array],
        val_loss: Optional[jax.Array],
        epoch_time: float,
    ) -> bool:
        """
        Track the training loss and stop once it has gone more than `patience` epochs without
        improving by more than `min_delta`.

        On an improvement, the current model is recorded as best_model and, if verbose >= 1, the
        epoch is logged. Epochs whose train_loss is None are ignored entirely.

        args:
            model: the current model, kept as best_model when the loss improves
            current_epoch: current epoch
            train_loss: current training loss (a jax Array or any scalar jnp.asarray accepts)
            val_loss: current validation loss, only passed through to logging
            epoch_time: how long the epoch took

        returns:
            whether to stop
        """
        if train_loss is None:
            # Nothing to compare; this epoch counts neither as an improvement nor as a miss.
            return False

        # jnp.asarray also accepts plain Python/NumPy scalars, not only jax Arrays.
        train_loss = jnp.asarray(train_loss, dtype=float)

        if train_loss < (self.best_train_loss - self.min_delta):
            self.best_train_loss = train_loss
            self.best_model = model
            self.epochs_since_best = 0

            if self.verbose >= 1:
                self.log_status(current_epoch, train_loss, val_loss, epoch_time)
        else:
            self.epochs_since_best += 1

        return self.epochs_since_best > self.patience
__init__(patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None ¤

TrainLoss constructor.

Parameters:

Name Type Description Default
patience int

how many epochs of non-improvement to wait before stopping

0
min_delta float

the minimum decrease to count as an improvement

0
verbose int

the verbose level, one of 0,1. 0 don't log, 1 log on improvement.

0
Source code in ginjax/ml/stopping_conditions.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def __init__(self: Self, patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None:
    """
    Set up a TrainLoss early-stopping condition.

    args:
        patience: number of non-improving epochs tolerated before stopping
        min_delta: the smallest decrease in training loss that counts as an improvement
        verbose: logging level; 0 never logs, 1 logs whenever the training loss improves.
    """
    # Explicit two-argument super() because this function is not defined inside a class body.
    super(TrainLoss, self).__init__(verbose=verbose)
    self.patience, self.min_delta = patience, min_delta
    # +inf means any finite first loss will register as an improvement.
    self.best_train_loss = jnp.inf
    self.epochs_since_best = 0
stop(model: eqx.Module, current_epoch: int, train_loss: Optional[jax.Array], val_loss: Optional[jax.Array], epoch_time: float) -> bool ¤

Stops once the training loss has gone more than patience epochs without improving, and calls log_status on each improvement depending on the verbose level.

Parameters:

Name Type Description Default
model Module

the current model, saved every epoch

required
current_epoch int

current epoch

required
train_loss Optional[Array]

current training loss

required
val_loss Optional[Array]

current validation loss

required
epoch_time float

how long the epoch took

required

Returns:

Type Description
bool

whether to stop

Source code in ginjax/ml/stopping_conditions.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def stop(
    self: Self,
    model: eqx.Module,
    current_epoch: int,
    train_loss: Optional[jax.Array],
    val_loss: Optional[jax.Array],
    epoch_time: float,
) -> bool:
    """
    Track the training loss and stop once it has gone more than `patience` epochs without
    improving by more than `min_delta`.

    On an improvement, the current model is recorded as best_model and, if verbose >= 1, the
    epoch is logged. Epochs whose train_loss is None are ignored entirely.

    args:
        model: the current model, kept as best_model when the loss improves
        current_epoch: current epoch
        train_loss: current training loss (a jax Array or any scalar jnp.asarray accepts)
        val_loss: current validation loss, only passed through to logging
        epoch_time: how long the epoch took

    returns:
        whether to stop
    """
    if train_loss is None:
        # Nothing to compare; this epoch counts neither as an improvement nor as a miss.
        return False

    # jnp.asarray also accepts plain Python/NumPy scalars, not only jax Arrays.
    train_loss = jnp.asarray(train_loss, dtype=float)

    if train_loss < (self.best_train_loss - self.min_delta):
        self.best_train_loss = train_loss
        self.best_model = model
        self.epochs_since_best = 0

        if self.verbose >= 1:
            self.log_status(current_epoch, train_loss, val_loss, epoch_time)
    else:
        self.epochs_since_best += 1

    return self.epochs_since_best > self.patience

ValLoss ¤

Bases: StopCondition

Stop when the validation error stops improving after patience number of epochs.

Source code in ginjax/ml/stopping_conditions.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class ValLoss(StopCondition):
    """
    Stop when the validation error stops improving after patience number of epochs.
    """

    def __init__(self: Self, patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None:
        """
        ValLoss constructor.

        args:
            patience: how many epochs of non-improvement to wait before stopping
            min_delta: the minimum decrease to count as an improvement
            verbose: the verbose level, one of 0,1. 0 don't log, 1 log on improvement.
        """
        super().__init__(verbose=verbose)
        self.patience = patience
        self.min_delta = min_delta
        # Start at +inf so that any finite first loss counts as an improvement.
        self.best_val_loss = jnp.inf
        self.epochs_since_best = 0

    def stop(
        self: Self,
        model: eqx.Module,
        current_epoch: int,
        train_loss: Optional[jax.Array],
        val_loss: Optional[jax.Array],
        epoch_time: float,
    ) -> bool:
        """
        Track the validation loss and stop once it has gone more than `patience` epochs without
        improving by more than `min_delta`.

        On an improvement, the current model is recorded as best_model and, if verbose >= 1, the
        epoch is logged. Epochs whose val_loss is None are ignored entirely.

        args:
            model: the current model, kept as best_model when the loss improves
            current_epoch: current epoch
            train_loss: current training loss, only passed through to logging
            val_loss: current validation loss (a jax Array or any scalar jnp.asarray accepts)
            epoch_time: how long the epoch took

        returns:
            whether to stop
        """
        if val_loss is None:
            # Nothing to compare; this epoch counts neither as an improvement nor as a miss.
            return False

        # jnp.asarray also accepts plain Python/NumPy scalars, not only jax Arrays.
        val_loss = jnp.asarray(val_loss, dtype=float)

        if val_loss < (self.best_val_loss - self.min_delta):
            self.best_val_loss = val_loss
            self.best_model = model
            self.epochs_since_best = 0

            if self.verbose >= 1:
                self.log_status(current_epoch, train_loss, val_loss, epoch_time)
        else:
            self.epochs_since_best += 1

        return self.epochs_since_best > self.patience
__init__(patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None ¤

ValLoss constructor.

Parameters:

Name Type Description Default
patience int

how many epochs of non-improvement to wait before stopping

0
min_delta float

the minimum decrease to count as an improvement

0
verbose int

the verbose level, one of 0,1. 0 don't log, 1 log on improvement.

0
Source code in ginjax/ml/stopping_conditions.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def __init__(self: Self, patience: int = 0, min_delta: float = 0, verbose: int = 0) -> None:
    """
    Set up a ValLoss early-stopping condition.

    args:
        patience: number of non-improving epochs tolerated before stopping
        min_delta: the smallest decrease in validation loss that counts as an improvement
        verbose: logging level; 0 never logs, 1 logs whenever the validation loss improves.
    """
    # Explicit two-argument super() because this function is not defined inside a class body.
    super(ValLoss, self).__init__(verbose=verbose)
    self.patience, self.min_delta = patience, min_delta
    # +inf means any finite first loss will register as an improvement.
    self.best_val_loss = jnp.inf
    self.epochs_since_best = 0
stop(model: eqx.Module, current_epoch: int, train_loss: Optional[jax.Array], val_loss: Optional[jax.Array], epoch_time: float) -> bool ¤

Stops once the validation loss has gone more than patience epochs without improving, and calls log_status on each improvement depending on the verbose level.

Parameters:

Name Type Description Default
model Module

the current model, saved every epoch

required
current_epoch int

current epoch

required
train_loss Optional[Array]

current training loss

required
val_loss Optional[Array]

current validation loss

required
epoch_time float

how long the epoch took

required

Returns:

Type Description
bool

whether to stop

Source code in ginjax/ml/stopping_conditions.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def stop(
    self: Self,
    model: eqx.Module,
    current_epoch: int,
    train_loss: Optional[jax.Array],
    val_loss: Optional[jax.Array],
    epoch_time: float,
) -> bool:
    """
    Track the validation loss and stop once it has gone more than `patience` epochs without
    improving by more than `min_delta`.

    On an improvement, the current model is recorded as best_model and, if verbose >= 1, the
    epoch is logged. Epochs whose val_loss is None are ignored entirely.

    args:
        model: the current model, kept as best_model when the loss improves
        current_epoch: current epoch
        train_loss: current training loss, only passed through to logging
        val_loss: current validation loss (a jax Array or any scalar jnp.asarray accepts)
        epoch_time: how long the epoch took

    returns:
        whether to stop
    """
    if val_loss is None:
        # Nothing to compare; this epoch counts neither as an improvement nor as a miss.
        return False

    # jnp.asarray also accepts plain Python/NumPy scalars, not only jax Arrays.
    val_loss = jnp.asarray(val_loss, dtype=float)

    if val_loss < (self.best_val_loss - self.min_delta):
        self.best_val_loss = val_loss
        self.best_model = model
        self.epochs_since_best = 0

        if self.verbose >= 1:
            self.log_status(current_epoch, train_loss, val_loss, epoch_time)
    else:
        self.epochs_since_best += 1

    return self.epochs_since_best > self.patience

AnyStop ¤

Bases: StopCondition

Combine multiple stopping conditions, and stop when any of them stop. Can be used to implement early stopping. The best model returned is according to the first stop condition, and each prints according to its verbosity.

Source code in ginjax/ml/stopping_conditions.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
class AnyStop(StopCondition):
    """
    Combine multiple stopping conditions, and stop when any of them stop. Can be used to implement
    early stopping. The best model returned is according to the first stop condition, and each
    prints according to its verbosity.
    """

    stop_conditions: list[StopCondition]

    def __init__(self: Self, stop_conditions: list[StopCondition]) -> None:
        """
        AnyStop constructor.

        args:
            stop_conditions: a non-empty list of the stopping conditions to combine
        """
        assert len(stop_conditions) > 0
        # Initialize the base attributes (best_model, verbose) so they exist before the first call
        # to stop(). AnyStop never logs itself; each wrapped condition logs on its own.
        super().__init__(verbose=0)
        self.stop_conditions = stop_conditions

    def stop(
        self: Self,
        model: eqx.Module,
        current_epoch: int,
        train_loss: Optional[jax.Array],
        val_loss: Optional[jax.Array],
        epoch_time: float,
    ) -> bool:
        """
        Evaluate every wrapped stopping condition and stop if any of them says to.

        All conditions are evaluated (no short-circuiting), so each one updates its own state and
        logs according to its own verbosity. best_model is taken from the first condition.

        args:
            model: the current model
            current_epoch: current epoch
            train_loss: current training loss
            val_loss: current validation loss
            epoch_time: how long the epoch took

        returns:
            whether to stop
        """
        # Build the full list first so every condition's stop() runs, even after one returns True.
        results = [
            sc.stop(model, current_epoch, train_loss, val_loss, epoch_time)
            for sc in self.stop_conditions
        ]
        self.best_model = self.stop_conditions[0].best_model
        return any(results)
__init__(stop_conditions: list[StopCondition]) -> None ¤

StopCondition constructor.

Parameters:

Name Type Description Default
stop_conditions list[StopCondition]

a list of all the stopping conditions

required
Source code in ginjax/ml/stopping_conditions.py
235
236
237
238
239
240
241
242
243
def __init__(self: Self, stop_conditions: list[StopCondition]) -> None:
    """
    AnyStop constructor.

    args:
        stop_conditions: a non-empty list of the stopping conditions to combine
    """
    assert len(stop_conditions) > 0
    # Initialize the base attributes (best_model, verbose) so they exist before the first call to
    # stop(). Explicit two-argument super() because this function is not defined inside a class body.
    super(AnyStop, self).__init__(verbose=0)
    self.stop_conditions = stop_conditions