Problem with model predictions failing (flatlining) when running on H100s #328
Description
Hi TimesFM team,
I have been noticing a very strange issue when running inference with TimesFM 2.5 on my H100 GPUs; I haven't seen it on any of my other GPUs. Often, after loading the model and calling model.forecast, I get a flat-line prediction: the same value repeated over the entire horizon. Sometimes, running exactly the same code, I get a reasonable forecast instead. Since the problem appears only on the H100s, I'm trying to figure out whether it lies in torch.compile (i.e. compiled_decode), in specific dependencies, or somewhere else entirely. This seems like a serious issue: I can reproduce it in a Jupyter notebook, where the prediction fails almost every time I re-run the notebook after restarting the kernel. Below is a description of my setup; thanks in advance for considering this issue, and please let me know if I can provide any additional information to help figure it out!
rseed = 123
set_seed(rseed)

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
    "google/timesfm-2.5-200m-pytorch", device=device
)
model.compile(
    timesfm.ForecastConfig(
        max_context=512,
        max_horizon=512,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
        per_core_batch_size=32,
    )
)

point_forecast, quantile_forecast = model.forecast(
    horizon=512,
    inputs=[context],
)
NOTE: While I used max_context=512 and max_horizon=512 in the compile config for this example, I have previously been using max_context=15360 and max_horizon=1024 following the timesfm2.5 notebook in the gift-eval directory, and the error persists regardless of these specific settings. I am using prediction_length = 512 for this example, hard-coded to be more explicit and clear.
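As a side note, one simple way to flag the flat-line failure automatically is to check whether the forecast's max-minus-min spread over the horizon is negligible relative to its scale. This is a minimal sketch; the `is_flatline` helper and its tolerance are illustrative, not part of the TimesFM API:

```python
import numpy as np

def is_flatline(forecast, rel_tol=1e-6):
    """Return True if the forecast is (nearly) one repeated value."""
    f = np.asarray(forecast, dtype=float).squeeze()
    spread = np.ptp(f)  # max - min over the horizon
    scale = max(float(np.abs(f).max()), 1.0)
    return bool(spread <= rel_tol * scale)

# A constant horizon is flagged; a varying one is not.
print(is_flatline(np.full(512, 3.14)))       # True
print(is_flatline(np.sin(np.arange(512))))   # False
```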
print(f"{model.model.device=}")
print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")
model.model.device='cuda:3'
torch.__version__='2.7.0+cu126'
torch.cuda.is_available()=True
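To rule out torch.compile entirely, the whole run can be forced into eager mode via PyTorch 2.x's documented TORCHDYNAMO_DISABLE escape hatch. A sketch (the flag must be set before torch is imported, e.g. at the very top of the notebook or exported in the shell that launches the kernel):

```python
import os

# TORCHDYNAMO_DISABLE=1 makes TorchDynamo a no-op, so every
# torch.compile'd function (including compiled_decode) runs eagerly.
# Set it before `import torch` for it to take effect.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
```

If the flat lines disappear under this flag, that would point squarely at the compiled decode path on H100.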
Plotting
# Setup data
context_timesteps = np.arange(context.shape[-1])
future_timesteps = np.arange(context.shape[-1], context.shape[-1] + prediction_length)
print(f"context shape: {context.shape}")
print(f"future_vals shape: {future_vals.shape}")
print(f"length of context_timesteps: {len(context_timesteps)}")
print(f"length of future_timesteps: {len(future_timesteps)}")
context shape: (469,)
future_vals shape: (18,)
length of context_timesteps: 469
length of future_timesteps: 512
Data split is from Gift-Eval (black is context, blue is labels over prediction horizon). Specifically, the first short-term split of m4_monthly:

fig, ax = plt.subplots(figsize=(6, 2))
# Plot context, ground truth and predictions
ax.plot(context_timesteps, context, color="black", linewidth=1, label="Context")
ax.plot(future_timesteps[:future_vals.shape[-1]], future_vals, color="black", linewidth=1, linestyle="--", label="Ground Truth")
ax.plot(future_timesteps, point_forecast.squeeze(), color="tab:blue", linewidth=2, label="Prediction")
ax.set_xlabel("Timestep", fontweight="bold")
fig.tight_layout()
Failed prediction (happens a lot):
Successful prediction:
Dependencies
timesfm 2.0.0
torch 2.7.0
transformers 4.52.4
tokenizers 0.21.4
huggingface_hub 0.35.3
triton 3.3.0
nvidia-cublas-cu12 12.6.4.1
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvrtc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12 9.5.1.17
nvidia-cufft-cu12 11.3.0.4
nvidia-cufile-cu12 1.11.1.6
nvidia-curand-cu12 10.3.7.77
nvidia-cusolver-cu12 11.7.1.2
nvidia-cusparse-cu12 12.5.4.2
nvidia-cusparselt-cu12 0.6.3
nvidia-nccl-cu12 2.26.2
nvidia-nvjitlink-cu12 12.6.85
nvidia-nvshmem-cu12 3.3.20
nvidia-nvtx-cu12 12.6.77
Furthermore, I have even tried clearing the CUDA cache before running inference, but the problem persists:
import gc

gc.collect()

# Clear CUDA caches specifically for device
if torch.cuda.is_available():
    print(f"Current GPU on {device}: {torch.cuda.get_device_name()}")
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    print(f"Cleared CUDA caches for device {device}")
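Since the failure is intermittent, it can also be quantified by re-running the forecast and counting flat-line outputs. A minimal sketch, where `forecast_fn` is a hypothetical zero-argument wrapper around model.forecast (e.g. `lambda: model.forecast(horizon=512, inputs=[context])[0]`) and the flat-line test and tolerance are illustrative:

```python
import numpy as np

def count_flatlines(forecast_fn, n_runs=20, rel_tol=1e-6):
    """Call forecast_fn() n_runs times and count flat-line outputs."""
    flat = 0
    for _ in range(n_runs):
        f = np.asarray(forecast_fn(), dtype=float).squeeze()
        scale = max(float(np.abs(f).max()), 1.0)
        if np.ptp(f) <= rel_tol * scale:  # max - min ~ 0 => flat line
            flat += 1
    return flat

# Stub demo: a constant forecast is flagged, a varying one is not.
print(count_flatlines(lambda: np.full(64, 2.0), n_runs=5))   # 5
print(count_flatlines(lambda: np.arange(64.0), n_runs=5))    # 0
```

Reporting the flat-line rate over, say, 20 fresh runs on the H100 versus another GPU would make the comparison concrete.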