| import re |
|
|
| import torch |
|
|
| import gradio as gr |
|
|
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
def load_model_tokenizer():
    """Load the Qwen2.5-3B-Instruct base model with its GRPO-trained
    SQL-reasoning LoRA adapter, plus the matching tokenizer.

    Returns:
        tuple: (model, tokenizer) — the adapter is loaded frozen
        (``is_trainable=False``), i.e. inference-only.
    """
    base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
    # Stack the fine-tuned LoRA adapter on top of the frozen base weights.
    adapted = PeftModel.from_pretrained(base, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)

    # NOTE(review): `max_length` passed to the tokenizer is not the same knob as
    # `model_max_length` — confirm this argument has the intended effect.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)

    return adapted, tok
|
|
|
|
# Load the model and tokenizer once at import time so every Gradio
# request reuses the same in-memory instances.
model, tokenizer = load_model_tokenizer()
|
|
|
|
def create_prompt(schemas, question):
    """Build the two-turn chat prompt for SQL generation.

    Args:
        schemas: CREATE TABLE statements (newline-separated) to ground the query.
        question: Natural-language question to be answered via SQL.

    Returns:
        list[dict]: chat messages — a system turn describing the
        <reason>/<answer> output format, then a user turn embedding
        the schemas and the question.
    """
    system_message = """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.

Remember that you should place all your reasoning between <reason> and </reason> tags.
Also, you should provide your solution between <answer> and </answer> tags.

An example generation is as follows:
<reason>
This is a sample reasoning that solves the question based on the schema.
</reason>
<answer>
SELECT
COLUMN
FROM TABLE_NAME
WHERE
CONDITION
</answer>"""

    user_message = f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""

    return [
        {'role': 'system', 'content': system_message},
        {'role': 'user', 'content': user_message},
    ]
|
|
|
|
def extract_answer(gen_output):
    """Extract the text between the first <answer>...</answer> pair.

    Args:
        gen_output: Decoded model generation (may contain <reason>/<answer> tags).

    Returns:
        str | None: the captured answer text (untrimmed), or None when
        no matching tag pair is present.
    """
    # DOTALL lets the answer span multiple lines; IGNORECASE tolerates
    # tag-casing drift in the generation; non-greedy stops at the first close tag.
    pattern = re.compile(
        r"<answer>(.+?)</answer>",
        flags=re.MULTILINE | re.DOTALL | re.IGNORECASE,
    )

    match = pattern.search(gen_output)
    return match.group(1) if match is not None else None
|
|
|
|
def response(user_schemas, user_question):
    """Generate a reasoned SQL answer for the user's schemas and question.

    Args:
        user_schemas: CREATE TABLE statements provided in the UI textbox.
        user_question: Natural-language question provided in the UI textbox.

    Returns:
        str: the full assistant generation, a separator, and the extracted
        <answer> section (or a placeholder when none was produced).
    """
    user_prompt = create_prompt(user_schemas, user_question)

    inputs = tokenizer.apply_chat_template(user_prompt,
                                           tokenize=True,
                                           add_generation_prompt=True,
                                           return_dict=True,
                                           return_tensors="pt")

    # inference_mode disables autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=1024)

    outputs = tokenizer.batch_decode(outputs)
    # Keep only the assistant turn of the decoded chat transcript.
    output = outputs[0].split("<|im_start|>assistant")[-1].strip()

    final_answer = extract_answer(output)

    # Bug fix: extract_answer returns None when the generation contains no
    # <answer> tags, and `str + None` below would raise TypeError, crashing
    # the request. Substitute a readable placeholder instead.
    if final_answer is None:
        final_answer = "No <answer> section was found in the generation."

    final_output = output + "\n\n" + "="*20 + "\n\nFinal Answer: \n" + final_answer

    return final_output
|
|
|
|
# Markdown description rendered above the Gradio interface inputs.
# Fix: removed the duplicated word in "to to generate".
desc = """
**NOTE: This HF Space is running on Free Version so the generation process will be very slow.**<br>

Please use the "Table Schemas" field to provide the required schemas to generate the SQL Query for - separated by new lines.<br>
**Example:**
```python
CREATE TABLE demographic (
subject_id text,
admission_type text,
hadm_id text)

CREATE TABLE diagnoses (
subject_id text,
hadm_id text)
```

Finally, use the "Question" field to provide the relevant question to be answered based on the provided schemas.<br>
**Example:** How many patients whose admission type is emergency.
"""
|
|
# Gradio UI wiring: two free-text inputs (schemas, question) mapped through
# `response` to a single text output containing reasoning + final SQL.
demo = gr.Interface(
    fn=response,
    inputs=[gr.Textbox(label="Table Schemas",
                       placeholder="Expected to have CREATE TABLE statements with datatypes separated by new lines"),
            gr.Textbox(label="Question",
                       placeholder="Eg. How many patients whose admission type is emergency")
            ],
    outputs=gr.Textbox(label="Generated SQL Query with reasoning"),
    title="SQL Query Generator trained with GRPO to elicit reasoning",
    description=desc
)


# Start the web server (blocking call).
demo.launch()