|
18 | 18 | import math |
19 | 19 | import os |
20 | 20 | from pathlib import Path |
21 | | -from typing import Dict, Optional, Sequence, Union |
| 21 | +from typing import Optional, Sequence, Union |
22 | 22 |
|
23 | 23 | import numpy as np |
24 | 24 | import torch |
25 | | -from huggingface_hub import ModelHubMixin, hf_hub_download |
| 25 | +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download |
26 | 26 | from safetensors.torch import load_file, save_file |
27 | 27 | from torch import nn |
28 | 28 |
|
@@ -263,75 +263,90 @@ def forecast_naive( |
263 | 263 | return outputs |
264 | 264 |
|
265 | 265 |
|
266 | | -class TimesFM_2p5_200M_torch(timesfm_2p5_base.TimesFM_2p5, ModelHubMixin): |
| 266 | +class TimesFM_2p5_200M_torch( |
| 267 | + timesfm_2p5_base.TimesFM_2p5, |
| 268 | + PyTorchModelHubMixin, |
| 269 | + library_name="timesfm", |
| 270 | + repo_url="https://github.com/google-research/timesfm", |
| 271 | + paper_url="https://arxiv.org/abs/2310.10688", |
| 272 | + docs_url="https://github.com/google-research/timesfm", |
| 273 | + license="apache-2.0", |
| 274 | + pipeline_tag="time-series-forecasting", |
| 275 | + tags=["pytorch", "timeseries", "forecasting", "timesfm-2.5"], |
| 276 | +): |
267 | 277 | """PyTorch implementation of TimesFM 2.5 with 200M parameters.""" |
268 | 278 |
|
269 | | - model: nn.Module = TimesFM_2p5_200M_torch_module() |
| 279 | + DEFAULT_REPO_ID = "google/timesfm-2.5-200m-pytorch" |
| 280 | + WEIGHTS_FILENAME = "model.safetensors" |
| 281 | + |
| 282 | + def __init__( |
| 283 | + self, |
| 284 | + torch_compile: bool = True, |
| 285 | + config: Optional[dict] = None, |
| 286 | + ): |
| 287 | + self.model = TimesFM_2p5_200M_torch_module() |
| 288 | + self.torch_compile = torch_compile |
| 289 | + if config is not None: |
| 290 | + self._hub_mixin_config = config |
270 | 291 |
|
271 | 292 | @classmethod |
272 | 293 | def _from_pretrained( |
273 | 294 | cls, |
274 | 295 | *, |
275 | | - model_id: str = "google/timesfm-2.5-200m-pytorch", |
| 296 | + model_id: str = DEFAULT_REPO_ID, |
276 | 297 | revision: Optional[str], |
277 | 298 | cache_dir: Optional[Union[str, Path]], |
278 | | - force_download: bool = True, |
279 | | - proxies: Optional[Dict] = None, |
280 | | - resume_download: Optional[bool] = None, |
| 299 | + force_download: bool = False, |
281 | 300 | local_files_only: bool, |
282 | 301 | token: Optional[Union[str, bool]], |
| 302 | + config: Optional[dict] = None, |
283 | 303 | **model_kwargs, |
284 | 304 | ): |
285 | 305 | """ |
286 | 306 | Loads a PyTorch safetensors TimesFM model from a local path or the Hugging |
287 | 307 | Face Hub. This method is the backend for the `from_pretrained` class |
288 | | - method provided by `ModelHubMixin`. |
| 308 | + method provided by `PyTorchModelHubMixin`. |
289 | 309 | """ |
290 | | - # Create an instance of the model wrapper class. |
291 | | - instance = cls(**model_kwargs) |
292 | | - # Download the config file for hf tracking. |
293 | | - _ = hf_hub_download( |
294 | | - repo_id="google/timesfm-2.5-200m-pytorch", |
295 | | - filename="config.json", |
296 | | - force_download=True, |
297 | | - ) |
298 | | - print("Downloaded.") |
299 | | - |
300 | 310 | # Determine the path to the model weights. |
301 | 311 | model_file_path = "" |
302 | 312 | if os.path.isdir(model_id): |
303 | 313 | logging.info("Loading checkpoint from local directory: %s", model_id) |
304 | | - model_file_path = os.path.join(model_id, "model.safetensors") |
| 314 | + model_file_path = os.path.join(model_id, cls.WEIGHTS_FILENAME) |
305 | 315 | if not os.path.exists(model_file_path): |
306 | | - raise FileNotFoundError(f"model.safetensors not found in directory {model_id}") |
| 316 | + raise FileNotFoundError( |
| 317 | + f"{cls.WEIGHTS_FILENAME} not found in directory {model_id}" |
| 318 | + ) |
307 | 319 | else: |
308 | 320 | logging.info("Downloading checkpoint from Hugging Face repo %s", model_id) |
309 | 321 | model_file_path = hf_hub_download( |
310 | 322 | repo_id=model_id, |
311 | | - filename="model.safetensors", |
| 323 | + filename=cls.WEIGHTS_FILENAME, |
312 | 324 | revision=revision, |
313 | 325 | cache_dir=cache_dir, |
314 | 326 | force_download=force_download, |
315 | | - proxies=proxies, |
316 | | - resume_download=resume_download, |
317 | 327 | token=token, |
318 | 328 | local_files_only=local_files_only, |
319 | 329 | ) |
320 | 330 |
|
| 331 | + # Create an instance of the model wrapper class. |
| 332 | + instance = cls(config=config, **model_kwargs) |
| 333 | + |
321 | 334 | logging.info("Loading checkpoint from: %s", model_file_path) |
322 | 335 | # Load the weights into the model. |
323 | | - instance.model.load_checkpoint(model_file_path, **model_kwargs) |
| 336 | + instance.model.load_checkpoint( |
| 337 | + model_file_path, torch_compile=instance.torch_compile |
| 338 | + ) |
324 | 339 | return instance |
325 | 340 |
|
326 | 341 | def _save_pretrained(self, save_directory: Union[str, Path]): |
327 | 342 | """ |
328 | 343 | Saves the model's state dictionary to a safetensors file. This method |
329 | | - is called by the `save_pretrained` method from `ModelHubMixin`. |
| 344 | + is called by the `save_pretrained` method from `PyTorchModelHubMixin`. |
330 | 345 | """ |
331 | 346 | if not os.path.exists(save_directory): |
332 | 347 | os.makedirs(save_directory) |
333 | 348 |
|
334 | | - weights_path = os.path.join(save_directory, "model.safetensors") |
| 349 | + weights_path = os.path.join(save_directory, self.WEIGHTS_FILENAME) |
335 | 350 | save_file(self.model.state_dict(), weights_path) |
336 | 351 |
|
337 | 352 | def compile(self, forecast_config: configs.ForecastConfig, **kwargs) -> None: |
|
0 commit comments