import os

import numpy as np
import gradio as gr
from supabase import create_client, Client
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# ------------------ Cached singletons (efficient) ------------------
_supabase: Client | None = None
_embedder: SentenceTransformer | None = None
_llm_client: InferenceClient | None = None


def get_supabase_client() -> Client:
    global _supabase
    if _supabase is None:
        url = os.environ["SUPABASE_URL"]
        key = os.environ["SUPABASE_KEY"]
        _supabase = create_client(url, key)
    return _supabase


def get_embedder() -> SentenceTransformer:
    global _embedder
    if _embedder is None:
        _embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
    return _embedder


def get_llm_client() -> InferenceClient:
    global _llm_client
    if _llm_client is None:
        hf_token = os.environ["HF_TOKEN"]
        _llm_client = InferenceClient(
            model="meta-llama/Llama-3.1-8B-Instruct",
            token=hf_token,
        )
    return _llm_client


# ------------------ Core RAG functions ------------------
def retrieve_context_from_supabase(question: str, k: int = 5):
    supabase = get_supabase_client()
    embedder = get_embedder()

    q_vec = embedder.encode([question], normalize_embeddings=True).astype(np.float32)[0]

    resp = supabase.rpc(
        "match_pregnancy_chunks",
        {
            "query_embedding": q_vec.tolist(),
            "match_count": int(k),
        },
    ).execute()

    rows = resp.data or []
    context = "\n\n".join([r.get("content", "") for r in rows])
    return context, rows


def answer_with_llama(question: str, k: int = 5, show_context: bool = False):
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", ""

    context, rows = retrieve_context_from_supabase(question, k=k)
    if not context.strip():
        return "I couldn't find relevant context in the knowledge base.", ""

    llm_client = get_llm_client()
    prompt = (
        "Answer the question based only on the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )

    response = llm_client.chat_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant who answers based only on the given context."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=300,
        temperature=0.3,
    )
    answer = response.choices[0].message.content

    # Optional: display retrieved chunks for transparency/debugging
    if show_context:
        formatted = []
        for i, r in enumerate(rows, start=1):
            meta = (
                f"Chunk {i} | source={r.get('source')} | pages={r.get('page_range')} "
                f"| idx={r.get('chunk_index')} | sim={r.get('similarity')}"
            )
            formatted.append(meta + "\n" + (r.get("content") or ""))
        # Separate each chunk with a dashed divider line
        debug_text = ("\n\n" + "-" * 60 + "\n\n").join(formatted)
    else:
        debug_text = ""

    return answer, debug_text


# ------------------ Gradio UI ------------------
with gr.Blocks(title="Pregnancy Book Q&A (RAG POC)") as demo:
    gr.Markdown(
        "## 👶 Pregnancy Book Q&A (RAG POC)\n"
        "Ask a question. The app retrieves the most relevant chunks from Supabase (pgvector) "
        "and answers using Llama 3.1 8B on Hugging Face."
    )

    with gr.Row():
        question_in = gr.Textbox(label="Your question", value="What are pain interventions?", lines=2)
    with gr.Row():
        k_in = gr.Slider(minimum=3, maximum=10, value=5, step=1, label="Number of chunks to retrieve (top-k)")
        show_ctx = gr.Checkbox(label="Show retrieved chunks (debug)", value=False)

    btn = gr.Button("Get answer")
    answer_out = gr.Textbox(label="Answer", lines=8)
    ctx_out = gr.Textbox(label="Retrieved chunks", lines=16, visible=True)

    btn.click(
        fn=answer_with_llama,
        inputs=[question_in, k_in, show_ctx],
        outputs=[answer_out, ctx_out],
    )

demo.launch()
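
# ------------------ Assumed Supabase setup (reference only) ------------------
# The app relies on an RPC named "match_pregnancy_chunks" that must already
# exist in the Supabase database; its definition is not part of this file.
# A minimal pgvector function returning the fields read above (content, source,
# page_range, chunk_index, similarity) might look roughly like the SQL sketched
# below. The table name "pregnancy_chunks" and the 384-dimension vector size
# (matching all-MiniLM-L12-v2) are assumptions -- adjust them to the real schema.
#
#   create or replace function match_pregnancy_chunks(
#       query_embedding vector(384),
#       match_count int
#   )
#   returns table (
#       content text,
#       source text,
#       page_range text,
#       chunk_index int,
#       similarity float
#   )
#   language sql stable
#   as $$
#       select content, source, page_range, chunk_index,
#              1 - (embedding <=> query_embedding) as similarity
#       from pregnancy_chunks
#       order by embedding <=> query_embedding
#       limit match_count;
#   $$;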