Jaswanth-0821 commited on
Commit
d521ae9
·
verified ·
1 Parent(s): 16adbe8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -35
app.py CHANGED
@@ -2,11 +2,11 @@ import gradio as gr
2
  from sentence_transformers import SentenceTransformer
3
  import torch
4
 
5
- # Cache loaded models to avoid reloading
6
  loaded_models = {}
7
 
8
  def load_model(model_name):
9
- """Load and cache a model."""
10
  if model_name in loaded_models:
11
  return loaded_models[model_name]
12
  model = SentenceTransformer(model_name)
@@ -20,16 +20,16 @@ def find_similar_documents(query, documents, model_name):
20
  return "⚠️ Please enter a query."
21
  if not documents.strip():
22
  return "⚠️ Please enter documents (one per line)."
23
-
24
  model = load_model(model_name)
25
  doc_list = [d.strip() for d in documents.split("\n") if d.strip()]
26
  if not doc_list:
27
  return "⚠️ Please enter at least one valid document."
28
-
29
  query_emb = model.encode_query(query)
30
  doc_emb = model.encode_document(doc_list)
31
  similarities = model.similarity(query_emb, doc_emb)
32
-
33
  sorted_idx = torch.argsort(similarities[0], descending=True)
34
  results = []
35
  for i, idx in enumerate(sorted_idx):
@@ -44,7 +44,7 @@ def compare_models(query, documents, tarka_model, open_model):
44
  return "⚠️ Please enter a query.", ""
45
  if not documents.strip():
46
  return "⚠️ Please enter documents (one per line).", ""
47
-
48
  tarka = load_model(tarka_model)
49
  openm = load_model(open_model)
50
 
@@ -52,7 +52,6 @@ def compare_models(query, documents, tarka_model, open_model):
52
  if not doc_list:
53
  return "⚠️ Please enter at least one valid document.", ""
54
 
55
- # Compute similarities for both models
56
  tq = tarka.encode_query(query)
57
  td = tarka.encode_document(doc_list)
58
  tsim = tarka.similarity(tq, td)
@@ -61,31 +60,40 @@ def compare_models(query, documents, tarka_model, open_model):
61
  od = openm.encode_document(doc_list)
62
  osim = openm.similarity(oq, od)
63
 
64
- # Sort for each model
65
  tsorted = torch.argsort(tsim[0], descending=True)
66
  osorted = torch.argsort(osim[0], descending=True)
67
 
68
- tarka_results, open_results = [], []
69
- for i, idx in enumerate(tsorted):
70
- tarka_results.append(f"**{i+1}. (Score: {tsim[0][idx]:.4f})**\n{doc_list[idx]}")
 
 
 
 
 
 
71
 
72
- for i, idx in enumerate(osorted):
73
- open_results.append(f"**{i+1}. (Score: {osim[0][idx]:.4f})**\n{doc_list[idx]}")
74
 
75
- return "\n\n".join(tarka_results), "\n\n".join(open_results)
76
 
77
 
78
- # --- UI Layout ---
79
  with gr.Blocks(
80
  title="Document Similarity Explorer",
81
- theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo", neutral_hue="zinc")
82
  ) as demo:
83
 
84
- gr.Markdown("# 🔍 Document Similarity Explorer")
85
- gr.Markdown("Compare document relevance across embedding models easily.")
86
-
 
 
 
 
87
  with gr.Tabs():
88
- # ----------------- SINGLE MODEL TAB -----------------
89
  with gr.Tab("Single Model Search"):
90
  with gr.Row():
91
  with gr.Column(scale=1):
@@ -100,21 +108,21 @@ with gr.Blocks(
100
  value="Tarka-AIR/Tarka-Embedding-150M-V1"
101
  )
102
  loading_msg = gr.Markdown(visible=False)
103
-
104
  query_input = gr.Textbox(
105
  label="Query",
106
  placeholder="Enter your search query...",
107
  lines=2
108
  )
109
-
110
  docs_input = gr.Textbox(
111
  label="Documents",
112
  placeholder="Enter one document per line...",
113
  lines=10
114
  )
115
-
116
  search_btn = gr.Button("🔎 Search", variant="primary")
117
-
118
  with gr.Column(scale=1):
119
  result_box = gr.Markdown(label="Results", elem_id="results-box")
120
 
