Spaces:
Sleeping
Sleeping
| """ | |
client.py
---------
DataCleaningEnv — the typed WebSocket client for the data cleaning pipeline.

This module contains exactly one public class: ``DataCleaningEnv``.
It extends ``EnvClient`` from OpenEnv core and implements the three abstract
translation methods that bridge Python objects and the server's JSON wire format:

    _step_payload(action)    CleanAction → dict                     (outbound)
    _parse_result(payload)   dict → StepResult[CleanObservation]    (inbound)
    _parse_state(payload)    dict → CleanState                      (inbound)

Everything else — WebSocket lifecycle, connect/disconnect, the async context
manager, and the ``.sync()`` wrapper — is handled by the base class.

Usage (async)
-------------
    import asyncio
    from data_cleaning_env.client import DataCleaningEnv
    from data_cleaning_env.models import CleanAction

    async def main():
        async with DataCleaningEnv(base_url="http://localhost:7860") as env:
            result = await env.reset(task_id="easy")
            print(result.observation.schema_hint)
            result = await env.set_value(row_index=3, column="price", value="29.99")
            print(result.reward, result.observation.current_score)
            result = await env.done()

    asyncio.run(main())

Usage (sync wrapper)
--------------------
    env = DataCleaningEnv(base_url="http://localhost:7860").sync()
    with env:
        result = env.reset(task_id="medium")
        result = env.fill_missing(column="amount", fill_strategy="median")
        result = env.done()
| """ | |
| from __future__ import annotations | |
| from typing import Any, Optional | |
| # ββ OpenEnv core imports ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| try: | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| except ImportError: | |
| from openenv.core.client_types import StepResult # type: ignore[no-redef] | |
| from openenv.core.env_client import EnvClient # type: ignore[no-redef] | |
| #7860 | |
| # ββ Local model imports (try relative then absolute) ββββββββββββββββββββββββββ | |
| try: | |
| from .models import ( | |
| CleanAction, | |
| CleanObservation, | |
| CleanState, | |
| MAX_STEPS, | |
| DONE_THRESHOLD, | |
| ) | |
| except ImportError: | |
| from models import ( # type: ignore[no-redef] | |
| CleanAction, | |
| CleanObservation, | |
| CleanState, | |
| MAX_STEPS, | |
| DONE_THRESHOLD, | |
| ) | |
class DataCleaningEnv(EnvClient[CleanAction, CleanObservation, CleanState]):
    """
    Async WebSocket client for the Data Cleaning Pipeline environment.

    Connects to a running ``DataCleaningEnvironment`` server and exposes the
    standard OpenEnv interface (``reset``, ``step``, ``state``) plus typed
    convenience helpers for each command.

    All methods are async. For synchronous use, call ``.sync()`` to get a
    ``SyncEnvClient`` wrapper::

        with DataCleaningEnv(base_url="http://localhost:7860").sync() as env:
            result = env.reset(task_id="easy")
            result = env.set_value(row_index=0, column="price", value="9.99")

    Connecting to different backends
    --------------------------------
    Local dev server (after ``openenv serve``)::

        env = DataCleaningEnv(base_url="http://localhost:7860")

    Local Docker image (after ``openenv build``)::

        env = await DataCleaningEnv.from_docker_image("data-cleaning-env:latest")

    Hugging Face Space (after ``openenv push``)::

        env = await DataCleaningEnv.from_env("your-org/data-cleaning-env")
    """

    # ─────────────────────────────────────────────────────────────────────────
    # Abstract method implementations — the three translation methods
    # ─────────────────────────────────────────────────────────────────────────

    def _step_payload(self, action: CleanAction) -> dict[str, Any]:
        """
        Serialise a ``CleanAction`` to the JSON dict the server expects.

        The server's ``step()`` endpoint receives this dict, validates it
        against ``CleanAction``, and dispatches to the correct handler.

        ``model_dump(exclude_none=True)`` omits every field the agent left as
        ``None``, so the wire message only carries the arguments that were
        actually set for the chosen command. (This is unrelated to
        ``extra="forbid"``, which rejects *unknown* keys rather than
        ``None``-valued declared ones; it simply keeps the message minimal.)
        """
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[CleanObservation]:
        """
        Parse the server's step/reset response into a ``StepResult``.

        Wire format (what the server sends back)::

            {
              "observation": {
                "done": false,
                "reward": -0.005,
                "metadata": {},
                "task_id": "easy",
                "schema_hint": "Sales orders...",
                "initial_dirty_cells": 29,
                "dirty_csv": "row_index,order_id,...\\n0,1001,...",
                "current_score": 0.9550,
                "issues_remaining": 18,
                "step_number": 1,
                "max_steps": 40,
                "last_action_success": true,
                "last_action_error": null
              },
              "reward": -0.005,
              "done": false
            }

        ``reward`` and ``done`` may appear both at the top level and inside
        ``observation`` (because the ``Observation`` base carries them). The
        top-level copies win; the observation copies are the fallback. The
        resolved values are used for BOTH the ``StepResult`` and the embedded
        ``CleanObservation`` so the two can never disagree about whether the
        episode has ended.

        Raises:
            KeyError: if a required observation field (``task_id``,
                ``schema_hint``, ``initial_dirty_cells``, ``dirty_csv`` or
                ``max_steps``) is missing from the payload.
        """
        # ``or {}`` also covers an explicit ``"observation": null``, which
        # ``.get(..., {})`` alone would pass through as ``None``.
        obs_data: dict[str, Any] = payload.get("observation") or {}

        # Resolve once so StepResult and the observation always agree.
        done = payload.get("done", obs_data.get("done", False))
        reward = payload.get("reward", obs_data.get("reward"))

        observation = CleanObservation(
            # ── inherited from Observation base ──────────────────────────────
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
            # ── task context (constant for the episode) ─────────────────────
            task_id=obs_data["task_id"],
            schema_hint=obs_data["schema_hint"],
            initial_dirty_cells=obs_data["initial_dirty_cells"],
            # ── per-step state ───────────────────────────────────────────────
            dirty_csv=obs_data["dirty_csv"],
            current_score=obs_data.get("current_score", 0.0),
            issues_remaining=obs_data.get("issues_remaining", 0),
            step_number=obs_data.get("step_number", 0),
            max_steps=obs_data["max_steps"],
            # ── last-action feedback ─────────────────────────────────────────
            last_action_success=obs_data.get("last_action_success", True),
            last_action_error=obs_data.get("last_action_error"),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict[str, Any]) -> CleanState:
        """
        Parse the server's state response into a ``CleanState``.

        The wire keys are expected to match ``CleanState`` field names.
        Missing keys fall back to the defaults below (e.g. ``task_id`` ->
        ``"easy"``, ``max_steps`` -> 40). NOTE(review): a key that is present
        with an explicit ``null`` value is passed through as ``None``, not
        replaced by the default.
        """
        return CleanState(
            # ── inherited from State base ────────────────────────────────────
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            # ── task identity ────────────────────────────────────────────────
            task_id=payload.get("task_id", "easy"),
            # ── DataFrame snapshots ──────────────────────────────────────────
            dirty_csv_snapshot=payload.get("dirty_csv_snapshot", ""),
            clean_csv_snapshot=payload.get("clean_csv_snapshot", ""),
            # ── scoring ──────────────────────────────────────────────────────
            initial_dirty_cells=payload.get("initial_dirty_cells", 0),
            current_score=payload.get("current_score", 0.0),
            previous_score=payload.get("previous_score", 0.0),
            # ── grader metadata ──────────────────────────────────────────────
            task_metadata=payload.get("task_metadata", {}),
            # ── schema ───────────────────────────────────────────────────────
            schema_hint=payload.get("schema_hint", ""),
            # ── step budget ──────────────────────────────────────────────────
            max_steps=payload.get("max_steps", 40),
        )

    # ─────────────────────────────────────────────────────────────────────────
    # Typed convenience helpers — one per CleanAction command
    # ─────────────────────────────────────────────────────────────────────────
    # These methods exist purely for ergonomics: they let callers write
    #
    #     await env.set_value(row_index=3, column="price", value="29.99")
    #
    # instead of the more verbose:
    #
    #     await env.step(CleanAction(
    #         command="SET_VALUE", row_index=3, column="price", value="29.99"
    #     ))
    #
    # The baseline inference script can use either form.

    async def set_value(
        self,
        row_index: int,
        column: str,
        value: str,
    ) -> StepResult[CleanObservation]:
        """Fix a single cell via a ``SET_VALUE`` command.

        ``value`` is always sent as a string; per the server contract it is
        cast to the column's target dtype server-side.
        """
        return await self.step(
            CleanAction(
                command="SET_VALUE",
                row_index=row_index,
                column=column,
                value=value,
            )
        )

    async def drop_row(self, row_index: int) -> StepResult[CleanObservation]:
        """Remove an entire row via ``DROP_ROW`` (e.g. a true outlier in the medium task)."""
        return await self.step(
            CleanAction(command="DROP_ROW", row_index=row_index)
        )

    async def standardize_col(self, column: str) -> StepResult[CleanObservation]:
        """Normalise a whole column's format via ``STANDARDIZE_COL``.

        The server is documented to auto-detect what to do:

        - Date columns → parse any format, reformat as ``YYYY-MM-DD``
        - Numeric columns → coerce to float/int, drop unit strings
        - String columns → strip leading/trailing whitespace
        """
        return await self.step(
            CleanAction(command="STANDARDIZE_COL", column=column)
        )

    async def fill_missing(
        self,
        column: str,
        fill_strategy: str,
    ) -> StepResult[CleanObservation]:
        """Fill ``NaN`` values in ``column`` via ``FILL_MISSING``.

        Args:
            column: Column name to fill.
            fill_strategy: One of ``"mean"``, ``"median"``, ``"mode"``,
                ``"drop"``. ``"drop"`` removes rows where the column is ``NaN``.
        """
        return await self.step(
            CleanAction(
                command="FILL_MISSING",
                column=column,
                fill_strategy=fill_strategy,
            )
        )

    async def done(self) -> StepResult[CleanObservation]:
        """Signal that the agent believes the CSV is clean (``DONE`` command).

        This ends the episode. Per the server's reward rules, finishing while
        the score is below its early-done threshold (``EARLY_DONE_THRESHOLD``,
        documented as 0.60) incurs a -0.20 penalty.
        NOTE(review): that constant is defined server-side, not in this
        module — confirm the values against the server before relying on them.
        """
        return await self.step(CleanAction(command="DONE"))

    # ─────────────────────────────────────────────────────────────────────────
    # Introspection helpers
    # Each one issues a fresh ``state()`` request, i.e. one server round-trip.
    # ─────────────────────────────────────────────────────────────────────────

    async def current_score(self) -> float:
        """Return the grader score from the last step (0.0–1.0)."""
        st = await self.state()
        return st.current_score

    async def task_id(self) -> str:
        """Return the active task ID (``"easy"``, ``"medium"``, or ``"hard"``)."""
        st = await self.state()
        return st.task_id

    async def steps_remaining(self) -> int:
        """Return the number of steps left before forced termination (never negative)."""
        st = await self.state()
        return max(0, st.max_steps - st.step_count)

    async def is_solved(self) -> bool:
        """Return ``True`` if the current score meets the task's done threshold.

        The threshold is ``DONE_THRESHOLD[task_id]``; task ids missing from
        that mapping fall back to 0.95.
        """
        st = await self.state()
        threshold = DONE_THRESHOLD.get(st.task_id, 0.95)
        return st.current_score >= threshold