
Problem with model predictions failing (flatlining) when running on H100s #328

@abao1999

Description

Hi TimesFM team,

I have been noticing a very strange issue when running inference with TimesFM 2.5 on my H100 GPUs; I have not seen it on my other GPUs. Often, after loading the model and calling model.forecast, I get a flat-line prediction: the same value repeated over the entire horizon. Sometimes, running exactly the same code, I get a reasonable forecast instead. Since the problem only appears on H100s, I am trying to figure out whether it comes from torch.compile (i.e. compiled_decode), from specific dependencies, or from something else entirely. This seems like a serious issue: I can reproduce it in a Jupyter notebook, where the predictions fail almost every time I re-run the notebook after restarting the kernel. Below is a description of my setup; thanks in advance for considering this issue, and please let me know if I should provide any additional information to help figure it out!

import numpy as np
import timesfm
import torch

device = "cuda:3"  # the H100 in question

rseed = 123
set_seed(rseed)  # helper that seeds Python, NumPy, and torch RNGs

model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch", device=device)

model.compile(
    timesfm.ForecastConfig(
        max_context=512,
        max_horizon=512,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
        per_core_batch_size=32,
    )
)

point_forecast, quantile_forecast = model.forecast(
    horizon=512,
    inputs=[context],
)

NOTE: while I used max_context=512 and max_horizon=512 in the compile config for this example, I have previously been using max_context=15360 and max_horizon=1024, following the TimesFM 2.5 notebook in the gift-eval directory, and the error persists regardless of these specific settings. I am using prediction_length = 512 in this example, hard-coded to be more explicit and clear.
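To make the failure mode concrete, the degenerate output can be detected mechanically rather than by eyeballing plots. A minimal sketch in plain NumPy (`is_flatline` is a name I made up for this post, not part of the TimesFM API):

```python
import numpy as np

def is_flatline(forecast, tol=1e-8):
    """Return True if the forecast repeats (nearly) one value over the horizon."""
    forecast = np.asarray(forecast).squeeze()
    return bool(np.ptp(forecast) < tol)  # peak-to-peak range ~ 0 => constant line

# Synthetic demonstration; in practice pass point_forecast from model.forecast:
assert is_flatline(np.full(512, 3.14))          # constant line -> degenerate
assert not is_flatline(np.sin(np.arange(512)))  # varying forecast -> fine
```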

print(f"{model.model.device=}")
print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")

model.model.device='cuda:3'
torch.__version__='2.7.0+cu126'
torch.cuda.is_available()=True

Plotting

# Setup data
context_timesteps = np.arange(context.shape[-1])
future_timesteps = np.arange(context.shape[-1], context.shape[-1] + prediction_length)
print(f"context shape: {context.shape}")
print(f"future_vals shape: {future_vals.shape}")
print(f"length of context_timesteps: {len(context_timesteps)}")
print(f"length of future_timesteps: {len(future_timesteps)}")
context shape: (469,)
future_vals shape: (18,)
length of context_timesteps: 469
length of future_timesteps: 512
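A side note on shapes: the context here is 469 points against max_context=512, and the model appears to accept variable-length inputs via the inputs list. For manual experiments with longer series, I truncate to the most recent max_context points with a generic helper (plain NumPy, not TimesFM API; `clip_context` is my own name):

```python
import numpy as np

def clip_context(series, max_context=512):
    # Keep only the most recent max_context points of the series
    series = np.asarray(series, dtype=np.float32)
    return series[-max_context:]

assert clip_context(np.arange(1000)).shape == (512,)  # long series is clipped
assert clip_context(np.arange(469)).shape == (469,)   # short series is untouched
```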

Data split is from Gift-Eval (black is context, blue is labels over the prediction horizon). Specifically, the first short-term split of m4_monthly:

[image: context and ground-truth labels for the m4_monthly split]

fig, ax = plt.subplots(figsize=(6, 2))

# Plot context, ground truth and predictions
ax.plot(context_timesteps, context, color="black", linewidth=1, label="Context")
ax.plot(future_timesteps[:future_vals.shape[-1]], future_vals, color="black", linewidth=1, linestyle="--", label="Ground Truth") 
ax.plot(future_timesteps, point_forecast.squeeze(), color="tab:blue", linewidth=2, label="Prediction")

ax.set_xlabel("Timestep", fontweight="bold")
fig.tight_layout()

Failure prediction (happens a lot):

[image: flat-line forecast over the horizon]

Successful prediction:

[image: forecast tracking the ground truth]

Dependencies

timesfm                   2.0.0
torch                     2.7.0
transformers              4.52.4
tokenizers                0.21.4
huggingface_hub           0.35.3
triton                    3.3.0

nvidia-cublas-cu12        12.6.4.1
nvidia-cuda-cupti-cu12    12.6.80
nvidia-cuda-nvrtc-cu12    12.6.77
nvidia-cuda-runtime-cu12  12.6.77
nvidia-cudnn-cu12         9.5.1.17
nvidia-cufft-cu12         11.3.0.4
nvidia-cufile-cu12        1.11.1.6
nvidia-curand-cu12        10.3.7.77
nvidia-cusolver-cu12      11.7.1.2
nvidia-cusparse-cu12      12.5.4.2
nvidia-cusparselt-cu12    0.6.3
nvidia-nccl-cu12          2.26.2
nvidia-nvjitlink-cu12     12.6.85
nvidia-nvshmem-cu12       3.3.20
nvidia-nvtx-cu12          12.6.77

Furthermore, I have even tried clearing the CUDA cache before running inference, but the problem persists:

import gc

import torch

gc.collect()
# Clear CUDA caches specifically for device
if torch.cuda.is_available():
    print(f"Current GPU on {device}: {torch.cuda.get_device_name()}")
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    print(f"Cleared CUDA caches for device {device}")
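To quantify how often the failure occurs across re-runs, I can imagine measuring it with something like the following (a sketch; `flatline_rate` and the `run_forecast` wrapper are hypothetical names, not TimesFM API):

```python
import numpy as np

def flatline_rate(run_forecast, n_runs=10, tol=1e-8):
    """Call a forecast function repeatedly and report the fraction of runs
    whose output collapses to a single repeated value."""
    hits = 0
    for _ in range(n_runs):
        out = np.asarray(run_forecast()).squeeze()
        if np.ptp(out) < tol:  # peak-to-peak range ~ 0 => flat line
            hits += 1
    return hits / n_runs

# Stub demonstration; a real wrapper would re-load/re-compile the model and
# call model.forecast(horizon=512, inputs=[context]):
assert flatline_rate(lambda: np.zeros(512), n_runs=4) == 1.0
assert flatline_rate(lambda: np.arange(512), n_runs=4) == 0.0
```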
