Mahmoud Sayed
commited on
Commit
·
7dc8b43
1
Parent(s):
64b4ab7
First
Browse files- .env +2 -0
- Dockerfile +5 -0
- __pycache__/pinecone.cpython-311.pyc +0 -0
- api.py +180 -0
- data/clean_medquad.py +37 -0
- data/coaching_millionaer_dataset.json +0 -0
- main.py +32 -0
- model/1_Pooling/config.json +10 -0
- model/README.md +173 -0
- model/config.json +25 -0
- model/config_sentence_transformers.json +14 -0
- model/model.safetensors +3 -0
- model/modules.json +20 -0
- model/sentence_bert_config.json +4 -0
- model/special_tokens_map.json +37 -0
- model/tokenizer.json +0 -0
- model/tokenizer_config.json +65 -0
- model/vocab.txt +0 -0
- pinecone_index.py +45 -0
- qa/__pycache__/biobert_qa.cpython-311.pyc +0 -0
- qa/biobert_qa.py +48 -0
- requirements.txt +8 -0
- retriever/__pycache__/bm25_retriever.cpython-311.pyc +0 -0
- retriever/__pycache__/faiss_retriever.cpython-311.pyc +0 -0
- retriever/bm25_retriever.py +48 -0
- retriever/faiss_retriever.py +89 -0
- retriever/pinecone_retriever.py +22 -0
.env
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENAI_API_KEY = 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
|
| 2 |
+
PINECONE_API_KEY = pcsk_6FCjSE_FFtwDN4PEY5Q7pqKGqGsNgBQrH2Ut9xWcpr3oe1FA28VDPFqei4XtpXMCwb7zdX
|
Dockerfile
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
WORKDIR /app
|
| 3 |
+
COPY . /app
|
| 4 |
+
RUN pip install -r requirements.txt
|
| 5 |
+
CMD ["python", "api.py"]
|
__pycache__/pinecone.cpython-311.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
api.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
from flask import Flask, request, jsonify
|
| 4 |
+
from flask_cors import CORS
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from langdetect import detect
|
| 8 |
+
from googletrans import Translator
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from pinecone import Pinecone
|
| 11 |
+
|
| 12 |
+
# ---------- Config ----------
|
| 13 |
+
DATASET_PATH = "data/coaching_millionaer_dataset.json"
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 17 |
+
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY") # add this to your .env
|
| 18 |
+
PINECONE_INDEX_NAME = "ebook"
|
| 19 |
+
|
| 20 |
+
# ---------- App ----------
|
| 21 |
+
app = Flask(__name__)
|
| 22 |
+
CORS(app, resources={r"/ask": {"origins": "*"}})
|
| 23 |
+
|
| 24 |
+
# ---------- OpenAI Client ----------
|
| 25 |
+
client = None
|
| 26 |
+
if OPENAI_API_KEY:
|
| 27 |
+
client = OpenAI(api_key=OPENAI_API_KEY)
|
| 28 |
+
else:
|
| 29 |
+
print("⚠️ OPENAI_API_KEY is missing in .env")
|
| 30 |
+
|
| 31 |
+
# ---------- Retriever ----------
|
| 32 |
+
retriever = None
|
| 33 |
+
try:
|
| 34 |
+
if not PINECONE_API_KEY:
|
| 35 |
+
raise ValueError("PINECONE_API_KEY missing in .env")
|
| 36 |
+
|
| 37 |
+
pc = Pinecone(api_key=PINECONE_API_KEY)
|
| 38 |
+
index = pc.Index(PINECONE_INDEX_NAME)
|
| 39 |
+
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 40 |
+
|
| 41 |
+
class PineconeRetriever:
|
| 42 |
+
def __init__(self, index, embedder):
|
| 43 |
+
self.index = index
|
| 44 |
+
self.embedder = embedder
|
| 45 |
+
|
| 46 |
+
def retrieve(self, query, top_k=10):
|
| 47 |
+
emb = self.embedder.encode(query).tolist()
|
| 48 |
+
res = self.index.query(vector=emb, top_k=top_k, include_metadata=True)
|
| 49 |
+
matches = res.get("matches", [])
|
| 50 |
+
results = []
|
| 51 |
+
for match in matches:
|
| 52 |
+
meta = match.get("metadata", {})
|
| 53 |
+
results.append({
|
| 54 |
+
"context": meta.get("context", ""),
|
| 55 |
+
"page": meta.get("page"),
|
| 56 |
+
"score": match.get("score", 0)
|
| 57 |
+
})
|
| 58 |
+
return results
|
| 59 |
+
|
| 60 |
+
retriever = PineconeRetriever(index, embedder)
|
| 61 |
+
print("✅ Pinecone retriever initialized successfully.")
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print("❌ Retriever initialization failed:", e)
|
| 64 |
+
traceback.print_exc()
|
| 65 |
+
|
| 66 |
+
translator = Translator()
|
| 67 |
+
|
| 68 |
+
# ---------- Helpers ----------
|
| 69 |
+
def detect_language(question: str) -> str:
|
| 70 |
+
"""Detect the user's language without translation."""
|
| 71 |
+
try:
|
| 72 |
+
return detect(question)
|
| 73 |
+
except Exception:
|
| 74 |
+
return "unknown"
|
| 75 |
+
|
| 76 |
+
def normalize_language(lang: str, text: str) -> str:
|
| 77 |
+
"""Fix incorrect language detection like 'wer is' → German."""
|
| 78 |
+
if lang == "nl" and any(word in text.lower() for word in ["wer", "was", "wie", "javid", "coaching"]):
|
| 79 |
+
return "de"
|
| 80 |
+
return lang
|
| 81 |
+
|
| 82 |
+
def system_prompt_book_only() -> str:
|
| 83 |
+
return (
|
| 84 |
+
"You are CoachingBot, a professional mentor trained on the book 'Coaching Millionär' by Javid Niazi-Hoffmann. "
|
| 85 |
+
"Use only the provided book context to answer the question. "
|
| 86 |
+
"If the user asks about people like Javid Niazi-Hoffmann, describe them factually using the book content. "
|
| 87 |
+
"Mention page numbers where possible. "
|
| 88 |
+
"If the context is not relevant, say you don’t have that information in the book and provide a general, helpful answer. "
|
| 89 |
+
"Always respond in the same language as the user's question, even if the book content is in another language."
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def system_prompt_fallback() -> str:
|
| 93 |
+
return (
|
| 94 |
+
"You are CoachingBot, a helpful business and life mentor. "
|
| 95 |
+
"The question cannot be answered from the book, so answer using your general coaching knowledge. "
|
| 96 |
+
"Always respond in the same language as the user's question, even if the book content is in another language. "
|
| 97 |
+
"Do not invent book citations."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def format_answers(question: str, answer: str, results):
|
| 101 |
+
pages = [f"Seite {r.get('page', '')}" for r in results if r.get("page")]
|
| 102 |
+
source = ", ".join(pages) if pages else "No source"
|
| 103 |
+
top_score = max([r.get("score", 0.0) for r in results], default=0.0)
|
| 104 |
+
return {"answers": [{"question": question, "answer": answer, "source": source, "bm25_score": top_score}]}
|
| 105 |
+
|
| 106 |
+
# ---------- Routes ----------
|
| 107 |
+
@app.route("/", methods=["GET"])
|
| 108 |
+
def health():
|
| 109 |
+
return jsonify({
|
| 110 |
+
"status": "running",
|
| 111 |
+
"retriever_ready": bool(retriever),
|
| 112 |
+
"openai_key_loaded": bool(OPENAI_API_KEY),
|
| 113 |
+
"pinecone_key_loaded": bool(PINECONE_API_KEY),
|
| 114 |
+
"index_name": PINECONE_INDEX_NAME
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
@app.route("/ask", methods=["POST", "OPTIONS"])
|
| 118 |
+
def ask():
|
| 119 |
+
if request.method == "OPTIONS":
|
| 120 |
+
return ("", 204)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
data = request.get_json(force=True) or {}
|
| 124 |
+
question = (data.get("question") or "").strip()
|
| 125 |
+
except Exception:
|
| 126 |
+
return jsonify(format_answers("", "Invalid JSON request", [])), 200
|
| 127 |
+
|
| 128 |
+
if not question:
|
| 129 |
+
return jsonify(format_answers("", "Please enter a question.", [])), 200
|
| 130 |
+
|
| 131 |
+
print(f"\n--- User Question ---\n{question}")
|
| 132 |
+
|
| 133 |
+
# Detect and normalize language
|
| 134 |
+
user_lang = normalize_language(detect_language(question), question)
|
| 135 |
+
print(f"Detected language: {user_lang}")
|
| 136 |
+
|
| 137 |
+
# Retrieve context
|
| 138 |
+
context, results = "", []
|
| 139 |
+
try:
|
| 140 |
+
raw_results = retriever.retrieve(question)
|
| 141 |
+
MIN_SCORE = 0.10 # Pinecone similarity scores are normalized (0–1)
|
| 142 |
+
results = [r for r in raw_results if r.get("score", 0) >= MIN_SCORE]
|
| 143 |
+
if results:
|
| 144 |
+
context = "\n\n---\n\n".join(
|
| 145 |
+
[f"(Seite {r['page']}) {r['context']}" for r in results]
|
| 146 |
+
)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
traceback.print_exc()
|
| 149 |
+
return jsonify(format_answers(question, f"Retriever error: {e}", [])), 200
|
| 150 |
+
|
| 151 |
+
# Build prompts
|
| 152 |
+
if context:
|
| 153 |
+
sys_prompt = system_prompt_book_only()
|
| 154 |
+
user_content = f"Question: {question}\n\nBook context:\n{context}"
|
| 155 |
+
else:
|
| 156 |
+
sys_prompt = system_prompt_fallback()
|
| 157 |
+
user_content = question
|
| 158 |
+
|
| 159 |
+
# Query GPT
|
| 160 |
+
try:
|
| 161 |
+
response = client.chat.completions.create(
|
| 162 |
+
model="gpt-4o-mini",
|
| 163 |
+
messages=[
|
| 164 |
+
{"role": "system", "content": sys_prompt},
|
| 165 |
+
{"role": "user", "content": user_content}
|
| 166 |
+
],
|
| 167 |
+
max_tokens=700,
|
| 168 |
+
)
|
| 169 |
+
answer = response.choices[0].message.content.strip()
|
| 170 |
+
except Exception as e:
|
| 171 |
+
traceback.print_exc()
|
| 172 |
+
return jsonify(format_answers(question, f"⚠️ OpenAI call failed: {e}", [])), 200
|
| 173 |
+
|
| 174 |
+
return jsonify(format_answers(question, answer, results))
|
| 175 |
+
|
| 176 |
+
# ---------- Run ----------
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
port = int(os.environ.get("PORT", 5000))
|
| 179 |
+
print(f"🚀 Server started on port {port}")
|
| 180 |
+
app.run(host="0.0.0.0", port=port)
|
data/clean_medquad.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# Input and output paths
|
| 7 |
+
input_csv_path = "data/medquad.csv"
|
| 8 |
+
output_json_path = "data/medquad_cleaned.json"
|
| 9 |
+
|
| 10 |
+
# Make sure output directory exists
|
| 11 |
+
os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
|
| 12 |
+
|
| 13 |
+
# Load CSV
|
| 14 |
+
df = pd.read_csv(input_csv_path)
|
| 15 |
+
|
| 16 |
+
# Basic cleaning
|
| 17 |
+
df.dropna(subset=["question", "answer"], inplace=True)
|
| 18 |
+
df["question"] = df["question"].str.strip()
|
| 19 |
+
df["answer"] = df["answer"].str.strip()
|
| 20 |
+
df["source"] = df["source"].fillna("").str.strip()
|
| 21 |
+
df.drop_duplicates(subset=["question", "answer"], inplace=True)
|
| 22 |
+
|
| 23 |
+
# Convert to list of dicts
|
| 24 |
+
cleaned_data = [
|
| 25 |
+
{
|
| 26 |
+
"title": row["question"],
|
| 27 |
+
"context": row["answer"],
|
| 28 |
+
"source": row["source"]
|
| 29 |
+
}
|
| 30 |
+
for _, row in df.iterrows()
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
# Save as JSON
|
| 34 |
+
with open(output_json_path, "w", encoding="utf-8") as f:
|
| 35 |
+
json.dump(cleaned_data, f, indent=2)
|
| 36 |
+
|
| 37 |
+
print(f"✅ Cleaned data saved to: {output_json_path}")
|
data/coaching_millionaer_dataset.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
main.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from retriever.bm25_retriever import BM25Retriever
|
| 3 |
+
from qa.biobert_qa import BioBERTAnswerExtractor
|
| 4 |
+
|
| 5 |
+
def main():
|
| 6 |
+
# Initialize retriever and QA model
|
| 7 |
+
retriever = BM25Retriever("data/medquad_cleaned.json")
|
| 8 |
+
qa = BioBERTAnswerExtractor()
|
| 9 |
+
|
| 10 |
+
print("\n🩺 MedBot is ready! Type your question or 'exit' to quit.")
|
| 11 |
+
|
| 12 |
+
while True:
|
| 13 |
+
question = input("\nAsk a medical question: ").strip()
|
| 14 |
+
if question.lower() in {"exit", "quit"}:
|
| 15 |
+
print("👋 Goodbye!")
|
| 16 |
+
break
|
| 17 |
+
|
| 18 |
+
# Step 1: Retrieve top 3 passages
|
| 19 |
+
results = retriever.retrieve(question, top_k=3)
|
| 20 |
+
|
| 21 |
+
# Step 2: Run BioBERT on each passage
|
| 22 |
+
print("\n🔍 Best answers:")
|
| 23 |
+
for idx, item in enumerate(results, 1):
|
| 24 |
+
context = item["context"]
|
| 25 |
+
answer = qa.extract_answer(question, context)
|
| 26 |
+
print(f"\nResult {idx}")
|
| 27 |
+
print(f"Q: {item['title']}")
|
| 28 |
+
print(f"A: {answer}")
|
| 29 |
+
print(f"Source: {item['source']} (BM25 Score: {item['score']:.2f})")
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main()
|
model/1_Pooling/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"word_embedding_dimension": 384,
|
| 3 |
+
"pooling_mode_cls_token": false,
|
| 4 |
+
"pooling_mode_mean_tokens": true,
|
| 5 |
+
"pooling_mode_max_tokens": false,
|
| 6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
| 7 |
+
"pooling_mode_weightedmean_tokens": false,
|
| 8 |
+
"pooling_mode_lasttoken": false,
|
| 9 |
+
"include_prompt": true
|
| 10 |
+
}
|
model/README.md
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
library_name: sentence-transformers
|
| 5 |
+
tags:
|
| 6 |
+
- sentence-transformers
|
| 7 |
+
- feature-extraction
|
| 8 |
+
- sentence-similarity
|
| 9 |
+
- transformers
|
| 10 |
+
datasets:
|
| 11 |
+
- s2orc
|
| 12 |
+
- flax-sentence-embeddings/stackexchange_xml
|
| 13 |
+
- ms_marco
|
| 14 |
+
- gooaq
|
| 15 |
+
- yahoo_answers_topics
|
| 16 |
+
- code_search_net
|
| 17 |
+
- search_qa
|
| 18 |
+
- eli5
|
| 19 |
+
- snli
|
| 20 |
+
- multi_nli
|
| 21 |
+
- wikihow
|
| 22 |
+
- natural_questions
|
| 23 |
+
- trivia_qa
|
| 24 |
+
- embedding-data/sentence-compression
|
| 25 |
+
- embedding-data/flickr30k-captions
|
| 26 |
+
- embedding-data/altlex
|
| 27 |
+
- embedding-data/simple-wiki
|
| 28 |
+
- embedding-data/QQP
|
| 29 |
+
- embedding-data/SPECTER
|
| 30 |
+
- embedding-data/PAQ_pairs
|
| 31 |
+
- embedding-data/WikiAnswers
|
| 32 |
+
pipeline_tag: sentence-similarity
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# all-MiniLM-L6-v2
|
| 37 |
+
This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.
|
| 38 |
+
|
| 39 |
+
## Usage (Sentence-Transformers)
|
| 40 |
+
Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
|
| 41 |
+
|
| 42 |
+
```
|
| 43 |
+
pip install -U sentence-transformers
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Then you can use the model like this:
|
| 47 |
+
```python
|
| 48 |
+
from sentence_transformers import SentenceTransformer
|
| 49 |
+
sentences = ["This is an example sentence", "Each sentence is converted"]
|
| 50 |
+
|
| 51 |
+
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
| 52 |
+
embeddings = model.encode(sentences)
|
| 53 |
+
print(embeddings)
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Usage (HuggingFace Transformers)
|
| 57 |
+
Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
from transformers import AutoTokenizer, AutoModel
|
| 61 |
+
import torch
|
| 62 |
+
import torch.nn.functional as F
|
| 63 |
+
|
| 64 |
+
#Mean Pooling - Take attention mask into account for correct averaging
|
| 65 |
+
def mean_pooling(model_output, attention_mask):
|
| 66 |
+
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
| 67 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 68 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Sentences we want sentence embeddings for
|
| 72 |
+
sentences = ['This is an example sentence', 'Each sentence is converted']
|
| 73 |
+
|
| 74 |
+
# Load model from HuggingFace Hub
|
| 75 |
+
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
| 76 |
+
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
| 77 |
+
|
| 78 |
+
# Tokenize sentences
|
| 79 |
+
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
| 80 |
+
|
| 81 |
+
# Compute token embeddings
|
| 82 |
+
with torch.no_grad():
|
| 83 |
+
model_output = model(**encoded_input)
|
| 84 |
+
|
| 85 |
+
# Perform pooling
|
| 86 |
+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
| 87 |
+
|
| 88 |
+
# Normalize embeddings
|
| 89 |
+
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
| 90 |
+
|
| 91 |
+
print("Sentence embeddings:")
|
| 92 |
+
print(sentence_embeddings)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
------
|
| 96 |
+
|
| 97 |
+
## Background
|
| 98 |
+
|
| 99 |
+
The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
|
| 100 |
+
contrastive learning objective. We used the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model and fine-tuned in on a
|
| 101 |
+
1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
|
| 102 |
+
|
| 103 |
+
We developed this model during the
|
| 104 |
+
[Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
|
| 105 |
+
organized by Hugging Face. We developed this model as part of the project:
|
| 106 |
+
[Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Googles Flax, JAX, and Cloud team member about efficient deep learning frameworks.
|
| 107 |
+
|
| 108 |
+
## Intended uses
|
| 109 |
+
|
| 110 |
+
Our model is intended to be used as a sentence and short paragraph encoder. Given an input text, it outputs a vector which captures
|
| 111 |
+
the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
|
| 112 |
+
|
| 113 |
+
By default, input text longer than 256 word pieces is truncated.
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
## Training procedure
|
| 117 |
+
|
| 118 |
+
### Pre-training
|
| 119 |
+
|
| 120 |
+
We use the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model. Please refer to the model card for more detailed information about the pre-training procedure.
|
| 121 |
+
|
| 122 |
+
### Fine-tuning
|
| 123 |
+
|
| 124 |
+
We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
|
| 125 |
+
We then apply the cross entropy loss by comparing with true pairs.
|
| 126 |
+
|
| 127 |
+
#### Hyper parameters
|
| 128 |
+
|
| 129 |
+
We trained our model on a TPU v3-8. We train the model during 100k steps using a batch size of 1024 (128 per TPU core).
|
| 130 |
+
We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
|
| 131 |
+
a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
|
| 132 |
+
|
| 133 |
+
#### Training data
|
| 134 |
+
|
| 135 |
+
We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
|
| 136 |
+
We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
| Dataset | Paper | Number of training tuples |
|
| 140 |
+
|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
|
| 141 |
+
| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
|
| 142 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
|
| 143 |
+
| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
|
| 144 |
+
| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
|
| 145 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
|
| 146 |
+
| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
|
| 147 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
|
| 148 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
|
| 149 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
|
| 150 |
+
| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
|
| 151 |
+
| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
|
| 152 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
|
| 153 |
+
| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
|
| 154 |
+
| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
|
| 155 |
+
| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
|
| 156 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
|
| 157 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
|
| 158 |
+
| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
|
| 159 |
+
| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
|
| 160 |
+
| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
|
| 161 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
|
| 162 |
+
| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
|
| 163 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
|
| 164 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
|
| 165 |
+
| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
|
| 166 |
+
| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
|
| 167 |
+
| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
|
| 168 |
+
| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
|
| 169 |
+
| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
|
| 170 |
+
| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
|
| 171 |
+
| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
|
| 172 |
+
| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
|
| 173 |
+
| **Total** | | **1,170,060,424** |
|
model/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertModel"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"classifier_dropout": null,
|
| 7 |
+
"dtype": "float32",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 384,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 1536,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 6,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"position_embedding_type": "absolute",
|
| 21 |
+
"transformers_version": "4.56.2",
|
| 22 |
+
"type_vocab_size": 2,
|
| 23 |
+
"use_cache": true,
|
| 24 |
+
"vocab_size": 30522
|
| 25 |
+
}
|
model/config_sentence_transformers.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"__version__": {
|
| 3 |
+
"sentence_transformers": "5.1.1",
|
| 4 |
+
"transformers": "4.56.2",
|
| 5 |
+
"pytorch": "2.8.0+cpu"
|
| 6 |
+
},
|
| 7 |
+
"model_type": "SentenceTransformer",
|
| 8 |
+
"prompts": {
|
| 9 |
+
"query": "",
|
| 10 |
+
"document": ""
|
| 11 |
+
},
|
| 12 |
+
"default_prompt_name": null,
|
| 13 |
+
"similarity_fn_name": "cosine"
|
| 14 |
+
}
|
model/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1377e9af0ca0b016a9f2aa584d6fc71ab3ea6804fae21ef9fb1416e2944057ac
|
| 3 |
+
size 90864192
|
model/modules.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "sentence_transformers.models.Transformer"
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"idx": 1,
|
| 10 |
+
"name": "1",
|
| 11 |
+
"path": "1_Pooling",
|
| 12 |
+
"type": "sentence_transformers.models.Pooling"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"idx": 2,
|
| 16 |
+
"name": "2",
|
| 17 |
+
"path": "2_Normalize",
|
| 18 |
+
"type": "sentence_transformers.models.Normalize"
|
| 19 |
+
}
|
| 20 |
+
]
|
model/sentence_bert_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"max_seq_length": 256,
|
| 3 |
+
"do_lower_case": false
|
| 4 |
+
}
|
model/special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": {
|
| 3 |
+
"content": "[CLS]",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"mask_token": {
|
| 10 |
+
"content": "[MASK]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"sep_token": {
|
| 24 |
+
"content": "[SEP]",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "[UNK]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": false,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_basic_tokenize": true,
|
| 47 |
+
"do_lower_case": true,
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "[MASK]",
|
| 50 |
+
"max_length": 128,
|
| 51 |
+
"model_max_length": 256,
|
| 52 |
+
"never_split": null,
|
| 53 |
+
"pad_to_multiple_of": null,
|
| 54 |
+
"pad_token": "[PAD]",
|
| 55 |
+
"pad_token_type_id": 0,
|
| 56 |
+
"padding_side": "right",
|
| 57 |
+
"sep_token": "[SEP]",
|
| 58 |
+
"stride": 0,
|
| 59 |
+
"strip_accents": null,
|
| 60 |
+
"tokenize_chinese_chars": true,
|
| 61 |
+
"tokenizer_class": "BertTokenizer",
|
| 62 |
+
"truncation_side": "right",
|
| 63 |
+
"truncation_strategy": "longest_first",
|
| 64 |
+
"unk_token": "[UNK]"
|
| 65 |
+
}
|
model/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pinecone_index.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import SentenceTransformer
|
| 2 |
+
from pinecone import Pinecone
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
# === Load dataset ===
|
| 6 |
+
with open("data/coaching_millionaer_dataset.json", "r", encoding="utf-8") as f:
|
| 7 |
+
docs = json.load(f)
|
| 8 |
+
|
| 9 |
+
# === Init embedding model ===
|
| 10 |
+
model = SentenceTransformer("./model")
|
| 11 |
+
|
| 12 |
+
# === Init Pinecone ===
|
| 13 |
+
pc = Pinecone(api_key="pcsk_6FCjSE_FFtwDN4PEY5Q7pqKGqGsNgBQrH2Ut9xWcpr3oe1FA28VDPFqei4XtpXMCwb7zdX")
|
| 14 |
+
index = pc.Index("ebook")
|
| 15 |
+
|
| 16 |
+
# === Upload data ===
|
| 17 |
+
vectors = []
|
| 18 |
+
|
| 19 |
+
for i, doc in enumerate(docs):
|
| 20 |
+
# Handle multiple possible content keys safely
|
| 21 |
+
content = (
|
| 22 |
+
doc.get("content")
|
| 23 |
+
or doc.get("text")
|
| 24 |
+
or doc.get("context")
|
| 25 |
+
or doc.get("paragraph")
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if not content:
|
| 29 |
+
print(f"⚠️ Skipping item {i} (no text field found)")
|
| 30 |
+
continue
|
| 31 |
+
|
| 32 |
+
emb = model.encode(content).tolist()
|
| 33 |
+
vectors.append((str(i), emb, {"page": doc.get("page"), "context": content}))
|
| 34 |
+
|
| 35 |
+
# Upload in batches
|
| 36 |
+
if len(vectors) >= 100:
|
| 37 |
+
index.upsert(vectors=vectors)
|
| 38 |
+
print(f"✅ Uploaded {i + 1} documents...")
|
| 39 |
+
vectors = []
|
| 40 |
+
|
| 41 |
+
# Upload remaining
|
| 42 |
+
if vectors:
|
| 43 |
+
index.upsert(vectors=vectors)
|
| 44 |
+
|
| 45 |
+
print("🎉 Upload complete! All documents added to Pinecone.")
|
qa/__pycache__/biobert_qa.cpython-311.pyc
ADDED
|
Binary file (2.94 kB). View file
|
|
|
qa/biobert_qa.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class BioBERTAnswerExtractor:
|
| 6 |
+
def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1-squad'):
|
| 7 |
+
print("⏳ Loading BioBERT model...")
|
| 8 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 9 |
+
self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
| 10 |
+
print("BioBERT model loaded.")
|
| 11 |
+
|
| 12 |
+
def extract_answer(self, question, context):
|
| 13 |
+
inputs = self.tokenizer.encode_plus(
|
| 14 |
+
question, context,
|
| 15 |
+
return_tensors='pt',
|
| 16 |
+
truncation=True,
|
| 17 |
+
max_length=512
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
with torch.no_grad():
|
| 21 |
+
outputs = self.model(**inputs)
|
| 22 |
+
start_scores = outputs.start_logits
|
| 23 |
+
end_scores = outputs.end_logits
|
| 24 |
+
|
| 25 |
+
start_idx = torch.argmax(start_scores)
|
| 26 |
+
end_idx = torch.argmax(end_scores)
|
| 27 |
+
|
| 28 |
+
if start_idx > end_idx:
|
| 29 |
+
return "" # invalid span
|
| 30 |
+
|
| 31 |
+
all_tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
| 32 |
+
answer_tokens = all_tokens[start_idx:end_idx + 1]
|
| 33 |
+
answer = self.tokenizer.convert_tokens_to_string(answer_tokens).strip()
|
| 34 |
+
|
| 35 |
+
# Filter out junk answers
|
| 36 |
+
if not answer or answer.lower() in ["[cls]", "[sep]"] or len(answer) < 3:
|
| 37 |
+
return "" # signal to use fallback
|
| 38 |
+
|
| 39 |
+
return answer
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Example usage
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
qa = BioBERTAnswerExtractor()
|
| 45 |
+
question = "What are the symptoms of flu?"
|
| 46 |
+
context = "The flu can cause fever, cough, sore throat, muscle aches, fatigue, and chills."
|
| 47 |
+
answer = qa.extract_answer(question, context)
|
| 48 |
+
print(f"Answer: {answer}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask
|
| 2 |
+
flask-cors
|
| 3 |
+
sentence-transformers
|
| 4 |
+
pinecone-client
|
| 5 |
+
langdetect
|
| 6 |
+
googletrans==4.0.0-rc1
|
| 7 |
+
openai
|
| 8 |
+
python-dotenv
|
retriever/__pycache__/bm25_retriever.cpython-311.pyc
ADDED
|
Binary file (3.94 kB). View file
|
|
|
retriever/__pycache__/faiss_retriever.cpython-311.pyc
ADDED
|
Binary file (6.84 kB). View file
|
|
|
retriever/bm25_retriever.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
from rank_bm25 import BM25Okapi
|
| 4 |
+
import nltk
|
| 5 |
+
from nltk.tokenize import word_tokenize
|
| 6 |
+
|
| 7 |
+
nltk.download('punkt')
|
| 8 |
+
|
| 9 |
+
class BM25Retriever:
|
| 10 |
+
def __init__(self, json_path):
|
| 11 |
+
self.data = self.load_data(json_path)
|
| 12 |
+
self.contexts = [item["context"] for item in self.data]
|
| 13 |
+
self.tokenized_corpus = [word_tokenize(doc.lower()) for doc in self.contexts]
|
| 14 |
+
self.bm25 = BM25Okapi(self.tokenized_corpus)
|
| 15 |
+
|
| 16 |
+
def load_data(self, path):
|
| 17 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 18 |
+
return json.load(f)
|
| 19 |
+
|
| 20 |
+
def retrieve(self, query, top_k=5):
|
| 21 |
+
tokenized_query = word_tokenize(query.lower())
|
| 22 |
+
scores = self.bm25.get_scores(tokenized_query)
|
| 23 |
+
top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
|
| 24 |
+
results = []
|
| 25 |
+
|
| 26 |
+
for i in top_k_indices:
|
| 27 |
+
item = self.data[i]
|
| 28 |
+
results.append({
|
| 29 |
+
"score": scores[i],
|
| 30 |
+
"title": item["title"],
|
| 31 |
+
"context": item["context"],
|
| 32 |
+
"source": item.get("source", "")
|
| 33 |
+
})
|
| 34 |
+
|
| 35 |
+
return results
|
| 36 |
+
|
| 37 |
+
# Example usage:
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
retriever = BM25Retriever("data/medquad_cleaned.json")
|
| 40 |
+
question = input("Ask a medical question: ")
|
| 41 |
+
results = retriever.retrieve(question)
|
| 42 |
+
|
| 43 |
+
for idx, result in enumerate(results, 1):
|
| 44 |
+
print(f"\nResult {idx}")
|
| 45 |
+
print(f"Score: {result['score']:.2f}")
|
| 46 |
+
print(f"Question: {result['title']}")
|
| 47 |
+
print(f"Answer: {result['context']}")
|
| 48 |
+
print(f"Source: {result['source']}")
|
retriever/faiss_retriever.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import faiss
|
| 5 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 6 |
+
from sklearn.preprocessing import normalize
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FAISSRetriever:
|
| 10 |
+
def __init__(self, data_path="data/coaching_millionaer_dataset.json"):
|
| 11 |
+
"""
|
| 12 |
+
Multilingual FAISS retriever for the 'Coaching Millionär' dataset.
|
| 13 |
+
Supports English and German queries.
|
| 14 |
+
"""
|
| 15 |
+
self.data_path = data_path
|
| 16 |
+
self.index_path = "data/faiss_index.bin"
|
| 17 |
+
self.meta_path = "data/faiss_metadata.json"
|
| 18 |
+
|
| 19 |
+
# ✅ multilingual model (English + German + 50+ languages)
|
| 20 |
+
self.model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
| 21 |
+
|
| 22 |
+
# optional reranker for better precision
|
| 23 |
+
self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
| 24 |
+
|
| 25 |
+
# Load existing FAISS index or build new one
|
| 26 |
+
if os.path.exists(self.index_path) and os.path.exists(self.meta_path):
|
| 27 |
+
self.index = faiss.read_index(self.index_path)
|
| 28 |
+
with open(self.meta_path, "r", encoding="utf-8") as f:
|
| 29 |
+
self.metadata = json.load(f)
|
| 30 |
+
print("✅ Loaded existing FAISS index.")
|
| 31 |
+
else:
|
| 32 |
+
self._build_index()
|
| 33 |
+
|
| 34 |
+
def _build_index(self):
|
| 35 |
+
"""Build and save FAISS index from dataset."""
|
| 36 |
+
with open(self.data_path, "r", encoding="utf-8") as f:
|
| 37 |
+
dataset = json.load(f)
|
| 38 |
+
|
| 39 |
+
texts = [item["text"] for item in dataset]
|
| 40 |
+
embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
|
| 41 |
+
embeddings = normalize(embeddings)
|
| 42 |
+
|
| 43 |
+
self.index = faiss.IndexFlatIP(embeddings.shape[1])
|
| 44 |
+
self.index.add(embeddings)
|
| 45 |
+
|
| 46 |
+
self.metadata = dataset
|
| 47 |
+
os.makedirs("data", exist_ok=True)
|
| 48 |
+
faiss.write_index(self.index, self.index_path)
|
| 49 |
+
with open(self.meta_path, "w", encoding="utf-8") as f:
|
| 50 |
+
json.dump(self.metadata, f, ensure_ascii=False)
|
| 51 |
+
|
| 52 |
+
print(f"✅ Built new FAISS index from {len(texts)} passages.")
|
| 53 |
+
|
| 54 |
+
def retrieve(self, question, top_k=10):
|
| 55 |
+
"""
|
| 56 |
+
Retrieve relevant passages from the FAISS index.
|
| 57 |
+
Automatically boosts results mentioning key entities like 'Javid Niazi-Hoffmann'.
|
| 58 |
+
"""
|
| 59 |
+
query_vec = self.model.encode([question], convert_to_numpy=True)
|
| 60 |
+
query_vec = normalize(query_vec)
|
| 61 |
+
|
| 62 |
+
scores, indices = self.index.search(query_vec, top_k)
|
| 63 |
+
results = []
|
| 64 |
+
|
| 65 |
+
# small keyword boost for known entities
|
| 66 |
+
boost_keywords = ["Javid", "Niazi", "Hoffmann", "Coaching", "Millionär"]
|
| 67 |
+
for idx, score in zip(indices[0], scores[0]):
|
| 68 |
+
if idx < len(self.metadata):
|
| 69 |
+
item = self.metadata[idx]
|
| 70 |
+
text = item["text"]
|
| 71 |
+
boost = any(k.lower() in text.lower() for k in boost_keywords)
|
| 72 |
+
final_score = float(score * 100 + (5 if boost else 0))
|
| 73 |
+
results.append({
|
| 74 |
+
"page": item.get("page", ""),
|
| 75 |
+
"context": text,
|
| 76 |
+
"score": final_score
|
| 77 |
+
})
|
| 78 |
+
|
| 79 |
+
# ✅ Rerank using cross-encoder for higher accuracy
|
| 80 |
+
if results:
|
| 81 |
+
pairs = [(question, r["context"]) for r in results]
|
| 82 |
+
rerank_scores = self.reranker.predict(pairs)
|
| 83 |
+
results = [
|
| 84 |
+
{**r, "rerank_score": float(s)}
|
| 85 |
+
for r, s in zip(results, rerank_scores)
|
| 86 |
+
]
|
| 87 |
+
results = sorted(results, key=lambda x: x["rerank_score"], reverse=True)[:top_k]
|
| 88 |
+
|
| 89 |
+
return results
|
retriever/pinecone_retriever.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pinecone import Pinecone
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
|
| 4 |
+
class PineconeRetriever:
|
| 5 |
+
def __init__(self, api_key: str, index_name: str):
|
| 6 |
+
self.model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")
|
| 7 |
+
self.pinecone = Pinecone(api_key=api_key)
|
| 8 |
+
self.index = self.pinecone.Index(index_name)
|
| 9 |
+
|
| 10 |
+
def retrieve(self, query: str, top_k: int = 5):
|
| 11 |
+
query_emb = self.model.encode(query).tolist()
|
| 12 |
+
results = self.index.query(vector=query_emb, top_k=top_k, include_metadata=True)
|
| 13 |
+
matches = results.get("matches", [])
|
| 14 |
+
docs = []
|
| 15 |
+
for match in matches:
|
| 16 |
+
meta = match["metadata"]
|
| 17 |
+
docs.append({
|
| 18 |
+
"content": meta.get("context", ""),
|
| 19 |
+
"page": meta.get("page"),
|
| 20 |
+
"score": match.get("score")
|
| 21 |
+
})
|
| 22 |
+
return docs
|