
fix(v1): respect batch_size in train_gen when permute=False#364

Open
staticpayload wants to merge 1 commit into google-research:master from staticpayload:fix/v1-train-gen-batching-359

Conversation

@staticpayload

Summary

  • fix TimeSeriesdata.train_gen() in v1/src/timesfm/data_loader.py so that permute=False yields deterministic batch-sized windows instead of all series on every step
  • keep existing permute=True sampling behavior unchanged
  • add regression tests that validate index windows, tensor leading dimensions, and no extra batch on even splits

Bug

Issue #359 reports that train_gen ignores batch_size when permute=False because it sets tsidx = np.arange(num_ts) on every iteration.

This makes non-permuted training load all series in each generated sample, which breaks expected batching semantics and inflates memory/compute.
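A self-contained sketch of the pre-fix behavior (example values, not the real TimeSeriesdata state):

```python
import numpy as np

num_ts, batch_size = 10, 4  # example values

# Buggy logic: tsidx spans all series on every generator step,
# so batch_size never limits the per-step window.
tsidx = np.arange(num_ts)
print(len(tsidx))  # the full series count, not batch_size
```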

Fix

For permute=False, iterate explicit batch starts (range(0, num_ts, self.batch_size)) and slice tsidx as:
np.arange(batch_idx, min(batch_idx + self.batch_size, num_ts)).
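A minimal standalone sketch of that windowing (plain NumPy; num_ts and batch_size are example values, not the actual loader attributes):

```python
import numpy as np

num_ts = 10      # total number of series (example value)
batch_size = 4   # requested batch size (example value)

# Iterate explicit batch starts; the final window is clipped to num_ts.
windows = []
for batch_idx in range(0, num_ts, batch_size):
    tsidx = np.arange(batch_idx, min(batch_idx + batch_size, num_ts))
    windows.append(tsidx)

for w in windows:
    print(w.tolist())
```

Each step now indexes at most batch_size series, and the clipped final window means no step ever runs past num_ts.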

Tests

Added v1/tests/test_data_loader.py with two regression tests:

  1. test_train_gen_non_permute_respects_batch_windows
  2. test_train_gen_non_permute_no_extra_batches_on_even_split
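The intent of both tests can be sketched against a stand-in for the fixed index logic (batch_windows is a hypothetical helper for illustration, not the contents of v1/tests/test_data_loader.py):

```python
import numpy as np

def batch_windows(num_ts, batch_size):
    # Stand-in for the fixed non-permute index logic.
    return [np.arange(s, min(s + batch_size, num_ts))
            for s in range(0, num_ts, batch_size)]

# Even split: 8 series / batch_size 4 must yield exactly 2 full windows,
# with no trailing empty batch.
even = batch_windows(8, 4)
assert len(even) == 2
assert all(len(w) == 4 for w in even)

# Uneven split: the final window is clipped, never empty.
uneven = batch_windows(10, 4)
assert [len(w) for w in uneven] == [4, 4, 2]
```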

Local run:

python3 -m pytest -q v1/tests/test_data_loader.py
# 2 passed

Fixes #359



Development

Successfully merging this pull request may close these issues.

train_gen loads indices of all time series instead of batch-size if self.permute is False

1 participant