import os
import re
import time
import xml.etree.ElementTree as ET

import requests
import streamlit as st
from openai import OpenAI


# =========================
# OpenAI Client
# =========================

def get_openai_client():
    """Build an OpenAI client from the key stored in ``st.session_state["OPENAI_API_KEY"]``.

    Raises:
        ValueError: if the key is missing or empty.
    """
    api_key = st.session_state.get("OPENAI_API_KEY", "")
    if not api_key:
        raise ValueError("OpenAI API Key が未設定です。")
    return OpenAI(api_key=api_key)


def ask_llm(prompt, model="gpt-5-nano"):
    """Send *prompt* as a single user message and return the stripped reply text.

    Returns "" when the model replies with no content.
    """
    client = get_openai_client()
    res = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return (res.choices[0].message.content or "").strip()


# =========================
# Utility
# =========================

def normalize_title(title: str) -> str:
    """Lower-case *title* and collapse whitespace so it can serve as a dedup key."""
    return " ".join((title or "").lower().strip().split())


def normalize_text(text: str) -> str:
    """Collapse every run of whitespace in *text* to a single space (None -> "")."""
    return " ".join((text or "").strip().split())


def deduplicate_papers(papers):
    """Drop papers whose (normalized title, lower-cased first author) was already seen.

    Papers with an empty title are dropped. The first occurrence of each key is kept,
    in input order.
    """
    seen = set()
    unique = []
    for p in papers:
        title = normalize_title(p.get("title", ""))
        if not title:
            continue
        authors = p.get("authors", []) or []
        first_author = (authors[0] or "").lower().strip() if authors else ""
        key = (title, first_author)
        if key not in seen:
            seen.add(key)
            unique.append(p)
    return unique


# =========================
# arXiv Search
# =========================

ARXIV_API_URL = "https://export.arxiv.org/api/query"

# time.time() of the most recent arXiv request made by this process.
# NOTE(review): Streamlit re-executes this script on every rerun, which resets the value
# to 0, so the spacing below is only enforced within a single run.
_last_arxiv_request_time = 0

_ATOM_NS = "{http://www.w3.org/2005/Atom}"
_ARXIV_NS = "{http://arxiv.org/schemas/atom}"

# Common conference acronyms followed by a year (e.g. "NeurIPS 2023").
# Checked in this order; the first pattern found anywhere in the text wins.
_ARXIV_VENUE_RES = tuple(
    re.compile(rf"\b{name}\s*\d{{4}}\b", re.IGNORECASE)
    for name in (
        "NeurIPS", "NIPS", "ICML", "ICLR", "ACL", "EMNLP", "NAACL", "COLING",
        "CVPR", "ICCV", "ECCV", "AAAI", "IJCAI", "KDD", "SIGIR", "WWW",
        "TheWebConf", "CHI", "UAI", "AISTATS", "ICRA", "IROS",
    )
)

# Free-text acceptance phrases in the arXiv comment; group(1) is the venue.
# Longer alternatives come first: with the bare "Accepted" first, "Accepted at X"
# would capture "at X".
_ARXIV_ACCEPTED_RES = (
    re.compile(
        r"(?:Accepted for publication (?:in|at)|Accepted (?:at|to|in|by|for)|Accepted"
        r"|To appear (?:in|at)|Published in)\s+(.+?)(?:\.|$)",
        re.IGNORECASE,
    ),
    re.compile(r"(?:Proceedings of)\s+(.+?)(?:\.|$)", re.IGNORECASE),
)

# Clauses describing a submission/review status rather than an acceptance.
_ARXIV_NOT_A_VENUE_RE = re.compile(
    r"\b(?:submitted|under review|in submission|under submission)\b", re.IGNORECASE
)


def normalize_space(text: str) -> str:
    """Collapse whitespace runs in *text* to single spaces and strip (None -> "")."""
    return re.sub(r"\s+", " ", text or "").strip()


def _drop_submission_clauses(comment: str) -> str:
    """Remove clauses like "Submitted to ICML 2024" so they are not mistaken for venues.

    The comment is split after every "." or ";". Clauses that mention a submission or
    review status are dropped, and the rest are re-joined unchanged.
    """
    clauses = re.split(r"(?<=[.;])", comment or "")
    return "".join(c for c in clauses if not _ARXIV_NOT_A_VENUE_RE.search(c)).strip()


