Skip to content

Data

ginjax.data ¤

time_series_idxs(past_steps: int, future_steps: int, delta_t: int, total_steps: int) -> tuple ¤

Get the input and output indices to split a time series into overlapping sequences of past steps and future steps.

Parameters:

Name Type Description Default
past_steps int

number of historical steps to use in the model

required
future_steps int

number of future steps of the output

required
delta_t int

number of timesteps per model step, applies to past and future steps

required
total_steps int

total number of timesteps that we are batching

required

Returns:

Type Description
tuple

tuple (in_idxs, out_idxs) of jnp arrays of input and output indices; the 1st axis indexes the sequences and the 2nd axis holds the timestep indices within each sequence

Source code in ginjax/data.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def time_series_idxs(past_steps: int, future_steps: int, delta_t: int, total_steps: int) -> tuple:
    """
    Build index arrays that cut a time series into overlapping (input, output) window pairs.

    Each input window holds `past_steps` timesteps spaced `delta_t` apart. Its paired output window
    holds the following `future_steps` timesteps, also spaced `delta_t` apart, starting `delta_t`
    after the last input timestep. Consecutive windows start one timestep apart.

    args:
        past_steps: number of historical steps to use in the model
        future_steps: number of future steps of the output
        delta_t: number of timesteps per model step, applies to past and future steps
        total_steps: total number of timesteps that we are batching

    Returns:
        tuple (in_idxs, out_idxs) of jnp arrays. Axis 0 indexes the windows; axis 1 holds the
        timestep indices inside each window (past_steps columns for inputs, future_steps for outputs).

    Raises:
        AssertionError if total_steps is too small to fit even one window pair.
    """
    # The final input window must leave room for its future samples, so the exclusive upper bound on
    # input start positions subtracts the future span and the (past_steps - 1) gaps in the window.
    in_lo = 0
    in_hi = total_steps - future_steps * delta_t - (past_steps - 1) * delta_t
    assert (
        in_lo < in_hi
    ), f"time_series_idxs: {total_steps}-{future_steps}*{delta_t} - ({past_steps}-1)*{delta_t}"
    window_starts = jnp.arange(in_lo, in_hi)
    past_offsets = jnp.arange(0, past_steps * delta_t, delta_t)
    in_idxs = window_starts[:, None] + past_offsets[None, :]

    # Output windows begin past_steps * delta_t after the matching input start, i.e. one delta_t
    # past the last input sample; the last output window ends at index total_steps - 1.
    out_lo = past_steps * delta_t
    out_hi = total_steps - (future_steps - 1) * delta_t
    assert (
        out_lo < out_hi
    ), f"time_series_idxs: {total_steps}-({future_steps}-1)*{delta_t}, {past_steps}*{delta_t}"
    out_starts = jnp.arange(out_lo, out_hi)
    future_offsets = jnp.arange(0, future_steps * delta_t, delta_t)
    out_idxs = out_starts[:, None] + future_offsets[None, :]

    # Both index sets must describe the same number of windows.
    assert len(in_idxs) == len(out_idxs)

    return in_idxs, out_idxs

batch_time_series(dynamic_fields: geom.MultiImage, constant_fields: geom.MultiImage, total_steps: int, past_steps: int, future_steps: int, skip_initial: int = 0, delta_t: int = 1, downsample: int = 0) -> tuple[geom.MultiImage, geom.MultiImage] ¤

Given time series fields with an initial batch dimension, convert them to input and output MultiImages based on the number of past steps, future steps, and any subsampling/downsampling.

Parameters:

Name Type Description Default
dynamic_fields MultiImage

the dynamic fields, shape (batch,channels*time,spatial,tensor)

required
constant_fields MultiImage

the constant fields, shape (batch,channels,spatial,tensor)

required
total_steps int

total number of timesteps we are working with

required
past_steps int

number of historical steps to use in the model

required
future_steps int

number of future steps

required
skip_initial int

number of initial time steps to skip

0
delta_t int

number of timesteps per model step

1
downsample int

number of times to downsample the image by average pooling; each pass reduces the resolution by a factor of 2

0

Returns:

Type Description
tuple[MultiImage, MultiImage]

tuple of MultiImages multi_image_X and multi_image_Y

Source code in ginjax/data.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def batch_time_series(
    dynamic_fields: geom.MultiImage,
    constant_fields: geom.MultiImage,
    total_steps: int,
    past_steps: int,
    future_steps: int,
    skip_initial: int = 0,
    delta_t: int = 1,
    downsample: int = 0,
) -> tuple[geom.MultiImage, geom.MultiImage]:
    """
    Given time series fields with an initial batch dimension, convert them to input and output
    MultiImages based on the number of past steps, future steps, and any subsampling/downsampling.

    Each batch entry is processed independently by times_series_to_multi_images (via jax.vmap),
    and then the batch axis and the per-entry window axis are merged.

    args:
        dynamic_fields: the dynamic fields, shape (batch,channels*time,spatial,tensor)
        constant_fields: the constant fields, shape (batch,channels,spatial,tensor)
        total_steps: total number of timesteps we are working with
        past_steps: number of historical steps to use in the model
        future_steps: number of future steps
        skip_initial: number of initial time steps to skip
        delta_t: number of timesteps per model step
        downsample: number of times to downsample the image by average pooling; each pass
            reduces the resolution by a factor of 2

    returns:
        tuple of MultiImages multi_image_X and multi_image_Y, with axes 0 and 1 combined
    """
    # Only the two field collections are mapped over axis 0; the six scalar settings are shared
    # across the whole batch (in_axes=None).
    axis_spec = (0, 0) + (None,) * 6
    per_entry = jax.vmap(times_series_to_multi_images, in_axes=axis_spec)
    batched_x, batched_y = per_entry(
        dynamic_fields,
        constant_fields,
        total_steps,
        past_steps,
        future_steps,
        skip_initial,
        delta_t,
        downsample,
    )

    # Collapse (batch, windows) into a single leading axis.
    leading = (0, 1)
    return batched_x.combine_axes(leading), batched_y.combine_axes(leading)

