File size: 1,688 Bytes
7dc8b43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

class BioBERTAnswerExtractor:
    """Extractive question answering backed by a Hugging Face QA checkpoint.

    The tokenizer and model are loaded once in ``__init__``;
    ``extract_answer`` then picks the highest-scoring start and end token
    and returns the text between them, or ``""`` when the prediction is
    unusable so callers can switch to a fallback.
    """

    # Answers shorter than this many characters are treated as junk.
    MIN_ANSWER_CHARS = 3

    def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1-squad'):
        """Load the tokenizer and question-answering model for *model_name*.

        Args:
            model_name: Identifier passed unchanged to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForQuestionAnswering.from_pretrained``.
        """
        print("⏳ Loading BioBERT model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        print("BioBERT model loaded.")

    def extract_answer(self, question, context):
        """Return the answer span for *question* found in *context*.

        Args:
            question: The question text.
            context: The passage the answer should be extracted from.

        Returns:
            The decoded answer string, or ``""`` when no usable answer was
            found: empty input, predicted start after predicted end, a span
            that contains a structural token (CLS/SEP/PAD) other than at
            its tail, or an answer shorter than ``MIN_ANSWER_CHARS``.
        """
        # Nothing to extract from; skip the (expensive) forward pass.
        if not question or not context:
            return ""

        # Calling the tokenizer directly replaces the deprecated encode_plus.
        inputs = self.tokenizer(
            question, context,
            return_tensors='pt',
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Batch size is 1, so the flat argmax is the token position.
        # Convert to plain ints so slicing does not depend on tensor
        # __index__ behaviour.
        start_idx = int(torch.argmax(outputs.start_logits))
        end_idx = int(torch.argmax(outputs.end_logits))

        if start_idx > end_idx:
            return ""  # invalid span

        all_tokens = self.tokenizer.convert_ids_to_tokens(
            inputs['input_ids'][0].tolist()
        )
        answer_tokens = all_tokens[start_idx:end_idx + 1]

        # Structural tokens are never part of the passage text.
        markers = {
            tok
            for tok in (
                getattr(self.tokenizer, 'cls_token', None),
                getattr(self.tokenizer, 'sep_token', None),
                getattr(self.tokenizer, 'pad_token', None),
            )
            if tok
        }

        # A span that merely runs onto the closing separator is still usable.
        while answer_tokens and answer_tokens[-1] in markers:
            answer_tokens.pop()

        # Any remaining marker means the span starts on CLS or straddles the
        # question/context boundary, so it is not a real answer.
        if any(tok in markers for tok in answer_tokens):
            return ""

        if not answer_tokens:
            return ""

        answer = self.tokenizer.convert_tokens_to_string(answer_tokens).strip()

        # Filter out junk answers (literal check kept for tokenizers that do
        # not expose cls_token/sep_token attributes).
        if (
            not answer
            or answer.lower() in ("[cls]", "[sep]")
            or len(answer) < self.MIN_ANSWER_CHARS
        ):
            return ""  # signal to use fallback

        return answer


# Demo: when run as a script, answer one sample question and print the result.
if __name__ == "__main__":
    extractor = BioBERTAnswerExtractor()
    demo_context = "The flu can cause fever, cough, sore throat, muscle aches, fatigue, and chills."
    demo_question = "What are the symptoms of flu?"
    print(f"Answer: {extractor.extract_answer(demo_question, demo_context)}")