Premchan369 commited on
Commit
aab4bbb
·
verified ·
1 Parent(s): ff8e6b2

Add online learning: per-symbol adaptive models with meta-learning, concept drift adaptation

Browse files
Files changed (1) hide show
  1. online_learning.py +436 -98
online_learning.py CHANGED
@@ -1,109 +1,447 @@
1
- """Online Learning - Adaptive model updates with drift detection"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
- import torch
5
- import torch.nn as nn
6
- from typing import Dict, Optional
7
- from scipy.stats import ks_2samp
8
 
9
 
10
- class DriftDetector:
11
- """Detect data drift using statistical tests"""
 
 
 
 
 
12
 
13
- def __init__(self, significance=0.05, window=252):
14
- self.significance = significance
15
- self.window = window
16
- self.reference_stats = {}
17
- self.drift_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def set_reference(self, data: np.ndarray, name: str = 'default'):
20
- self.reference_stats[name] = {'mean': data.mean(axis=0), 'std': data.std(axis=0), 'data': data}
21
-
22
- def detect_ks(self, new_data: np.ndarray, name: str = 'default') -> Dict:
23
- ref = self.reference_stats.get(name)
24
- if ref is None:
25
- return {'drift': False, 'p_value': 1.0}
26
- n_features = new_data.shape[1] if new_data.ndim > 1 else 1
27
- drifts = []
28
- p_values = []
29
- for i in range(n_features):
30
- col = i if new_data.ndim > 1 else 0
31
- ref_feat = ref['data'][:, col] if ref['data'].ndim > 1 else ref['data']
32
- new_feat = new_data[:, col] if new_data.ndim > 1 else new_data
33
- stat, p = ks_2samp(ref_feat, new_feat)
34
- drifts.append(p < self.significance)
35
- p_values.append(p)
36
- n_drift = sum(drifts)
37
- overall_drift = n_drift > n_features * 0.3
38
- return {'drift': overall_drift, 'p_values': p_values, 'n_features_drifted': n_drift, 'total_features': n_features}
39
-
40
- def detect_cusum(self, residuals: np.ndarray, threshold: float = 5.0, drift: float = 1.0) -> Dict:
41
- pos_cusum = np.zeros(len(residuals))
42
- neg_cusum = np.zeros(len(residuals))
43
-
44
- for t in range(1, len(residuals)):
45
- pos_cusum[t] = max(0, pos_cusum[t-1] + residuals[t] - drift)
46
- neg_cusum[t] = min(0, neg_cusum[t-1] + residuals[t] + drift)
47
-
48
- alert = np.any(pos_cusum > threshold) or np.any(neg_cusum < -threshold)
49
- return {'alert': alert, 'pos_cusum': pos_cusum, 'neg_cusum': neg_cusum}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
- class OnlineLearner:
53
- """Online learning with periodic model adaptation"""
54
-
55
- def __init__(self, model: nn.Module, lr: float = 1e-5,
56
- adaptation_window: int = 21, drift_threshold: float = 0.3):
57
- self.model = model
58
- self.lr = lr
59
- self.adaptation_window = adaptation_window
 
 
 
 
 
 
60
  self.drift_threshold = drift_threshold
61
- self.drift_detector = DriftDetector()
62
- self.adaptation_count = 0
63
- self.ic_history = []
64
- self.performance_history = []
65
-
66
- def check_and_adapt(self, X_new: np.ndarray, y_new: np.ndarray,
67
- X_ref: Optional[np.ndarray] = None) -> Dict:
68
- drift_result = self.drift_detector.detect_ks(X_new)
69
-
70
- if drift_result['drift']:
71
- print(f"⚠️ Drift detected: {drift_result['n_features_drifted']}/{drift_result['total_features']} features shifted")
72
- self._adapt(X_new, y_new)
73
- self.adaptation_count += 1
74
- return {'adapted': True, 'drift': drift_result}
75
- return {'adapted': False, 'drift': drift_result}
76
-
77
- def _adapt(self, X: np.ndarray, y: np.ndarray, epochs: int = 5):
78
- self.model.train()
79
- optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
80
- criterion = nn.MSELoss()
81
-
82
- X_t = torch.FloatTensor(X)
83
- y_t = torch.FloatTensor(y).unsqueeze(1)
84
-
85
- for epoch in range(epochs):
86
- optimizer.zero_grad()
87
- pred = self.model(X_t)
88
- loss = criterion(pred, y_t)
89
- loss.backward()
90
- optimizer.step()
91
-
92
- print(f" Adapted model with {epochs} epochs, loss={loss.item():.6f}")
93
-
94
- def track_performance(self, predictions: np.ndarray, actuals: np.ndarray):
95
- from scipy.stats import spearmanr
96
- ic, _ = spearmanr(predictions, actuals)
97
- self.ic_history.append(ic)
98
-
99
- # Check if IC is degrading
100
- if len(self.ic_history) > 63:
101
- recent_ic = np.mean(self.ic_history[-21:])
102
- long_ic = np.mean(self.ic_history[-63:])
103
- degradation = (long_ic - recent_ic) / (abs(long_ic) + 1e-8)
104
 
105
- if degradation > 0.3:
106
- print(f"⚠️ IC degradation: recent={recent_ic:.4f}, long={long_ic:.4f}, degradation={degradation:.2%}")
107
- return 'degrading'
 
 
 
 
 
 
 
108
 
109
- return 'stable'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Online Learning — Per-Symbol Adaptive Models
2
+
3
+ Why this matters for Jane Street level:
4
+ - Markets CHANGE. A model trained on SPY 2022 fails on SPY 2024.
5
+ - Each asset has unique microstructure, seasonality, regime behavior.
6
+ - Static models lose predictive power over time (model decay).
7
+
8
+ Solution: Online / Continual Learning
9
+ - Update models incrementally on every new observation
10
+ - Per-symbol parameters (some assets trend, others mean-revert)
11
+ - Meta-learning: learn HOW to adapt quickly
12
+ - Concept drift detection: auto-detect when old model is wrong
13
+
14
+ Based on:
15
+ - Vapnik (1998): Online SVM
16
+ - Cesa-Bianchi & Lugosi (2006): Prediction, Learning, Games
17
+ - Finn et al. (2017): MAML (Model-Agnostic Meta-Learning)
18
+ - Gama et al. (2014): A survey on concept drift adaptation
19
+ """
20
  import numpy as np