def extract_venue_from_arxiv(journal_ref: str, comment: str) -> str:
    """Best-effort guess of the publication venue from arXiv ``journal_ref`` / ``comment``.

    Priority:
      1. a known "ACRONYM YEAR" in journal_ref or comment (e.g. "CVPR 2023");
      2. the whole journal_ref, if present;
      3. the text after "Accepted at/to", "To appear in", "Published in" or
         "Proceedings of" in the comment.

    Comment clauses that only say the paper was submitted / is under review are ignored.
    Returns "" when nothing matches.
    """
    journal_ref = journal_ref or ""
    comment = _drop_submission_clauses(comment)
    text = f"{journal_ref} {comment}".strip()
    if not text:
        return ""

    for pattern in _ARXIV_VENUE_RES:
        m = pattern.search(text)
        if m:
            return m.group(0)

    # A journal_ref is authoritative even when no acronym was recognised.
    if journal_ref:
        return journal_ref

    for pattern in _ARXIV_ACCEPTED_RES:
        m = pattern.search(comment)
        if m:
            return normalize_space(m.group(1))
    return ""


def _element_text(el) -> str:
    """Return the whitespace-normalized text of XML element *el*, or "" if absent/empty."""
    return normalize_space(el.text) if el is not None and el.text else ""


def parse_arxiv_response(xml_text):
    """Parse an arXiv Atom feed into a list of paper dicts.

    Each dict has the keys title, authors, abstract, date, source ("arXiv"), venue,
    journal_ref, comment, url and pdf_url. Entries without a title are skipped, as is
    the error entry arXiv returns for a malformed query (its id lives under /api/errors).

    Raises:
        xml.etree.ElementTree.ParseError: if *xml_text* is not well-formed XML.
    """
    root = ET.fromstring(xml_text)
    papers = []

    for entry in root.findall(f"{_ATOM_NS}entry"):
        title = _element_text(entry.find(f"{_ATOM_NS}title"))
        url = _element_text(entry.find(f"{_ATOM_NS}id"))
        if not title or "/api/errors" in url:
            continue

        authors = [
            name
            for name in (
                _element_text(author.find(f"{_ATOM_NS}name"))
                for author in entry.findall(f"{_ATOM_NS}author")
            )
            if name
        ]
        journal_ref = _element_text(entry.find(f"{_ARXIV_NS}journal_ref"))
        comment = _element_text(entry.find(f"{_ARXIV_NS}comment"))
        pdf_url = next(
            (
                link.attrib.get("href", "")
                for link in entry.findall(f"{_ATOM_NS}link")
                if link.attrib.get("title") == "pdf"
            ),
            "",
        )

        papers.append(
            {
                "title": title,
                "authors": authors,
                "abstract": _element_text(entry.find(f"{_ATOM_NS}summary")),
                "date": _element_text(entry.find(f"{_ATOM_NS}published")),
                "source": "arXiv",
                "venue": extract_venue_from_arxiv(journal_ref, comment),
                "journal_ref": journal_ref,
                "comment": comment,
                "url": url,
                "pdf_url": pdf_url,
            }
        )
    return papers


def escape_arxiv_phrase(text: str) -> str:
    """Minimal cleanup so *text* can sit inside an arXiv quoted phrase query.

    Replaces double quotes with spaces, collapses whitespace and strips the ends.
    """
    text = text.strip()
    text = text.replace('"', " ")
    text = re.sub(r"\s+", " ", text)
    return text


def wait_for_arxiv_rate_limit(min_interval=3.2):
    """Sleep until at least *min_interval* seconds have passed since the last arXiv request.

    arXiv's API terms ask clients to space requests about 3 seconds apart.
    """
    global _last_arxiv_request_time
    elapsed = time.time() - _last_arxiv_request_time
    if elapsed < min_interval:
        time.sleep(min_interval - elapsed)


def search_arxiv_once(search_query, max_results=3, retries=3):
    """Run one arXiv API query and return the parsed papers.

    Retries on HTTP 429 and on ``requests`` errors, with linear backoff between attempts.
    At least one attempt is always made.

    Raises:
        requests.RequestException or RuntimeError: the last error once all attempts fail.
        xml.etree.ElementTree.ParseError: if arXiv returns malformed XML.
    """
    global _last_arxiv_request_time

    params = {
        "search_query": search_query,
        "start": 0,
        "max_results": max_results,
        "sortBy": "relevance",
        "sortOrder": "descending",
    }
    headers = {
        "User-Agent": "paper-finder/0.1 contact:your-email@example.com"
    }

    attempts = max(1, retries)
    last_error = None
    for attempt in range(attempts):
        wait_for_arxiv_rate_limit()
        try:
            res = requests.get(
                ARXIV_API_URL,
                params=params,
                timeout=30,
                headers=headers,
            )
            if res.status_code == 429:
                last_error = RuntimeError("arXiv rate limited: 429")
                backoff = 5 * (attempt + 1)
            else:
                res.raise_for_status()
                return parse_arxiv_response(res.text)
        except requests.RequestException as e:
            last_error = e
            backoff = 2 * (attempt + 1)
        finally:
            # Record every attempt, including failed ones, so the next request still
            # honours the minimum interval.
            _last_arxiv_request_time = time.time()

        # Only back off if another attempt follows.
        if attempt < attempts - 1:
            time.sleep(backoff)

    raise last_error


def search_arxiv(query, max_results=3, debug=False):
    """Search arXiv with progressively looser strategies until *max_results* unique papers are found.

    Strategies, in order:
      1. exact phrase in all fields;
      2. exact phrase in the title;
      3. exact phrase in the abstract;
      4. all words (AND);
      5. any word (OR), only when the query has at least two words.

    Results are deduplicated by normalized title. A strategy that fails is skipped;
    the failure is shown in the UI only when *debug* is set.
    """
    # Escape before the emptiness check so a query made only of quotes becomes "" here
    # instead of an `all:""` query.
    query = escape_arxiv_phrase(normalize_text(query))
    if not query:
        return []

    terms = query.split()
    strategies = [
        f'all:"{query}"',
        f'ti:"{query}"',
        f'abs:"{query}"',
        " AND ".join(f"all:{t}" for t in terms),
    ]
    if len(terms) >= 2:
        strategies.append(" OR ".join(f"all:{t}" for t in terms))

    seen = set()
    all_papers = []
    for s in strategies:
        try:
            if debug:
                st.write("arXiv API query:", s)
            papers = search_arxiv_once(s, max_results=max_results)
            for p in papers:
                key = normalize_title(p["title"])
                if key not in seen:
                    seen.add(key)
                    all_papers.append(p)
            if len(all_papers) >= max_results:
                return all_papers[:max_results]
        except Exception as e:
            # Best-effort: one failed strategy should not abort the whole search.
            if debug:
                st.warning(f"arXiv query failed: {s} / {e}")

    return all_papers[:max_results]


# =========================
# OpenAlex Search
# =========================

# OpenAlex usually names conference sources by their full title
# (e.g. "International Conference on Machine Learning"), so matching only the
# acronym misses most of them. Keys are lower-cased acronyms.
_OPENALEX_VENUE_ALIASES = {
    "neurips": ("Neural Information Processing Systems",),
    "icml": ("International Conference on Machine Learning",),
    "iclr": ("International Conference on Learning Representations",),
    "acl": ("Association for Computational Linguistics",),
    "emnlp": ("Empirical Methods in Natural Language Processing",),
    "naacl": ("North American Chapter of the Association for Computational Linguistics",),
    "cvpr": ("Computer Vision and Pattern Recognition",),
    "iccv": ("International Conference on Computer Vision",),
    "eccv": ("European Conference on Computer Vision",),
}


def reconstruct_abstract(inv_index):
    """Rebuild plain text from an inverted index of the form ``{word: [positions...]}``."""
    if not inv_index:
        return ""
    positioned = [
        (pos, word)
        for word, positions in inv_index.items()
        for pos in (positions or ())
    ]
    positioned.sort(key=lambda x: x[0])
    return " ".join(word for _, word in positioned)


def extract_openalex_venue(item):
    """Return a display name for a work's venue.

    Looks at primary_location.source first, then the first location that has a source
    name, then host_venue. Returns "" if none of them has a name.
    """
    primary_source = (item.get("primary_location") or {}).get("source") or {}
    venue = primary_source.get("display_name", "") or ""
    if venue:
        return venue

    for loc in item.get("locations") or []:
        src = (loc or {}).get("source") or {}
        venue = src.get("display_name", "") or ""
        if venue:
            return venue

    host_venue = item.get("host_venue") or {}
    return host_venue.get("display_name", "") or ""


def _venue_matches(venue: str, venues) -> bool:
    """Return True if *venue* mentions any name in *venues* or a known full-name alias.

    Matching is case-insensitive, and a name must not be glued to other letters.
    For example, "ACL" matches "ACL 2023" or "(ACL)" but not "NAACL" or "Oracle";
    digits may touch the name, as in "CVPR2023".
    """
    if not venue:
        return False
    for v in venues:
        for name in (v, *_OPENALEX_VENUE_ALIASES.get(v.lower(), ())):
            if name and re.search(
                rf"(?<![a-z]){re.escape(name)}(?![a-z])", venue, flags=re.IGNORECASE
            ):
                return True
    return False


def search_openalex(query, venues, max_results=3, debug=False):
    """Search OpenAlex works for *query* and keep only those published at one of *venues*.

    The first 50 search hits are filtered client-side by venue name. Up to *max_results*
    paper dicts are returned, with the keys title, authors, abstract, date,
    source ("OpenAlex"), venue and url. Returns [] on an empty query, an empty venue
    list, or any error; errors are shown only in debug mode.
    """
    query = normalize_text(query)
    if not query or not venues:
        return []

    url = "https://api.openalex.org/works"
    params = {
        "search": query,
        "per-page": 50,
    }

    try:
        res = requests.get(
            url,
            params=params,
            timeout=30,
            headers={"User-Agent": "paper-finder/0.1"},
        )
        res.raise_for_status()
        data = res.json()

        papers = []
        for item in data.get("results", []):
            venue = extract_openalex_venue(item)
            if not _venue_matches(venue, venues):
                continue

            authors = []
            for a in item.get("authorships") or []:
                author = (a or {}).get("author") or {}
                name = author.get("display_name")
                if name:
                    authors.append(name)

            abstract = item.get("abstract_inverted_index")
            if isinstance(abstract, dict):
                abstract = reconstruct_abstract(abstract)
            elif not isinstance(abstract, str):
                abstract = ""

            papers.append(
                {
                    "title": item.get("title", "") or "",
                    "authors": authors,
                    "abstract": abstract,
                    "date": item.get("publication_date", "") or "",
                    "source": "OpenAlex",
                    "venue": venue,
                    "url": item.get("id", "") or "",
                }
            )
            if len(papers) >= max_results:
                break

        if debug:
            st.write("OpenAlex matched papers:", len(papers))
        return papers

    except Exception as e:
        # Best-effort: the app continues with arXiv results if OpenAlex fails.
        if debug:
            st.warning(f"OpenAlex search failed: {e}")
        return []


# =========================
# LLM Utilities
# =========================

_FIELD_LABEL_RE = re.compile(r"\b(ML|NLP|CV|OTHER)\b")


def normalize_keyword_for_search(keyword, model):
    """Ask the LLM to turn *keyword* (Japanese or English) into a short English search query.

    Surrounding quotes that the model may add are stripped.
    """
    prompt = f"""
あなたは学術論文検索アシスタントです。
以下のユーザー入力を、arXivやOpenAlexで検索しやすい英語の短い検索クエリに変換してください。

ルール:
- 出力は英語の検索クエリ1つだけ
- 余計な説明は不要
- 日本語入力なら自然な英語の研究キーワードへ変換
- 英語入力なら意味を保って簡潔に整形
- 2語から8語程度が望ましい
- 不要な記号は入れない

input: {keyword}
"""
    return normalize_text(normalize_text(ask_llm(prompt, model)).strip("\"'"))


def paraphrase_query(keyword, model):
    """Ask the LLM for one alternative English search query for *keyword*."""
    prompt = f"""
次の研究トピックを、英語の論文検索クエリとして言い換えてください。
出力は短い英語クエリを1つだけにしてください。
説明は不要です。

topic: {keyword}
"""
    return normalize_text(ask_llm(prompt, model))


def classify_field(keyword, model):
    """Classify *keyword* into one of "ML", "NLP", "CV" or "OTHER".

    The first label found in the model's reply is used, so answers like "NLP." or
    "Label: NLP" still work. An unparseable reply maps to "OTHER".
    """
    prompt = f"""
次の研究トピックが主に属する分野を、以下から1つだけ選んでください。

候補:
ML
NLP
CV
OTHER

研究トピック:
{keyword}

判定ルール:
- 機械学習全般、最適化、表現学習、強化学習、生成モデルなどは ML
- 自然言語処理、対話、翻訳、要約、LLM、RAG などは NLP
- 画像、動画、物体検出、セグメンテーション、3D vision などは CV
- 上記に明確に当てはまらなければ OTHER

出力はラベル1つだけにしてください。
"""
    m = _FIELD_LABEL_RE.search(ask_llm(prompt, model).strip().upper())
    return m.group(1) if m else "OTHER"


def summarize_paper(title, abstract, model, venue=""):
    """Ask the LLM for a short Japanese explanation of a paper (summary, novelty, audience)."""
    prompt = f"""
次の論文を簡潔に日本語で解説してください。

Title: {title}
Venue: {venue}

Abstract:
{abstract}

出力形式:
- 要約
- 何が新しいか
- どんな人におすすめか
"""
    return ask_llm(prompt, model)


def select_best_papers(papers, keyword, model, top_k=3):
    """Let the LLM pick the *top_k* papers most relevant to *keyword*.

    If there are at most *top_k* candidates, they are returned unchanged. Every integer
    in the reply is treated as a candidate index. Out-of-range indices and duplicate
    titles are ignored. Falls back to the first *top_k* candidates if the LLM call
    fails or yields no usable index.
    """
    if not papers:
        return []
    if len(papers) <= top_k:
        return papers[:top_k]

    parts = []
    for i, p in enumerate(papers):
        parts.append(f"""
Paper {i}
Title: {p.get('title', '')}
Venue: {p.get('venue', '')}
Abstract: {p.get('abstract', '')}
""")
    text = "".join(parts)

    prompt = f"""
次の論文リストから、研究トピック「{keyword}」に最も関連があり重要度が高い論文を {top_k} 本選んでください。
必ず異なる論文を選んでください。

{text}

出力形式:
0,2,5
"""

    try:
        res = ask_llm(prompt, model)
    except Exception:
        # Best-effort: fall back to the first candidates if the LLM call fails.
        return papers[:top_k]

    # Accept "0,2,5", "0, 2\n5", "Paper 0, Paper 2", ... ; keep first-seen order.
    ids = dict.fromkeys(int(x) for x in re.findall(r"\d+", res))

    results = []
    seen_titles = set()
    for i in ids:
        if 0 <= i < len(papers):
            title_key = normalize_title(papers[i].get("title", ""))
            if title_key and title_key not in seen_titles:
                results.append(papers[i])
                seen_titles.add(title_key)
                if len(results) >= top_k:
                    break

    return results[:top_k] if results else papers[:top_k]


# =========================
# Streamlit UI
# =========================

# Conferences searched on OpenAlex for each field label returned by classify_field.
FIELD_VENUES = {
    "ML": ["ICML", "ICLR", "NeurIPS"],
    "NLP": ["ACL", "EMNLP", "NAACL", "AACL"],
    "CV": ["CVPR", "ICCV", "ECCV", "SIGGRAPH"],
}

st.set_page_config(page_title="Paper Finder", layout="wide")
st.title("📚 Paper Finder")

st.sidebar.header("Settings")

# The API key comes from the environment, not from a UI field.
openai_api_key = os.getenv("OPENAI_API_KEY") or ""
st.session_state["OPENAI_API_KEY"] = openai_api_key

model = "gpt-5-nano"

debug_mode = st.sidebar.checkbox("Debug mode", value=True)

keyword = st.text_input("Research Keyword")

