"""RAG chain wrapper: selects a local or online LLM, builds a Chroma
retriever, and composes a LangChain RAG chain over it."""

import os
from textwrap import fill

import torch
import pandas as pd
import transformers
from pynvml import *

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from model_ret import load_model_and_pipeline
from create_retriever import retriever_chroma
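
# The local helper modules are assumed (from how they are used below) to
# provide:
#   model_ret.load_model_and_pipeline(name) -> (tokenizer, model, llm)
#   create_retriever.retriever_chroma(...)  -> a LangChain retriever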


class model_chain:
    """Wraps a RAG pipeline around either a locally downloaded model or an
    online fallback."""

    model_name = ""

    def __init__(self,
                 model_name_local,
                 model_name_online="Llama",
                 use_online=True,
                 embedding_name="sentence-transformers/all-mpnet-base-v2",
                 splitter_type_dropdown="character",
                 chunk_size_slider=512,
                 chunk_overlap_slider=30,
                 separator_textbox="\n",
                 max_tokens_slider=2048) -> None:
        # Prefer the locally downloaded model when its directory exists and is
        # non-empty; otherwise fall back to the online model name.
        local_path = os.path.join("models", model_name_local)
        if os.path.exists(local_path) and len(os.listdir(local_path)):
            import gradio as gr  # lazy import: only needed for the UI notification
            gr.Info(f"Loading model '{model_name_local}' from local storage!")
            self.model_name = model_name_local
        else:
            self.model_name = model_name_online

        self.tokenizer, self.model, self.llm = load_model_and_pipeline(self.model_name)

        # Build the Chroma retriever over the ingested documents.
        self.retriever = retriever_chroma(False, embedding_name, splitter_type_dropdown,
                                          chunk_size_slider, chunk_overlap_slider,
                                          separator_textbox, max_tokens_slider)

        # Pull the standard RAG prompt from the LangChain hub and compose the
        # chain: retrieved documents are flattened into a context string while
        # the question passes through unchanged, and the model output is
        # parsed to a plain string.
        prompt = hub.pull("rlm/rag-prompt")
        self.rag_chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )
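
        # Illustrative note (not in the original code): the composed chain maps
        # a plain question string to an answer string, e.g.
        #   self.rag_chain.invoke("What is KUET?")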

    def format_docs(self, docs):
        # Join retrieved documents into a single context string.
        return "\n\n".join(doc.page_content for doc in docs)

    def rag_chain_ret(self):
        # Expose the composed chain to callers.
        return self.rag_chain

    def ans_ret(self, inp):
        if self.model_name == 'Flant5':
            # Flan-T5 is a seq2seq model, so build the prompt manually from
            # the top retrieved documents instead of going through the LCEL
            # chain.
            data = self.retriever.invoke(inp)
            context = ""
            for x in data[:2]:
                context += x.page_content + "\n"
            prompt_text = f"Please answer this question using this context:\n{context}\n{inp}"
            inputs = self.tokenizer(prompt_text, return_tensors="pt")
            outputs = self.model.generate(**inputs)
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return fill(answer, width=100)

        ans = self.rag_chain.invoke(inp)
        # The rag-prompt completion typically begins with "Answer:"; keep only
        # the text after it when present.
        return ans.split("Answer:")[-1].strip()
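

# Minimal usage sketch (illustrative, not part of the original module): assumes
# either a downloaded model under models/Flant5 or a reachable online fallback.
if __name__ == "__main__":
    chain = model_chain(model_name_local="Flant5")
    print(chain.ans_ret("What is KUET?"))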