"""Gradio app for TransDis creativity auto-assessment.

Scores Originality / Flexibility of creative-writing responses with a
sentence-embedding pipeline, exposing a simple upload-or-paste web UI.
"""

import os
import tempfile
import traceback
from io import StringIO
from pathlib import Path
from typing import Generator, Optional

import gradio as gr
import pandas as pd
from loguru import logger

from utils import pipeline
from utils.models import QWEN3_EMBEDDING_MODEL, get_embedding_model, list_models

# Pooling strategies: legacy (BERT-style) models support mean/cls pooling,
# while the Qwen3 embedding model only supports last-token pooling.
LEGACY_POOLING_CHOICES = ['mean', 'cls']
QWEN3_POOLING_CHOICES = ['last-token']
DEFAULT_POOLING = 'mean'
QWEN3_POOLING = 'last-token'

# Upper bound on input size so one request cannot monopolize the single worker.
MAX_ROWS = 10_000


def resolve_file_path(file) -> str:
    """Return a filesystem path for the shapes a Gradio file input may take.

    Gradio hands the upload over as a dict with a ``path`` key, a plain
    path-like, or a tempfile-style object exposing ``.name``.

    Raises:
        TypeError: if ``file`` matches none of the supported shapes.
    """
    if isinstance(file, dict) and file.get('path'):
        return os.fspath(file['path'])
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    if hasattr(file, 'name'):
        return os.fspath(file.name)
    raise TypeError(f'Unsupported file input: {type(file)!r}')


def read_data(filepath: str) -> pd.DataFrame:
    """Load a ``.csv`` or ``.xlsx`` file into a DataFrame.

    The extension check is case-insensitive, so ``.CSV``/``.XLSX`` also work.

    Raises:
        ValueError: for any other file extension.
    """
    suffix = Path(os.fspath(filepath)).suffix.lower()
    if suffix == '.xlsx':
        return pd.read_excel(filepath)
    if suffix == '.csv':
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')


def effective_pooling(model_name: str, pooling: str) -> str:
    """Return the pooling strategy actually used for ``model_name``.

    Qwen3 ignores the user's choice and always uses last-token pooling.
    """
    if model_name == QWEN3_EMBEDDING_MODEL:
        return QWEN3_POOLING
    return pooling


def update_pooling_for_model(model_name: str):
    """Sync the pooling dropdown with the selected model.

    Locks the dropdown to last-token for Qwen3; restores the editable
    legacy choices otherwise.
    """
    if model_name == QWEN3_EMBEDDING_MODEL:
        return gr.update(
            choices=QWEN3_POOLING_CHOICES,
            value=QWEN3_POOLING,
            interactive=False,
        )
    return gr.update(
        choices=LEGACY_POOLING_CHOICES,
        value=DEFAULT_POOLING,
        interactive=True,
    )


def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Generator[tuple[str, Optional[pd.DataFrame], Optional[str]], None, None]:
    """Run one assessment task, yielding ``(status, preview_df, csv_path)``.

    Intermediate yields drive the UI progress text; the final yield carries a
    10-row preview and the path of the full result CSV. Any failure is caught
    at this UI boundary, logged, and surfaced to the user as text.
    """
    try:
        pooling = effective_pooling(model_name, pooling)
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')

        # Load input: an uploaded file takes precedence over pasted CSV text.
        if file:
            df = read_data(resolve_file_path(file))
        elif text:
            df = pd.read_csv(StringIO(text))
            if df.empty:
                # Was an `assert`, which vanishes under `python -O`.
                raise ValueError('No input data')
        else:
            raise ValueError('No input data')

        if len(df) > MAX_ROWS:
            raise ValueError('Data exceeds 10,000 rows')

        yield f'模型加载中:{model_name}', None, None
        # Warm the embedding-model cache before announcing computation.
        get_embedding_model(model_name)
        yield '计算中...', None, None

        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise ValueError('Task not supported')

        # Persist the full result; mkstemp keeps the file after the fd closes,
        # and utf-8-sig makes the CSV open cleanly in Excel.
        fd, path = tempfile.mkstemp(prefix='transdis_', suffix='.csv')
        os.close(fd)
        df.to_csv(path, index=False, encoding='utf-8-sig')
        yield '完成', df.head(10), path
    except Exception:
        # Boundary handler: show the traceback in the UI instead of crashing.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        yield f'Something wrong\n\n{error}', None, None


# --- UI components -----------------------------------------------------------

task_name_dropdown = gr.components.Dropdown(
    label='Task Name', value='Originality', choices=['Originality', 'Flexibility']
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name', value=list_models[0], choices=list_models
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling', value=DEFAULT_POOLING, choices=LEGACY_POOLING_CHOICES
)
# Pre-fill with the bundled example so the demo can be run immediately.
# Path.read_text closes the file, unlike the previous bare open().read().
text_input = gr.components.Textbox(
    value=Path('data/example_xlm.csv').read_text(),
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])

text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

with gr.Blocks(title='TransDis-CreativityAutoAssessment') as app:
    gr.Markdown('# TransDis-CreativityAutoAssessment')
    gr.Markdown(Path('data/description.txt').read_text())
    with gr.Row():
        with gr.Column():
            task_name_dropdown.render()
            model_name_dropdown.render()
            pooling_dropdown.render()
            text_input.render()
            file_input.render()
            submit_button = gr.Button('Submit', variant='primary')
        with gr.Column():
            text_output.render()
            dataframe_output.render()
            file_output.render()
    # Keep the pooling dropdown consistent with the chosen model.
    model_name_dropdown.change(
        fn=update_pooling_for_model,
        inputs=model_name_dropdown,
        outputs=pooling_dropdown,
    )
    submit_button.click(
        fn=process,
        inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
        outputs=[text_output, dataframe_output, file_output],
        api_name='predict',
        concurrency_limit=1,
    )

if __name__ == '__main__':
    app.launch(max_threads=1)