Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from datetime import datetime | |
| import math | |
| from functools import lru_cache | |
| from typing import Dict, Any, List | |
| import numpy as np | |
| from fastapi import FastAPI, Query | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| import gradio as gr | |
| # Local import: vendored from working project | |
| from backend.grib_wave_puller import GRIBWavePuller | |
# Shared FastAPI application; the Gradio UI is mounted onto it later in this module.
_API_TITLE = "Wave Visualizer API"
app = FastAPI(title=_API_TITLE)
def _compute_uv_from_wave(height: np.ndarray, direction_deg: np.ndarray, scale: float = 0.1):
    """Compute U/V components from wave height and meteorological 'from' direction.

    The returned vector points the way the waves travel. That is the opposite
    of the 'coming from' bearing, so both components carry a minus sign.

    Args:
        height: significant wave height array (m). Negative values are
            clipped to 0; NaN values propagate to the outputs.
        direction_deg: wave direction in degrees clockwise from north,
            meteorological convention (direction the waves come from).
        scale: visualization scaling factor applied to the clipped height.

    Returns:
        Tuple ``(u, v)`` of arrays holding the eastward and northward
        components. They are broadcast from ``height`` and ``direction_deg``.
    """
    height = np.asarray(height, dtype=float)
    dir_rad = np.deg2rad(np.asarray(direction_deg, dtype=float))
    # Only the lower bound matters. An upper bound of nanmax(height) was a
    # no-op that raised on empty input and warned on all-NaN input.
    mag = np.clip(height, 0.0, None) * scale
    # 'From' convention: waves from the east (90 deg) travel west (u < 0);
    # waves from the north (0 deg) travel south (v < 0).
    u = -mag * np.sin(dir_rad)
    v = -mag * np.cos(dir_rad)
    return u, v
def _build_velocity_grib_json(
    lats: np.ndarray,
    lons: np.ndarray,
    u: np.ndarray,
    v: np.ndarray,
    ref_time: str,
    max_abs: float = 20.0,
) -> List[Dict[str, Any]]:
    """Build leaflet-velocity compatible JSON (Wind/Earth GRIB-like format).

    Data must lie on a regular lat-lon grid. ``u`` and ``v`` are 2D with shape
    ``(ny, nx)``, where ``ny = len(lats)`` and ``nx = len(lons)``. A flat
    array of ``ny * nx`` values in row-major order is also accepted.

    Rows are reordered north->south. Longitudes are normalized to
    [-180, 180), sorted west->east, and de-duplicated (a 0..360 grid listing
    both 0 and 360 yields the same column twice). The columns of ``u``/``v``
    are reordered to match.

    Args:
        lats: latitudes, 1D or a 2D meshgrid (the first column is used).
        lons: longitudes, 1D or a 2D meshgrid (the first row is used).
        u: eastward component grid.
        v: northward component grid.
        ref_time: reference time string copied into both headers.
        max_abs: symmetric clamp applied to ``u``/``v`` after NaN/Inf are
            zeroed (default 20 m/s, a visualization safety limit).

    Returns:
        ``[u_record, v_record]``, each of the form ``{"header": {...}, "data": [...]}``.

    Raises:
        ValueError: if ``lats``/``lons`` are empty, or ``u``/``v`` cannot be
            laid out as ``(ny, nx)``.
    """
    lats = np.asarray(lats, dtype=float)
    lons = np.asarray(lons, dtype=float)
    # Accept 1D axes or 2D meshgrids (lat varies down rows, lon across columns).
    lats_1d = lats if lats.ndim == 1 else lats[:, 0]
    lons_1d = lons if lons.ndim == 1 else lons[0, :]
    ny = int(len(lats_1d))
    nx = int(len(lons_1d))
    if ny == 0 or nx == 0:
        raise ValueError("lats and lons must be non-empty")

    def _as_grid(arr: Any, name: str) -> np.ndarray:
        """Return ``arr`` as a float ``(ny, nx)`` array, or raise ValueError."""
        grid = np.asarray(arr, dtype=float)
        if grid.ndim == 1 and grid.size == ny * nx:
            return grid.reshape(ny, nx)
        if grid.shape != (ny, nx):
            raise ValueError(
                f"{name} has shape {grid.shape}; expected ({ny}, {nx}) from lats/lons"
            )
        return grid

    u = _as_grid(u, "u")
    v = _as_grid(v, "v")

    # If latitude increases northward, reverse rows so data runs north->south.
    if ny > 1 and lats_1d[0] < lats_1d[-1]:
        lats_1d = lats_1d[::-1]
        u = np.flipud(u)
        v = np.flipud(v)

    # Wrap to [-180, 180) so 0..360 grids are not clipped. np.unique sorts
    # west->east and keeps the first occurrence of a repeated seam longitude,
    # keeping the grid regular as the header advertises.
    if nx > 1:
        lons_wrapped = ((lons_1d + 180.0) % 360.0) - 180.0
        lons_1d, keep = np.unique(lons_wrapped, return_index=True)
        u = u[:, keep]
        v = v[:, keep]
        nx = int(len(lons_1d))

    la1, la2 = float(lats_1d[0]), float(lats_1d[-1])
    lo1, lo2 = float(lons_1d[0]), float(lons_1d[-1])
    # Grid spacing (approximate: taken from the first two coordinates).
    dy = float(abs(lats_1d[1] - lats_1d[0])) if ny > 1 else 0.0
    dx = float(abs(lons_1d[1] - lons_1d[0])) if nx > 1 else 0.0

    # JSON has no NaN/Inf: zero them, then clamp so stray values cannot
    # dominate the visualization.
    u = np.clip(np.nan_to_num(u, nan=0.0, posinf=0.0, neginf=0.0), -max_abs, max_abs)
    v = np.clip(np.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0), -max_abs, max_abs)

    header_common = {
        "lo1": lo1,
        "la1": la1,
        "lo2": lo2,
        "la2": la2,
        "nx": nx,
        "ny": ny,
        "dx": dx,
        "dy": dy,
        "refTime": ref_time,
    }

    def _record(parameter_number: int, grid_values: np.ndarray) -> Dict[str, Any]:
        """Wrap one component grid in a GRIB-like header/data record."""
        return {
            "header": {
                **header_common,
                "parameterCategory": 2,
                "parameterNumber": parameter_number,
                "parameterUnit": "m/s",
            },
            # Row-major flatten: latitude rows (north->south), then longitude.
            "data": grid_values.flatten().tolist(),
        }

    # parameterNumber 2 marks the U component record, 3 the V component record.
    return [_record(2, u), _record(3, v)]
