Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pickle | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
class ArxivClassifierModelsPipeline():
    """Classify a text with two fine-tuned sequence-classification models.

    Two models are loaded from local directories:

    * a *topic* classifier (``models/scibert/``), used to report every class
      whose probability exceeds ``TOPIC_PROB_THRESHOLD``;
    * a *main topic* classifier (``models/maintopic_clf/``), used to report
      the single most probable class.

    Each model has its own tokenizer and a pair of pickled lookup tables:
    ``decode_dict_*`` is indexed by class position, and the value found there
    is then looked up in ``topic_dict`` / ``main_topic_dict`` (presumably
    code -> human-readable name; confirm against the pickled data).
    """

    # Checkpoints the tokenizers are loaded from.
    TOPIC_TOKENIZER_NAME = "allenai/scibert_scivocab_uncased"
    MAINTOPIC_TOKENIZER_NAME = "Wi/arxiv-topics-distilbert-base-cased"

    # A topic is reported only when its probability is strictly above this.
    TOPIC_PROB_THRESHOLD = 0.1

    # Inputs are truncated to this many tokens. Both checkpoints above are
    # BERT-family encoders, which accept at most 512 positions; feeding a
    # longer sequence fails inside the model instead of producing a result.
    MAX_INPUT_TOKENS = 512

    def __init__(self):
        """Load both models, both tokenizers and the four lookup tables."""
        self.model_topic_clf = self.__load_topic_clf()
        self.model_maintopic_clf = self.__load_maintopic_clf()

        self.topic_tokenizer = AutoTokenizer.from_pretrained(
            self.TOPIC_TOKENIZER_NAME
        )
        self.maintopic_tokenizer = AutoTokenizer.from_pretrained(
            self.MAINTOPIC_TOKENIZER_NAME
        )

        # NOTE: pickle executes arbitrary code on load. These files ship with
        # the app; never point these paths at user-supplied data.
        self.decode_dict_topic = self._load_pickle(
            'models/scibert/decode_dict_topic.pkl'
        )
        self.decode_dict_maintopic = self._load_pickle(
            'models/maintopic_clf/decode_dict_maintopic.pkl'
        )
        self.main_topic_dict = self._load_pickle(
            'models/maintopic_clf/main_topic_dict.pkl'
        )
        self.topic_dict = self._load_pickle('models/scibert/topic_dict.pkl')

    def make_predict(self, text):
        """Predict topics and the main topic for ``text``.

        Args:
            text: The text to classify. It is truncated to
                ``MAX_INPUT_TOKENS`` tokens before being fed to each model.

        Returns:
            A ``(topic_probs, main_topic)`` tuple. ``topic_probs`` maps each
            topic whose probability is above ``TOPIC_PROB_THRESHOLD`` to that
            probability; the key is the ``topic_dict`` entry for the decoded
            class when one exists, otherwise the decoded class itself.
            ``main_topic`` is the ``main_topic_dict`` entry for the most
            probable class of the main-topic model.

        Raises:
            KeyError: If a class index is missing from a decode table, or the
                decoded main topic is missing from ``main_topic_dict``.
        """
        topic_class_probs = self._class_probabilities(
            self.topic_tokenizer, self.model_topic_clf, text
        )
        topic_probs = {}
        for class_idx, prob in enumerate(topic_class_probs):
            if prob > self.TOPIC_PROB_THRESHOLD:
                code = self.decode_dict_topic[class_idx]
                # Fall back to the raw decoded value when no entry exists.
                topic_probs[self.topic_dict.get(code, code)] = prob

        maintopic_class_probs = self._class_probabilities(
            self.maintopic_tokenizer, self.model_maintopic_clf, text
        )
        # Pick the most probable class. The previous code always decoded
        # index 0, so the reported main topic never depended on the input.
        best_idx = max(
            range(len(maintopic_class_probs)),
            key=maintopic_class_probs.__getitem__,
        )
        main_code = self.decode_dict_maintopic[best_idx]
        return topic_probs, self.main_topic_dict[main_code]

    def _class_probabilities(self, tokenizer, model, text):
        """Return the softmax class probabilities of ``model`` for ``text``."""
        tokens = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_INPUT_TOKENS,
        )
        outputs = model(tokens.input_ids)
        # The batch holds a single text, so keep only its row.
        return outputs["logits"].softmax(dim=-1).tolist()[0]

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __load_topic_clf(self):
        """Announce loading in the Streamlit UI and load the topic model."""
        st.write("Loading model")
        return AutoModelForSequenceClassification.from_pretrained("models/scibert/")

    def __load_maintopic_clf(self):
        """Announce loading in the Streamlit UI and load the main-topic model."""
        st.write("Loading second model")
        return AutoModelForSequenceClassification.from_pretrained("models/maintopic_clf/")