# Minimal RAG demo: index the text of an uploaded PDF / TXT / EPUB file with FAISS,
# retrieve the chunks most relevant to a question, and answer with a local
# Qwen chat model behind a Gradio interface.
import os

import torch
import numpy as np
import faiss
import gradio as gr
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel, pipeline
import ebooklib
from ebooklib import epub
from bs4 import BeautifulSoup

# Chinese embedding model (BAAI/bge-small-zh) used to vectorize chunks and queries.
embed_model = AutoModel.from_pretrained(
    "BAAI/bge-small-zh", trust_remote_code=True
)
embed_tokenizer = AutoTokenizer.from_pretrained(
    "BAAI/bge-small-zh", trust_remote_code=True
)


def embed_text(text):
    """Return the [CLS] token embedding of `text` as a 1-D numpy vector."""
    inputs = embed_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        embeddings = embed_model(**inputs).last_hidden_state[:, 0, :]
    return embeddings[0].numpy()
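
# Note: BGE embedding models are commonly used with L2-normalized vectors and
# inner-product (cosine) search; the raw CLS vectors returned here also work
# with the exact L2 index built in load_file(), so normalization is left optional.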


# Generator model used to answer questions; device=-1 runs the pipeline on CPU
# (use device=0 for the first GPU).
generator = pipeline(
    "text-generation",
    model="Qwen/Qwen1.5-1.8B-Chat",
    device=-1
)

# Global state shared between the Gradio callbacks: the FAISS index and the
# list of chunk records it was built from.
index = None
docs = []


def load_file(file_obj):
    """Parse the uploaded file, chunk its text, and build the FAISS index."""
    global index, docs
    docs = []
    text_data = ""

    # Gradio may pass a tempfile-like object or a plain path string.
    file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
    ext = os.path.splitext(file_path)[1].lower()

    try:
        if ext == ".pdf":
            reader = PdfReader(file_path)
            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text_data += page_text + "\n"

        elif ext == ".txt":
            with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                text_data = f.read()

        elif ext == ".epub":
            book = epub.read_epub(file_path)
            for item in book.get_items():
                # Only HTML/XHTML document items carry the book text.
                if item.get_type() == ebooklib.ITEM_DOCUMENT:
                    soup = BeautifulSoup(item.get_content(), "html.parser")
                    text_data += soup.get_text() + "\n"

        else:
            return "Only PDF / TXT / EPUB files are supported"

    except Exception as e:
        return f"Failed to parse the file: {str(e)}"

    if not text_data.strip():
        return "No text could be extracted from the file"

    # Split the text into fixed-size chunks with overlap (sliding window), so
    # sentences cut at a boundary still appear intact in a neighboring chunk.
    chunk_size = 350
    overlap = 100
    start = 0
    chunks = []
    while start < len(text_data):
        end = min(start + chunk_size, len(text_data))
        chunks.append(text_data[start:end])
        start += chunk_size - overlap

    docs = [{"text": chunk, "source": f"chunk_{i}"} for i, chunk in enumerate(chunks)]

    # Embed every chunk and index the vectors with an exact (brute-force) L2 index.
    doc_embeddings = np.array([embed_text(d["text"]) for d in docs]).astype("float32")
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(doc_embeddings)

    return f"Loaded {len(docs)} text chunks"
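
# Note: the index and chunks live only in process memory and are rebuilt on every
# upload; faiss.write_index / faiss.read_index could persist the index if needed.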


def rag_query(query):
    """Retrieve the chunks most relevant to `query` and ask the generator to answer."""
    if index is None or not docs:
        return "Please upload a file and build the knowledge base first"
    q_emb = embed_text(query).reshape(1, -1).astype("float32")
    # Never ask FAISS for more neighbors than there are chunks.
    k = min(8, len(docs))
    D, I = index.search(q_emb, k)
    retrieved = [docs[i]["text"] for i in I[0]]
    context = "\n".join([f"[{idx+1}] {txt}" for idx, txt in enumerate(retrieved)])

    prompt = f"""Known information:
{context}

Question: {query}

Answer strictly in the following format:
[Conclusion]
Summarize the key information from all cited passages in 2-3 sentences to form a complete conclusion.

[Details]
Integrate the details from all cited passages, describe them in paragraphs, and add the citation number after each key point.
If the question cannot be answered, simply say "I don't know".

[Cited passages]
List each citation number together with the corresponding original text.
"""

    # Generate only the continuation; return_full_text=False strips the echoed prompt.
    result = generator(prompt, max_new_tokens=512, do_sample=False, return_full_text=False)
    answer = result[0]["generated_text"]

    return answer
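
# Note: Qwen1.5 "Chat" checkpoints are normally prompted through the tokenizer's
# chat template (tokenizer.apply_chat_template) rather than a raw string; the
# plain prompt above still works with the text-generation pipeline, but answer
# quality may improve with the proper chat format.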


with gr.Blocks() as demo:
    gr.Markdown("## 📚 Completeness-enhanced RAG (PDF/TXT/EPUB support + conclusions + citations)")
    with gr.Row():
        file_input = gr.File(label="Upload a PDF / TXT / EPUB file")
        load_btn = gr.Button("Build knowledge base")
    status = gr.Textbox(label="Status")
    query_input = gr.Textbox(label="Enter your question")
    answer_output = gr.Textbox(label="Answer", lines=15)

    # Wire the callbacks: the button builds the index; pressing Enter in the
    # question box runs the retrieval-augmented query.
    load_btn.click(load_file, inputs=file_input, outputs=status)
    query_input.submit(rag_query, inputs=query_input, outputs=answer_output)


if __name__ == "__main__":
    demo.launch()
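
# Usage note: running this script (e.g. `python app.py`, filename assumed) serves
# the app on Gradio's default local address (http://127.0.0.1:7860); pass
# share=True to demo.launch() if a temporary public link is needed.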