@lru_cache(maxsize=1)
def get_puller() -> GRIBWavePuller:
    """Return a shared GRIBWavePuller instance.

    The instance is cached so repeated calls reuse one puller instead of
    constructing a new one every time. The module already imports
    ``lru_cache`` but never applied it.

    NOTE(review): assumes GRIBWavePuller is safe to share across FastAPI's
    worker threads. Confirm before relying on this under concurrent load.
    """
    return GRIBWavePuller()
def data_points(hour: int = Query(0, ge=0, le=240)):
    """Return the sampled wave points for forecast ``hour`` (0-240) as JSON.

    Responds with ``{"type": "points", "refTime": ..., "points": [...]}``,
    built from the puller result's ``timestamp`` and ``sample_points``.
    When the puller returns nothing, responds with HTTP 503 and
    ``{"error": "No data available"}``.

    The payload is made strict-JSON safe: NumPy scalars and arrays become
    plain Python values, and non-finite floats become ``null``.
    """
    result = get_puller().fetch_global_wave_data(hour)
    if not result:
        return JSONResponse(status_code=503, content={"error": "No data available"})

    def _sanitize(obj):
        """Recursively convert ``obj`` into values that encode as strict JSON."""
        if isinstance(obj, dict):
            return {key: _sanitize(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_sanitize(val) for val in obj]
        if isinstance(obj, np.ndarray):
            # Convert to nested Python lists, then sanitize the elements.
            return _sanitize(obj.tolist())
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, (float, np.floating)):
            val = float(obj)
            # NaN/Inf are not valid JSON.
            return val if math.isfinite(val) else None
        return obj

    points = result.get("sample_points")
    payload = {
        "type": "points",
        "refTime": result.get("timestamp"),
        "points": [] if points is None else points,
    }
    return JSONResponse(content=_sanitize(payload))
