Spaces:
Running
Running
Aarush commited on
Commit ·
25762a1
1
Parent(s): b27731b
feat: Groq API support, nuanced fractional rewards, .env.example
Browse files- .env.example +10 -0
- .gitignore +1 -0
- hybrid_agent.py +11 -4
- multi_step_env.py +25 -3
.env.example
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# SQL Debug Environment — Configuration
|
| 3 |
+
# ============================================
|
| 4 |
+
# Copy this file to .env and fill in your keys.
|
| 5 |
+
# The .env file is gitignored and will NOT be committed.
|
| 6 |
+
|
| 7 |
+
# Optional: Only needed for the Live Dashboard demo agent
|
| 8 |
+
# Supports Groq (free) or OpenAI — set ONE of these:
|
| 9 |
+
GROQ_API_KEY=your_groq_key_here
|
| 10 |
+
# OPENAI_API_KEY=your_openai_key_here
|
.gitignore
CHANGED
|
@@ -21,6 +21,7 @@ Thumbs.db
|
|
| 21 |
*.db-wal
|
| 22 |
|
| 23 |
databases/*.db
|
|
|
|
| 24 |
|
| 25 |
baseline_scores.py
|
| 26 |
|
|
|
|
| 21 |
*.db-wal
|
| 22 |
|
| 23 |
databases/*.db
|
| 24 |
+
outputs/trajectories/
|
| 25 |
|
| 26 |
baseline_scores.py
|
| 27 |
|
hybrid_agent.py
CHANGED
|
@@ -52,9 +52,16 @@ class LLMPolicy:
|
|
| 52 |
"""
|
| 53 |
def __init__(self, model_name="gpt-4o-mini", api_key=None):
|
| 54 |
self.model_name = model_name
|
| 55 |
-
if
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
# Prevent Streamlit UI crash if key is missing locally by injecting a placeholder
|
| 59 |
resolved_key = api_key or os.getenv("OPENAI_API_KEY") or "mock_key_to_prevent_ui_crash"
|
| 60 |
self.client = OpenAI(api_key=resolved_key)
|
|
@@ -87,7 +94,7 @@ class LLMPolicy:
|
|
| 87 |
)
|
| 88 |
return response.choices[0].message.content.strip()
|
| 89 |
except Exception as e:
|
| 90 |
-
return "GIVE_UP\n-- Missing valid
|
| 91 |
else:
|
| 92 |
# Fallback for unconnected local testing
|
| 93 |
return "SHOW_TABLES"
|
|
|
|
| 52 |
"""
|
| 53 |
def __init__(self, model_name="gpt-4o-mini", api_key=None):
|
| 54 |
self.model_name = model_name
|
| 55 |
+
if not OpenAI:
|
| 56 |
+
raise ImportError("OpenAI SDK not installed. Run: pip install openai")
|
| 57 |
+
|
| 58 |
+
if os.getenv("GROQ_API_KEY"):
|
| 59 |
+
# Use Groq via OpenAI client compatibility
|
| 60 |
+
if "gpt" in self.model_name.lower():
|
| 61 |
+
self.model_name = "llama-3.3-70b-versatile" # Map to a strong Groq model
|
| 62 |
+
resolved_key = api_key or os.getenv("GROQ_API_KEY")
|
| 63 |
+
self.client = OpenAI(api_key=resolved_key, base_url="https://api.groq.com/openai/v1")
|
| 64 |
+
elif "gpt" in model_name.lower():
|
| 65 |
# Prevent Streamlit UI crash if key is missing locally by injecting a placeholder
|
| 66 |
resolved_key = api_key or os.getenv("OPENAI_API_KEY") or "mock_key_to_prevent_ui_crash"
|
| 67 |
self.client = OpenAI(api_key=resolved_key)
|
|
|
|
| 94 |
)
|
| 95 |
return response.choices[0].message.content.strip()
|
| 96 |
except Exception as e:
|
| 97 |
+
return f"GIVE_UP\n-- API Error or Missing valid API KEY. Add GROQ_API_KEY or OPENAI_API_KEY to environment. Error: {e}"
|
| 98 |
else:
|
| 99 |
# Fallback for unconnected local testing
|
| 100 |
return "SHOW_TABLES"
|
multi_step_env.py
CHANGED
|
@@ -180,16 +180,38 @@ class MultiStepSQLEnv(gym.Env):
|
|
| 180 |
correctness = metadata.get("correctness", 0.0) if metadata else 0.0
|
| 181 |
|
| 182 |
if correctness >= 1.0:
|
| 183 |
-
reward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
done = True
|
| 185 |
-
feedback = "Success! The query produces the correct result set."
|
|
|
|
|
|
|
|
|
|
| 186 |
else:
|
| 187 |
if err:
|
| 188 |
feedback = f"SQL Error: {err}"
|
| 189 |
reward -= 0.05
|
| 190 |
else:
|
|
|
|
|
|
|
|
|
|
| 191 |
feedback = ("Query executed successfully but results are incorrect. "
|
| 192 |
-
f"Correctness score: {correctness}")
|
| 193 |
|
| 194 |
elif command == "GIVE_UP":
|
| 195 |
feedback = "Session aborted by agent."
|
|
|
|
| 180 |
correctness = metadata.get("correctness", 0.0) if metadata else 0.0
|
| 181 |
|
| 182 |
if correctness >= 1.0:
|
| 183 |
+
# --- Nuanced reward: never exactly 0.0 or 1.0 ---
|
| 184 |
+
# Base correctness component (max 0.60)
|
| 185 |
+
correctness_reward = 0.60
|
| 186 |
+
|
| 187 |
+
# Exploration bonus (max 0.20): reward agents that investigated first
|
| 188 |
+
exploration_actions = sum(
|
| 189 |
+
1 for act, _, _ in self.history
|
| 190 |
+
if act.strip().upper().startswith(("EXPLAIN", "DESCRIBE", "SHOW_TABLES"))
|
| 191 |
+
)
|
| 192 |
+
exploration_bonus = min(0.20, exploration_actions * 0.05)
|
| 193 |
+
|
| 194 |
+
# Efficiency bonus (max 0.15): reward solving in fewer steps
|
| 195 |
+
steps_used = self.current_step
|
| 196 |
+
efficiency_bonus = max(0.0, (self.max_steps - steps_used) / self.max_steps) * 0.15
|
| 197 |
+
|
| 198 |
+
# Final reward: fractional, clamped to [0.05, 0.95]
|
| 199 |
+
reward += round(min(0.95, max(0.05, correctness_reward + exploration_bonus + efficiency_bonus)), 4)
|
| 200 |
done = True
|
| 201 |
+
feedback = (f"Success! The query produces the correct result set. "
|
| 202 |
+
f"(Reward breakdown: correctness={correctness_reward}, "
|
| 203 |
+
f"exploration={round(exploration_bonus, 4)}, "
|
| 204 |
+
f"efficiency={round(efficiency_bonus, 4)})")
|
| 205 |
else:
|
| 206 |
if err:
|
| 207 |
feedback = f"SQL Error: {err}"
|
| 208 |
reward -= 0.05
|
| 209 |
else:
|
| 210 |
+
# Partial credit for close attempts
|
| 211 |
+
partial = round(correctness * 0.4, 4)
|
| 212 |
+
reward += partial
|
| 213 |
feedback = ("Query executed successfully but results are incorrect. "
|
| 214 |
+
f"Correctness score: {correctness}, partial reward: +{partial}")
|
| 215 |
|
| 216 |
elif command == "GIVE_UP":
|
| 217 |
feedback = "Session aborted by agent."
|