# app.py — Hugging Face Space by wayne0603 (commit 16873bc, verified)
import os
import torch
import numpy as np
import faiss
import gradio as gr
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel, pipeline
from ebooklib import epub
from bs4 import BeautifulSoup
# ===== Embedding model =====
# BAAI/bge-small-zh: small Chinese sentence-embedding model.
# trust_remote_code is required because the model repo ships custom code.
embed_model = AutoModel.from_pretrained(
    "BAAI/bge-small-zh", trust_remote_code=True
)
embed_tokenizer = AutoTokenizer.from_pretrained(
    "BAAI/bge-small-zh", trust_remote_code=True
)
def embed_text(text):
    """Embed *text* as a 1-D numpy vector using the model's [CLS] hidden state.

    The input is truncated to the model's 512-token limit; inference runs
    without gradient tracking.
    """
    encoded = embed_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512
    )
    with torch.no_grad():
        hidden_states = embed_model(**encoded).last_hidden_state
        # First token ([CLS]) of the last layer serves as the sentence vector.
        cls_vector = hidden_states[:, 0, :]
    return cls_vector[0].numpy()
# ===== Generation model (Qwen 1.5 1.8B chat) =====
generator = pipeline(
    "text-generation",
    model="Qwen/Qwen1.5-1.8B-Chat",
    device=-1  # -1 = run on CPU
)
# ===== Module-level state =====
index = None  # FAISS index; built by load_file(), queried by rag_query()
docs = []     # list of {"text": ..., "source": ...} chunks backing the index
# ===== 文件解析 =====
def load_file(file_obj):
    """Parse an uploaded PDF / TXT / EPUB file, chunk it, and build the FAISS index.

    Returns a single status string for the Gradio status textbox.
    (The original returned a 2-tuple, but the click handler registers only
    one output component, so the second value was a Gradio output mismatch.)
    """
    global index, docs
    docs = []
    index = None  # reset so a failed reload cannot leave a stale index behind
    text_data = ""
    # gr.File may hand us a tempfile-like object or a plain path string.
    file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext == ".pdf":
            reader = PdfReader(file_path)
            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:  # extract_text() returns None for image-only pages
                    text_data += page_text + "\n"
        elif ext == ".txt":
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                text_data = f.read()
        elif ext == ".epub":
            book = epub.read_epub(file_path)
            for item in book.get_items():
                # 9 == ebooklib.ITEM_DOCUMENT (XHTML content documents)
                if item.get_type() == 9:
                    soup = BeautifulSoup(item.get_content(), "html.parser")
                    text_data += soup.get_text() + "\n"
        else:
            return "仅支持 PDF / TXT / EPUB 文件"
    except Exception as e:
        return f"文件解析失败: {str(e)}"
    if not text_data.strip():
        return "未能从文件中提取到文本"
    # Chunking: 350-char windows with 100-char overlap so that context
    # spanning a chunk boundary is still retrievable.
    chunk_size = 350
    overlap = 100
    chunks = []
    start = 0
    while start < len(text_data):
        end = min(start + chunk_size, len(text_data))
        chunks.append(text_data[start:end])
        start += chunk_size - overlap
    docs = [{"text": chunk, "source": f"chunk_{i}"} for i, chunk in enumerate(chunks)]
    # Embed every chunk and build an exact L2 index.  faiss requires
    # float32 input, so cast explicitly rather than relying on the model's
    # output dtype.
    doc_embeddings = np.array(
        [embed_text(d["text"]) for d in docs], dtype=np.float32
    )
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(doc_embeddings)
    return f"已加载 {len(docs)} 个文本块"
# ===== RAG 查询 =====
def rag_query(query):
    """Answer *query* by retrieving top-k chunks from the FAISS index and
    prompting the Qwen generator with them.

    Returns the generated answer string, or a hint if no knowledge base
    has been built yet.
    """
    if index is None or not docs:
        return "请先上传文件并构建知识库"
    q_emb = embed_text(query).reshape(1, -1)
    # Never ask faiss for more neighbours than there are chunks: for small
    # documents k=8 can exceed len(docs), and faiss then pads the result
    # with -1, which Python indexing would silently turn into docs[-1]
    # duplicates.  Filter any remaining -1 defensively.
    k = min(8, len(docs))
    D, I = index.search(q_emb, k=k)
    retrieved = [docs[i]["text"] for i in I[0] if i >= 0]
    context = "\n".join(f"[{idx+1}] {txt}" for idx, txt in enumerate(retrieved))
    prompt = f"""已知信息:
{context}
问题:{query}
请严格按照以下格式输出:
【结论】
用 2-3 句话总结所有引用片段的关键信息,形成一个完整结论。
【详细说明】
整合所有引用片段的细节,分段描述,并在每个关键信息后标注引用编号。
无法回答时直接说“我不知道”。
【引用片段】
逐条列出引用编号及对应的原文。
"""
    # max_new_tokens bounds only the answer; the original max_length=800
    # counted the (often longer) prompt too and could truncate generation
    # to nothing.  return_full_text=False stops the pipeline from echoing
    # the entire prompt in front of the answer.
    result = generator(
        prompt, max_new_tokens=800, do_sample=False, return_full_text=False
    )
    return result[0]["generated_text"]
# ===== Gradio interface =====
with gr.Blocks() as demo:
    gr.Markdown("## 📚 完整性增强版 RAG(PDF/TXT/EPUB 支持 + 结论 + 引用)")

    # Upload + build controls on one row.
    with gr.Row():
        uploaded_file = gr.File(label="上传 PDF / TXT / EPUB 文件")
        build_button = gr.Button("构建知识库")

    status_box = gr.Textbox(label="状态")
    question_box = gr.Textbox(label="输入你的问题")
    answer_box = gr.Textbox(label="回答", lines=15)

    # Wire events: button builds the knowledge base; pressing Enter in the
    # question box runs the RAG query.
    build_button.click(load_file, inputs=uploaded_file, outputs=status_box)
    question_box.submit(rag_query, inputs=question_box, outputs=answer_box)

if __name__ == "__main__":
    demo.launch()