@@ -122,7 +130,7 @@ with gr.Blocks(
122
  loading_msg.update(value=f"⏳ Loading **{model_name}**...", visible=True)
123
  load_model(model_name)
124
  return gr.update(value=f"✅ {model_name} loaded and ready!", visible=True)
125
-
126
  model_selector.change(fn=on_model_change, inputs=[model_selector], outputs=[loading_msg])
127
 
128
  search_btn.click(fn=find_similar_documents,
@@ -133,16 +141,16 @@ with gr.Blocks(
133
  inputs=[query_input, docs_input, model_selector],
134
  outputs=result_box)
135
 
136
- # ----------------- COMPARISON TAB -----------------
137
  with gr.Tab("Compare Models"):
 
 
138
  with gr.Row():
139
  with gr.Column(scale=1):
140
  tarka_selector = gr.Dropdown(
141
  label="Tarka Model",
142
  choices=[
143
  "Tarka-AIR/Tarka-Embedding-150M-V1",
144
- "Tarka-AIR/Tarka-Embedding-200M-V1",
145
- "Tarka-AIR/Tarka-Embedding-300M-V1"
146
  ],
147
  value="Tarka-AIR/Tarka-Embedding-150M-V1"
148
  )
@@ -173,9 +181,10 @@ with gr.Blocks(
173
  compare_btn = gr.Button("⚖️ Compare Models", variant="primary")
174
 
175
  with gr.Column(scale=2):
176
- with gr.Row():
177
- tarka_output = gr.Markdown(label="Tarka Model Results")
178
- open_output = gr.Markdown(label="Open Source Model Results")
 
179
 
180
  def on_compare_models_load(tarka_model, open_model):
181
  compare_loading.update(value=f"⏳ Loading **{tarka_model}** and **{open_model}**...", visible=True)
@@ -198,16 +207,15 @@ with gr.Blocks(
198
  inputs=[query_compare, docs_compare, tarka_selector, open_selector],
199
  outputs=[tarka_output, open_output])
200
 
201
- # Example block for both tabs
202
  gr.Examples(
203
  examples=[
204
  [
205
  "Which planet is known as the Red Planet?",
206
- "Venus is Earth's twin.\nMars, the Red Planet.\nJupiter is the biggest.\nSaturn has rings.",
207
- "Tarka-AIR/Tarka-Embedding-150M-V1"
208
  ]
209
  ],
210
- inputs=[query_input, docs_input, model_selector],
211
  label="Try Example"
212
  )
213
 
 
2
  from sentence_transformers import SentenceTransformer
3
  import torch
4
 
5
+ # Cache loaded models
6
  loaded_models = {}
7
 
8
  def load_model(model_name):
9
+ """Load and cache a SentenceTransformer model."""
10
  if model_name in loaded_models:
11
  return loaded_models[model_name]
12
  model = SentenceTransformer(model_name)
 
20
  return "⚠️ Please enter a query."
21
  if not documents.strip():
22
  return "⚠️ Please enter documents (one per line)."
23
+
24
  model = load_model(model_name)
25
  doc_list = [d.strip() for d in documents.split("\n") if d.strip()]
26
  if not doc_list:
27
  return "⚠️ Please enter at least one valid document."
28
+
29
  query_emb = model.encode_query(query)
30
  doc_emb = model.encode_document(doc_list)
31
  similarities = model.similarity(query_emb, doc_emb)
32
+
33
  sorted_idx = torch.argsort(similarities[0], descending=True)
34
  results = []
35
  for i, idx in enumerate(sorted_idx):
 
44
  return "⚠️ Please enter a query.", ""
45
  if not documents.strip():
46
  return "⚠️ Please enter documents (one per line).", ""
47
+
48
  tarka = load_model(tarka_model)
49
  openm = load_model(open_model)
50
 
 
52
  if not doc_list:
53
  return "⚠️ Please enter at least one valid document.", ""
54
 
 
55
  tq = tarka.encode_query(query)
56
  td = tarka.encode_document(doc_list)
57
  tsim = tarka.similarity(tq, td)
 
60
  od = openm.encode_document(doc_list)
61
  osim = openm.similarity(oq, od)
62
 
 
63
  tsorted = torch.argsort(tsim[0], descending=True)
64
  osorted = torch.argsort(osim[0], descending=True)
65
 
66
+ # Make them look like cards
67
+ def format_result(sorted_indices, sims, model_label):
68
+ res = [
69
+ f"<div style='background-color:#f9fafb;border-radius:12px;padding:10px 14px;margin-bottom:10px;border:1px solid #e5e7eb;'>"
70
+ f"<b>{i+1}. (Score: {sims[0][idx]:.4f})</b><br>{doc_list[idx]}"
71
+ f"</div>"
72
+ for i, idx in enumerate(sorted_indices)
73
+ ]
74
+ return f"<div style='font-family:Inter,sans-serif;font-size:15px;line-height:1.5;'>{''.join(res)}</div>"
75
 
76
+ tarka_html = format_result(tsorted, tsim, "Tarka Model")
77
+ open_html = format_result(osorted, osim, "Open Model")
78
 
79
+ return tarka_html, open_html
80
 
81
 
82
+ # --------------------------- UI ---------------------------
83
  with gr.Blocks(
84
  title="Document Similarity Explorer",
85
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo", neutral_hue="zinc", font=[gr.themes.GoogleFont("Inter"), "Inter", "sans-serif"]),
86
  ) as demo:
87
 
88
+ gr.Markdown(
89
+ """
90
+ # 🧠 Tarka Embedding Model Playground
91
+ Experiment with Tarka-AIR’s embedding family for semantic search and compare performance with open-source baselines.
92
+ """,
93
+ )
94
+
95
  with gr.Tabs():
96
+ # ---------------- SINGLE MODEL SEARCH ----------------
97
  with gr.Tab("Single Model Search"):
98
  with gr.Row():
99
  with gr.Column(scale=1):
 
108
  value="Tarka-AIR/Tarka-Embedding-150M-V1"
109
  )
110
  loading_msg = gr.Markdown(visible=False)
111
+
112
  query_input = gr.Textbox(
113
  label="Query",
114
  placeholder="Enter your search query...",
115
  lines=2
116
  )
117
+
118
  docs_input = gr.Textbox(
119
  label="Documents",
120
  placeholder="Enter one document per line...",
121
  lines=10
122
  )
123
+
124
  search_btn = gr.Button("🔎 Search", variant="primary")
125
+
126
  with gr.Column(scale=1):
127
  result_box = gr.Markdown(label="Results", elem_id="results-box")
128
 
 
130
  loading_msg.update(value=f"⏳ Loading **{model_name}**...", visible=True)
131
  load_model(model_name)
132
  return gr.update(value=f"✅ {model_name} loaded and ready!", visible=True)
133
+
134
  model_selector.change(fn=on_model_change, inputs=[model_selector], outputs=[loading_msg])
135
 
136
  search_btn.click(fn=find_similar_documents,
 
141
  inputs=[query_input, docs_input, model_selector],
142
  outputs=result_box)
143
 
144
+ # ---------------- MODEL COMPARISON ----------------
145
  with gr.Tab("Compare Models"):
146
+ gr.Markdown("### ⚖️ Compare how different models rank the same documents")
147
+
148
  with gr.Row():
149
  with gr.Column(scale=1):
150
  tarka_selector = gr.Dropdown(
151
  label="Tarka Model",
152
  choices=[
153
  "Tarka-AIR/Tarka-Embedding-150M-V1",
 
 
154
  ],
155
  value="Tarka-AIR/Tarka-Embedding-150M-V1"
156
  )
 
181
  compare_btn = gr.Button("⚖️ Compare Models", variant="primary")
182
 
183
  with gr.Column(scale=2):
184
+ gr.Markdown("#### 📊 Comparison Results")
185
+ with gr.Row(equal_height=True):
186
+ tarka_output = gr.HTML(label="Tarka Model Results")
187
+ open_output = gr.HTML(label="Open Source Model Results")
188
 
189
  def on_compare_models_load(tarka_model, open_model):
190
  compare_loading.update(value=f"⏳ Loading **{tarka_model}** and **{open_model}**...", visible=True)
 
207
  inputs=[query_compare, docs_compare, tarka_selector, open_selector],
208
  outputs=[tarka_output, open_output])
209
 
210
+ # ---------------- Example Section ----------------
211
  gr.Examples(
212
  examples=[
213
  [
214
  "Which planet is known as the Red Planet?",
215
+ "Venus is Earth's twin.\nMars, the Red Planet.\nJupiter is the biggest.\nSaturn has rings."
 
216
  ]
217
  ],
218
+ inputs=[query_input, docs_input],
219
  label="Try Example"
220
  )
221