EgoisticCoderX commited on
Commit
7a8b57c
Β·
verified Β·
1 Parent(s): c92b888

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -90
app.py CHANGED
@@ -1,19 +1,17 @@
1
- """"
2
  =================================================================
3
- Disaster AI_HuggingFace Spaces API
4
- Final version_all fixes applied
5
  =================================================================
6
  """
7
 
8
  import os
9
  import io
10
- import json
11
  import time
12
  import base64
13
  import threading
14
  import traceback
15
  import numpy as np
16
- from pathlib import Path
17
  from PIL import Image
18
  import cv2
19
  import torch
@@ -21,7 +19,6 @@ import requests
21
 
22
  from fastapi import FastAPI, File, UploadFile, HTTPException
23
  from fastapi.middleware.cors import CORSMiddleware
24
- from fastapi.responses import JSONResponse
25
  from huggingface_hub import hf_hub_download
26
 
27
  # ════════════════════════════════
@@ -30,7 +27,7 @@ from huggingface_hub import hf_hub_download
30
  app = FastAPI(
31
  title="Disaster AI Inference API",
32
  description="Multi-model disaster scene analysis for Dokai / RoboXavier",
33
- version="2.0.0",
34
  )
35
 
36
  app.add_middleware(
@@ -41,13 +38,17 @@ app.add_middleware(
41
  )
42
 
43
  # ════════════════════════════════
44
- # Configuration
45
  # ════════════════════════════════
46
- HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "")
 
 
47
  ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "rltTa8UANpettqj6aHJG")
48
  MODEL_CACHE_DIR = "/tmp/model_cache"
 
49
  os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
50
 
 
51
  TARGET_CLASSES = {
52
  0: "injured_civilian",
53
  1: "trapped_civilian",
@@ -62,14 +63,22 @@ CLASS_PRIORITY = {
62
  "rescue_personnel": 0.0,
63
  }
64
 
 
 
 
 
 
 
 
 
65
  # ════════════════════════════════
66
  # Model Registry
67
  # ════════════════════════════════
68
  class ModelRegistry:
69
  def __init__(self):
70
- self._models = {}
71
- self._errors = {}
72
- self._lock = threading.Lock()
73
 
74
  def get(self, name):
75
  return self._models.get(name)
@@ -103,20 +112,17 @@ registry = ModelRegistry()
103
  # ════════════════════════════════
104
 
105
  def load_ladi_model():
106
- """Load LADI-v2 classifier from HuggingFace Hub."""
107
  if registry.is_loaded("ladi"):
108
  return registry.get("ladi")
109
-
110
  try:
111
  from transformers import AutoImageProcessor, AutoModelForImageClassification
112
 
113
  print("⬇️ Loading MITLL/LADI-v2-classifier-small ...")
114
-
115
  processor = AutoImageProcessor.from_pretrained(
116
  "MITLL/LADI-v2-classifier-small",
117
  cache_dir=MODEL_CACHE_DIR,
118
  )
119
-
120
  model = AutoModelForImageClassification.from_pretrained(
121
  "MITLL/LADI-v2-classifier-small",
122
  cache_dir=MODEL_CACHE_DIR,
@@ -124,8 +130,6 @@ def load_ladi_model():
124
  ignore_mismatched_sizes=True,
125
  )
126
  model.eval()
127
-
128
- # CPU only on HF free tier
129
  registry.register("ladi", {"model": model, "processor": processor})
130
  print("βœ… LADI-v2 ready")
131
  return registry.get("ladi")
@@ -142,7 +146,7 @@ def load_victim_model():
142
  return registry.get("victim")
143
 
144
  if not HF_VICTIM_MODEL_REPO:
145
- registry.set_error("victim", "HF_VICTIM_MODEL_REPO secret not set β€” train the model first")
146
  return None
147
 
148
  try:
@@ -153,6 +157,7 @@ def load_victim_model():
153
  repo_id=HF_VICTIM_MODEL_REPO,
154
  filename="best.pt",
155
  cache_dir=MODEL_CACHE_DIR,
 
156
  )
