# waveVisualizer-Codex / backend/grib_wave_puller.py
# Last commit 1a437b8 by nakas: "Animate wave arrows; use height+period for per-arrow styling;
# add gribplayground Arctic fetcher (no pygrib)"
"""
Vendored GRIBWavePuller from NWPS_SWAN project (trimmed only for runtime import here).
If Arctic-specific helpers are missing, the puller will log and continue with fallbacks.
"""
import os
import sys
import tempfile
import logging
import subprocess
from datetime import datetime, timedelta
import numpy as np
import xarray as xr
from ecmwf.opendata import Client
import requests
# NOTE(review): configuring the root logger at import time affects the whole
# process; kept because callers may rely on it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Optional Arctic handler. Any import-time failure (not just ImportError) is
# tolerated: the flag simply stays False and a warning is logged.
ARCTIC_HANDLER_AVAILABLE = False
try:
    from arctic_grib_handler import ArcticGRIBHandler  # type: ignore
except Exception as exc:
    logger.warning(f"Arctic GRIB handler not available: {exc}")
else:
    ARCTIC_HANDLER_AVAILABLE = True
    logger.info("✅ Arctic GRIB Handler loaded (Docker production version)")
class GRIBWavePuller:
    """Fetch wave-model GRIB2 files and reduce them to JSON-friendly summaries.

    NOAA GFS-Wave grids from NOMADS are the primary source; ECMWF open data
    (significant wave height only) is the fallback.  Downloads are written to
    temporary files.  ``fetch_global_wave_data`` and
    ``process_multiple_regional_files`` delete the files they consume; the
    lower-level ``fetch_*`` methods leave deletion to the caller.
    """

    # NOMADS directory holding operational GFS (including GFS-Wave) output.
    NOAA_BASE_URL = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"
    # Upper bound on the scattered sample points returned per file.
    MAX_SAMPLE_POINTS = 1000
    # Candidate variable names, checked in order; the first present one wins.
    WAVE_HEIGHT_NAMES = ('swh', 'HTSGW', 'htsgw')
    WAVE_DIRECTION_NAMES = ('dirpw', 'DIRPW', 'dp', 'wvdir', 'WVDIR', 'dir', 'mwd', 'MWD', 'MWDIR')
    WAVE_PERIOD_NAMES = ('perpw', 'PERPW', 'tp', 'wvper', 'WVPER', 'per', 'pp1d', 'PP1D', 'mwp', 'MWP')

    def __init__(self):
        self.client = Client("ecmwf")
        # NOTE(review): output_dir is created but not used within this class;
        # temporary downloads go to the system temp directory.
        self.output_dir = os.getenv('OUTPUT_DIR', '/tmp/wave_data')
        os.makedirs(self.output_dir, exist_ok=True)
        self._setup_eccodes_environment()

    def _setup_eccodes_environment(self):
        """Set environment variables intended to relax ecCodes GRIB decoding.

        NOTE(review): presumably read by ecCodes when cfgrib decodes files;
        confirm the installed ecCodes version honours them.
        """
        os.environ['ECCODES_GRIB_STRICT_PARSING'] = '0'
        os.environ['ECCODES_GRIB_IGNORE_GRID_DEFINITION'] = '1'

    # ------------------------------------------------------------------ utils

    @staticmethod
    def _utcnow():
        """Return the current UTC time as a naive datetime.

        Same value as ``datetime.utcnow()``, without its 3.12+ deprecation.
        """
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _remove_quietly(*paths):
        """Best-effort deletion of temporary files; missing files are ignored."""
        for path in paths:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                logger.warning("Could not remove temporary file %s: %s", path, exc)

    @staticmethod
    def _candidate_cycles(now, lookback_days=2):
        """List 6-hourly model cycles to try, newest first.

        Starts at the latest 00/06/12/18 UTC cycle not after ``now`` and goes
        back to 00 UTC ``lookback_days`` days before that cycle's date.

        Args:
            now: naive UTC datetime.
            lookback_days: how many whole days before ``now``'s date to cover.

        Returns:
            List of ``(YYYYMMDD, HH)`` string pairs.
        """
        cycle = now.replace(hour=now.hour - now.hour % 6, minute=0, second=0, microsecond=0)
        earliest = cycle.replace(hour=0) - timedelta(days=lookback_days)
        cycles = []
        while cycle >= earliest:
            cycles.append((cycle.strftime("%Y%m%d"), cycle.strftime("%H")))
            cycle -= timedelta(hours=6)
        return cycles

    def _download_to_temp(self, url, prefix='gfswave_', timeout=300):
        """Stream ``url`` into a new temporary ``.grib2`` file.

        Args:
            url: file to download.
            prefix: temp-file name prefix.  The default keeps "gfswave" in the
                name, which ``process_grib_file`` uses to tag NOAA data.
            timeout: requests connect/read timeout in seconds.

        Returns:
            The file path on HTTP 200.  Otherwise None, and the file is removed.
            Network and file-system errors are logged, not raised.
        """
        try:
            fd, path = tempfile.mkstemp(prefix=prefix, suffix='.grib2')
        except OSError as exc:
            logger.warning("Could not create temporary file for %s: %s", url, exc)
            return None
        ok = False
        try:
            with os.fdopen(fd, 'wb') as out, requests.get(url, stream=True, timeout=timeout) as resp:
                if resp.status_code == 200:
                    # Stream in 1 MiB chunks instead of holding the file in memory.
                    for chunk in resp.iter_content(chunk_size=1 << 20):
                        out.write(chunk)
                    ok = True
                else:
                    logger.debug("HTTP %s for %s", resp.status_code, url)
        except Exception as exc:
            logger.warning("Download failed for %s: %s", url, exc)
        if not ok:
            self._remove_quietly(path)
            return None
        return path

    # -------------------------------------------------------------- fetching

    def fetch_ecmwf_wave_grib(self, forecast_time=0):
        """Download ECMWF open-data significant wave height for one step.

        Args:
            forecast_time: forecast step in hours from the 00 UTC run.

        Returns:
            Path of a temporary GRIB2 file (the caller must delete it), or None
            if the request failed.
        """
        try:
            fd, path = tempfile.mkstemp(suffix='.grib2')
            os.close(fd)  # the client writes to the path itself
        except OSError as exc:
            logger.warning("Could not create temporary file for ECMWF data: %s", exc)
            return None
        try:
            self.client.retrieve(
                # Wave parameters such as swh are published in the "wave"
                # stream, not the default atmospheric "oper" stream.
                stream="wave",
                type="fc",
                param=["swh"],
                time=0,
                step=forecast_time,
                target=path,
            )
        except Exception as exc:
            logger.warning("ECMWF wave retrieval failed (step=%s): %s", forecast_time, exc)
            self._remove_quietly(path)
            return None
        return path

    def fetch_noaa_wave_grib(self, forecast_hour=0):
        """Fetch NOAA GFS-Wave gridded GRIB2 files for the newest available cycle.

        Cycles are tried newest-first (see ``_candidate_cycles``).  For each
        cycle the Atlantic, East Pacific, Arctic and Global grids are
        requested.  The first cycle with at least one successful download wins.

        Args:
            forecast_hour: forecast lead time in hours (an int; formatted as fNNN).

        Returns:
            List of ``(path, region, run, fh)`` tuples: the temp file path, the
            region name, the run hour as "HH", and ``forecast_hour``.  The caller
            owns and must delete the files.  Returns None if nothing could be
            downloaded.
        """
        try:
            fh = f"{forecast_hour:03d}"
        except (TypeError, ValueError):
            logger.warning("Invalid forecast hour %r", forecast_hour)
            return None
        for date_str, run in self._candidate_cycles(self._utcnow()):
            regional_files = (
                (f"gfswave.t{run}z.atlocn.0p16.f{fh}.grib2", "Atlantic"),
                (f"gfswave.t{run}z.epacif.0p16.f{fh}.grib2", "East_Pacific"),
                (f"gfswave.t{run}z.arctic.9km.f{fh}.grib2", "Arctic"),
                (f"gfswave.t{run}z.global.0p16.f{fh}.grib2", "Global"),
            )
            successful = []
            for filename, region_name in regional_files:
                url = f"{self.NOAA_BASE_URL}/gfs.{date_str}/{run}/wave/gridded/{filename}"
                path = self._download_to_temp(url)
                if path is not None:
                    successful.append((path, region_name, run, forecast_hour))
            if successful:
                logger.info("Fetched %d NOAA wave file(s) for %s %sz f%s",
                            len(successful), date_str, run, fh)
                return successful
        return None

    # ------------------------------------------------------------ processing

    def process_grib_file(self, grib_file_path, region_name=None):
        """Open one GRIB file and summarise its wave fields.

        Args:
            grib_file_path: path of a GRIB file readable by cfgrib.  The file
                is not deleted.
            region_name: accepted for API compatibility; currently unused.

        Returns:
            ``(data, grib_file_path)`` on success, where ``data`` contains
            'timestamp', 'data_source', 'parameters_found', 'grid_info',
            'sample_points' and, for regular grids with a direction field,
            'grid_uv' (plus usually 'grid_uv_info').  Returns ``(None, None)``
            if the file cannot be opened or has no wave-height or coordinate
            variables.
        """
        try:
            ds = xr.open_dataset(
                grib_file_path,
                engine='cfgrib',
                decode_timedelta=True,
                # Don't leave cfgrib "<file>.<hash>.idx" index files behind
                # next to short-lived temporary GRIB files.
                backend_kwargs={'indexpath': ''},
            )
        except Exception as exc:
            logger.warning("Could not open GRIB file %s: %s", grib_file_path, exc)
            return None, None
        try:
            return self._summarise_dataset(ds, grib_file_path)
        finally:
            ds.close()

    def _summarise_dataset(self, ds, grib_file_path):
        """Build the ``process_grib_file`` result for an already-open dataset."""
        names = set(ds.variables)

        def first_present(candidates):
            return next((name for name in candidates if name in names), None)

        wave_var = first_present(self.WAVE_HEIGHT_NAMES)
        if wave_var is None:
            # Fallback heuristic for files with unusual variable naming.
            wave_var = next(
                (name for name in ds.variables
                 if 'wave' in name.lower() and 'height' in name.lower()),
                None,
            )
        if wave_var is None:
            logger.warning("No wave-height variable found in %s", grib_file_path)
            return None, None
        wave_heights = ds[wave_var].values

        dir_var = first_present(self.WAVE_DIRECTION_NAMES)
        per_var = first_present(self.WAVE_PERIOD_NAMES)
        wave_dir = ds[dir_var].values if dir_var is not None else None
        wave_per = ds[per_var].values if per_var is not None else None

        lat_var = first_present(('latitude', 'lat'))
        lon_var = first_present(('longitude', 'lon'))
        if lat_var is None or lon_var is None:
            logger.warning("No latitude/longitude coordinates in %s", grib_file_path)
            return None, None
        lats = ds[lat_var].values
        lons = ds[lon_var].values

        regular_grid = lats.ndim == 1 and lons.ndim == 1
        if regular_grid:
            lon_grid, lat_grid = np.meshgrid(lons, lats)
        else:
            # Curvilinear grids (presumably e.g. the 9 km Arctic file) already
            # carry 2-D coordinates.  meshgrid would flatten them and build an
            # (N*M) x (N*M) array.
            lat_grid, lon_grid = lats, lons

        points = self._sample_points(lat_grid, lon_grid, wave_heights, wave_dir, wave_per)
        data = {
            'timestamp': self._utcnow().isoformat(),
            'data_source': 'NOAA_GRIB' if 'gfswave' in os.path.basename(grib_file_path) else 'GRIB',
            'parameters_found': {
                'wave_height': wave_var,
                'wave_direction': 'present' if wave_dir is not None else None,
                'wave_period': 'present' if wave_per is not None else None,
                'has_velocity_components': wave_dir is not None,
            },
            'grid_info': {
                'lat_min': float(np.nanmin(lats)),
                'lat_max': float(np.nanmax(lats)),
                'lon_min': float(np.nanmin(lons)),
                'lon_max': float(np.nanmax(lons)),
                'grid_shape': list(wave_heights.shape),
            },
            'sample_points': points,
        }
        # The grid_uv lats/lons lists describe a regular grid, so skip it for
        # curvilinear files.
        if wave_dir is not None and regular_grid:
            try:
                data.update(self._build_grid_uv(lats, lons, wave_heights, wave_dir, wave_per))
            except Exception as exc:
                logger.warning("Skipping grid_uv for %s: %s", grib_file_path, exc)
        return data, grib_file_path

    @staticmethod
    def _direction_to_uv(magnitude, direction_deg):
        """Return ``(magnitude*sin(d), -magnitude*cos(d))`` for d in degrees.

        NOTE(review): this is the arrow convention the code has always used.
        Suppose the GRIB direction is where waves come *from* (clockwise from
        north) and u/v are east/north components.  Then the east-west term
        looks sign-flipped: both terms would be negated.  Confirm against the
        frontend renderer before changing it.
        """
        rad = np.deg2rad(direction_deg)
        return magnitude * np.sin(rad), -magnitude * np.cos(rad)

    @classmethod
    def _sample_points(cls, lat_grid, lon_grid, wave_heights, wave_dir, wave_per, max_points=None):
        """Randomly sample up to ``max_points`` valid grid cells.

        A cell is valid when its height (and direction, if present) is not NaN.
        Each point has lat/lon/wave_height.  When a direction field exists it
        also has wave_direction and u/v components with magnitude 0.1 * height.
        It has wave_period when that value is finite, since NaN is not valid
        JSON.  Sampling uses NumPy's global random state.
        """
        if max_points is None:
            max_points = cls.MAX_SAMPLE_POINTS
        # Flatten every field once; doing it per point copied the whole grid each time.
        flat_lats = np.ravel(lat_grid)
        flat_lons = np.ravel(lon_grid)
        flat_waves = np.ravel(wave_heights)
        flat_dir = np.ravel(wave_dir) if wave_dir is not None else None
        flat_per = np.ravel(wave_per) if wave_per is not None else None

        mask = ~np.isnan(flat_waves)
        if flat_dir is not None:
            mask &= ~np.isnan(flat_dir)
        valid = np.flatnonzero(mask)
        if valid.size == 0:
            return []
        idx = np.random.choice(valid, size=min(max_points, valid.size), replace=False)

        points = []
        for i in idx:
            point = {
                'lat': float(flat_lats[i]),
                'lon': float(flat_lons[i]),
                'wave_height': float(flat_waves[i]),
            }
            if flat_dir is not None:
                direction = float(flat_dir[i])
                point['wave_direction'] = direction
                u, v = cls._direction_to_uv(point['wave_height'] * 0.1, direction)
                point['u_component'] = float(u)
                point['v_component'] = float(v)
            if flat_per is not None:
                period = float(flat_per[i])
                if np.isfinite(period):
                    point['wave_period'] = period
            points.append(point)
        return points

    @classmethod
    def _build_grid_uv(cls, lats, lons, wave_heights, wave_dir, wave_per, max_rows=180, max_cols=360):
        """Downsample 1-D-coordinate fields and derive display U/V vectors.

        Strides keep the grid at roughly ``max_rows`` x ``max_cols``.  The base
        speed is 0.78 * clip(T, 0, 20), i.e. deep-water group velocity
        g*T/(4*pi), when a period field is available; otherwise it is
        1 + 0.2 * clip(H, 0, 10).  It is scaled by the wave height normalised
        between its median and 90th percentile, then clamped to [0, 15].
        NaN/inf U/V values become 0 in the output lists.

        Returns:
            ``{'grid_uv': {...}}``, plus ``'grid_uv_info'`` speed statistics
            when they can be computed.
        """
        ny, nx = wave_heights.shape
        sy = max(1, ny // max_rows)
        sx = max(1, nx // max_cols)
        lats_ds = lats[::sy]
        lons_ds = lons[::sx]
        wh_ds = wave_heights[::sy, ::sx]
        wd_ds = wave_dir[::sy, ::sx]
        wp_ds = None
        if wave_per is not None:
            try:
                wp_ds = wave_per[::sy, ::sx]
            except IndexError:
                wp_ds = None

        if wp_ds is not None:
            base = 0.78 * np.clip(wp_ds, 0, 20)
        else:
            base = 1.0 + 0.2 * np.clip(wh_ds, 0, 10)
        # Spatial variation via normalised wave height.
        try:
            p50 = float(np.nanpercentile(wh_ds, 50))
            p90 = float(np.nanpercentile(wh_ds, 90))
            denom = (p90 - p50) if (p90 - p50) > 1e-6 else 1.0
            hnorm = np.clip((wh_ds - p50) / denom, -1.0, 2.0)
        except Exception:
            hnorm = 0.0
        mag = np.clip(base * (1.0 + 0.5 * hnorm), 0.0, 15.0)
        u_ds, v_ds = cls._direction_to_uv(mag, wd_ds)

        result = {
            'grid_uv': {
                'lats': np.asarray(lats_ds).tolist(),
                'lons': np.asarray(lons_ds).tolist(),
                'u': np.nan_to_num(u_ds, nan=0.0, posinf=0.0, neginf=0.0).tolist(),
                'v': np.nan_to_num(v_ds, nan=0.0, posinf=0.0, neginf=0.0).tolist(),
            }
        }
        try:
            speed = np.sqrt(u_ds * u_ds + v_ds * v_ds)
            result['grid_uv_info'] = {
                'speed_min': float(np.nanmin(speed)),
                'speed_max': float(np.nanmax(speed)),
                'speed_mean': float(np.nanmean(speed)),
            }
        except Exception as exc:
            logger.debug("grid_uv speed statistics unavailable: %s", exc)
        return result

    def process_multiple_regional_files(self, regional_files):
        """Combine the sample points of several regional GRIB files.

        Every file in ``regional_files`` is deleted afterwards, whether or not
        it could be processed.  A file that fails is logged and skipped.

        Args:
            regional_files: iterable of ``(path, region_name, ...)`` tuples as
                returned by ``fetch_noaa_wave_grib``.

        Returns:
            Combined summary dict, or None if no file produced sample points.
        """
        combined = []
        for path, region_name, *_ in regional_files:
            try:
                result, _ = self.process_grib_file(path, region_name=region_name)
            except Exception:
                logger.exception("Failed to process %s file %s", region_name, path)
                result = None
            finally:
                self._remove_quietly(path)
            if result and 'sample_points' in result:
                combined.extend(result['sample_points'])
        if not combined:
            return None
        return {
            'timestamp': self._utcnow().isoformat(),
            'data_source': 'NOAA_MULTI_REGIONAL_GRIB',
            'parameters_found': {
                'has_velocity_components': any('u_component' in p for p in combined),
            },
            'grid_info': {},
            'sample_points': combined,
        }

    def _process_noaa_downloads(self, downloads):
        """Summarise NOAA downloads, preferring the Global grid.

        If the Global file is missing or unusable, the regional files are
        combined instead.  Every downloaded file is deleted.
        """
        global_entry = next((entry for entry in downloads if entry[1] == 'Global'), None)
        if global_entry is None:
            return self.process_multiple_regional_files(downloads)
        regional = [entry for entry in downloads if entry is not global_entry]
        try:
            data, _ = self.process_grib_file(global_entry[0], region_name='Global')
        except Exception:
            logger.exception("Failed to process global NOAA file %s", global_entry[0])
            data = None
        finally:
            self._remove_quietly(global_entry[0])
        if data is not None:
            self._remove_quietly(*(entry[0] for entry in regional))
            return data
        return self.process_multiple_regional_files(regional)

    def fetch_global_wave_data(self, forecast_hour=0):
        """Fetch and summarise wave data for ``forecast_hour``.

        NOAA GFS-Wave is tried first (Global grid preferred, else combined
        regional samples).  If that yields nothing, ECMWF open data is used;
        it has only significant wave height, so no direction or arrows.  All
        temporary files are deleted.

        Returns:
            Summary dict (see ``process_grib_file``), or None.
        """
        downloads = self.fetch_noaa_wave_grib(forecast_hour)
        if downloads:
            data = self._process_noaa_downloads(downloads)
            if data is not None:
                return data
        grib_file = self.fetch_ecmwf_wave_grib(forecast_hour)
        if not grib_file:
            return None
        try:
            data, _ = self.process_grib_file(grib_file)
        except Exception:
            logger.exception("Failed to process ECMWF file %s", grib_file)
            data = None
        finally:
            self._remove_quietly(grib_file)
        return data