Add API support for time-varying scalars (#364) #375
Adds a public time-varying scalars (TVS) channel that flows end-to-end from datasets through the encoder's global conditioning vector and is correctly sliced during autoregressive rollouts.
Changes
- types: new TensorTC alias; Sample/Batch carry optional time_varying_scalars with collation, repeat(), and to() support.
- data: HDF5Dataset loads per-trajectory time_varying_scalars and stores a window of n_steps_output + n_tvs_extra_steps slices per sample; new n_tvs_extra_steps dataset/datamodule parameter so rollout has enough future slices.
- encoders.base: encode_cond appends the current TVS slice to global_cond and raises a clear error if the buffer is exhausted.
- models.encoder_processor_decoder: _clone_batch clones TVS; _advance_batch shifts TVS by stride at every rollout step.
- scripts/setup: expose n_time_varying_scalars in logic_stats.
- scripts/data, benchmarking/inference: TVS-aware device move and synthetic batch construction.
Closes alan-turing-institute#364.
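As a rough illustration of the encode/advance interplay described in the change list, here is a minimal sketch using plain Python lists in place of tensors. The function names `encode_cond_sketch` and `advance_tvs_sketch` are hypothetical stand-ins for `encode_cond` and `_advance_batch`, and the assumption that the "current" slice sits at index 0 of the buffer is mine, not confirmed by the PR.

```python
def encode_cond_sketch(global_cond, tvs_buffer):
    """Append the current TVS slice (assumed to be index 0) to global_cond.

    global_cond: list of floats for one sample.
    tvs_buffer: list of per-timestep scalar lists for the same sample.
    """
    if not tvs_buffer:
        # Mirrors the "clear error if the buffer is exhausted" behaviour.
        raise ValueError(
            "time_varying_scalars buffer exhausted; "
            "the dataset must provide more future slices for this rollout"
        )
    return list(global_cond) + list(tvs_buffer[0])


def advance_tvs_sketch(tvs_buffer, stride):
    """Drop the first `stride` slices so the next step sees a fresh slice."""
    return tvs_buffer[stride:]


buffer = [[0.0], [1.0], [2.0]]
cond_step0 = encode_cond_sketch([9.0], buffer)   # uses slice 0
buffer = advance_tvs_sketch(buffer, stride=2)
cond_step1 = encode_cond_sketch([9.0], buffer)   # uses slice 2
```

With a stride of 2 and three stored slices, a third rollout step would find the buffer empty, which is the situation the extra-slices parameter is meant to prevent.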
I have the time-varying fields code change locally and will follow up with that PR once we agree on the implementation for the scalar case.
sgreenbury
left a comment
Thank you @qiencai for taking a look at this and opening the PR! I added some initial comments below, and the following thoughts to discuss:
- When the dataset is not in full trajectory mode, would it make sense to get slices of `time_varying_scalars` that match the size of the input temporal dimension?
- For full trajectory mode, should this provide the full sequence with the inputs or with the outputs?
- Remaining aspects include incorporating this into the encoder class/latent spaces
sgreenbury
left a comment
Thanks for the discussion earlier, adding some notes/code here on the ideas:
- Update dataset windowing: we could aim to have the `n_steps_input` window for the time varying scalars (so the same length of time as the `input_fields` input).
- Permuting the scalars in encode
- Add flexibility around `n_steps_input` being different to `n_time_varying_scalars`
- Tests: it could work well to add some tests using the `time_varying_scalars` like some of the examples in this test file https://github.com/alan-turing-institute/autocast/blob/3cf33ddcfa80ab78375212809824b9c06595621c/tests/models/test_encoder_processor_decoder.py
Address PR alan-turing-institute#375 review feedback:
- Condition on the current input window (input-aligned) instead of the prediction step; encode_cond flattens the last n_tvs_steps steps via rearrange('b t c -> b (t c)'). New encoder attr n_tvs_steps (default 1) selects a past-frame subset of the input window.
- Drop n_tvs_extra_steps; the per-sample TVS buffer is sized by the window (input, extended across the rollout horizon in full/subtrajectory mode so the sliding window reaches each frame without ever looking ahead).
- Add subtrajectory_mode + subtrajectory_start_idxs to build rollout subtrajectories at explicit starts; datamodule uses it for rollout datasets when set.
- Add data/calendar.py:month_start_indices (no-leap default) to generate month-anchored start indices, ported from monthly_init_indices.
- inference.py: drop getattr; rename tvs -> time_varying_scalars.
- Tests for calendar, dataset windowing/subtrajectory mode, encode_cond rearrange, and _advance_batch striding.
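To make the input-aligned conditioning concrete, the sketch below reproduces the effect of the `'b t c -> b (t c)'` flattening on the last `n_tvs_steps` steps, using nested Python lists rather than tensors so it runs without any dependencies. The helper names `flatten_tvs_window` and `concat_global_cond` are hypothetical, and the validation shown is an assumption rather than the PR's actual behaviour.

```python
def flatten_tvs_window(tvs_window, n_tvs_steps=1):
    """Flatten the last n_tvs_steps steps of a (batch, time, channels) list.

    Equivalent in layout to 'b t c -> b (t c)': time is the outer axis,
    channels the inner axis, so older steps come first in each row.
    """
    if n_tvs_steps < 1:
        raise ValueError("n_tvs_steps must be at least 1")
    flat = []
    for sample in tvs_window:
        if n_tvs_steps > len(sample):
            raise ValueError("n_tvs_steps exceeds the available input window")
        recent = sample[-n_tvs_steps:]  # past-frame subset of the input window
        flat.append([value for step in recent for value in step])
    return flat


def concat_global_cond(global_cond, flat_tvs):
    """Append each sample's flattened TVS row to its conditioning vector."""
    return [list(g) + list(f) for g, f in zip(global_cond, flat_tvs)]


# One sample, three input steps, two scalar channels.
window = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
last_two = flatten_tvs_window(window, n_tvs_steps=2)
cond = concat_global_cond([[0.5]], last_two)
```

Under these assumptions the conditioning vector grows by `n_tvs_steps * n_channels` entries per sample, which is the width an encoder would need to account for.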
Calendar-agnostic counterpart to month_start_indices: place rollout initialisations on a fixed-period grid (every init_interval steps, with an optional init_offset) so subtrajectory_mode covers regular-cadence environmental forecasts in addition to first-of-month inits. Same input-window-fits / horizon-fits filtering; returns input-window start indices for subtrajectory_start_idxs.
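A possible shape for the fixed-period grid is sketched below. The function name `fixed_interval_start_indices` is hypothetical, and two conventions are assumed rather than taken from the PR: an initialisation index marks the first predicted step (so the input window ends just before it), and `n_steps_rollout` is counted in timesteps.

```python
def fixed_interval_start_indices(
    n_timesteps,
    n_steps_input,
    n_steps_rollout,
    init_interval,
    init_offset=0,
):
    """Return input-window start indices for inits on a regular grid.

    Inits are placed at init_offset, init_offset + init_interval, ...
    An init is kept only if the input window fits before it and the
    rollout horizon fits after it.
    """
    if init_interval < 1:
        raise ValueError("init_interval must be at least 1")
    starts = []
    for init in range(init_offset, n_timesteps, init_interval):
        start = init - n_steps_input
        if start < 0:
            continue  # input window would begin before the trajectory
        if init + n_steps_rollout > n_timesteps:
            break  # horizon runs past the end; later inits fail too
        starts.append(start)
    return starts


weekly = fixed_interval_start_indices(
    n_timesteps=20, n_steps_input=2, n_steps_rollout=5, init_interval=5
)
shifted = fixed_interval_start_indices(
    n_timesteps=20, n_steps_input=2, n_steps_rollout=5,
    init_interval=5, init_offset=2,
)
```

The returned list could then be passed as `subtrajectory_start_idxs`, in the same way as the month-anchored starts.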
sgreenbury
left a comment
Thanks for the updates @qiencai, I think this is shaping up really well. Added some initial comments below for discussion.
normalization_path: str | None = None,
normalization_stats: dict | DictConfig | None = None,
subtrajectory_mode: bool = False,
subtrajectory_start_idxs: list[int] | None = None,
Looking at the design again, I think some additional API is needed so that each subtrajectory is long enough for the rollout?
- subtrajectory_start_idxs: list[int] | None = None,
+ subtrajectory_start_idxs: list[int] | None = None,
+ n_steps_rollout: int | None = None
self.all_time_varying_scalars: list[torch.Tensor] = []

# Each sample spans `n_steps_input + n_steps_output` steps.
window_size = self.n_steps_input + self.n_steps_output
With the above suggestion re. n_steps_rollout, I think this would be updated to?
- window_size = self.n_steps_input + self.n_steps_output
+ window_size = self.n_steps_input + self.n_steps_rollout
Use `data.calendar.month_start_indices` to generate month-anchored
starts. Cannot be combined with `full_trajectory_mode` or
`autoencoder_mode`. Defaults to False.
subtrajectory_start_idxs: list[int] | None
I think we're assuming `subtrajectory_start_idxs: list[int] | None = None` applies to every trajectory with the current API. I think this would ideally be a list per trajectory, so it could possibly be extended to something like:
subtrajectory_start_idxs: list[int] | dict[int, list[int]] | None
Input-window start indices for `subtrajectory_mode`. A flat list is
applied to every trajectory. A dict maps trajectory index to the
starts for that trajectory, and missing trajectories contribute no
rollout subtrajectories. Required when `subtrajectory_mode` is True.
Defaults to None.
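One way the flat-list and dict forms described in this docstring could be normalised inside the dataset is sketched below. The helper name `resolve_start_idxs` and the choice to return one list per trajectory are assumptions for illustration, not part of the PR.

```python
from collections.abc import Mapping


def resolve_start_idxs(subtrajectory_start_idxs, n_trajectories):
    """Expand flat or per-trajectory starts into one list per trajectory.

    - A flat list is applied to every trajectory.
    - A dict maps trajectory index to its starts; trajectories missing
      from the dict contribute no rollout subtrajectories.
    """
    if subtrajectory_start_idxs is None:
        raise ValueError(
            "subtrajectory_start_idxs is required when subtrajectory_mode is True"
        )
    if isinstance(subtrajectory_start_idxs, Mapping):
        return [
            list(subtrajectory_start_idxs.get(traj_idx, []))
            for traj_idx in range(n_trajectories)
        ]
    return [list(subtrajectory_start_idxs) for _ in range(n_trajectories)]


shared = resolve_start_idxs([0, 5], n_trajectories=2)
per_traj = resolve_start_idxs({1: [3]}, n_trajectories=3)
```

Normalising early like this would let the rest of the windowing code iterate over trajectories without caring which form the user supplied.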
If we could update this and document it as applying to any time series data where days are the unit, that would be great.
return dates


def month_start_indices(
From discussion with @qiencai, adding a comment here around the idea of enabling this to be potentially more hydra-configurable with something like:
- An abstract base class to describe index generators (`class SubtrajectoryIndexGenerator(abc.ABC)`)
- Some subclasses (e.g. `class MonthStartIndexGenerator(SubtrajectoryIndexGenerator)`)
Sketch below that could go in a separate module in e.g. `src/autocast/data/subtrajectory.py`:
import abc
from collections.abc import Mapping, Sequence
from dataclasses import dataclass

# Assumed import path for the data/calendar.py module added in this PR.
from autocast.data.calendar import month_start_indices

# i.e. list or dict of lists but for hydra
SubtrajectoryStartIdxs = Sequence[int] | Mapping[int | str, Sequence[int]]


class SubtrajectoryIndexGenerator(abc.ABC):
    """Generate rollout subtrajectory start indices for a dataset split."""

    @abc.abstractmethod
    def __call__(
        self,
        *,
        n_trajectories: int,
        n_timesteps: int,
        n_steps_input: int,
        n_steps_rollout: int,
    ) -> SubtrajectoryStartIdxs:
        """Return flat or per-trajectory rollout start indices."""


@dataclass
class MonthStartIndexGenerator(SubtrajectoryIndexGenerator):
    """Generate starts whose initialization date falls on the first of a month."""

    years: list[int]
    stride: int = 1
    drop_leap_day: bool = True

    def __call__(
        self,
        *,
        n_trajectories: int,
        n_timesteps: int,
        n_steps_input: int,
        n_steps_rollout: int,
    ) -> list[int]:
        """Return month-start initialization starts shared by every trajectory."""
        # NB. I think max_rollout_steps might need changing in its current
        # form, since it counts rollouts rather than timesteps.
        start_idxs, _ = month_start_indices(
            test_years=self.years,
            n_steps_input=n_steps_input,
            max_rollout_steps=n_steps_rollout,
            stride=self.stride,
            drop_leap_day=self.drop_leap_day,
        )
        return start_idxs
Adds a time-varying scalars (TVS) channel that flows end-to-end from datasets through the encoder's global conditioning vector and is correctly sliced during autoregressive rollouts.
Usage
Provide a `time_varying_scalars` array of shape `(n_trajectories, n_timesteps, n_channels)` in your HDF5 data file alongside the existing data arrays, stored under the key "time_varying_scalars". Set `n_tvs_extra_steps` on the datamodule to at least `autoregressive_train_steps * stride` so that rollout always has a fresh slice. The encoder then automatically concatenates the current TVS slice onto `global_cond` at every step; no model-side changes are required.
Changes
Closes #364.