"""
Custom handler for Hugging Face Inference Endpoints.
Serves TimesFM 2.5 (200M), using the `timesfm` package installed from its GitHub repository.
"""
|
|
| import numpy as np |
| import torch |
| from typing import Any |
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler serving TimesFM 2.5 (200M).

    One instance is built at startup; each request calls it with the decoded
    JSON payload::

        {"inputs": [number, ...], "parameters": {"horizon": int}}

    Malformed requests are answered with ``{"error": <message>}`` rather than
    raising, which is the response convention this handler already used.
    """

    # Hub repository and file that hold the pretrained weights.
    MODEL_ID = "google/timesfm-2.5-200m-pytorch"
    WEIGHTS_FILENAME = "model.safetensors"

    # Limits the model is compiled for in __init__. Requests are validated
    # against the same constants so the compile config and the clamp can't drift.
    MAX_CONTEXT = 1024
    MAX_HORIZON = 128
    DEFAULT_HORIZON = 24

    _INPUTS_ERROR = "inputs must be a non-empty list of numbers"
    _HORIZON_ERROR = "parameters.horizon must be a positive integer"

    def __init__(self, path: str = ""):
        """Download the TimesFM 2.5 weights and compile the forecaster.

        Args:
            path: Local repository path passed in by the endpoint toolkit.
                Accepted for interface compatibility but not used: the weights
                are always downloaded from ``MODEL_ID`` on the Hugging Face Hub.
        """
        # Imported here so that importing this module alone does not require
        # these packages.
        import timesfm
        from timesfm.timesfm_2p5.timesfm_2p5_torch import (
            TimesFM_2p5_200M_torch,
            TimesFM_2p5_200M_torch_module,
        )
        from huggingface_hub import hf_hub_download

        # Process-wide setting. Per the PyTorch docs, "high" lets float32
        # matmuls use faster reduced-precision internals (e.g. TF32) where
        # the hardware supports them.
        torch.set_float32_matmul_precision("high")

        weight_file = hf_hub_download(
            repo_id=self.MODEL_ID, filename=self.WEIGHTS_FILENAME
        )

        self.tfm = TimesFM_2p5_200M_torch()
        self.tfm.model = TimesFM_2p5_200M_torch_module()
        self.tfm.model.load_checkpoint(weight_file, torch_compile=False)

        self.tfm.compile(
            timesfm.ForecastConfig(
                max_context=self.MAX_CONTEXT,
                max_horizon=self.MAX_HORIZON,
                normalize_inputs=True,
                use_continuous_quantile_head=True,
                force_flip_invariance=True,
                infer_is_positive=False,
                fix_quantile_crossing=True,
            )
        )

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Forecast the continuation of a single univariate series.

        Args:
            data: Request payload.
                ``data["inputs"]`` is the history as a flat list of numbers.
                ``data["parameters"]["horizon"]`` is optional (default
                ``DEFAULT_HORIZON``). It is the number of future steps and is
                clamped to ``MAX_HORIZON``. ``"parameters"`` may be omitted
                or ``null``.

        Returns:
            ``{"point_forecast": [...], "quantile_forecast": [...]}`` for the
            input series, or ``{"error": <message>}`` if the request is
            malformed. The model is not called for malformed requests.
        """
        try:
            series = self._parse_inputs(data.get("inputs"))
            # `or {}` also covers an explicit JSON null for "parameters".
            parameters = data.get("parameters") or {}
            if not isinstance(parameters, dict):
                raise ValueError("parameters must be an object")
            horizon = self._parse_horizon(
                parameters.get("horizon", self.DEFAULT_HORIZON)
            )
        except ValueError as err:
            return {"error": str(err)}

        # forecast() takes a batch of series; this handler always sends a
        # batch of one and unwraps element 0 of each result.
        point, quantiles = self.tfm.forecast(horizon=horizon, inputs=[series])

        return {
            "point_forecast": point[0].tolist(),
            "quantile_forecast": quantiles[0].tolist(),
        }

    @classmethod
    def _parse_inputs(cls, inputs: Any) -> np.ndarray:
        """Validate the request history and return it as a 1-D float64 array.

        Raises:
            ValueError: if ``inputs`` is not a non-empty flat list of numbers.
        """
        # Check the type first: truth-testing e.g. a numpy array is ambiguous.
        if not isinstance(inputs, list) or not inputs:
            raise ValueError(cls._INPUTS_ERROR)
        try:
            series = np.asarray(inputs, dtype=np.float64)
        except (TypeError, ValueError):
            # Non-numeric entries (e.g. "abc", dicts) or ragged nesting.
            raise ValueError(cls._INPUTS_ERROR) from None
        if series.ndim != 1:
            # Nested lists would otherwise reach forecast() as a 2-D array.
            raise ValueError("inputs must be a flat (1-D) list of numbers")
        return series

    @classmethod
    def _parse_horizon(cls, raw: Any) -> int:
        """Return ``raw`` as a positive int clamped to ``MAX_HORIZON``.

        Integral floats (``24.0``) and numeric strings (``"24"``) are accepted,
        since some JSON clients send numbers that way. Booleans, fractional or
        non-finite values, non-positive values and non-numeric values are
        rejected.

        Raises:
            ValueError: if ``raw`` is not a valid positive integer.
        """
        if isinstance(raw, bool):  # bool is an int subclass; True would mean 1
            raise ValueError(cls._HORIZON_ERROR)
        if isinstance(raw, float):
            if not raw.is_integer():  # also rejects NaN and +/-inf
                raise ValueError(cls._HORIZON_ERROR)
            raw = int(raw)
        try:
            horizon = int(raw)
        except (TypeError, ValueError):
            raise ValueError(cls._HORIZON_ERROR) from None
        if horizon < 1:
            raise ValueError(cls._HORIZON_ERROR)
        return min(horizon, cls.MAX_HORIZON)
|
|