# farmerbot/inference.py
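"""Inference utilities for farmerbot: wraps a local or online LLM and a Chroma
retriever into a LangChain RAG pipeline for answering questions."""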
import os
import torch
import pandas as pd
import transformers
from pynvml import *
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from model_ret import load_model_and_pipeline
from create_retriever import retriever_chroma
# Model chain class
class model_chain:
    model_name = ""

    def __init__(self,
                 model_name_local,
                 model_name_online="Llama",
                 use_online=True,
                 embedding_name="sentence-transformers/all-mpnet-base-v2",
                 splitter_type_dropdown="character",
                 chunk_size_slider=512,
                 chunk_overlap_slider=30,
                 separator_textbox="\n",
                 max_tokens_slider=2048) -> None:
        # Prefer a locally downloaded model if its folder exists and is non-empty,
        # otherwise fall back to the online model name.
        local_path = os.path.join("models", model_name_local)
        if os.path.exists(local_path) and len(os.listdir(local_path)):
            import gradio as gr
            gr.Info(f"Loading model {model_name_local} from local storage")
            self.model_name = model_name_local
        else:
            self.model_name = model_name_online
        self.tokenizer, self.model, self.llm = load_model_and_pipeline(self.model_name)
        # Creating the retriever
        # self.retriever = ensemble_retriever(embedding_name,
        #                                     splitter_type=splitter_type_dropdown,
        #                                     chunk_size=chunk_size_slider,
        #                                     chunk_overlap=chunk_overlap_slider,
        #                                     separator=separator_textbox,
        #                                     max_tokens=max_tokens_slider)
        self.retriever = retriever_chroma(False, embedding_name, splitter_type_dropdown,
                                          chunk_size_slider, chunk_overlap_slider,
                                          separator_textbox, max_tokens_slider)
        # Defining the RAG chain
        prompt = hub.pull("rlm/rag-prompt")
        self.rag_chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )
    # Helper function to format retrieved documents into a single context string
    def format_docs(self, docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieve the RAG chain
    def rag_chain_ret(self):
        return self.rag_chain
    # Answer retrieval function
    def ans_ret(self, inp):
        if self.model_name == 'Flant5':
            # Flan-T5 is queried directly: build a context from the top retrieved
            # documents and generate the answer with the seq2seq model.
            my_question = inp
            data = self.retriever.invoke(inp)
            context = ""
            for x in data[:2]:
                context += x.page_content + "\n"
            inputs = f"""Please answer this question using this context:\n{context}\n{my_question}"""
            inputs = self.tokenizer(inputs, return_tensors="pt")
            outputs = self.model.generate(**inputs)
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            from textwrap import fill
            return fill(answer, width=100)
        # All other models go through the RAG chain; keep only the text after
        # "Answer:" when the prompt template echoes it back.
        ans = self.rag_chain.invoke(inp)
        if "Answer:" in ans:
            ans = ans.split("Answer:", 1)[1]
        return ans
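
# Minimal usage sketch. It assumes a local model folder such as "models/Flant5"
# (or any name accepted by load_model_and_pipeline as an online fallback); the
# model name and question below are illustrative, not part of the original app.
if __name__ == "__main__":
    chain = model_chain(model_name_local="Flant5")
    print(chain.ans_ret("How do I control maize leaf blight?"))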