| import os |
| import tempfile |
| import traceback |
| from io import StringIO |
| from typing import Generator, Optional |
|
|
| import gradio as gr |
| import pandas as pd |
| from loguru import logger |
|
|
| from utils import pipeline |
| from utils.models import QWEN3_EMBEDDING_MODEL, get_embedding_model, list_models |
|
|
# Pooling pre-selected in the dropdown for models other than the Qwen3 one.
DEFAULT_POOLING = 'mean'
# The only pooling applied when the Qwen3 embedding model is selected.
QWEN3_POOLING = 'last-token'

# Dropdown choices offered for each model family.
LEGACY_POOLING_CHOICES = [DEFAULT_POOLING, 'cls']
QWEN3_POOLING_CHOICES = [QWEN3_POOLING]
|
|
|
|
def resolve_file_path(file) -> str:
    """Turn the value delivered by a Gradio ``File`` input into a path string.

    Accepted forms, checked in this order:
      * a dict carrying a non-empty ``'path'`` entry;
      * a ``str`` or ``os.PathLike``;
      * any object exposing a ``name`` attribute (e.g. a temp-file wrapper).

    Raises:
        TypeError: If *file* matches none of the forms above.
    """
    if isinstance(file, dict):
        dict_path = file.get('path')
        if dict_path:
            return os.fspath(dict_path)

    # Must precede the ``name`` lookup: ``pathlib.Path`` also has ``.name``,
    # which is only the final path component.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)

    missing = object()
    name_attr = getattr(file, 'name', missing)
    if name_attr is not missing:
        return os.fspath(name_attr)

    raise TypeError(f'Unsupported file input: {type(file)!r}')
|
|
|
|
def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    The reader is chosen from the file extension, compared case-insensitively
    so that e.g. ``DATA.CSV`` is accepted as well:
      * ``.xlsx`` -> :func:`pandas.read_excel`
      * ``.csv``  -> :func:`pandas.read_csv`

    Args:
        filepath: Path to the file; any ``os.PathLike`` is accepted too.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    filepath = os.fspath(filepath)
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    # ValueError subclasses Exception, so existing `except Exception` callers
    # keep working.
    raise ValueError('File type not supported')
|
|
|
|
def effective_pooling(model_name: str, pooling: str) -> str:
    """Return the pooling strategy that will actually be applied.

    The Qwen3 embedding model is always paired with ``QWEN3_POOLING``, whatever
    the caller requested; every other model uses *pooling* unchanged.
    """
    return QWEN3_POOLING if model_name == QWEN3_EMBEDDING_MODEL else pooling
|
|
|
|
def update_pooling_for_model(model_name: str):
    """Build a ``gr.update`` that re-configures the pooling dropdown.

    For the Qwen3 embedding model the dropdown is locked to ``QWEN3_POOLING``.
    For any other model it offers ``LEGACY_POOLING_CHOICES``, resets to
    ``DEFAULT_POOLING``, and is editable.
    """
    locked = model_name == QWEN3_EMBEDDING_MODEL
    if locked:
        options, selected = QWEN3_POOLING_CHOICES, QWEN3_POOLING
    else:
        options, selected = LEGACY_POOLING_CHOICES, DEFAULT_POOLING
    return gr.update(choices=options, value=selected, interactive=not locked)
|
|
|
|
def _load_input_frame(text: str, file=None) -> pd.DataFrame:
    """Build the input table from an uploaded file or from pasted CSV text.

    An uploaded file takes precedence over the text box.

    Raises:
        ValueError: If neither source provides data, or the parsed table has
            no rows.
    """
    if file:
        df = read_data(resolve_file_path(file))
    elif text and text.strip():
        df = pd.read_csv(StringIO(text))
    else:
        # Whitespace-only text counts as "no input" instead of being handed to
        # the CSV parser, which would fail with an opaque EmptyDataError.
        raise ValueError('No input data')

    # Explicit check rather than `assert` (stripped under `python -O`), and
    # applied to both sources so a header-only upload is rejected too.
    if len(df) < 1:
        raise ValueError('No input data')
    return df