def data_velocity(hour: int = Query(0, ge=0, le=240), scale: float = Query(0.1)):
    """Return wave flow data for forecast ``hour`` (0-240).

    The preferred response is the leaflet-velocity ``[u_record, v_record]``
    list built from ``result["grid_uv"]``. The handler falls back to
    ``{"type": "points", "refTime": ..., "points": [...]}`` when the grid:

    * is missing or has unusable keys/values;
    * is too small (< 16 cells);
    * has no finite values;
    * is effectively zero everywhere; or
    * cannot be laid out as a regular grid.

    Both fallbacks are sanitized for strict JSON. When the puller has no
    data, responds with HTTP 503.

    NOTE(review): ``scale`` is accepted but not used by this handler. It is
    kept so existing clients that pass it keep working.
    """
    import logging
    from datetime import timezone

    log = logging.getLogger(__name__)
    result = get_puller().fetch_global_wave_data(hour)
    if not result:
        return JSONResponse(status_code=503, content={"error": "No data available"})

    def _sanitize(obj):
        """Recursively convert ``obj`` into values that encode as strict JSON."""
        if isinstance(obj, dict):
            return {key: _sanitize(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_sanitize(val) for val in obj]
        if isinstance(obj, np.ndarray):
            return _sanitize(obj.tolist())
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, (float, np.floating)):
            val = float(obj)
            # NaN/Inf are not valid JSON.
            return val if math.isfinite(val) else None
        return obj

    def _points_response():
        """Respond with the sanitized sample-point payload."""
        points = result.get("sample_points")
        payload = {
            "type": "points",
            "refTime": result.get("timestamp"),
            "points": [] if points is None else points,
        }
        return JSONResponse(content=_sanitize(payload))

    grid_uv = result.get("grid_uv")
    if not grid_uv:
        return _points_response()

    try:
        # dtype=float turns None entries into NaN instead of an object array
        # that np.isfinite would reject.
        lats = np.asarray(grid_uv["lats"], dtype=float)
        lons = np.asarray(grid_uv["lons"], dtype=float)
        u = np.asarray(grid_uv["u"], dtype=float)
        v = np.asarray(grid_uv["v"], dtype=float)
    except (KeyError, TypeError, ValueError) as err:
        log.warning("Unusable grid_uv (%s); serving points instead", err)
        return _points_response()

    # Reject grids that are tiny, contain no finite values, or are effectively
    # zero everywhere. Only finite values count, because non-finite ones are
    # zeroed before being served.
    finite_u = np.isfinite(u)
    finite_v = np.isfinite(v)
    if (
        u.size < 16
        or v.size < 16
        or not finite_u.any()
        or not finite_v.any()
        or (np.abs(u[finite_u]).max() < 1e-6 and np.abs(v[finite_v]).max() < 1e-6)
    ):
        return _points_response()

    ref_time = result.get("timestamp") or datetime.now(timezone.utc).isoformat()
    try:
        payload = _build_velocity_grib_json(lats, lons, u, v, ref_time=ref_time)
    except (ValueError, IndexError) as err:
        log.warning("Could not build velocity grid (%s); serving points instead", err)
        return _points_response()
    return JSONResponse(content=payload)