times_series_to_multi_images(dynamic_fields: geom.MultiImage, constant_fields: geom.MultiImage, total_steps: int, past_steps: int, future_steps: int, skip_initial: int = 0, delta_t: int = 1, downsample: int = 0) -> tuple[geom.MultiImage, geom.MultiImage] ¤

Given time series fields, convert them to input and output MultiImages based on the number of past steps, future steps, and any subsampling/downsampling.

Parameters:

Name Type Description Default
dynamic_fields MultiImage

the dynamic fields, shape (channels*time,spatial,tensor)

required
constant_fields MultiImage

the constant fields, shape (channels,spatial,tensor)

required
total_steps int

total number of timesteps we are working with

required
past_steps int

number of historical steps to use in the model

required
future_steps int

number of future steps

required
skip_initial int

number of initial time steps to skip

0
delta_t int

number of timesteps per model step

1
downsample int

number of times to downsample the image by average pooling; each pass reduces the resolution by a factor of 2

0

Returns:

Type Description
tuple[MultiImage, MultiImage]

tuple of MultiImages multi_image_X and multi_image_Y

Source code in ginjax/data.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def times_series_to_multi_images(
    dynamic_fields: geom.MultiImage,
    constant_fields: geom.MultiImage,
    total_steps: int,
    past_steps: int,
    future_steps: int,
    skip_initial: int = 0,
    delta_t: int = 1,
    downsample: int = 0,
) -> tuple[geom.MultiImage, geom.MultiImage]:
    """
    Slice one time series into input/output MultiImages of overlapping windows, based on the number
    of past steps, future steps, and any subsampling/downsampling.

    The first `skip_initial` timesteps are dropped, window indices come from time_series_idxs, and
    the constant fields are repeated for every window and appended to the inputs only.

    args:
        dynamic_fields: the dynamic fields, shape (channels*time,spatial,tensor)
        constant_fields: the constant fields, shape (channels,spatial,tensor)
        total_steps: total number of timesteps we are working with
        past_steps: number of historical steps to use in the model
        future_steps: number of future steps
        skip_initial: number of initial time steps to skip
        delta_t: number of timesteps per model step
        downsample: number of times to downsample the image by average pooling; each pass
            reduces the resolution by a factor of 2

    returns:
        tuple of MultiImages multi_image_X and multi_image_Y; axis 0 of each image indexes the
        windows

    Raises:
        AssertionError if dynamic_fields is empty, or (from time_series_idxs) if there are too few
        timesteps to form a single window pair.
    """
    assert len(dynamic_fields.values()) != 0

    spatial_dims = dynamic_fields.get_spatial_dims()
    D = dynamic_fields.D
    # Indices are relative to the series after the first skip_initial steps are removed.
    past_windows, future_windows = time_series_idxs(
        past_steps, future_steps, delta_t, total_steps - skip_initial
    )

    multi_image_x = dynamic_fields.empty()
    multi_image_y = dynamic_fields.empty()
    # expand(0, total_steps) presumably splits axis 0 into (channels, total_steps) — confirm in geom.
    for key, series in dynamic_fields.expand(0, total_steps).items():
        k, parity = key
        tensor_shape = (D,) * len(k)
        kept = series[:, skip_initial:]
        channels = len(kept)

        # Gathering with a 2-D index array gives (c,b,timesteps,spatial,tensor).
        x_windows = kept[:, past_windows].reshape(
            (channels, -1, past_steps) + spatial_dims + tensor_shape
        )
        y_windows = kept[:, future_windows].reshape(
            (channels, -1, future_steps) + spatial_dims + tensor_shape
        )

        # Put the window axis first: (b,c,timesteps,spatial,tensor).
        multi_image_x.append(k, parity, jnp.moveaxis(x_windows, 1, 0))
        multi_image_y.append(k, parity, jnp.moveaxis(y_windows, 1, 0))

    # Merge axes 1 and 2 (channels and timesteps).
    multi_image_x = multi_image_x.combine_axes((1, 2))
    multi_image_y = multi_image_y.combine_axes((1, 2))

    # Every window gets its own copy of each constant field, appended to the inputs along axis 1.
    num_windows = len(next(iter(multi_image_x.values())))
    for (k, parity), const_image in constant_fields.items():
        tiled = jnp.full((num_windows,) + const_image.shape, const_image)
        multi_image_x.append(k, parity, tiled, axis=1)

    for _pass in range(downsample):
        multi_image_x = multi_image_x.average_pool(2)
        multi_image_y = multi_image_y.average_pool(2)

    return multi_image_x, multi_image_y