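"""Pregnancy Book Q&A (RAG proof of concept).

Gradio app that retrieves the most relevant text chunks from Supabase
(pgvector) and answers questions with Llama 3.1 8B via the Hugging Face
Inference API.

Required environment variables: SUPABASE_URL, SUPABASE_KEY, HF_TOKEN.
"""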
import os
import numpy as np
import gradio as gr
from supabase import create_client, Client
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient


# ------------------ Cached singletons (created once, reused across requests) ------------------

_supabase: Client | None = None
_embedder: SentenceTransformer | None = None
_llm_client: InferenceClient | None = None


def get_supabase_client() -> Client:
    global _supabase
    if _supabase is None:
        url = os.environ["SUPABASE_URL"]
        key = os.environ["SUPABASE_KEY"]
        _supabase = create_client(url, key)
    return _supabase


def get_embedder() -> SentenceTransformer:
    global _embedder
    if _embedder is None:
        _embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
    return _embedder


def get_llm_client() -> InferenceClient:
    global _llm_client
    if _llm_client is None:
        hf_token = os.environ["HF_TOKEN"]
        _llm_client = InferenceClient(
            model="meta-llama/Llama-3.1-8B-Instruct",
            token=hf_token,
        )
    return _llm_client


# ------------------ Core RAG functions ------------------
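
# The RPC call below relies on a Postgres function named `match_pregnancy_chunks`
# defined in Supabase. A minimal pgvector sketch of such a function is shown here
# for reference; the table name `pregnancy_chunks` and exact column types are
# assumptions, not taken from this file (all-MiniLM-L12-v2 produces
# 384-dimensional embeddings):
#
#   create or replace function match_pregnancy_chunks(
#       query_embedding vector(384),
#       match_count int
#   )
#   returns table (
#       content text,
#       source text,
#       page_range text,
#       chunk_index int,
#       similarity float
#   )
#   language sql stable as $$
#       select content, source, page_range, chunk_index,
#              1 - (embedding <=> query_embedding) as similarity
#       from pregnancy_chunks
#       order by embedding <=> query_embedding
#       limit match_count;
#   $$;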

def retrieve_context_from_supabase(question: str, k: int = 5):
    supabase = get_supabase_client()
    embedder = get_embedder()

    # Embed the question; normalized embeddings let pgvector's cosine distance
    # double as a similarity score.
    q_vec = embedder.encode([question], normalize_embeddings=True).astype(np.float32)[0]

    # Delegate the vector similarity search to a Postgres function via Supabase RPC.
    resp = supabase.rpc(
        "match_pregnancy_chunks",
        {
            "query_embedding": q_vec.tolist(),
            "match_count": int(k),
        },
    ).execute()

    rows = resp.data or []
    context = "\n\n".join([r.get("content", "") for r in rows])
    return context, rows


def answer_with_llama(question: str, k: int = 5, show_context: bool = False):
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", ""

    context, rows = retrieve_context_from_supabase(question, k=k)

    if not context.strip():
        return "I couldn't find relevant context in the knowledge base.", ""

    llm_client = get_llm_client()

    prompt = (
        "Answer the question based only on the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )

    response = llm_client.chat_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant who answers based only on the given context."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=300,
        temperature=0.3,
    )

    answer = response.choices[0].message.content

    # Optional: display retrieved chunks for transparency/debugging
    if show_context:
        formatted = []
        for i, r in enumerate(rows, start=1):
            meta = (
                f"Chunk {i} | source={r.get('source')} | pages={r.get('page_range')} "
                f"| idx={r.get('chunk_index')} | sim={r.get('similarity')}"
            )
            formatted.append(meta + "\n" + (r.get("content") or ""))
        # Separate chunks with a horizontal rule so they are easy to scan.
        separator = "\n\n" + ("-" * 60) + "\n\n"
        debug_text = separator.join(formatted)
    else:
        debug_text = ""

    return answer, debug_text


# ------------------ Gradio UI ------------------

with gr.Blocks(title="Pregnancy Book Q&A (RAG POC)") as demo:
    gr.Markdown("## 👶 Pregnancy Book Q&A (RAG POC)\nAsk a question. The app retrieves the most relevant chunks from Supabase (pgvector) and answers using Llama 3.1 8B on Hugging Face.")

    with gr.Row():
        question_in = gr.Textbox(label="Your question", value="What are pain interventions?", lines=2)
    with gr.Row():
        k_in = gr.Slider(minimum=3, maximum=10, value=5, step=1, label="Number of chunks to retrieve (top-k)")
        show_ctx = gr.Checkbox(label="Show retrieved chunks (debug)", value=False)

    btn = gr.Button("Get answer")

    answer_out = gr.Textbox(label="Answer", lines=8)
    ctx_out = gr.Textbox(label="Retrieved chunks", lines=16, visible=True)

    btn.click(
        fn=answer_with_llama,
        inputs=[question_in, k_in, show_ctx],
        outputs=[answer_out, ctx_out],
    )

if __name__ == "__main__":
    demo.launch()