indievish commited on
Commit
d628f44
·
1 Parent(s): e4ccd9a

Use TimesFM 2.5 from GitHub repo (not PyPI)

Browse files
Files changed (2) hide show
  1. handler.py +17 -21
  2. requirements.txt +1 -2
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Custom handler for HuggingFace Inference Endpoints.
3
- Uses timesfm pip package (v1.3.0) with the 2.0 PyTorch model.
4
  """
5
 
6
  import numpy as np
@@ -14,18 +14,19 @@ class EndpointHandler:
14
 
15
  torch.set_float32_matmul_precision("high")
16
 
17
- self.model = timesfm.TimesFm(
18
- hparams=timesfm.TimesFmHparams(
19
- backend="gpu" if torch.cuda.is_available() else "cpu",
20
- per_core_batch_size=32,
21
- horizon_len=128,
22
- num_layers=20,
23
- context_len=1024,
24
- use_positional_embedding=False,
25
- ),
26
- checkpoint=timesfm.TimesFmCheckpoint(
27
- huggingface_repo_id="google/timesfm-2.0-500m-pytorch",
28
- ),
 
29
  )
30
 
31
  def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
@@ -37,14 +38,9 @@ class EndpointHandler:
37
  return {"error": "inputs must be a non-empty list of numbers"}
38
 
39
  input_array = [np.array(inputs, dtype=np.float64)]
40
- frequency_input = [0] # 0 = no specific frequency
41
-
42
- point, quantiles = self.model.forecast(
43
- input_array,
44
- freq=frequency_input,
45
- )
46
 
47
  return {
48
- "point_forecast": point[0][:horizon].tolist(),
49
- "quantile_forecast": quantiles[0][:horizon].tolist(),
50
  }
 
1
  """
2
  Custom handler for HuggingFace Inference Endpoints.
3
+ Uses TimesFM 2.5 (200M) installed from GitHub repo.
4
  """
5
 
6
  import numpy as np
 
14
 
15
  torch.set_float32_matmul_precision("high")
16
 
17
+ self.model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
18
+ "google/timesfm-2.5-200m-pytorch"
19
+ )
20
+ self.model.compile(
21
+ timesfm.ForecastConfig(
22
+ max_context=1024,
23
+ max_horizon=128,
24
+ normalize_inputs=True,
25
+ use_continuous_quantile_head=True,
26
+ force_flip_invariance=True,
27
+ infer_is_positive=False,
28
+ fix_quantile_crossing=True,
29
+ )
30
  )
31
 
32
  def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
 
38
  return {"error": "inputs must be a non-empty list of numbers"}
39
 
40
  input_array = [np.array(inputs, dtype=np.float64)]
41
+ point, quantiles = self.model.forecast(horizon=horizon, inputs=input_array)
 
 
 
 
 
42
 
43
  return {
44
+ "point_forecast": point[0].tolist(),
45
+ "quantile_forecast": quantiles[0].tolist(),
46
  }
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
- timesfm==1.3.0
2
- torch
3
  numpy
 
1
+ git+https://github.com/google-research/timesfm.git@master#egg=timesfm[torch]
 
2
  numpy