Spaces:
Sleeping
Sleeping
File size: 2,619 Bytes
e4accbb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | """Small FastAPI compatibility layer used when openenv-core is unavailable."""
from __future__ import annotations

import threading
from typing import Any, Generic, TypeVar

from fastapi import Body, FastAPI
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ValidationError
# Type parameters for the environment's action, observation and state models.
# Each is bound to pydantic's BaseModel; create_app relies on model_dump /
# model_validate / model_json_schema being available on these types.
ActT = TypeVar("ActT", bound=BaseModel)
ObsT = TypeVar("ObsT", bound=BaseModel)
StateT = TypeVar("StateT", bound=BaseModel)
class Environment(Generic[ActT, ObsT, StateT]):
    """Fallback environment base used when openenv-core is unavailable.

    The type parameters are the action, observation and state model types.
    Subclasses are expected to override the three hooks below; the base
    implementations only raise ``NotImplementedError``.
    """

    # Class-level capability flag; this fallback base declares it False.
    SUPPORTS_CONCURRENT_SESSIONS = False

    def reset(self, **kwargs: Any) -> ObsT:
        """Hook for resetting the environment; must return an observation."""
        raise NotImplementedError()

    def step(self, action: ActT) -> ObsT:
        """Hook for advancing the environment with ``action``; returns an observation."""
        raise NotImplementedError()

    def state(self) -> StateT:
        """Hook returning the environment's state model."""
        raise NotImplementedError()
def create_app(
    environment_cls: type[Environment[ActT, ObsT, StateT]],
    action_model: type[ActT],
    observation_model: type[ObsT],
    env_name: str,
    **_: Any,
) -> FastAPI:
    """Build a FastAPI app that serves one shared environment instance over HTTP.

    Args:
        environment_cls: Environment class; instantiated once, with no
            arguments, when the app is built.
        action_model: Pydantic model used to validate ``/step`` request bodies
            and reported by ``/schema``.
        observation_model: Pydantic model whose JSON schema ``/schema`` reports.
        env_name: Used as the app title and reported by ``/`` and ``/metadata``.
        **_: Extra keyword arguments are accepted and ignored (presumably for
            call compatibility with openenv-core's ``create_app`` — confirm).

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI(title=env_name)
    app.state.environment = environment_cls()
    # FastAPI runs plain ``def`` endpoints in a worker threadpool, so requests
    # can overlap; every call into the single shared environment (and the
    # serialization of what it returns) is serialized behind this lock.
    env_lock = threading.Lock()

    def _as_response(observation: BaseModel) -> dict[str, Any]:
        """Wrap a dumped observation with top-level ``reward``/``done`` fields."""
        data = observation.model_dump()
        return {
            "observation": data,
            # A missing or None reward is reported as 0.0.
            "reward": float(data.get("reward") or 0.0),
            "done": bool(data.get("done", False)),
        }

    @app.get("/")
    def root() -> dict[str, Any]:
        return {
            "name": env_name,
            "status": "ok",
            # NOTE(review): "/tasks" is advertised here (and /metadata reports
            # supports_tasks=True), but no /tasks route is registered below.
            "endpoints": ["/health", "/reset", "/step", "/state", "/tasks", "/metadata", "/schema"],
        }

    @app.get("/health")
    def health() -> dict[str, str]:
        return {"status": "ok"}

    @app.get("/metadata")
    def metadata() -> dict[str, Any]:
        return {
            "name": env_name,
            "supports_state": True,
            "supports_tasks": True,
            "transport": "http",
        }

    @app.get("/schema")
    def schema() -> dict[str, Any]:
        return {
            "action": action_model.model_json_schema(),
            "observation": observation_model.model_json_schema(),
        }

    @app.post("/reset")
    def reset(payload: dict[str, Any] | None = Body(default=None)) -> dict[str, Any]:
        # The optional JSON object body is forwarded as keyword arguments.
        with env_lock:
            observation = app.state.environment.reset(**(payload or {}))
            return _as_response(observation)

    @app.post("/step")
    def step(payload: dict[str, Any]) -> dict[str, Any]:
        try:
            action = action_model.model_validate(payload)
        except ValidationError as exc:
            # Report a malformed action as FastAPI's standard 422 validation
            # error instead of letting it escape as a 500.
            raise RequestValidationError(exc.errors()) from exc
        with env_lock:
            observation = app.state.environment.step(action)
            return _as_response(observation)

    @app.get("/state")
    def state() -> dict[str, Any]:
        with env_lock:
            return app.state.environment.state().model_dump()

    return app
|