Aarush commited on
Commit
25762a1
·
1 Parent(s): b27731b

feat: Groq API support, nuanced fractional rewards, .env.example

Browse files
Files changed (4) hide show
  1. .env.example +10 -0
  2. .gitignore +1 -0
  3. hybrid_agent.py +11 -4
  4. multi_step_env.py +25 -3
.env.example ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================
2
+ # SQL Debug Environment — Configuration
3
+ # ============================================
4
+ # Copy this file to .env and fill in your keys.
5
+ # The .env file is gitignored and will NOT be committed.
6
+
7
+ # Optional: Only needed for the Live Dashboard demo agent
8
+ # Supports Groq (free) or OpenAI — set ONE of these:
9
+ GROQ_API_KEY=your_groq_key_here
10
+ # OPENAI_API_KEY=your_openai_key_here
.gitignore CHANGED
@@ -21,6 +21,7 @@ Thumbs.db
21
  *.db-wal
22
 
23
  databases/*.db
 
24
 
25
  baseline_scores.py
26
 
 
21
  *.db-wal
22
 
23
  databases/*.db
24
+ outputs/trajectories/
25
 
26
  baseline_scores.py
27
 
hybrid_agent.py CHANGED
@@ -52,9 +52,16 @@ class LLMPolicy:
52
  """
53
  def __init__(self, model_name="gpt-4o-mini", api_key=None):
54
  self.model_name = model_name
55
- if "gpt" in model_name.lower():
56
- if not OpenAI:
57
- raise ImportError("OpenAI SDK not installed. Run: pip install openai")
 
 
 
 
 
 
 
58
  # Prevent Streamlit UI crash if key is missing locally by injecting a placeholder
59
  resolved_key = api_key or os.getenv("OPENAI_API_KEY") or "mock_key_to_prevent_ui_crash"
60
  self.client = OpenAI(api_key=resolved_key)
@@ -87,7 +94,7 @@ class LLMPolicy:
87
  )
88
  return response.choices[0].message.content.strip()
89
  except Exception as e:
90
- return "GIVE_UP\n-- Missing valid OPENAI_API_KEY. Add key to environment to enable LLM generation."
91
  else:
92
  # Fallback for unconnected local testing
93
  return "SHOW_TABLES"
 
52
  """
53
  def __init__(self, model_name="gpt-4o-mini", api_key=None):
54
  self.model_name = model_name
55
+ if not OpenAI:
56
+ raise ImportError("OpenAI SDK not installed. Run: pip install openai")
57
+
58
+ if os.getenv("GROQ_API_KEY"):
59
+ # Use Groq via OpenAI client compatibility
60
+ if "gpt" in self.model_name.lower():
61
+ self.model_name = "llama-3.3-70b-versatile" # Map to a strong Groq model
62
+ resolved_key = api_key or os.getenv("GROQ_API_KEY")
63
+ self.client = OpenAI(api_key=resolved_key, base_url="https://api.groq.com/openai/v1")
64
+ elif "gpt" in model_name.lower():
65
  # Prevent Streamlit UI crash if key is missing locally by injecting a placeholder
66
  resolved_key = api_key or os.getenv("OPENAI_API_KEY") or "mock_key_to_prevent_ui_crash"
67
  self.client = OpenAI(api_key=resolved_key)
 
94
  )
95
  return response.choices[0].message.content.strip()
96
  except Exception as e:
97
+ return f"GIVE_UP\n-- API Error or Missing valid API KEY. Add GROQ_API_KEY or OPENAI_API_KEY to environment. Error: {e}"
98
  else:
99
  # Fallback for unconnected local testing
100
  return "SHOW_TABLES"
multi_step_env.py CHANGED
@@ -180,16 +180,38 @@ class MultiStepSQLEnv(gym.Env):
180
  correctness = metadata.get("correctness", 0.0) if metadata else 0.0
181
 
182
  if correctness >= 1.0:
183
- reward += 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  done = True
185
- feedback = "Success! The query produces the correct result set."
 
 
 
186
  else:
187
  if err:
188
  feedback = f"SQL Error: {err}"
189
  reward -= 0.05
190
  else:
 
 
 
191
  feedback = ("Query executed successfully but results are incorrect. "
192
- f"Correctness score: {correctness}")
193
 
194
  elif command == "GIVE_UP":
195
  feedback = "Session aborted by agent."
 
180
  correctness = metadata.get("correctness", 0.0) if metadata else 0.0
181
 
182
  if correctness >= 1.0:
183
+ # --- Nuanced reward: never exactly 0.0 or 1.0 ---
184
+ # Base correctness component (max 0.60)
185
+ correctness_reward = 0.60
186
+
187
+ # Exploration bonus (max 0.20): reward agents that investigated first
188
+ exploration_actions = sum(
189
+ 1 for act, _, _ in self.history
190
+ if act.strip().upper().startswith(("EXPLAIN", "DESCRIBE", "SHOW_TABLES"))
191
+ )
192
+ exploration_bonus = min(0.20, exploration_actions * 0.05)
193
+
194
+ # Efficiency bonus (max 0.15): reward solving in fewer steps
195
+ steps_used = self.current_step
196
+ efficiency_bonus = max(0.0, (self.max_steps - steps_used) / self.max_steps) * 0.15
197
+
198
+ # Final reward: fractional, clamped to [0.05, 0.95]
199
+ reward += round(min(0.95, max(0.05, correctness_reward + exploration_bonus + efficiency_bonus)), 4)
200
  done = True
201
+ feedback = (f"Success! The query produces the correct result set. "
202
+ f"(Reward breakdown: correctness={correctness_reward}, "
203
+ f"exploration={round(exploration_bonus, 4)}, "
204
+ f"efficiency={round(efficiency_bonus, 4)})")
205
  else:
206
  if err:
207
  feedback = f"SQL Error: {err}"
208
  reward -= 0.05
209
  else:
210
+ # Partial credit for close attempts
211
+ partial = round(correctness * 0.4, 4)
212
+ reward += partial
213
  feedback = ("Query executed successfully but results are incorrect. "
214
+ f"Correctness score: {correctness}, partial reward: +{partial}")
215
 
216
  elif command == "GIVE_UP":
217
  feedback = "Session aborted by agent."