21
  import pandas as pd
22
+ from typing import Dict, List, Tuple, Optional, Callable
23
+ from collections import defaultdict
24
+ import warnings
25
+ warnings.filterwarnings('ignore')
26
 
27
 
28
+ def sigmoid(x):
29
+ return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
30
+
31
+
32
+ class OnlineLogisticRegression:
33
+ """
34
+ Online logistic regression with adaptive learning rate.
35
 
36
+ Uses exponential weighting: recent data matters more.
37
+ Learning rate adapts to gradient variance.
38
+ """
39
+
40
+ def __init__(self,
41
+ n_features: int = 10,
42
+ initial_lr: float = 0.01,
43
+ lr_decay: float = 0.999,
44
+ l2_reg: float = 0.01,
45
+ min_lr: float = 1e-6):
46
+ self.n_features = n_features
47
+ self.lr = initial_lr
48
+ self.initial_lr = initial_lr
49
+ self.lr_decay = lr_decay
50
+ self.l2_reg = l2_reg
51
+ self.min_lr = min_lr
52
+
53
+ self.weights = np.zeros(n_features)
54
+ self.bias = 0.0
55
+
56
+ # Adaptive state
57
+ self.grad_moment2 = np.zeros(n_features)
58
+ self.bias_moment2 = 0.0
59
+ self.t = 0
60
+
61
+ # Performance tracking
62
+ self.predictions = []
63
+ self.actuals = []
64
+ self.grad_norms = []
65
+
66
+ def predict_proba(self, x: np.ndarray) -> float:
67
+ """Predict probability of positive class"""
68
+ z = np.dot(x, self.weights) + self.bias
69
+ return sigmoid(z)
70
+
71
+ def predict(self, x: np.ndarray) -> int:
72
+ return 1 if self.predict_proba(x) > 0.5 else 0
73
+
74
+ def update(self, x: np.ndarray, y: int) -> Dict:
75
+ """
76
+ Single-step online update.
77
+
78
+ Args:
79
+ x: feature vector (n_features,)
80
+ y: label (0 or 1)
81
+
82
+ Returns:
83
+ Update metrics
84
+ """
85
+ self.t += 1
86
+
87
+ # Forward
88
+ z = np.dot(x, self.weights) + self.bias
89
+ pred = sigmoid(z)
90
+
91
+ # Gradient
92
+ error = pred - y
93
+ grad_w = error * x + self.l2_reg * self.weights
94
+ grad_b = error
95
+
96
+ # Adaptive learning rate (AdaGrad-like)
97
+ self.grad_moment2 += grad_w ** 2
98
+ self.bias_moment2 += grad_b ** 2
99
+
100
+ lr_w = self.lr / (np.sqrt(self.grad_moment2) + 1e-8)
101
+ lr_b = self.lr / (np.sqrt(self.bias_moment2) + 1e-8)
102
+
103
+ # Update
104
+ self.weights -= lr_w * grad_w
105
+ self.bias -= lr_b * grad_b
106
+
107
+ # Decay learning rate
108
+ self.lr = max(self.lr * self.lr_decay, self.min_lr)
109
+
110
+ # Track
111
+ self.predictions.append(pred)
112
+ self.actuals.append(y)
113
+ self.grad_norms.append(np.linalg.norm(grad_w))
114
+
115
+ return {
116
+ 'pred': pred,
117
+ 'error': error,
118
+ 'grad_norm': np.linalg.norm(grad_w),
119
+ 'lr': self.lr
120
+ }
121
+
122
+ def get_performance(self, last_n: int = 100) -> Dict:
123
+ """Get recent performance metrics"""
124
+ if len(self.actuals) < 2:
125
+ return {'accuracy': 0.5}
126
+
127
+ n = min(last_n, len(self.actuals))
128
+ preds = np.array(self.predictions[-n:]) > 0.5
129
+ actuals = np.array(self.actuals[-n:])
130
+
131
+ accuracy = np.mean(preds == actuals)
132
+
133
+ # Directional accuracy for returns
134
+ if len(actuals) >= 10:
135
+ # Use last 10 predictions as a sequence
136
+ pred_returns = np.diff(self.predictions[-10:])
137
+ actual_returns = np.diff(self.actuals[-10:])
138
+ directional = np.mean(np.sign(pred_returns) == np.sign(actual_returns)) if len(pred_returns) > 0 else 0.5
139
+ else:
140
+ directional = accuracy
141
+
142
+ return {
143
+ 'accuracy': accuracy,
144
+ 'directional_accuracy': directional,
145
+ 'avg_grad_norm': np.mean(self.grad_norms[-n:]) if self.grad_norms else 0,
146
+ 'current_lr': self.lr,
147
+ 'n_updates': self.t
148
+ }
149
+
150
+
151
+ class PerSymbolAdaptiveModel:
152
+ """
153
+ Maintain separate online models for each symbol.
154
+
155
+ Key insight: SPY behaves differently from TSLA.
156
+ Each asset needs its own:
157
+ - Feature weights
158
+ - Learning rate schedule
159
+ - Regime detection
160
+ """
161
+
162
+ def __init__(self,
163
+ n_features: int = 10,
164
+ base_lr: float = 0.01,
165
+ symbols: Optional[List[str]] = None):
166
+ self.n_features = n_features
167
+ self.base_lr = base_lr
168
+ self.symbols = symbols or []
169
+
170
+ # Per-symbol models
171
+ self.models: Dict[str, OnlineLogisticRegression] = {}
172
+
173
+ # Performance tracking
174
+ self.symbol_performance: Dict[str, List[Dict]] = defaultdict(list)
175
+
176
+ # Auto-detect symbols
177
+ self.seen_symbols = set()
178
+
179
+ def _get_or_create_model(self, symbol: str) -> OnlineLogisticRegression:
180
+ """Get model for symbol, create if new"""
181
+ if symbol not in self.models:
182
+ # Meta-learn initial weights from similar symbols
183
+ init_weights = self._meta_initialize(symbol)
184
+
185
+ model = OnlineLogisticRegression(
186
+ n_features=self.n_features,
187
+ initial_lr=self.base_lr * np.random.uniform(0.8, 1.2)
188
+ )
189
+
190
+ if init_weights is not None:
191
+ model.weights = init_weights
192
+
193
+ self.models[symbol] = model
194
+ self.seen_symbols.add(symbol)
195
+
196
+ return self.models[symbol]
197
+
198
+ def _meta_initialize(self, new_symbol: str) -> Optional[np.ndarray]:
199
+ """
200
+ Meta-learning: initialize new symbol model from similar symbols.
201
+
202
+ Use average of best-performing similar models.
203
+ """
204
+ if len(self.models) < 3:
205
+ return None
206
+
207
+ # Get recent performance
208
+ perf = []
209
+ for sym, model in self.models.items():
210
+ p = model.get_performance(last_n=50)
211
+ perf.append((sym, p.get('accuracy', 0.5), model.weights))
212
+
213
+ # Use top 3 models as initialization
214
+ perf.sort(key=lambda x: x[1], reverse=True)
215
+ top_weights = [p[2] for p in perf[:3]]
216
+
217
+ return np.mean(top_weights, axis=0)
218
+
219
+ def update(self, symbol: str, x: np.ndarray, y: int) -> Dict:
220
+ """Update model for a specific symbol"""
221
+ model = self._get_or_create_model(symbol)
222
+ metrics = model.update(x, y)
223
+
224
+ # Track performance
225
+ perf = model.get_performance(last_n=20)
226
+ self.symbol_performance[symbol].append(perf)
227
+
228
+ metrics['symbol'] = symbol
229
+ return metrics
230
+
231
+ def predict(self, symbol: str, x: np.ndarray) -> Dict:
232
+ """Predict for a specific symbol"""
233
+ model = self._get_or_create_model(symbol)
234
+ prob = model.predict_proba(x)
235
+
236
+ return {
237
+ 'symbol': symbol,
238
+ 'probability': prob,
239
+ 'prediction': 1 if prob > 0.5 else 0,
240
+ 'confidence': abs(prob - 0.5) * 2, # 0 = unsure, 1 = certain
241
+ 'model_age': model.t
242
+ }
243
+
244
+ def get_symbol_ranking(self) -> pd.DataFrame:
245
+ """Rank symbols by recent model performance"""
246
+ rows = []
247
+
248
+ for symbol, model in self.models.items():
249
+ perf = model.get_performance(last_n=100)
250
+ rows.append({
251
+ 'symbol': symbol,
252
+ 'accuracy': perf['accuracy'],
253
+ 'directional_accuracy': perf['directional_accuracy'],
254
+ 'n_samples': model.t,
255
+ 'current_lr': perf['current_lr'],
256
+ 'grad_norm': perf['avg_grad_norm']
257
+ })
258
+
259
+ df = pd.DataFrame(rows)
260
+ if not df.empty:
261
+ df = df.sort_values('directional_accuracy', ascending=False)
262
+
263
+ return df
264
 