157
  model = YOLO(model_path)
158
  registry.register("victim", model)
@@ -165,33 +170,69 @@ def load_victim_model():
165
  return None
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # ════════════════════════════════
169
- # Startup β€” preload everything
170
  # ════════════════════════════════
171
  @app.on_event("startup")
172
  async def startup_event():
173
- print("\n" + "="*50)
174
- print("πŸš€ Disaster AI API starting up...")
175
- print("="*50)
176
 
177
- # Always load LADI β€” it's a public HF model
178
  load_ladi_model()
179
 
180
- # Only load victim model if repo is configured
181
  if HF_VICTIM_MODEL_REPO:
182
  load_victim_model()
183
  else:
184
  print("⚠️ Victim model skipped β€” HF_VICTIM_MODEL_REPO not set")
185
- print(" Train the model first, then add the secret to this Space")
186
 
187
- print("="*50)
188
- print(f"πŸ“Š Registry status: {registry.status()}")
189
- print("="*50 + "\n")
 
 
 
 
 
 
190
 
191
 
192
  # ════════════════════════════════
193
- # Utility
194
  # ════════════════════════════════
 
195
  def read_image(file_bytes: bytes) -> np.ndarray:
196
  nparr = np.frombuffer(file_bytes, np.uint8)
197
  img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
@@ -239,7 +280,7 @@ def compute_triage(detections: list) -> dict:
239
  "total": 0, "critical": 0, "high": 0,
240
  "moderate": 0, "low": 0,
241
  "highest_score": 0.0,
242
- "action": "βœ… No victims detected",
243
  "ranked_victims": [],
244
  }
245
 
@@ -265,10 +306,10 @@ def compute_triage(detections: list) -> dict:
265
  low = sum(1 for d in scored if d["priority_rank"] == "LOW")
266
 
267
  action = (
268
- "⚠️ IMMEDIATE RESCUE β€” Critical victims present" if critical else
269
- "πŸ”΄ Deploy rescue team β€” High priority victims" if high else
270
- "🟑 Assess and triage β€” Moderate victims present" if moderate else
271
- "🟒 Low priority β€” Monitor the area"
272
  )
273
 
274
  return {
@@ -283,6 +324,31 @@ def compute_triage(detections: list) -> dict:
283
  }
284
 
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  # ════════════════════════════════
287
  # Routes
288
  # ════════════════════════════════
@@ -290,16 +356,17 @@ def compute_triage(detections: list) -> dict:
290
  @app.get("/")
291
  def root():
292
  return {
293
- "service": "Disaster AI Inference API",
294
- "version": "2.0.0",
295
- "status": registry.status(),
296
  "endpoints": {
297
  "GET /health": "Health check + model status",
298
- "POST /classify": "LADI-v2 scene classification",
299
  "POST /detect/victims": "Victim detection + triage priority",
300
- "POST /detect/vehicles": "Emergency vehicle detection",
301
- "POST /analyze/full": "All models in one call",
302
- }
 
303
  }
304
 
305
 
@@ -309,12 +376,18 @@ def health():
309
  "status": "ok",
310
  "registry": registry.status(),
311
  "gpu_available": torch.cuda.is_available(),
312
- "timestamp": time.time(),
 
 
 
 
 
 
313
  }
314
 
315
 
316
  # ─────────────────────────────────────────────
317
- # LADI-v2 Classification
318
  # ─────────────────────────────────────────────
319
  @app.post("/classify")
