# Author: Qifan Zhang
# Change: Add Qwen3 embedding support with last-token pooling
# Commit: b91a6bd
import os
import tempfile
import traceback
from io import StringIO
from typing import Generator, Optional
import gradio as gr
import pandas as pd
from loguru import logger
from utils import pipeline
from utils.models import QWEN3_EMBEDDING_MODEL, get_embedding_model, list_models
# Pooling strategies selectable for the legacy (non-Qwen3) embedding models.
LEGACY_POOLING_CHOICES = ['mean', 'cls']
# The only pooling strategy offered for the Qwen3 embedding model.
QWEN3_POOLING_CHOICES = ['last-token']
# Default pooling shown in the UI for legacy models.
DEFAULT_POOLING = 'mean'
# Pooling forced whenever the Qwen3 embedding model is selected.
QWEN3_POOLING = 'last-token'
def resolve_file_path(file) -> str:
    """Normalize a Gradio file input to a filesystem path string.

    Accepts a dict payload with a ``'path'`` entry, a plain str /
    ``os.PathLike``, or a file-like object exposing ``.name``.

    Raises:
        TypeError: If the value matches none of the supported shapes.
    """
    if isinstance(file, dict) and file.get('path'):
        candidate = file['path']
    elif isinstance(file, (str, os.PathLike)):
        candidate = file
    elif hasattr(file, 'name'):
        candidate = file.name
    else:
        raise TypeError(f'Unsupported file input: {type(file)!r}')
    return os.fspath(candidate)
def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular dataset from a ``.xlsx`` or ``.csv`` file.

    Args:
        filepath: Path to the input file; the extension selects the parser.

    Returns:
        The parsed DataFrame (never ``None`` — the previous ``Optional``
        annotation was inaccurate).

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    filepath = os.fspath(filepath)
    # Match extensions case-insensitively so e.g. 'DATA.CSV' is accepted too.
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    # Specific exception type instead of bare Exception; still caught by the
    # broad handler in process().
    raise ValueError('File type not supported')
def effective_pooling(model_name: str, pooling: str) -> str:
    """Return the pooling strategy actually used for *model_name*.

    The Qwen3 embedding model is pinned to last-token pooling; every
    other model keeps the user-selected value.
    """
    is_qwen3 = model_name == QWEN3_EMBEDDING_MODEL
    return QWEN3_POOLING if is_qwen3 else pooling
def update_pooling_for_model(model_name: str):
    """Sync the pooling dropdown with the selected model.

    Selecting the Qwen3 embedding model locks the dropdown to last-token
    pooling; any other model restores the editable legacy choices.
    """
    if model_name != QWEN3_EMBEDDING_MODEL:
        return gr.update(
            choices=LEGACY_POOLING_CHOICES,
            value=DEFAULT_POOLING,
            interactive=True,
        )
    return gr.update(
        choices=QWEN3_POOLING_CHOICES,
        value=QWEN3_POOLING,
        interactive=False,
    )
def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Generator[tuple[str, Optional[pd.DataFrame], Optional[str]], None, None]:
    """Run one assessment task and stream (status, preview, file path) tuples.

    Yields progress updates for the three Gradio outputs: a status message,
    a preview DataFrame (first 10 rows on success, else None), and the path
    of the result CSV (None until finished). Any failure is reported as a
    final status message instead of propagating, so the UI never crashes.

    Args:
        task_name: 'Originality' or 'Flexibility'.
        model_name: Embedding model identifier; Qwen3 forces last-token pooling.
        pooling: User-selected pooling; may be overridden per model.
        text: CSV text pasted into the textbox (used when no file is given).
        file: Optional uploaded file; takes precedence over ``text``.
    """
    try:
        # Qwen3 models only support last-token pooling, regardless of the UI value.
        pooling = effective_pooling(model_name, pooling)
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')
        # Load input: an uploaded file wins over pasted text.
        if file:
            df = read_data(resolve_file_path(file))
        elif text:
            string_io = StringIO(text)
            df = pd.read_csv(string_io)
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, silently skipping this validation.
            if len(df) < 1:
                raise Exception('No input data')
        else:
            raise Exception('No input data')
        # Reject oversized jobs before the expensive model load.
        if len(df) > 10000:
            raise Exception('Data exceeds 10,000 rows')
        yield f'模型加载中:{model_name}', None, None
        get_embedding_model(model_name)
        yield '计算中...', None, None
        # Dispatch to the requested pipeline.
        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise Exception('Task not supported')
        # Persist to a unique temp file; close the fd so pandas can reopen
        # the path (required on Windows).
        fd, path = tempfile.mkstemp(prefix='transdis_', suffix='.csv')
        os.close(fd)
        df.to_csv(path, index=False, encoding='utf-8-sig')
        yield '完成', df.iloc[:10], path
    except Exception:
        # Top-level UI boundary: log full context, surface traceback to the user.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        yield f'Something wrong\n\n{error}', None, None
# --- UI components (rendered inside the Blocks layout below) ---
# input
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility']
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value=DEFAULT_POOLING,
    choices=LEGACY_POOLING_CHOICES
)
# Read the example once with a context manager (the previous bare
# open(...).read() leaked the file handle) and an explicit encoding.
with open('data/example_xlm.csv', 'r', encoding='utf-8') as _example_file:
    _example_text = _example_file.read()
text_input = gr.components.Textbox(
    value=_example_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])
# output
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
# --- Page layout and event wiring ---
with gr.Blocks(title='TransDis-CreativityAutoAssessment') as app:
    gr.Markdown('# TransDis-CreativityAutoAssessment')
    # Read the description with a context manager (the previous bare
    # open(...).read() leaked the file handle) and an explicit encoding.
    with open('data/description.txt', 'r', encoding='utf-8') as _description_file:
        gr.Markdown(_description_file.read())
    with gr.Row():
        with gr.Column():
            task_name_dropdown.render()
            model_name_dropdown.render()
            pooling_dropdown.render()
            text_input.render()
            file_input.render()
            submit_button = gr.Button('Submit', variant='primary')
        with gr.Column():
            text_output.render()
            dataframe_output.render()
            file_output.render()
    # Keep the pooling dropdown consistent with the selected model.
    model_name_dropdown.change(
        fn=update_pooling_for_model,
        inputs=model_name_dropdown,
        outputs=pooling_dropdown,
    )
    submit_button.click(
        fn=process,
        inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
        outputs=[text_output, dataframe_output, file_output],
        api_name='predict',
        concurrency_limit=1,
    )
if __name__ == '__main__':
    # max_threads=1 mirrors the concurrency_limit=1 on the submit handler —
    # presumably to keep only one heavy embedding job running at a time; confirm.
    app.launch(max_threads=1)