Skip to content

Commit a3beaa8

Browse files
authored
Merge pull request #372 from kashif/modelhubmixin
[HF] use the ModelHubMixin api
2 parents 8a755c9 + d2cb484 commit a3beaa8

File tree

1 file changed

+42
-27
lines changed

1 file changed

+42
-27
lines changed

src/timesfm/timesfm_2p5/timesfm_2p5_torch.py

Lines changed: 42 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,11 @@
1818
import math
1919
import os
2020
from pathlib import Path
21-
from typing import Dict, Optional, Sequence, Union
21+
from typing import Optional, Sequence, Union
2222

2323
import numpy as np
2424
import torch
25-
from huggingface_hub import ModelHubMixin, hf_hub_download
25+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
2626
from safetensors.torch import load_file, save_file
2727
from torch import nn
2828

@@ -263,75 +263,90 @@ def forecast_naive(
263263
return outputs
264264

265265

266-
class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, ModelHubMixin):
266+
class TimesFM_2p5_200M_torch(
267+
timesfm_2p5_base.TimesFM_2p5,
268+
PyTorchModelHubMixin,
269+
library_name="timesfm",
270+
repo_url="https://github.com/google-research/timesfm",
271+
paper_url="https://arxiv.org/abs/2310.10688",
272+
docs_url="https://github.com/google-research/timesfm",
273+
license="apache-2.0",
274+
pipeline_tag="time-series-forecasting",
275+
tags=["pytorch", "timeseries", "forecasting", "timesfm-2.5"],
276+
):
267277
"""PyTorch implementation of TimesFM 2.5 with 200M parameters."""
268278

269-
model: nn.Module = TimesFM_2p5_200M_torch_module()
279+
DEFAULT_REPO_ID = "google/timesfm-2.5-200m-pytorch"
280+
WEIGHTS_FILENAME = "model.safetensors"
281+
282+
def __init__(
283+
self,
284+
torch_compile: bool = True,
285+
config: Optional[dict] = None,
286+
):
287+
self.model = TimesFM_2p5_200M_torch_module()
288+
self.torch_compile = torch_compile
289+
if config is not None:
290+
self._hub_mixin_config = config
270291

271292
@classmethod
272293
def _from_pretrained(
273294
cls,
274295
*,
275-
model_id: str = "google/timesfm-2.5-200m-pytorch",
296+
model_id: str = DEFAULT_REPO_ID,
276297
revision: Optional[str],
277298
cache_dir: Optional[Union[str, Path]],
278-
force_download: bool = True,
279-
proxies: Optional[Dict] = None,
280-
resume_download: Optional[bool] = None,
299+
force_download: bool = False,
281300
local_files_only: bool,
282301
token: Optional[Union[str, bool]],
302+
config: Optional[dict] = None,
283303
**model_kwargs,
284304
):
285305
"""
286306
Loads a PyTorch safetensors TimesFM model from a local path or the Hugging
287307
Face Hub. This method is the backend for the `from_pretrained` class
288-
method provided by `ModelHubMixin`.
308+
method provided by `PyTorchModelHubMixin`.
289309
"""
290-
# Create an instance of the model wrapper class.
291-
instance = cls(**model_kwargs)
292-
# Download the config file for hf tracking.
293-
_ = hf_hub_download(
294-
repo_id="google/timesfm-2.5-200m-pytorch",
295-
filename="config.json",
296-
force_download=True,
297-
)
298-
print("Downloaded.")
299-
300310
# Determine the path to the model weights.
301311
model_file_path = ""
302312
if os.path.isdir(model_id):
303313
logging.info("Loading checkpoint from local directory: %s", model_id)
304-
model_file_path = os.path.join(model_id, "model.safetensors")
314+
model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME)
305315
if not os.path.exists(model_file_path):
306-
raise FileNotFoundError(f"model.safetensors not found in directory {model_id}")
316+
raise FileNotFoundError(
317+
f"{cls.WEIGHTS_FILENAME} not found in directory {model_id}"
318+
)
307319
else:
308320
logging.info("Downloading checkpoint from Hugging Face repo %s", model_id)
309321
model_file_path = hf_hub_download(
310322
repo_id=model_id,
311-
filename="model.safetensors",
323+
filename=cls.WEIGHTS_FILENAME,
312324
revision=revision,
313325
cache_dir=cache_dir,
314326
force_download=force_download,
315-
proxies=proxies,
316-
resume_download=resume_download,
317327
token=token,
318328
local_files_only=local_files_only,
319329
)
320330

331+
# Create an instance of the model wrapper class.
332+
instance = cls(config=config, **model_kwargs)
333+
321334
logging.info("Loading checkpoint from: %s", model_file_path)
322335
# Load the weights into the model.
323-
instance.model.load_checkpoint(model_file_path, **model_kwargs)
336+
instance.model.load_checkpoint(
337+
model_file_path, torch_compile=instance.torch_compile
338+
)
324339
return instance
325340

326341
def _save_pretrained(self, save_directory: Union[str, Path]):
327342
"""
328343
Saves the model's state dictionary to a safetensors file. This method
329-
is called by the `save_pretrained` method from `ModelHubMixin`.
344+
is called by the `save_pretrained` method from `PyTorchModelHubMixin`.
330345
"""
331346
if not os.path.exists(save_directory):
332347
os.makedirs(save_directory)
333348

334-
weights_path = os.path.join(save_directory, "model.safetensors")
349+
weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME)
335350
save_file(self.model.state_dict(), weights_path)
336351

337352
def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None:

0 commit comments

Comments (0)