|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
|
|
import torch |
|
|
|
|
|
class BioBERTAnswerExtractor:
    """Extractive question answering on top of a Hugging Face QA checkpoint.

    Loads a tokenizer and an ``AutoModelForQuestionAnswering`` model (by
    default the BioBERT SQuAD checkpoint named in ``__init__``) and uses the
    model's start/end logits to pull an answer span out of a context passage.
    """

    # Answers shorter than this many characters are reported as "no answer".
    MIN_ANSWER_CHARS = 3

    def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1-squad'):
        """Load the tokenizer and question-answering model for *model_name*.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        print("⏳ Loading BioBERT model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        print("BioBERT model loaded.")

    def _context_mask(self, token_ids):
        """Return one bool per token: True where an answer may start or end.

        A position is eligible only if it comes after the first separator
        token and is not itself a CLS/SEP/PAD token. This keeps the span out
        of the question and off the special tokens. If the tokenizer defines
        no separator, every non-special position is eligible.
        """
        sep_id = self.tokenizer.sep_token_id
        excluded = {
            tid
            for tid in (
                self.tokenizer.cls_token_id,
                sep_id,
                self.tokenizer.pad_token_id,
            )
            if tid is not None
        }
        past_question = sep_id is None
        mask = []
        for tid in token_ids:
            if sep_id is not None and tid == sep_id:
                past_question = True
                mask.append(False)
            else:
                mask.append(past_question and tid not in excluded)
        return mask

    def extract_answer(self, question, context):
        """Extract the most likely answer to *question* from *context*.

        Args:
            question: The question text.
            context: The passage to search for the answer.

        Returns:
            The decoded answer span, or ``""`` when no usable answer is found:
            empty inputs, no context tokens after truncation, a best start
            position that lies after the best end position, or a decoded
            answer shorter than ``MIN_ANSWER_CHARS`` characters.
        """
        if not question or not context:
            return ""

        # Calling the tokenizer directly replaces the deprecated encode_plus.
        inputs = self.tokenizer(
            question, context,
            return_tensors='pt',
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # A single question/context pair is encoded, so use batch row 0.
        token_ids = inputs['input_ids'][0].tolist()
        start_scores = outputs.start_logits[0]
        end_scores = outputs.end_logits[0]

        eligible = self._context_mask(token_ids)
        if not any(eligible):
            return ""

        # Push question/special-token positions to -inf so the argmax can only
        # land inside the context; otherwise the span could contain question
        # text or a literal separator token.
        blocked = torch.tensor(
            [not ok for ok in eligible],
            dtype=torch.bool,
            device=start_scores.device,
        )
        neg_inf = float("-inf")
        start_idx = int(torch.argmax(start_scores.masked_fill(blocked, neg_inf)))
        end_idx = int(torch.argmax(end_scores.masked_fill(blocked, neg_inf)))

        # Start and end are chosen independently; an inverted pair means the
        # model produced no coherent span.
        if start_idx > end_idx:
            return ""

        answer_tokens = self.tokenizer.convert_ids_to_tokens(
            token_ids[start_idx:end_idx + 1]
        )
        answer = self.tokenizer.convert_tokens_to_string(answer_tokens).strip()

        # Also covers the empty string; very short spans are treated as noise.
        if len(answer) < self.MIN_ANSWER_CHARS:
            return ""

        return answer
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the default model and answer one sample question.
    extractor = BioBERTAnswerExtractor()
    sample_question = "What are the symptoms of flu?"
    sample_context = (
        "The flu can cause fever, cough, sore throat, muscle aches, fatigue, and chills."
    )
    print(f"Answer: {extractor.extract_answer(sample_question, sample_context)}")
|
|
|