# farmerbot / model_ret.py.old
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
    pipeline,
)
from langchain import HuggingFacePipeline  # moved to langchain_community.llms in newer LangChain

def zephyr_model(model_info, quantization):
    """Load a local Zephyr checkpoint in 8-bit or 4-bit and wrap it as a LangChain LLM."""
    path = f"models/{model_info}"
    # use_auth_token is deprecated in newer transformers (use token=...); kept for compatibility.
    tokenizer = AutoTokenizer.from_pretrained(path, use_auth_token=True)
    if quantization == "8":
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_8bit=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_4bit=True,
        )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,  # match the dtype the model was loaded in
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    # temperature=0 makes decoding effectively greedy, which sits oddly with do_sample=True above.
    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})
    return llm

def llama_model(model_info, quantization):
    """Load a local Llama checkpoint; identical to zephyr_model apart from the weights it points at."""
    path = f"models/{model_info}"
    tokenizer = AutoTokenizer.from_pretrained(path, use_auth_token=True)
    if quantization == "8":
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_8bit=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_4bit=True,
        )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})
    return llm

def mistral_model(model_info, quantization):
    """Load a local Mistral checkpoint; identical to zephyr_model apart from the weights it points at."""
    path = f"models/{model_info}"
    tokenizer = AutoTokenizer.from_pretrained(path, use_auth_token=True)
    if quantization == "8":
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_8bit=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_4bit=True,
        )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})
    return llm

def phi_model(model_info, quantization):
    """Load a local Phi checkpoint; Phi ships custom modeling code, hence trust_remote_code=True."""
    path = f"models/{model_info}"
    tokenizer = AutoTokenizer.from_pretrained(path, use_auth_token=True)
    if quantization == "8":
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_8bit=True,
            trust_remote_code=True,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            use_auth_token=True,
            load_in_4bit=True,
            trust_remote_code=True,
        )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})
    return llm
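
# Hypothetical refactor (not part of the original interface): zephyr_model, llama_model,
# mistral_model, and phi_model differ only in trust_remote_code, so a single shared helper
# could serve all four. A sketch under that assumption:
def _load_causal_llm(model_info, quantization, trust_remote_code=False):
    path = f"models/{model_info}"
    tokenizer = AutoTokenizer.from_pretrained(path, use_auth_token=True)
    model = AutoModelForCausalLM.from_pretrained(
        path,
        device_map="auto",
        torch_dtype=torch.float16,
        use_auth_token=True,
        load_in_8bit=(quantization == "8"),  # 8-bit when "8" is requested...
        load_in_4bit=(quantization != "8"),  # ...otherwise fall back to 4-bit, as above
        trust_remote_code=trust_remote_code,
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    return HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})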

def flant5_model(model_info):
    """Load local FLAN-T5 weights (tokenizer comes from the hub) and wrap them for LangChain."""
    path = f"models/{model_info}"
    model = T5ForConditionalGeneration.from_pretrained(path)
    MODEL_NAME = "google/flan-t5-base"  # tokenizer is fetched from the hub, weights from disk
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
    # T5 is an encoder-decoder model, so the pipeline task is "text2text-generation",
    # not "text-generation" as in the causal loaders above.
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        top_k=30,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": 0})
    return tokenizer, model, llm
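
# Illustrative usage (the checkpoint folder names under models/ are assumptions, not
# names shipped with this file):
#
#     llm = zephyr_model("zephyr-7b-beta", quantization="4")
#     print(llm("Which cover crops suit sandy soils?"))
#
#     tokenizer, model, llm = flant5_model("flan-t5-base")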

# Alternative loader kept for reference:
# model = AutoModelForCausalLM.from_pretrained(
#     path,
#     # quantization_config=bnb_config,
#     device_map="auto",
#     trust_remote_code=True,
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16,
# )
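
# On newer transformers/bitsandbytes versions the load_in_8bit / load_in_4bit kwargs used
# throughout this file are deprecated in favor of a quantization_config; that is what the
# bnb_config in the commented-out loader above refers to. A minimal sketch:
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for compute at inference time
)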