Skip to content

Add API support for time-varying scalars (#364)#375

Open
qiencai wants to merge 4 commits into
alan-turing-institute:mainfrom
qiencai:feature/time-varying-scalars
Open

Add API support for time-varying scalars (#364)#375
qiencai wants to merge 4 commits into
alan-turing-institute:mainfrom
qiencai:feature/time-varying-scalars

Conversation

@qiencai

@qiencai qiencai commented May 18, 2026

Copy link
Copy Markdown
Collaborator

Adds a time-varying scalars (TVS) channel, end-to-end from datasets through the encoder's global conditioning vector and is correctly sliced during autoregressive rollouts.

Usage

Provide a time_varying_scalars array of shape (n_trajectories, n_timesteps, n_channels) in your HDF5 data file alongside the existing data arrays, the dictionary key should be "time_varying_scalars", and set n_tvs_extra_steps on the datamodule to at least autoregressive_train_steps * stride so rollout always has a fresh slice; the encoder will then automatically concatenate the current TVS slice onto global_cond at every step, no model-side changes required.

Changes

  • types: new TensorTC alias; Sample/Batch carry optional time_varying_scalars with collation, repeat(), and to() support.
  • data: HDF5Dataset loads per-trajectory time_varying_scalars and stores a window of n_steps_output + n_tvs_extra_steps slices per sample(which I set to 365, this is just to slice long enough TVS so the sequence won't be all consumed during rollout, such parameter needs to exist, to make sure the shape is consistent for all samples); new n_tvs_extra_steps dataset/datamodule parameter so rollout has enough future slices.
  • encoders.base: encode_cond appends the current TVS slice to global_cond and raises a clear error if the buffer is exhausted.
  • models.encoder_processor_decoder: _clone_batch clones TVS; _advance_batch shifts TVS by stride at every rollout step.
  • scripts/setup: expose n_time_varying_scalars in logic_stats.
  • scripts/data, benchmarking/inference: TVS-aware device move and synthetic batch construction.

Closes #364.

Adds a public time-varying scalars (TVS) channel that flows end-to-end
from datasets through the encoder's global conditioning vector and is
correctly sliced during autoregressive rollouts.

Changes
- types: new TensorTC alias; Sample/Batch carry optional time_varying_scalars
  with collation, repeat(), and to() support.
- data: HDF5Dataset loads per-trajectory time_varying_scalars and stores a
  window of n_steps_output + n_tvs_extra_steps slices per sample; new
  n_tvs_extra_steps dataset/datamodule parameter so rollout has enough
  future slices.
- encoders.base: encode_cond appends the current TVS slice to global_cond
  and raises a clear error if the buffer is exhausted.
- models.encoder_processor_decoder: _clone_batch clones TVS; _advance_batch
  shifts TVS by stride at every rollout step.
- scripts/setup: expose n_time_varying_scalars in logic_stats.
- scripts/data, benchmarking/inference: TVS-aware device move and synthetic
  batch construction.

Closes alan-turing-institute#364.
@qiencai

qiencai commented May 18, 2026

Copy link
Copy Markdown
Collaborator Author

I have the time varying fields code change locally, will follow up with that PR after we agree on implementations for the scalar case.

@sgreenbury sgreenbury left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @qiencai for taking a look at this and opening the PR! I added some initial comments below, and the following thoughts to discuss:

  • When dataset is not in full trajectory mode, would it make sense to get slices of time_varying_scalars that match the input temporal dim's size?
  • For full trajectory model, should this provide the full sequence with inputs or on the outputs?
  • Remaining aspects include incorporating into the encoder class/latent spaces

Comment thread src/autocast/data/dataset.py Outdated
Comment thread src/autocast/benchmarking/inference.py Outdated
Comment thread src/autocast/encoders/base.py Outdated

@sgreenbury sgreenbury left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the discussion earlier, adding some notes/code here on the ideas:

Comment thread src/autocast/data/dataset.py Outdated
Comment thread src/autocast/encoders/base.py Outdated
qiencai added 2 commits June 5, 2026 15:04
Address PR alan-turing-institute#375 review feedback:
- Condition on the current input window (input-aligned) instead of the
  prediction step; encode_cond flattens the last n_tvs_steps steps via
  rearrange('b t c -> b (t c)'). New encoder attr n_tvs_steps (default 1)
  selects a past-frame subset of the input window.
- Drop n_tvs_extra_steps; the per-sample TVS buffer is sized by the window
  (input, extended across the rollout horizon in full/subtrajectory mode so
  the sliding window reaches each frame without ever looking ahead).
- Add subtrajectory_mode + subtrajectory_start_idxs to build rollout
  subtrajectories at explicit starts; datamodule uses it for rollout
  datasets when set.
- Add data/calendar.py:month_start_indices (no-leap default) to generate
  month-anchored start indices, ported from monthly_init_indices.
- inference.py: drop getattr; rename tvs -> time_varying_scalars.
- Tests for calendar, dataset windowing/subtrajectory mode, encode_cond
  rearrange, and _advance_batch striding.
Calendar-agnostic counterpart to month_start_indices: place rollout
initialisations on a fixed-period grid (every init_interval steps, with an
optional init_offset) so subtrajectory_mode covers regular-cadence
environmental forecasts in addition to first-of-month inits. Same
input-window-fits / horizon-fits filtering; returns input-window start
indices for subtrajectory_start_idxs.

@sgreenbury sgreenbury left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @qiencai, I think this is shaping up really well. Added some initial comments below for discussion.

normalization_path: str | None = None,
normalization_stats: dict | DictConfig | None = None,
subtrajectory_mode: bool = False,
subtrajectory_start_idxs: list[int] | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the design again, I think some additional API is needed in order to enable the rollout for each subtrajectory to have sufficient length for the rollout?

Suggested change
subtrajectory_start_idxs: list[int] | None = None,
subtrajectory_start_idxs: list[int] | None = None,
n_steps_rollout: int | None = None

self.all_time_varying_scalars: list[torch.Tensor] = []

# Each sample spans `n_steps_input + n_steps_output` steps.
window_size = self.n_steps_input + self.n_steps_output

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the above suggestion re. n_steps_rollout, I think this would be updated to?

Suggested change
window_size = self.n_steps_input + self.n_steps_output
window_size = self.n_steps_input + self.n_steps_rollout

Use `data.calendar.month_start_indices` to generate month-anchored
starts. Cannot be combined with `full_trajectory_mode` or
`autoencoder_mode`. Defaults to False.
subtrajectory_start_idxs: list[int] | None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're assuming: subtrajectory_start_idxs: list[int] | None = None for each trajectory with the current API. I think this would ideally a list per trajecrory so possibly could be extended to something like:

subtrajectory_start_idxs: list[int] | dict[int, list[int]] | None
            Input-window start indices for `subtrajectory_mode`. A flat list is
            applied to every trajectory. A dict maps trajectory index to the
            starts for that trajectory, and missing trajectories contribute no
            rollout subtrajectories. Required when `subtrajectory_mode` is True.
            Defaults to None.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we could document and update this to be documented for any time series data where days are the unit, thaat would be great.

return dates


def month_start_indices(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From discussion with @qiencai, adding a comment here around the idea of enabling this to be potentially more hydra-configurable with something like:

  • An abstract base class to describe index generators (class SubtrajectoryIndexGenerator(abc.ABC))
  • Some subclasses (e.g. class MonthStartIndexGenerator(SubtrajectoryIndexGenerator)

Sketch below that could go in a separate module in e.g. src/autocast/data/subtrajectory.py

# i.e. list or dict of lists but for hydra
SubtrajectoryStartIdxs = Sequence[int] | Mapping[int | str | Sequence[int]]

class SubtrajectoryIndexGenerator(abc.ABC):
    """Generate rollout subtrajectory start indices for a dataset split."""

    @abc.abstractmethod
    def __call__(
        self,
        *,
        n_trajectories: int,
        n_timesteps: int,
        n_steps_input: int,
        n_steps_rollout: int,
    ) -> SubtrajectoryStartIdxs:
        """Return flat or per-trajectory rollout start indices."""

@dataclass
class MonthStartIndexGenerator(SubtrajectoryIndexGenerator):
    """Generate starts whose initialization date falls on the first of a month."""

    years: list[int]
    stride: int = 1
    drop_leap_day: bool = True

    def __call__(
        self,
        *,
        n_trajectories: int,
        n_timesteps: int,
        n_steps_input: int,
        n_steps_rollout: int,
    ) -> list[int]:
        """Return month-start initialization starts shared by every trajectory."""
        start_idxs, _ = month_start_indices(
            test_years=self.years,
            n_steps_input=n_steps_input,
            max_rollout_steps=n_steps_rollout, # NB. I think this might need changing in current form since max_rollout steps is rollouts not timesteps.
            stride=self.stride,
            drop_leap_day=self.drop_leap_day,
        )
        return start_idxs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add API to support time-varying scalars

2 participants