320
  async def classify_scene(
@@ -322,8 +395,8 @@ async def classify_scene(
322
  top_k: int = 5,
323
  ):
324
  """
325
- Classify disaster scene using LADI-v2.
326
- Returns top-k predicted damage categories with confidence scores.
327
  """
328
  ladi = load_ladi_model()
329
  if ladi is None:
@@ -350,9 +423,9 @@ async def classify_scene(
350
  except Exception as e:
351
  raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
352
 
353
- elapsed = round((time.time() - t0) * 1000, 2)
 
354
 
355
- id2label = model.config.id2label
356
  all_scores = sorted(
357
  [
358
  {
@@ -365,7 +438,6 @@ async def classify_scene(
365
  reverse=True,
366
  )
367
 
368
- # Exclude water/flood from top predictions (not relevant for rover)
369
  relevant = [
370
  s for s in all_scores
371
  if not any(w in s["class"] for w in ["water", "flood"])
@@ -380,7 +452,7 @@ async def classify_scene(
380
 
381
 
382
  # ─────────────────────────────────────────────
383
- # Victim Detection
384
  # ─────────────────────────────────────────────
385
  @app.post("/detect/victims")
386
  async def detect_victims(
@@ -388,8 +460,9 @@ async def detect_victims(
388
  confidence: float = 0.35,
389
  ):
390
  """
391
- Detect victims and classify by triage priority.
392
- Returns CRITICAL / HIGH / MODERATE / LOW ranked detections.
 
393
  """
394
  model = load_victim_model()
395
  if model is None:
@@ -418,11 +491,11 @@ async def detect_victims(
418
  "class": TARGET_CLASSES.get(cls_id, "unknown"),
419
  "class_id": cls_id,
420
  "confidence": round(conf_val, 4),
421
- "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
422
  })
423
 
424
- triage = compute_triage(raw)
425
- victims = triage.pop("ranked_victims", raw)
426
 
427
  return {
428
  "detections": victims,
@@ -432,26 +505,25 @@ async def detect_victims(
432
 
433
 
434
  # ─────────────────────────────────────────────
435
- # Emergency Vehicle Detection
436
  # ─────────────────────────────────────────────
437
  @app.post("/detect/vehicles")
438
  async def detect_vehicles(file: UploadFile = File(...)):
439
  """
440
- Detect emergency vehicles using Roboflow.
441
- Returns ambulance / fire truck / rescue vehicle detections.
442
  """
443
  if not ROBOFLOW_API_KEY:
444
  raise HTTPException(status_code=503, detail="ROBOFLOW_API_KEY secret not set")
445
 
446
- contents = await file.read()
447
- img = read_image(contents)
448
-
449
  t0 = time.time()
450
  detections = call_roboflow(img, "ambulance-4bova/1", confidence=40)
451
  elapsed = round((time.time() - t0) * 1000, 2)
452
 
453
- has_ambulance = any("ambulance" in d["class"].lower() for d in detections)
454
- has_fire_truck = any("fire" in d["class"].lower() for d in detections)
455
 
456
  return {
457
  "detections": detections,
@@ -465,30 +537,106 @@ async def detect_vehicles(file: UploadFile = File(...)):
465
 
466
 
467
  # ─────────────────────────────────────────────
468
- # Full Analysis β€” all models in one call
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  # ─────────────────────────────────────────────
470
  @app.post("/analyze/full")
471
  async def full_analysis(
472
  file: UploadFile = File(...),
 
473
  run_victims: bool = True,
474
  run_vehicles: bool = True,
475
- run_classify: bool = True,
476
  ):
477
  """
478
  Run all available models on one image.
479
- This is the main endpoint your rover Flask app should call.
480
-
481
- Returns unified JSON with zone color, all detections, triage summary.
482
  """
483
  contents = await file.read()
484
  t_total = time.time()
485
  output = {}
486
 
487
- # ── LADI classification ──
488
  if run_classify:
489
  try:
490
- fake_file = UploadFile(filename="f.jpg", file=io.BytesIO(contents))
491
- output["classification"] = await classify_scene(fake_file)
 
492
  except HTTPException as e:
493
  output["classification"] = {"error": e.detail}
494
  except Exception as e:
@@ -497,39 +645,43 @@ async def full_analysis(
497
  # ── Victim detection ──
498
  if run_victims:
499
  try:
500
- fake_file = UploadFile(filename="f.jpg", file=io.BytesIO(contents))
501
- output["victims"] = await detect_victims(fake_file)
 
502
  except HTTPException as e:
503
  output["victims"] = {"error": e.detail}
504
  except Exception as e:
505
  output["victims"] = {"error": str(e)}
506
 
507
- # ── Vehicle detection ──
508
  if run_vehicles:
509
  try:
510
- fake_file = UploadFile(filename="f.jpg", file=io.BytesIO(contents))
511
- output["vehicles"] = await detect_vehicles(fake_file)
 
512
  except HTTPException as e:
513
  output["vehicles"] = {"error": e.detail}
514
  except Exception as e:
515
  output["vehicles"] = {"error": str(e)}
516
 
517
- # ── Zone color ──
518
- triage_data = output.get("victims", {}).get("triage_summary", {})
519
- classify_top = output.get("classification", {}).get("top_predictions", [{}])
520
- top_class = classify_top[0].get("class", "") if classify_top else ""
 
 
 
 
 
 
521
 
522
- critical = triage_data.get("critical", 0)
523
- high = triage_data.get("high", 0)
 
 
 
524
 
525
- if critical > 0 or any(w in top_class for w in ["destroy", "collapse", "major"]):
526
- zone_color = "red"
527
- elif high > 0 or "minor_damage" in top_class:
528
- zone_color = "orange"
529
- elif triage_data.get("total", 0) > 0:
530
- zone_color = "yellow"
531
- else:
532
- zone_color = "green"
533
 
534
  return {
535
  "zone_color": zone_color,
 
1
+ """
2
  =================================================================
3
+ Disaster AI - HuggingFace Spaces API
4
+ Final version - all models integrated
5
  =================================================================
6
  """
7
 
8
  import os
9
  import io
 
10
  import time
11
  import base64
12
  import threading
13
  import traceback
14
  import numpy as np
 
15
  from PIL import Image
16
  import cv2
17
  import torch
 
19
 
20
  from fastapi import FastAPI, File, UploadFile, HTTPException
21
  from fastapi.middleware.cors import CORSMiddleware
 
22
  from huggingface_hub import hf_hub_download
23
 
24
  # ════════════════════════════════
 
27
  app = FastAPI(
28
  title="Disaster AI Inference API",
29
  description="Multi-model disaster scene analysis for Dokai / RoboXavier",
30
+ version="3.0.0",
31
  )
32
 
33
  app.add_middleware(
 
38
  )
39
 
40
  # ════════════════════════════════
41
+ # Configuration β€” all from secrets
42
  # ════════════════════════════════
43
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
44
+ HF_VICTIM_MODEL_REPO = os.getenv("HF_VICTIM_MODEL_REPO", "EgoisticCoderX/dokai-victim-detection")
45
+ HF_XVIEW2_MODEL_REPO = os.getenv("HF_XVIEW2_MODEL_REPO", "EgoisticCoderX/dokai-xview2-damage")
46
  ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "rltTa8UANpettqj6aHJG")
47
  MODEL_CACHE_DIR = "/tmp/model_cache"
48
+
49
  os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
50
 
51
+ # ── Victim detection class map ──
52
  TARGET_CLASSES = {
53
  0: "injured_civilian",
54
  1: "trapped_civilian",
 
63
  "rescue_personnel": 0.0,
64
  }
65
 
66
+ # ── xView2 damage severity map ──
67
+ DAMAGE_SEVERITY_ORDER = {
68
+ "destroyed": 0,
69
+ "major_damage": 1,
70
+ "minor_damage": 2,
71
+ "no_damage": 3,
72
+ }
73
+
74
  # ════════════════════════════════
75
  # Model Registry
76
  # ════════════════════════════════
77
  class ModelRegistry:
78
  def __init__(self):
79
+ self._models = {}
80
+ self._errors = {}
81
+ self._lock = threading.Lock()
82
 
83
  def get(self, name):
84
  return self._models.get(name)
 
112
  # ════════════════════════════════
113
 
114
  def load_ladi_model():
115
+ """Load LADI-v2 scene classifier from HuggingFace Hub."""
116
  if registry.is_loaded("ladi"):
117
  return registry.get("ladi")
 
118
  try:
119
  from transformers import AutoImageProcessor, AutoModelForImageClassification
120
 
121
  print("⬇️ Loading MITLL/LADI-v2-classifier-small ...")
 
122
  processor = AutoImageProcessor.from_pretrained(
123
  "MITLL/LADI-v2-classifier-small",
124
  cache_dir=MODEL_CACHE_DIR,
125
  )
 
126
  model = AutoModelForImageClassification.from_pretrained(
127
  "MITLL/LADI-v2-classifier-small",
128
  cache_dir=MODEL_CACHE_DIR,
 
130
  ignore_mismatched_sizes=True,
131
  )
132
  model.eval()
 
 
133
  registry.register("ladi", {"model": model, "processor": processor})
134
  print("βœ… LADI-v2 ready")
135
  return registry.get("ladi")
 
146
  return registry.get("victim")
147
 
148
  if not HF_VICTIM_MODEL_REPO:
149
+ registry.set_error("victim", "HF_VICTIM_MODEL_REPO secret not set")
150
  return None
151
 
152
  try:
 
157
  repo_id=HF_VICTIM_MODEL_REPO,
158
  filename="best.pt",
159
  cache_dir=MODEL_CACHE_DIR,
160
+ token=HF_TOKEN,
161
  )
162
  model = YOLO(model_path)
163
  registry.register("victim", model)
 
170
  return None
171
 
172
 
173
+ def load_xview2_model():
174
+ """Load xView2 building damage YOLOv8 model from HuggingFace Hub."""
175
+ if registry.is_loaded("xview2"):
176
+ return registry.get("xview2")
177
+
178
+ if not HF_XVIEW2_MODEL_REPO:
179
+ registry.set_error("xview2", "HF_XVIEW2_MODEL_REPO secret not set")
180
+ return None
181
+
182
+ try:
183
+ from ultralytics import YOLO
184
+
185
+ print(f"⬇️ Loading xView2 model from {HF_XVIEW2_MODEL_REPO} ...")
186
+ model_path = hf_hub_download(
187
+ repo_id=HF_XVIEW2_MODEL_REPO,
188
+ filename="best.pt",
189
+ cache_dir=MODEL_CACHE_DIR,
190
+ token=HF_TOKEN,
191
+ )
192
+ model = YOLO(model_path)
193
+ registry.register("xview2", model)
194
+ print("βœ… xView2 damage model ready")
195
+ return model
196
+
197
+ except Exception as e:
198
+ print(f"❌ xView2 model load failed:\n{traceback.format_exc()}")
199
+ registry.set_error("xview2", e)
200
+ return None
201
+
202
+
203
  # ════════════════════════════════
204
+ # Startup
205
  # ════════════════════════════════
206
  @app.on_event("startup")
207
  async def startup_event():
208
+ print("\n" + "=" * 55)
209
+ print("πŸš€ Disaster AI API v3.0 starting up...")
210
+ print("=" * 55)
211
 
212
+ # LADI always loads β€” public model
213
  load_ladi_model()
214
 
215
+ # Victim model β€” needs secret
216
  if HF_VICTIM_MODEL_REPO:
217
  load_victim_model()
218
  else:
219
  print("⚠️ Victim model skipped β€” HF_VICTIM_MODEL_REPO not set")
 
220
 
221
+ # xView2 model β€” needs secret
222
+ if HF_XVIEW2_MODEL_REPO:
223
+ load_xview2_model()
224
+ else:
225
+ print("⚠️ xView2 model skipped β€” HF_XVIEW2_MODEL_REPO not set")
226
+
227
+ print("=" * 55)
228
+ print(f"πŸ“Š Registry: {registry.status()}")
229
+ print("=" * 55 + "\n")
230
 
231
 
232
  # ════════════════════════════════
233
+ # Utilities
234
  # ════════════════════════════════
235
+
236
  def read_image(file_bytes: bytes) -> np.ndarray:
237
  nparr = np.frombuffer(file_bytes, np.uint8)
238
  img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 
280
  "total": 0, "critical": 0, "high": 0,
281
  "moderate": 0, "low": 0,
282
  "highest_score": 0.0,
283
+ "action": "No victims detected",
284
  "ranked_victims": [],
285
  }
286
 
 
306
  low = sum(1 for d in scored if d["priority_rank"] == "LOW")
307
 
308
  action = (
309
+ "IMMEDIATE RESCUE - Critical victims present" if critical else
310
+ "Deploy rescue team - High priority victims" if high else
311
+ "Assess and triage - Moderate victims present" if moderate else
312
+ "Low priority - Monitor the area"
313
  )
314
 
315
  return {
 
324
  }
325
 
326
 
327
+ def compute_zone_color(triage_data: dict, damage_counts: dict, top_class: str) -> str:
328
+ """
329
+ Unified zone color logic combining victim triage + building damage + scene class.
330
+ red > orange > yellow > green
331
+ """
332
+ critical = triage_data.get("critical", 0)
333
+ high = triage_data.get("high", 0)
334
+ destroyed = damage_counts.get("destroyed", 0)
335
+ major_damage = damage_counts.get("major_damage", 0)
336
+ minor_damage = damage_counts.get("minor_damage", 0)
337
+ victim_total = triage_data.get("total", 0)
338
+
339
+ scene_critical = any(w in top_class for w in ["destroy", "collapse", "major"])
340
+ scene_moderate = "minor_damage" in top_class
341
+
342
+ if critical > 0 or destroyed > 0 or scene_critical:
343
+ return "red"
344
+ elif high > 0 or major_damage > 0 or scene_moderate:
345
+ return "orange"
346
+ elif victim_total > 0 or minor_damage > 0:
347
+ return "yellow"
348
+ else:
349
+ return "green"
350
+
351
+
352
  # ════════════════════════════════
353
  # Routes
354
  # ════════════════════════════════
 
356
  @app.get("/")
357
  def root():
358
  return {
359
+ "service": "Disaster AI Inference API",
360
+ "version": "3.0.0",
361
+ "status": registry.status(),
362
  "endpoints": {
363
  "GET /health": "Health check + model status",
364
+ "POST /classify": "LADI-v2 disaster scene classification",
365
  "POST /detect/victims": "Victim detection + triage priority",
366
+ "POST /detect/vehicles": "Emergency vehicle detection (Roboflow)",
367
+ "POST /detect/damage": "xView2 building damage assessment",
368
+ "POST /analyze/full": "All models in one call (main endpoint)",
369
+ },
370
  }
371
 
372
 
 
376
  "status": "ok",
377
  "registry": registry.status(),
378
  "gpu_available": torch.cuda.is_available(),
379
+ "secrets_set": {
380
+ "HF_TOKEN": HF_TOKEN is not None,
381
+ "HF_VICTIM_MODEL_REPO": bool(HF_VICTIM_MODEL_REPO),
382
+ "HF_XVIEW2_MODEL_REPO": bool(HF_XVIEW2_MODEL_REPO),
383
+ "ROBOFLOW_API_KEY": bool(ROBOFLOW_API_KEY),
384
+ },
385
+ "timestamp": time.time(),
386
  }
387
 
388
 
389
  # ─────────────────────────────────────────────
390
+ # 1. LADI-v2 Scene Classification
391
  # ─────────────────────────────────────────────
392
  @app.post("/classify")
393
  async def classify_scene(
 
395
  top_k: int = 5,
396
  ):
397
  """
398
+ Classify disaster scene using LADI-v2 (aerial damage categories).
399
+ Returns top-k predicted labels with confidence scores.
400
  """
401
  ladi = load_ladi_model()
402
  if ladi is None:
 
423
  except Exception as e:
424
  raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
425
 
426
+ elapsed = round((time.time() - t0) * 1000, 2)
427
+ id2label = model.config.id2label
428
 
 
429
  all_scores = sorted(
430
  [
431
  {
 
438
  reverse=True,
439
  )
440
 
 
441
  relevant = [
442
  s for s in all_scores
443
  if not any(w in s["class"] for w in ["water", "flood"])
 
452
 
453
 
454
  # ─────────────────────────────────────────────
455
+ # 2. Victim Detection + Triage
456
  # ─────────────────────────────────────────────
457
  @app.post("/detect/victims")
458
  async def detect_victims(
 
460
  confidence: float = 0.35,
461
  ):
462
  """
463
+ Detect victims and rank by triage priority.
464
+ Classes: injured_civilian, trapped_civilian, safe_civilian, rescue_personnel.
465
+ Priority ranks: CRITICAL / HIGH / MODERATE / LOW
466
  """
467
  model = load_victim_model()
468
  if model is None:
 
491
  "class": TARGET_CLASSES.get(cls_id, "unknown"),
492
  "class_id": cls_id,
493
  "confidence": round(conf_val, 4),
494
+ "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
495
  })
496
 
497
+ triage = compute_triage(raw)
498
+ victims = triage.pop("ranked_victims", raw)
499
 
500
  return {
501
  "detections": victims,
 
505
 
506
 
507
  # ─────────────────────────────────────────────
508
+ # 3. Emergency Vehicle Detection (Roboflow)
509
  # ─────────────────────────────────────────────
510
  @app.post("/detect/vehicles")
511
  async def detect_vehicles(file: UploadFile = File(...)):
512
  """
513
+ Detect emergency vehicles via Roboflow hosted model.
514
+ Returns ambulance / fire truck detections and rescue_arrived flag.
515
  """
516
  if not ROBOFLOW_API_KEY:
517
  raise HTTPException(status_code=503, detail="ROBOFLOW_API_KEY secret not set")
518
 
519
+ contents = await file.read()
520
+ img = read_image(contents)
 
521
  t0 = time.time()
522
  detections = call_roboflow(img, "ambulance-4bova/1", confidence=40)
523
  elapsed = round((time.time() - t0) * 1000, 2)
524
 
525
+ has_ambulance = any("ambulance" in d["class"].lower() for d in detections)
526
+ has_fire_truck = any("fire" in d["class"].lower() for d in detections)
527
 
528
  return {
529
  "detections": detections,
 
537
 
538
 
539
  # ─────────────────────────────────────────────
540
+ # 4. xView2 Building Damage Assessment
541
+ # ─────────────────────────────────────────────
542
+ @app.post("/detect/damage")
543
+ async def detect_building_damage(
544
+ file: UploadFile = File(...),
545
+ confidence: float = 0.30,
546
+ ):
547
+ """
548
+ Assess building damage using xView2-trained YOLOv8.
549
+ Classes: destroyed, major_damage, minor_damage, no_damage.
550
+ Returns per-building detections, counts, and zone color.
551
+ """
552
+ model = load_xview2_model()
553
+ if model is None:
554
+ raise HTTPException(
555
+ status_code=503,
556
+ detail=f"xView2 model unavailable: {registry.get_error('xview2')}"
557
+ )
558
+
559
+ contents = await file.read()
560
+ img = read_image(contents)
561
+
562
+ t0 = time.time()
563
+ try:
564
+ results = model.predict(source=img, conf=confidence, verbose=False)
565
+ except Exception as e:
566
+ raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
567
+ elapsed = round((time.time() - t0) * 1000, 2)
568
+
569
+ detections = []
570
+ counts = {"destroyed": 0, "major_damage": 0, "minor_damage": 0, "no_damage": 0}
571
+
572
+ for r in results:
573
+ for box in r.boxes:
574
+ x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
575
+ conf_val = float(box.conf[0])
576
+ cls_id = int(box.cls[0])
577
+ class_name = model.names[cls_id].lower().replace(" ", "_")
578
+
579
+ # Map raw class name to standard severity key
580
+ matched_key = next(
581
+ (k for k in counts if k in class_name),
582
+ "no_damage"
583
+ )
584
+ counts[matched_key] += 1
585
+
586
+ detections.append({
587
+ "class": class_name,
588
+ "confidence": round(conf_val, 4),
589
+ "severity": matched_key,
590
+ "box": {"xmin": x1, "ymin": y1, "xmax": x2, "ymax": y2},
591
+ })
592
+
593
+ # Sort: destroyed first, no_damage last
594
+ detections.sort(key=lambda d: DAMAGE_SEVERITY_ORDER.get(d["severity"], 9))
595
+
596
+ if counts["destroyed"] > 0:
597
+ zone_color = "red"
598
+ elif counts["major_damage"] > 0:
599
+ zone_color = "orange"
600
+ elif counts["minor_damage"] > 0:
601
+ zone_color = "yellow"
602
+ else:
603
+ zone_color = "green"
604
+
605
+ return {
606
+ "detections": detections,
607
+ "summary": counts,
608
+ "total_buildings": sum(counts.values()),
609
+ "zone_color": zone_color,
610
+ "inference_time_ms": elapsed,
611
+ }
612
+
613
+
614
+ # ─────────────────────────────────────────────
615
+ # 5. Full Analysis β€” all models in one call
616
  # ─────────────────────────────────────────────
617
  @app.post("/analyze/full")
618
  async def full_analysis(
619
  file: UploadFile = File(...),
620
+ run_classify: bool = True,
621
  run_victims: bool = True,
622
  run_vehicles: bool = True,
623
+ run_damage: bool = True,
624
  ):
625
  """
626
  Run all available models on one image.
627
+ Main endpoint for the RoboXavier rover Flask app.
628
+ Returns unified zone_color, all detections, and triage/damage summaries.
 
629
  """
630
  contents = await file.read()
631
  t_total = time.time()
632
  output = {}
633
 
634
+ # ── LADI scene classification ──
635
  if run_classify:
636
  try:
637
+ output["classification"] = await classify_scene(
638
+ UploadFile(filename="f.jpg", file=io.BytesIO(contents))
639
+ )
640
  except HTTPException as e:
641
  output["classification"] = {"error": e.detail}
642
  except Exception as e:
 
645
  # ── Victim detection ──
646
  if run_victims:
647
  try:
648
+ output["victims"] = await detect_victims(
649
+ UploadFile(filename="f.jpg", file=io.BytesIO(contents))
650
+ )
651
  except HTTPException as e:
652
  output["victims"] = {"error": e.detail}
653
  except Exception as e:
654
  output["victims"] = {"error": str(e)}
655
 
656
+ # ── Emergency vehicle detection ──
657
  if run_vehicles:
658
  try:
659
+ output["vehicles"] = await detect_vehicles(
660
+ UploadFile(filename="f.jpg", file=io.BytesIO(contents))
661
+ )
662
  except HTTPException as e:
663
  output["vehicles"] = {"error": e.detail}
664
  except Exception as e:
665
  output["vehicles"] = {"error": str(e)}
666
 
667
+ # ── xView2 building damage ──
668
+ if run_damage:
669
+ try:
670
+ output["building_damage"] = await detect_building_damage(
671
+ UploadFile(filename="f.jpg", file=io.BytesIO(contents))
672
+ )
673
+ except HTTPException as e:
674
+ output["building_damage"] = {"error": e.detail}
675
+ except Exception as e:
676
+ output["building_damage"] = {"error": str(e)}
677
 
678
+ # ── Unified zone color (all signals combined) ──
679
+ triage_data = output.get("victims", {}).get("triage_summary", {})
680
+ damage_counts = output.get("building_damage", {}).get("summary", {})
681
+ classify_top = output.get("classification", {}).get("top_predictions", [{}])
682
+ top_class = classify_top[0].get("class", "") if classify_top else ""
683
 
684
+ zone_color = compute_zone_color(triage_data, damage_counts, top_class)
 
 
 
 
 
 
 
685
 
686
  return {
687
  "zone_color": zone_color,