-
Notifications
You must be signed in to change notification settings - Fork 1.2k
load local models #340
Copy link
Copy link
Open
Description
import torch
import numpy as np
import timesfm
from timesfm.timesfm_2p5.timesfm_2p5_torch import TimesFM_2p5_200M_torch
# Alternative: download the checkpoint from the Hugging Face Hub and load it.
# model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch", torch_compile=True)

# Load the checkpoint from a directory on disk, asking for local files only.
LOCAL_CHECKPOINT_DIR = "./models/timesfm-2.5-200m-pytorch"
model = TimesFM_2p5_200M_torch.from_pretrained(
    pretrained_model_name_or_path=LOCAL_CHECKPOINT_DIR,
    torch_compile=True,
    local_files_only=True,
)

# Build the forecast configuration first, then hand it to the model.
forecast_config = timesfm.ForecastConfig(
    max_context=1024,
    max_horizon=256,
    normalize_inputs=True,
    use_continuous_quantile_head=True,
    force_flip_invariance=True,
    infer_is_positive=True,
    fix_quantile_crossing=True,
)
model.compile(forecast_config)

# Two dummy input series of different lengths: a linear ramp and a sine wave.
dummy_series = [
    np.linspace(0, 1, 100),
    np.sin(np.linspace(0, 20, 67)),
]
point_forecast, quantile_forecast = model.forecast(
    horizon=12,
    inputs=dummy_series,
)
print(point_forecast.shape)  # (2, 12)
print(quantile_forecast.shape) # (2, 12, 10): mean, then 10th to 90th quantiles.
When loading the model from a local path, the remote config file is still downloaded every time. Can this be optimized so that no download happens for local loads?
class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, ModelHubMixin):
    """PyTorch implementation of TimesFM 2.5 with 200M parameters."""

    model: nn.Module = TimesFM_2p5_200M_torch_module()

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str = "google/timesfm-2.5-200m-pytorch",
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool = True,
        proxies: Optional[Dict] = None,
        resume_download: Optional[bool] = None,
        local_files_only: bool,
        token: Optional[Union[str, bool]],
        **model_kwargs,
    ):
        """
        Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
        Face Hub. This method is the backend for the `from_pretrained` class
        method provided by `ModelHubMixin`.

        The Hub is contacted for the tracking copy of `config.json` only when the
        caller has not asked for an offline load: the download is skipped if
        `local_files_only` is true or if `model_id` names an existing local
        directory.

        Args:
          model_id: Hub repository id, or path to a local checkpoint directory.
          local_files_only: When true, no network request is made for the
            tracking config.
          **model_kwargs: Forwarded unchanged to the class constructor.

        The remaining keyword-only parameters (`revision`, `cache_dir`,
        `force_download`, `proxies`, `resume_download`, `token`) are accepted as
        part of the `ModelHubMixin` backend signature; they are not used in the
        portion of the method shown here.
        """
        # Create an instance of the model wrapper class.
        instance = cls(**model_kwargs)

        # A local load must not touch the network: previously the config was
        # force-downloaded on every call, even with `local_files_only=True`.
        is_local_load = local_files_only or Path(model_id).is_dir()
        if not is_local_load:
            # Download the config file for hf tracking.
            _ = hf_hub_download(
                repo_id="google/timesfm-2.5-200m-pytorch",
                filename="config.json",
                force_download=True,
            )
            print("Downloaded.")
        # NOTE(review): this excerpt stops here; the weight loading and the
        # return of `instance` are presumably in the rest of the method, which
        # is not shown - confirm against the full file.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels