"""Small retrieval-augmented QA demo: upload a PDF/TXT/EPUB, index it, ask questions.

The uploaded document is split into overlapping character chunks, each chunk is
embedded with a BGE model, the vectors are stored in a flat L2 FAISS index, and
questions are answered by a Qwen chat model prompted with the nearest chunks.
"""

import os
import torch
import numpy as np
import faiss
import gradio as gr
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel, pipeline
from ebooklib import epub
from bs4 import BeautifulSoup

# ===== Tunables =====
_SUPPORTED_EXTENSIONS = (".pdf", ".txt", ".epub")
_EPUB_ITEM_DOCUMENT = 9  # value compared against item.get_type() for EPUB document items
_CHUNK_SIZE = 350  # characters per chunk
_CHUNK_OVERLAP = 100  # characters shared by consecutive chunks
_TOP_K = 8  # maximum number of chunks retrieved per question
_MAX_NEW_TOKENS = 800  # generation budget for the answer itself

# ===== Embedding model =====
embed_model = AutoModel.from_pretrained(
    "BAAI/bge-small-zh",
    trust_remote_code=True,
)
embed_tokenizer = AutoTokenizer.from_pretrained(
    "BAAI/bge-small-zh",
    trust_remote_code=True,
)


def embed_text(text):
    """Return the embedding of *text* as a 1-D NumPy array.

    The text is tokenized (truncated to 512 tokens) and the hidden state of the
    first token of the last layer is used as the embedding.
    """
    inputs = embed_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512
    )
    with torch.no_grad():
        embeddings = embed_model(**inputs).last_hidden_state[:, 0, :]
    return embeddings[0].numpy()


# ===== Generation model (Qwen 1.8B) =====
generator = pipeline(
    "text-generation",
    model="Qwen/Qwen1.5-1.8B-Chat",
    device=-1,
)

# ===== Global state =====
index = None  # FAISS index over the chunk embeddings, or None when nothing is loaded
docs = []  # list of {"text": ..., "source": ...} dicts, parallel to the index rows


# ===== File parsing =====
def _extract_text(file_path, ext):
    """Read the plain text of a .pdf, .txt or .epub file at *file_path*."""
    if ext == ".pdf":
        reader = PdfReader(file_path)
        page_texts = (page.extract_text() for page in reader.pages)
        return "".join(page_text + "\n" for page_text in page_texts if page_text)

    if ext == ".txt":
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    # Remaining supported case: .epub
    book = epub.read_epub(file_path)
    parts = []
    for item in book.get_items():
        if item.get_type() == _EPUB_ITEM_DOCUMENT:
            soup = BeautifulSoup(item.get_content(), "html.parser")
            parts.append(soup.get_text() + "\n")
    return "".join(parts)


def _split_text(text, chunk_size=_CHUNK_SIZE, overlap=_CHUNK_OVERLAP):
    """Split *text* into chunks of *chunk_size* characters overlapping by *overlap*."""
    step = chunk_size - overlap
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        if end == len(text):
            # The text is fully covered; another step would only produce a
            # tail that is already contained in this chunk.
            break
        start += step
    return chunks


def load_file(file_obj):
    """Parse an uploaded file and rebuild the global knowledge base.

    Args:
        file_obj: Either a path string or an object whose ``name`` attribute
            is the path of the uploaded file. ``None`` means nothing was
            uploaded.

    Returns:
        A ``(status_message, None)`` pair. The message reports either the
        number of chunks loaded or why loading failed.

    Side effects:
        Resets the module-level ``index`` and ``docs``; they are repopulated
        only when the whole file has been parsed and embedded.
    """
    global index, docs
    # Drop the previous knowledge base up front so that a failed load never
    # leaves an index that disagrees with ``docs``.
    index = None
    docs = []

    if file_obj is None:
        return "请先上传文件", None

    file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
    ext = os.path.splitext(file_path)[1].lower()
    if ext not in _SUPPORTED_EXTENSIONS:
        return "仅支持 PDF / TXT / EPUB 文件", None

    try:
        text_data = _extract_text(file_path, ext)
    except Exception as e:  # UI boundary: report any parser failure to the user
        return f"文件解析失败: {e}", None

    if not text_data.strip():
        return "未能从文件中提取到文本", None

    # Chunking (350 characters with a 100-character overlap)
    new_docs = [
        {"text": chunk, "source": f"chunk_{i}"}
        for i, chunk in enumerate(_split_text(text_data))
    ]

    # Vectorize and build the index; FAISS expects contiguous float32 rows.
    doc_embeddings = np.ascontiguousarray(
        np.stack([embed_text(d["text"]) for d in new_docs]), dtype=np.float32
    )
    new_index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    new_index.add(doc_embeddings)

    index = new_index
    docs = new_docs
    return f"已加载 {len(docs)} 个文本块", None


# ===== RAG query =====
def rag_query(query):
    """Answer *query* using the chunks of the loaded document as context.

    Returns the generated answer text, or a prompt to upload a file first when
    no knowledge base has been built.
    """
    if index is None or not docs:
        return "请先上传文件并构建知识库"

    q_emb = np.ascontiguousarray(
        embed_text(query).reshape(1, -1), dtype=np.float32
    )
    # Never ask for more neighbours than there are chunks: FAISS pads missing
    # results with -1, which would otherwise index the last chunk repeatedly.
    top_k = min(_TOP_K, len(docs))
    _, neighbor_ids = index.search(q_emb, top_k)
    retrieved = [docs[i]["text"] for i in neighbor_ids[0] if i >= 0]

    context = "\n".join(
        f"[{idx + 1}] {txt}" for idx, txt in enumerate(retrieved)
    )

    prompt = f"""已知信息:
{context}

问题:{query}

请严格按照以下格式输出:
【结论】
用 2-3 句话总结所有引用片段的关键信息,形成一个完整结论。
【详细说明】
整合所有引用片段的细节,分段描述,并在每个关键信息后标注引用编号。
无法回答时直接说“我不知道”。
【引用片段】
逐条列出引用编号及对应的原文。
"""

    # max_new_tokens bounds only the answer (max_length would also count the
    # long prompt), and return_full_text=False keeps the prompt out of the
    # text shown to the user.
    result = generator(
        prompt,
        max_new_tokens=_MAX_NEW_TOKENS,
        do_sample=False,
        return_full_text=False,
    )
    return result[0]["generated_text"]


def _load_file_status(file_obj):
    """Run ``load_file`` and return only its status message for the UI."""
    # load_file returns a (message, None) pair, but a single output component
    # is wired to the button, so hand it just the message.
    return load_file(file_obj)[0]


# ===== Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown("## 📚 完整性增强版 RAG(PDF/TXT/EPUB 支持 + 结论 + 引用)")
    with gr.Row():
        file_input = gr.File(label="上传 PDF / TXT / EPUB 文件")
        load_btn = gr.Button("构建知识库")
    status = gr.Textbox(label="状态")
    query_input = gr.Textbox(label="输入你的问题")
    answer_output = gr.Textbox(label="回答", lines=15)

    load_btn.click(_load_file_status, inputs=file_input, outputs=status)
    query_input.submit(rag_query, inputs=query_input, outputs=answer_output)

if __name__ == "__main__":
    demo.launch()