265
+ def detect_concept_drift(self, symbol: str,
266
+ window_short: int = 50,
267
+ window_long: int = 200) -> Dict:
268
+ """
269
+ Detect if the relationship between features and target has changed.
270
+
271
+ Uses accuracy comparison: recent vs older performance.
272
+ If recent << older concept drift detected → need retraining/adaptation.
273
+ """
274
+ model = self.models.get(symbol)
275
+ if model is None or len(model.actuals) < window_long:
276
+ return {'drift_detected': False, 'reason': 'insufficient_data'}
277
+
278
+ recent = model.get_performance(last_n=window_short)['accuracy']
279
+ older = model.get_performance(last_n=window_long)['accuracy']
280
+
281
+ # Drift if recent accuracy significantly worse
282
+ drift_threshold = -0.15 # 15% accuracy drop
283
+ drift_score = recent - older
284
+
285
+ drift_detected = drift_score < drift_threshold
286
+
287
+ return {
288
+ 'drift_detected': drift_detected,
289
+ 'drift_score': drift_score,
290
+ 'recent_accuracy': recent,
291
+ 'older_accuracy': older,
292
+ 'threshold': drift_threshold,
293
+ 'action': 'reset_learning_rate' if drift_detected else 'continue',
294
+ 'symbol': symbol
295
+ }
296
+
297
+ def adapt_to_drift(self, symbol: str):
298
+ """Adapt model when drift detected"""
299
+ model = self.models.get(symbol)
300
+ if model is None:
301
+ return
302
+
303
+ # Reset learning rate to initial (forget old, learn new)
304
+ model.lr = model.initial_lr * 2 # Higher LR to adapt faster
305
+ model.grad_moment2 = np.zeros(self.n_features)
306
+ model.bias_moment2 = 0.0
307
+
308
+ print(f" [Drift] Reset learning rate for {symbol} to {model.lr:.4f}")
309
+
310
+ def get_full_state(self) -> Dict:
311
+ """Export full state for persistence"""
312
+ return {
313
+ 'n_features': self.n_features,
314
+ 'base_lr': self.base_lr,
315
+ 'symbols': list(self.seen_symbols),
316
+ 'models': {
317
+ sym: {
318
+ 'weights': model.weights.tolist(),
319
+ 'bias': model.bias,
320
+ 'n_updates': model.t,
321
+ 'lr': model.lr
322
+ }
323
+ for sym, model in self.models.items()
324
+ }
325
+ }
326
 