def leaflet_html() -> str:
    """Return the standalone Leaflet map page as an HTML string.

    The page loads Leaflet and leaflet-velocity from public CDNs and lets the
    user pick a forecast hour. It then requests ``/data/velocity?hour=H``:

    * a ``[u, v]`` grid is drawn as an animated leaflet-velocity layer, with
      up to 300 arrows from ``/data/points`` overlaid;
    * a ``{"type": "points"}`` payload is drawn as arrows only;
    * a malformed or all-zero grid, or a failure while building the layer,
      falls back to ``/data/points``.

    Arrows point the way the waves travel, assuming ``wave_direction`` is the
    meteorological "coming from" bearing. The Load button is disabled while a
    request is in flight, so overlapping loads cannot leave orphaned layers.
    Every failure path updates the status text instead of leaving
    "Loading..." on screen.
    """
    return """
<!doctype html>
<html>
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
  <style>
    html, body, #map { height: 100%; margin: 0; }
    .leaflet-control-container .leaflet-top.leaflet-left { z-index: 1000; }
    .control { position:absolute; top:10px; left:10px; z-index:1000; background:#fff; padding:8px; border-radius:4px; box-shadow:0 1px 3px rgba(0,0,0,0.3); pointer-events:auto; }
    /* Animated flow for arrow polylines */
    @keyframes arrow-dash {
      0% { stroke-dashoffset: 0; }
      100% { stroke-dashoffset: -20; }
    }
    /* Leaflet renders polylines as SVG paths */
    .leaflet-overlay-pane path.arrow-line {
      vector-effect: non-scaling-stroke;
      fill: none;
      stroke-linecap: butt;
      stroke-dasharray: 6 10;
      animation: arrow-dash 1.2s linear infinite;
    }
  </style>
</head>
<body>
  <div id="map"></div>
  <div class="control">
    <label>Forecast hour: <input type="number" id="hour" min="0" max="240" step="6" value="0" /></label>
    <button id="load">Load Waves</button>
    <span id="status" style="margin-left:8px; font-size:12px; color:#333"></span>
  </div>
  <script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
  <!-- Use a known-good Leaflet-Velocity build -->
  <script src="https://cdn.jsdelivr.net/npm/leaflet-velocity@1.8.0/dist/leaflet-velocity.min.js"></script>
  <script>
    const map = L.map('map').setView([20, 0], 2);
    L.tileLayer('https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', {
      maxZoom: 6,
      attribution: '© OpenStreetMap contributors'
    }).addTo(map);
    const statusEl = document.getElementById('status');
    const loadBtn = document.getElementById('load');
    let velocityLayer = null;
    let pointLayer = null;

    // Create a short dash polyline with attached metadata (wind-like particle).
    function mkArrow(p, opts = {}) {
      const dir = p.wave_direction ?? 0;
      const height = p.wave_height ?? 0.5;
      const period = p.wave_period ?? null;
      // Keep segments short to resemble particles rather than long arrows.
      // len is a nominal length; the /1e6 below is a visual scale, not a
      // true metres-to-degrees conversion.
      const len = Math.max(15000, Math.min(60000, height * 35000));
      const rad = dir * Math.PI / 180.0;
      // wave_direction is treated as the bearing the waves come FROM, so the
      // segment points the opposite way (where they travel): negate both.
      const dx = -Math.sin(rad) * len;
      const dy = -Math.cos(rad) * len;
      const poly = L.polyline(
        [[p.lat, p.lon], [p.lat + dy / 1e6, p.lon + dx / 1e6]],
        { color: '#e5242a', weight: 1.0, opacity: 0.9, className: 'arrow-line', ...opts }
      );
      // Attach metadata so styles can be tuned after the layer is on the map.
      poly._waveMeta = { height, period, direction: dir };
      return poly;
    }

    // Tune per-feature animation speed/appearance using period and height.
    function tuneArrowStyles(group) {
      if (!group) return;
      const tune = (layer) => {
        const el = (layer.getElement && layer.getElement()) || layer._path;
        if (!el || !layer._waveMeta) return;
        const { height, period } = layer._waveMeta;
        // Stroke weight by height (clamped)
        const w = Math.max(0.6, Math.min(1.6, 0.8 + (height || 0) * 0.25));
        layer.setStyle && layer.setStyle({ weight: w });
        // Particle-like short dashes
        const dashLen = Math.max(2, Math.min(10, 3 + (height || 0) * 1.2));
        const gapLen = Math.round(dashLen * 1.4);
        el.style.strokeDasharray = `${dashLen} ${gapLen}`;
        // Animation speed by period: longer period => faster flow (shorter duration)
        let dur;
        if (period && isFinite(period)) {
          // Map 2s..20s -> 1.2s..0.5s duration for livelier particles
          const p = Math.max(2, Math.min(20, period));
          dur = 1.2 - (p - 2) * ((1.2 - 0.5) / (20 - 2));
        } else {
          dur = 0.9; // default
        }
        el.style.animationDuration = `${dur.toFixed(2)}s`;
      };
      // Defer slightly to ensure SVG paths exist
      setTimeout(() => {
        group.eachLayer(tune);
      }, 0);
    }

    function clearLayers() {
      if (velocityLayer) { map.removeLayer(velocityLayer); velocityLayer = null; }
      if (pointLayer) { map.removeLayer(pointLayer); pointLayer = null; }
    }

    // Draw points as animated arrows, replacing any existing arrow layer.
    function renderPoints(points, opts = {}) {
      const features = (points || []).map(p => mkArrow(p, opts));
      if (pointLayer) { map.removeLayer(pointLayer); }
      pointLayer = L.layerGroup(features).addTo(map);
      tuneArrowStyles(pointLayer);
      return features.length;
    }

    // Fetch the sampled wave points; throws on HTTP errors.
    async function fetchPoints(hour) {
      const res = await fetch(`/data/points?hour=${hour}`);
      if (!res.ok) { throw new Error(`/data/points failed with HTTP ${res.status}`); }
      const payload = await res.json();
      console.log('points payload', payload);
      return (payload && payload.points) || [];
    }

    async function showPoints(hour) {
      const n = renderPoints(await fetchPoints(hour));
      statusEl.textContent = n ? `Rendered ${n} wave arrows` : 'No wave data available';
    }

    // True when the payload is the expected pair of GRIB-like u/v records.
    function isVelocityGrid(payload) {
      return Array.isArray(payload) && payload.length >= 2 &&
        payload.slice(0, 2).every(r => r && Array.isArray(r.data) && r.data.length > 0);
    }

    // Scan all of u and v. The first rows of a north->south grid are polar
    // and often all zero, so sampling only the head of u gives false alarms.
    function hasNonZero(payload) {
      return payload.slice(0, 2).some(r => r.data.some(x => Math.abs(x) > 1e-6));
    }

    function addVelocityLayer(payload) {
      velocityLayer = L.velocityLayer({
        data: payload,
        displayValues: true,
        displayOptions: {
          velocityType: 'Wave',
          position: 'bottomleft',
          emptyString: 'No wave data',
          speedUnit: 'm/s',
          angleConvention: 'bearingCW',
          showCardinal: true
        },
        // Settings aligned with the working wind demo
        velocityScale: 0.01,
        opacity: 0.9,
        maxVelocity: 20,
        particleMultiplier: 0.002,
        lineWidth: 1.2,
        frameRate: 15,
        particleAge: 40,
        fadeOpacity: 0,
        animationDuration: 0,
        // No strict bounds/wrap, to support 0..360 or -180..180 grids
        // Red gradient color scale
        colorScale: [
          "#4c0000", "#660000", "#800000", "#990000", "#b30000",
          "#cc0000", "#e60000", "#ff0000", "#ff3333", "#ff6666", "#ff9999"
        ],
      });
      velocityLayer.addTo(map);
    }

    async function load(hour) {
      clearLayers();
      statusEl.textContent = 'Loading...';
      const res = await fetch(`/data/velocity?hour=${hour}`);
      if (!res.ok) {
        statusEl.textContent = 'Failed to fetch data';
        alert('Failed to fetch data');
        return;
      }
      const payload = await res.json();
      console.log('velocity payload', payload);

      if (payload && payload.type === 'points') {
        // Server-side fallback: draw particle-like arrows only.
        const n = renderPoints(payload.points);
        statusEl.textContent = `Rendered ${n} wave arrows`;
        return;
      }
      if (!isVelocityGrid(payload)) {
        // Unexpected payload shape: fetch the points explicitly.
        await showPoints(hour);
        return;
      }
      try {
        if (!hasNonZero(payload)) {
          throw new Error('Velocity grid near-zero; fallback to points');
        }
        addVelocityLayer(payload);
        statusEl.textContent = 'Velocity layer active';
      } catch (e) {
        console.warn('Velocity layer failed, falling back to points:', e);
        if (velocityLayer) { map.removeLayer(velocityLayer); velocityLayer = null; }
        await showPoints(hour);
        return;
      }
      // Also overlay a sparse set of arrows for immediate visual feedback.
      try {
        const pts = (await fetchPoints(hour)).slice(0, 300);
        renderPoints(pts, { color: '#e5242a', opacity: 0.85 });
      } catch (e2) {
        console.warn('arrow overlay failed', e2);
      }
    }

    loadBtn.onclick = async () => {
      const parsed = parseInt(document.getElementById('hour').value || '0', 10);
      const hour = Number.isFinite(parsed) ? parsed : 0;
      // Block re-entry so a second load cannot race the first one's layers.
      loadBtn.disabled = true;
      try {
        await load(hour);
      } catch (e) {
        console.error('Loading wave data failed:', e);
        statusEl.textContent = 'Failed to fetch data';
      } finally {
        loadBtn.disabled = false;
      }
    };
  </script>
</body>
</html>
"""
def map_page():
    """Return the Leaflet map page HTML (the same document as the root page)."""
    page = leaflet_html()
    return page
def root_page():
    """Return the Leaflet map page HTML for the site root."""
    html_document = leaflet_html()
    return html_document
# Optional Gradio UI mounted under /ui: a minimal page linking to the map.
from gradio.routes import mount_gradio_app

_UI_TITLE = "Wave Visualizer UI"
with gr.Blocks(title=_UI_TITLE) as demo:
    gr.Markdown("# Wave Visualizer\nUse the link below to open the map page.")
    gr.HTML('<p><a href="/map" target="_blank">Open Map</a></p>')
app = mount_gradio_app(app, demo, path="/ui")