indievish commited on
Commit ·
d628f44
1
Parent(s): e4ccd9a
Use TimesFM 2.5 from GitHub repo (not PyPI)
Browse files- handler.py +17 -21
- requirements.txt +1 -2
handler.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
Custom handler for HuggingFace Inference Endpoints.
|
| 3 |
-
Uses
|
| 4 |
"""
|
| 5 |
|
| 6 |
import numpy as np
|
|
@@ -14,18 +14,19 @@ class EndpointHandler:
|
|
| 14 |
|
| 15 |
torch.set_float32_matmul_precision("high")
|
| 16 |
|
| 17 |
-
self.model = timesfm.
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
@@ -37,14 +38,9 @@ class EndpointHandler:
|
|
| 37 |
return {"error": "inputs must be a non-empty list of numbers"}
|
| 38 |
|
| 39 |
input_array = [np.array(inputs, dtype=np.float64)]
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
point, quantiles = self.model.forecast(
|
| 43 |
-
input_array,
|
| 44 |
-
freq=frequency_input,
|
| 45 |
-
)
|
| 46 |
|
| 47 |
return {
|
| 48 |
-
"point_forecast": point[0]
|
| 49 |
-
"quantile_forecast": quantiles[0]
|
| 50 |
}
|
|
|
|
| 1 |
"""
|
| 2 |
Custom handler for HuggingFace Inference Endpoints.
|
| 3 |
+
Uses TimesFM 2.5 (200M) installed from GitHub repo.
|
| 4 |
"""
|
| 5 |
|
| 6 |
import numpy as np
|
|
|
|
| 14 |
|
| 15 |
torch.set_float32_matmul_precision("high")
|
| 16 |
|
| 17 |
+
self.model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
|
| 18 |
+
"google/timesfm-2.5-200m-pytorch"
|
| 19 |
+
)
|
| 20 |
+
self.model.compile(
|
| 21 |
+
timesfm.ForecastConfig(
|
| 22 |
+
max_context=1024,
|
| 23 |
+
max_horizon=128,
|
| 24 |
+
normalize_inputs=True,
|
| 25 |
+
use_continuous_quantile_head=True,
|
| 26 |
+
force_flip_invariance=True,
|
| 27 |
+
infer_is_positive=False,
|
| 28 |
+
fix_quantile_crossing=True,
|
| 29 |
+
)
|
| 30 |
)
|
| 31 |
|
| 32 |
def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
|
|
| 38 |
return {"error": "inputs must be a non-empty list of numbers"}
|
| 39 |
|
| 40 |
input_array = [np.array(inputs, dtype=np.float64)]
|
| 41 |
+
point, quantiles = self.model.forecast(horizon=horizon, inputs=input_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
return {
|
| 44 |
+
"point_forecast": point[0].tolist(),
|
| 45 |
+
"quantile_forecast": quantiles[0].tolist(),
|
| 46 |
}
|
requirements.txt
CHANGED
|
@@ -1,3 +1,2 @@
|
|
| 1 |
-
timesfm
|
| 2 |
-
torch
|
| 3 |
numpy
|
|
|
|
| 1 |
+
git+https://github.com/google-research/timesfm.git@master#egg=timesfm[torch]
|
|
|
|
| 2 |
numpy
|