def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Generator[tuple[str, Optional[pd.DataFrame], Optional[str]], None, None]:
    """Run a creativity-scoring task and stream progress to the UI.

    Yields ``(status, preview, csv_path)`` tuples. Intermediate yields carry
    only a status message. The final successful yield carries the first 10
    result rows and the path of a temporary CSV containing the full result.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``.
        model_name: Embedding model to use.
        pooling: Requested pooling; overridden by ``effective_pooling``.
        text: CSV text pasted by the user (used when no file is given).
        file: Optional uploaded file (anything ``resolve_file_path`` accepts).

    Any exception is caught at this UI boundary, logged, and reported to the
    user as a final status message containing the traceback.
    """
    max_rows = 10000
    try:
        pooling = effective_pooling(model_name, pooling)
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')

        df = _load_input_frame(text, file)
        if len(df) > max_rows:
            raise ValueError(f'Data exceeds {max_rows:,} rows')

        # Resolve the task before loading the model, so an unsupported name
        # (reachable through the `predict` API) fails fast.
        if task_name == 'Originality':
            run_task = pipeline.p0_originality
        elif task_name == 'Flexibility':
            run_task = pipeline.p1_flexibility
        else:
            raise ValueError('Task not supported')

        yield f'模型加载中:{model_name}', None, None
        # Return value unused; presumably called to load/cache the model before
        # the pipeline runs — TODO confirm against utils.models.
        get_embedding_model(model_name)

        yield '计算中...', None, None
        df = run_task(df, model_name, pooling)

        # mkstemp creates the file; close our descriptor and let pandas reopen
        # it by path. 'utf-8-sig' prepends a BOM, presumably so spreadsheet
        # tools detect the encoding.
        fd, path = tempfile.mkstemp(prefix='transdis_', suffix='.csv')
        os.close(fd)
        df.to_csv(path, index=False, encoding='utf-8-sig')
        yield '完成', df.iloc[:10], path

    except Exception:
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        yield f'Something wrong\n\n{error}', None, None
|
|
|
|
| |
# Input widgets. They are created unattached here and placed into the layout
# later with `.render()`.
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility'],
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models,
)

# The model dropdown's `.change` handler only fires on a change, so derive the
# initial pooling state from the pre-selected model the same way
# `update_pooling_for_model` does.
_initial_model_is_qwen3 = list_models[0] == QWEN3_EMBEDDING_MODEL
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value=QWEN3_POOLING if _initial_model_is_qwen3 else DEFAULT_POOLING,
    choices=QWEN3_POOLING_CHOICES if _initial_model_is_qwen3 else LEGACY_POOLING_CHOICES,
    interactive=not _initial_model_is_qwen3,
)

# Pre-fill the text box with the bundled example CSV. `with` closes the handle
# deterministically; an explicit encoding avoids depending on the OS locale.
with open('data/example_xlm.csv', 'r', encoding='utf-8') as _example_fh:
    _example_csv_text = _example_fh.read()
text_input = gr.components.Textbox(
    value=_example_csv_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])

# Output widgets: status message, result preview, and downloadable result file.
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
|
|
# Read the page description up front. `with` closes the handle deterministically,
# and the explicit encoding keeps the result independent of the OS locale.
with open('data/description.txt', 'r', encoding='utf-8') as _description_fh:
    _description_text = _description_fh.read()

with gr.Blocks(title='TransDis-CreativityAutoAssessment') as app:
    gr.Markdown('# TransDis-CreativityAutoAssessment')
    gr.Markdown(_description_text)
    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column():
            task_name_dropdown.render()
            model_name_dropdown.render()
            pooling_dropdown.render()
            text_input.render()
            file_input.render()
            submit_button = gr.Button('Submit', variant='primary')
        # Right column: status, preview, and downloadable result.
        with gr.Column():
            text_output.render()
            dataframe_output.render()
            file_output.render()

    # Gradio requires event listeners to be registered inside the Blocks context.
    model_name_dropdown.change(
        fn=update_pooling_for_model,
        inputs=model_name_dropdown,
        outputs=pooling_dropdown,
    )
    submit_button.click(
        fn=process,
        inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
        outputs=[text_output, dataframe_output, file_output],
        api_name='predict',
        concurrency_limit=1,
    )


if __name__ == '__main__':
    app.launch(max_threads=1)
|
|