import os
import glob

from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader, CSVLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever


def docs_return(flag):
    """Load every .docx, .pdf, .txt and .csv file under rag_data/.

    Returns the list of Document objects when flag == 0, otherwise one string
    joining the first Document of each file (first PDF page / first CSV row)
    with blank lines.
    """
    directory_path = 'rag_data/'
    docx_file_pattern = '*.docx'
    pdf_file_pattern = '*.pdf'
    txt_file_pattern = '*.txt'
    csv_file_pattern = '*.csv'

    docx_file_paths = glob.glob(directory_path + docx_file_pattern)
    pdf_file_paths = glob.glob(directory_path + pdf_file_pattern)
    txt_file_paths = glob.glob(directory_path + txt_file_pattern)
    csv_file_paths = glob.glob(directory_path + csv_file_pattern)

    all_doc, all_doc2 = [], []

    for x in docx_file_paths:
        loader = Docx2txtLoader(x)
        documents = loader.load()
        all_doc.extend(documents)
        all_doc2.append(str(documents[0].page_content))

    for x in pdf_file_paths:
        # extract_images=True runs OCR over embedded images (needs rapidocr-onnxruntime)
        loader = PyPDFLoader(x, extract_images=True)
        # lazy_load yields one Document per page without holding the whole file at once
        documents = list(loader.lazy_load())
        all_doc.extend(documents)
        all_doc2.append(str(documents[0].page_content))

    for x in txt_file_paths:
        loader = TextLoader(x)
        documents = loader.load()
        all_doc.extend(documents)
        all_doc2.append(str(documents[0].page_content))

    for x in csv_file_paths:
        # one Document per row; the "translation" column becomes the source metadata
        loader = CSVLoader(file_path=x, source_column="translation")
        documents = loader.load()
        all_doc.extend(documents)
        all_doc2.append(str(documents[0].page_content))

    docs = '\n\n'.join(all_doc2)

    return all_doc if flag == 0 else docs
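
# Sketch of the two modes (names below are from this file, nothing new):
#
#     all_docs = docs_return(0)  # list[Document], ready for a splitter/vector store
#     context = docs_return(1)   # one '\n\n'-joined string, e.g. for prompt stuffing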


def get_embedding_model(model_name):
    """Return a HuggingFaceEmbeddings model, preferring a locally saved copy."""
    local_model_path = f"embedding_model/{model_name.replace('/', '_')}"
    if os.path.exists(local_model_path):
        print(f"Loading local model from {local_model_path}")
        return HuggingFaceEmbeddings(model_name=local_model_path)
    else:
        print(f"Downloading model {model_name}")
        return HuggingFaceEmbeddings(model_name=model_name)
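
# The local-path branch above only fires if a model was saved there beforehand.
# A sketch of one way to populate it, assuming the sentence-transformers
# package (which HuggingFaceEmbeddings wraps) is available:
#
#     from sentence_transformers import SentenceTransformer
#     SentenceTransformer("sentence-transformers/all-mpnet-base-v2").save(
#         "embedding_model/sentence-transformers_all-mpnet-base-v2"
#     )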


def get_text_splitter(splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Build the requested splitter.

    'character' splits on separator, 'recursive' retries progressively smaller
    separators, and 'token' measures chunks in tokens rather than characters.
    """
    if splitter_type == 'character':
        return CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator)
    elif splitter_type == 'recursive':
        return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    elif splitter_type == 'token':
        # TokenTextSplitter counts tokens, so max_tokens stands in for chunk_size
        return TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=chunk_overlap)
    else:
        raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")
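
# Illustrative comparison of the three splitters on hypothetical text
# (Document comes from langchain_core.documents; the sample string is made up):
#
#     sample = [Document(page_content="line one\nline two\n" * 200)]
#     get_text_splitter('character').split_documents(sample)
#     get_text_splitter('recursive', chunk_size=200).split_documents(sample)
#     get_text_splitter('token', max_tokens=128).split_documents(sample)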


def retriever_chroma(flag, model_name="sentence-transformers/all-mpnet-base-v2", splitter_type='character', chunk_size=500, chunk_overlap=30, separator="\n", max_tokens=1000):
    """Build a Chroma retriever: index rag_data/ when flag is falsy, otherwise
    reopen the collection persisted in ./chroma_db."""
    embeddings = get_embedding_model(model_name)

    if not flag:
        all_doc = docs_return(0)
        text_splitter = get_text_splitter(splitter_type=splitter_type, chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator, max_tokens=max_tokens)
        docs = text_splitter.split_documents(documents=all_doc)
        # index the split chunks (not the unsplit documents) and persist to disk
        vectordb = Chroma.from_documents(docs, embeddings, persist_directory="./chroma_db")
        return vectordb.as_retriever()
    else:
        # Chroma has no load_local (that is a FAISS method); reopen the
        # persisted directory from the indexing branch instead
        vectordb = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
        return vectordb.as_retriever(
            search_type="mmr", search_kwargs={"k": 4, "fetch_k": 10}
        )
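

# Usage sketch: flag=0 builds and persists the index, flag=1 reopens it with
# MMR search. The query string is illustrative only.
if __name__ == "__main__":
    retriever = retriever_chroma(0, splitter_type='recursive')
    for doc in retriever.invoke("example query about the rag_data corpus"):
        print(doc.metadata.get("source"), doc.page_content[:80])

# The EnsembleRetriever import is unused above; a hedged sketch of the hybrid
# setup it suggests (BM25Retriever requires the rank_bm25 package):
#
#     from langchain_community.retrievers import BM25Retriever
#     bm25 = BM25Retriever.from_documents(docs_return(0))
#     hybrid = EnsembleRetriever(
#         retrievers=[bm25, retriever_chroma(1)], weights=[0.4, 0.6]
#     )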