if st.button("Search Papers"):
    if not st.session_state.get("OPENAI_API_KEY"):
        st.error("環境変数 OPENAI_API_KEY を設定してください。")
        st.stop()

    if not keyword.strip():
        st.warning("Research Keyword を入力してください。")
        st.stop()

    paper_list = []

    st.write("### Step0 Query Normalization")
    try:
        normalized_keyword = normalize_keyword_for_search(keyword, model)
    except Exception as e:
        st.error(f"検索クエリ正規化に失敗しました: {e}")
        st.stop()
    # An empty LLM reply would make every search below return nothing.
    normalized_keyword = normalized_keyword or normalize_text(keyword)

    st.write("**Input keyword:**", keyword)
    st.write("**Normalized English query:**", normalized_keyword)

    st.write("### Step1 arXiv search")
    papers_step1 = search_arxiv(normalized_keyword, max_results=10, debug=debug_mode)
    paper_list.extend(papers_step1)
    st.write(f"found {len(papers_step1)} papers")

    st.write("### Step2 Query Paraphrase")
    try:
        paraphrased = paraphrase_query(normalized_keyword, model)
    except Exception as e:
        paraphrased = normalized_keyword
        if debug_mode:
            st.warning(f"Query paraphrase failed: {e}")
    paraphrased = paraphrased or normalized_keyword

    st.write("**Paraphrased query:**", paraphrased)
    if normalize_title(paraphrased) == normalize_title(normalized_keyword):
        # Re-running the identical query would only repeat several rate-limited requests.
        papers_step2 = []
        st.write("Paraphrased query is identical to the normalized query; skipped.")
    else:
        papers_step2 = search_arxiv(paraphrased, max_results=10, debug=debug_mode)
        paper_list.extend(papers_step2)
        st.write(f"found {len(papers_step2)} papers")

    st.write("### Step3 Field Classification")
    try:
        field = classify_field(keyword, model)
    except Exception as e:
        field = "OTHER"
        if debug_mode:
            st.warning(f"Field classification failed: {e}")

    st.write("**field:**", field)

    venues = FIELD_VENUES.get(field, [])

    papers_step3 = []
    if venues:
        st.write("### Step4 Top-conference Search")
        papers_step3 = search_openalex(normalized_keyword, venues, max_results=10, debug=debug_mode)
        paper_list.extend(papers_step3)
        st.write(f"found {len(papers_step3)} papers")

    paper_list = deduplicate_papers(paper_list)
    st.write("### Total candidate papers:", len(paper_list))

    if debug_mode and paper_list:
        with st.expander("Candidate Papers"):
            for i, p in enumerate(paper_list):
                st.write(
                    f"{i}. {p.get('title', '')} | venue={p.get('venue', '') or '-'} | source={p.get('source', '')}"
                )

    if not paper_list:
        st.error("論文が見つかりませんでした。より一般的な表現や別のキーワードで試してください。")
        st.stop()

    st.write("### Selecting best papers")
    best = select_best_papers(paper_list, keyword, model, top_k=3)

    if not best:
        st.warning("推薦論文の選定に失敗したため、候補論文をそのまま表示します。")
        best = paper_list[:3]

    st.write("## Recommended Papers")

    for p in best:
        abstract = p.get("abstract", "") or ""
        venue = p.get("venue", "") or "-"

        if not abstract:
            summary = "アブストラクトが取得できなかったため、要約を生成できませんでした。"
        else:
            try:
                summary = summarize_paper(
                    title=p.get("title", ""),
                    abstract=abstract,
                    model=model,
                    venue=venue,
                )
            except Exception as e:
                summary = f"要約生成に失敗しました: {e}"

        st.markdown("---")

        title = p.get("title", "No title")
        url = p.get("url")

        if url:
            st.markdown(f"### [{title}]({url})")
        else:
            st.markdown(f"### {title}")

        st.write("**Explanation:**")
        st.write(summary)

        st.write("**Authors:**", ", ".join(p.get("authors", [])) if p.get("authors") else "-")
        st.write("**Date:**", p.get("date", "") or "-")
        st.write("**Source:**", p.get("source", "") or "-")
        st.write("**Venue:**", venue)

        st.write("**Abstract:**")
        st.write(abstract if abstract else "アブストラクトなし")