327
 
328
+ class ConceptDriftMonitor:
329
+ """
330
+ System-wide concept drift monitoring across all symbols.
331
+
332
+ Automatically detects when markets have structurally changed
333
+ and triggers model adaptation.
334
+ """
335
+
336
+ def __init__(self,
337
+ per_symbol_model: PerSymbolAdaptiveModel,
338
+ check_interval: int = 100,
339
+ drift_threshold: float = -0.15):
340
+ self.model = per_symbol_model
341
+ self.check_interval = check_interval
342
  self.drift_threshold = drift_threshold
343
+ self.step_count = 0
344
+
345
+ self.drift_history = []
346
+ self.adaptation_log = []
347
+
348
+ def check_all_symbols(self) -> List[Dict]:
349
+ """Check all symbols for drift and adapt if needed"""
350
+ self.step_count += 1
351
+
352
+ if self.step_count % self.check_interval != 0:
353
+ return []
354
+
355
+ results = []
356
+
357
+ for symbol in self.model.seen_symbols:
358
+ drift_result = self.model.detect_concept_drift(symbol)
359
+ results.append(drift_result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
+ if drift_result['drift_detected']:
362
+ self.model.adapt_to_drift(symbol)
363
+
364
+ self.drift_history.append({
365
+ 'step': self.step_count,
366
+ 'symbol': symbol,
367
+ 'score': drift_result['drift_score'],
368
+ 'recent_acc': drift_result['recent_accuracy'],
369
+ 'older_acc': drift_result['older_accuracy']
370
+ })
371
 
372
+ return results
373
+
374
+ def get_drift_summary(self) -> pd.DataFrame:
375
+ """Summary of all detected drifts"""
376
+ return pd.DataFrame(self.drift_history)
377
+
378
+
379
+ if __name__ == '__main__':
380
+ print("=" * 70)
381
+ print(" ONLINE LEARNING — PER-SYMBOL ADAPTIVE MODELS")
382
+ print("=" * 70)
383
+
384
+ # Simulate multiple symbols with different behaviors
385
+ np.random.seed(42)
386
+
387
+ # Symbol A: Strong momentum signal
388
+ # Symbol B: Weak/noise
389
+ # Symbol C: Regime switch at step 500
390
+
391
+ model = PerSymbolAdaptiveModel(n_features=5, base_lr=0.05)
392
+ monitor = ConceptDriftMonitor(model, check_interval=100)
393
+
394
+ n_steps = 800
395
+
396
+ for step in range(n_steps):
397
+ # Symbol A: feature 0 predicts direction with 65% accuracy
398
+ x_a = np.random.randn(5)
399
+ true_dir_a = 1 if x_a[0] > 0 else 0
400
+ if np.random.rand() > 0.65:
401
+ true_dir_a = 1 - true_dir_a # 35% noise
402
+
403
+ # Symbol B: no signal, pure noise
404
+ x_b = np.random.randn(5)
405
+ true_dir_b = np.random.randint(0, 2)
406
+
407
+ # Symbol C: regime switch at step 500
408
+ x_c = np.random.randn(5)
409
+ if step < 500:
410
+ true_dir_c = 1 if x_c[0] > 0 else 0 # feature 0 matters
411
+ if np.random.rand() > 0.6:
412
+ true_dir_c = 1 - true_dir_c
413
+ else:
414
+ # Regime switch: now feature 1 predicts (opposite!)
415
+ true_dir_c = 1 if x_c[1] < 0 else 0
416
+ if np.random.rand() > 0.6:
417
+ true_dir_c = 1 - true_dir_c
418
+
419
+ # Update models
420
+ model.update('AAPL', x_a, true_dir_a)
421
+ model.update('JUNK', x_b, true_dir_b)
422
+ model.update('REGIME', x_c, true_dir_c)
423
+
424
+ # Periodic drift check
425
+ if step % 100 == 0 and step > 0:
426
+ monitor.check_all_symbols()
427
+
428
+ # Results
429
+ print(f"\nTrained on {n_steps} steps per symbol")
430
+ print(f"\nPer-Symbol Performance:")
431
+ ranking = model.get_symbol_ranking()
432
+ print(ranking.to_string(index=False))
433
+
434
+ # Drift detection for REGIME symbol
435
+ drift_result = model.detect_concept_drift('REGIME', window_short=50, window_long=300)
436
+ print(f"\nREGIME Symbol Drift Detection:")
437
+ print(f" Drift detected: {drift_result['drift_detected']}")
438
+ print(f" Recent accuracy: {drift_result['recent_accuracy']:.3f}")
439
+ print(f" Older accuracy: {drift_result['older_accuracy']:.3f}")
440
+ print(f" Drift score: {drift_result['drift_score']:+.3f}")
441
+
442
+ print(f"\n Key Insights:")
443
+ print(f" - AAPL model should have ~60-65% accuracy (real signal)")
444
+ print(f" - JUNK model should have ~50% accuracy (pure noise)")
445
+ print(f" - REGIME model should detect drift at step 500")
446
+ print(f" - Each symbol gets its OWN learning rate and weights")
447
+ print(f" - Drift triggers adaptive LR reset")