farmerbot / create_retriever.py
Nelly-43's picture
Update create_retriever.py
21defd4 verified
import os
import glob
from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader, CSVLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever
# from ragatouille import RAGPretrainedModel
# Function to load and process documents
def docs_return(flag):
    """Load every supported file directly under ``rag_data/``.

    Files are matched by extension (non-recursively) and processed in a fixed
    order: ``.docx``, ``.pdf``, ``.txt``, then ``.csv``. Within one extension
    the paths are sorted so repeated runs see the documents in the same order.

    Args:
        flag: ``0`` selects the list-of-documents result; any other value
            selects the joined-text result.

    Returns:
        If ``flag == 0``: a list containing every document the loaders
        produced. Otherwise: one string made of the ``page_content`` of the
        *first* document of each file, separated by blank lines.
    """
    directory_path = 'rag_data/'

    # Each factory is a lambda so the loader class is only looked up when a
    # matching file actually exists.
    loader_factories = (
        ('*.docx', lambda path: Docx2txtLoader(path)),
        ('*.pdf', lambda path: PyPDFLoader(path, extract_images=True)),
        ('*.txt', lambda path: TextLoader(path)),
        ('*.csv', lambda path: CSVLoader(file_path=path, source_column="translation")),
    )

    all_doc, first_texts = [], []
    for pattern, make_loader in loader_factories:
        for path in sorted(glob.glob(os.path.join(directory_path, pattern))):
            documents = make_loader(path).load()
            if not documents:
                # An empty file yields no documents; skip it instead of
                # failing on documents[0] below.
                continue
            all_doc.extend(documents)
            # NOTE(review): only the first document of each file reaches the
            # joined text. For loaders that emit several documents per file
            # (presumably one per PDF page / CSV row) the rest is omitted
            # from the string result -- confirm this is intended.
            first_texts.append(str(documents[0].page_content))

    if flag == 0:
        return all_doc
    return '\n\n'.join(first_texts)
# Function to get or download the embedding model
def get_embedding_model(model_name):
    """Build a ``HuggingFaceEmbeddings`` instance for ``model_name``.

    A local copy is looked up under ``embedding_model/`` using the model name
    with every ``/`` replaced by ``_``. If that path exists it is passed to
    ``HuggingFaceEmbeddings``; otherwise the original model name is passed
    through unchanged.

    Args:
        model_name: Model identifier, e.g. ``"org/model"``.

    Returns:
        The constructed ``HuggingFaceEmbeddings`` object.
    """
    cached_dir = "embedding_model/" + model_name.replace('/', '_')

    # Guard clause: no local copy, so hand the plain name to the library.
    if not os.path.exists(cached_dir):
        print(f"Downloading model {model_name}")
        return HuggingFaceEmbeddings(model_name=model_name)

    print(f"Loading local model from {cached_dir}")
    return HuggingFaceEmbeddings(model_name=cached_dir)
# Function to return different types of text splitters
def get_text_splitter(splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Create the text splitter selected by ``splitter_type``.

    Args:
        splitter_type: ``'character'``, ``'recursive'`` or ``'token'``.
        chunk_size: Chunk size for the character and recursive splitters.
        chunk_overlap: Overlap passed to whichever splitter is built.
        separator: Separator for the character splitter only.
        max_tokens: Used as the chunk size of the token splitter.

    Returns:
        A ``CharacterTextSplitter``, ``RecursiveCharacterTextSplitter`` or
        ``TokenTextSplitter`` configured with the arguments above.

    Raises:
        ValueError: If ``splitter_type`` is not one of the supported names.
    """
    if splitter_type not in ('character', 'recursive', 'token'):
        raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")

    # Lazy factories: only the selected splitter class is ever instantiated.
    builders = {
        'character': lambda: CharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator
        ),
        'recursive': lambda: RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        ),
        'token': lambda: TokenTextSplitter(
            chunk_size=max_tokens, chunk_overlap=chunk_overlap
        ),
    }
    build = builders[splitter_type]
    return build()
# Retriever using Chroma and HuggingFace embeddings
def retriever_chroma(flag, model_name="sentence-transformers/all-mpnet-base-v2", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000, persist_directory="./chroma_db"):
    """Return a Chroma-backed retriever, building or reopening the index.

    Args:
        flag: Falsy -> load the documents, split them, embed them and build a
            new Chroma store persisted at ``persist_directory``. Truthy ->
            reopen the store already persisted at ``persist_directory``.
        model_name: Embedding model passed to ``get_embedding_model``.
        splitter_type: Forwarded to ``get_text_splitter`` (build path only).
        chunk_size: Forwarded to ``get_text_splitter`` (build path only).
        chunk_overlap: Forwarded to ``get_text_splitter`` (build path only).
        separator: Forwarded to ``get_text_splitter`` (build path only).
        max_tokens: Forwarded to ``get_text_splitter`` (build path only).
        persist_directory: Directory of the Chroma store. Both paths use the
            same location so a store built with ``flag`` falsy can be
            reopened with ``flag`` truthy.

    Returns:
        The retriever from ``vectordb.as_retriever``: default settings on the
        build path; MMR search with ``k=4`` and ``fetch_k=10`` on the reopen
        path.
    """
    # Load or download the embedding model
    embeddings = get_embedding_model(model_name)

    if flag:
        # Reopen the persisted store through the Chroma constructor.
        # NOTE(review): this used to call Chroma.load_local("vectorstore", ...),
        # which is not part of the Chroma API; the path now points at the
        # directory the build branch writes to -- confirm no separate
        # "vectorstore" directory is expected.
        vectordb = Chroma(
            persist_directory=persist_directory,
            embedding_function=embeddings,
        )
        return vectordb.as_retriever(
            search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
        )

    # Load the documents
    all_doc = docs_return(0)

    # Split the documents with the requested splitter configuration
    text_splitter = get_text_splitter(
        splitter_type=splitter_type,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=separator,
        max_tokens=max_tokens,
    )
    docs = text_splitter.split_documents(documents=all_doc)

    # Index the split chunks (not the unsplit documents) so the splitter
    # parameters actually take effect.
    vectordb = Chroma.from_documents(
        docs, embeddings, persist_directory=persist_directory
    )
    return vectordb.as_retriever()