| | import torch |
| | from typing import Any, Optional |
| | from transformers import LayoutLMv2ForQuestionAnswering |
| | from transformers import LayoutLMv2Processor |
| | from transformers import LayoutLMv2FeatureExtractor |
| | from transformers import LayoutLMv2ImageProcessor |
| | from transformers import LayoutLMv2TokenizerFast |
| | from transformers.tokenization_utils_base import BatchEncoding |
| | from transformers.tokenization_utils_base import TruncationStrategy |
| | from transformers.utils import TensorType |
| | |
| | |
| | |
| | import numpy as np |
| | |
| | |
| | import pdf2image |
| | |
| | import logging |
| | from os import environ |
| | |
| |
|
| | |
| | |
| | |
| |
|
# Module-level image processor shared by pdf_to_image().
# LayoutLMv2FeatureExtractor is a deprecated alias that only adds a
# FutureWarning on construction; LayoutLMv2ImageProcessor (already imported
# by this module) is the supported class and is called the same way.
# The variable keeps its historical name so existing references still work.
feature_extractor = LayoutLMv2ImageProcessor()
| |
|
| | |
| | |
| | |
| |
|
class NoOCRReaderFound(Exception):
    """Signals that an OCR reader could not be loaded.

    The underlying cause is stored on ``self.e`` and is included in the
    exception's string form.
    """

    def __init__(self, e: object) -> None:
        """Remember the underlying cause *e* for later reporting."""
        self.e = e

    def __str__(self) -> str:
        """Return the human-readable message, including the stored cause."""
        return f"Could not load OCR Reader: {self.e}"
| |
|
def pdf_to_image(b: bytes):
    """Render a PDF and run the module-level feature extractor on its images.

    Args:
        b: Raw bytes of a PDF document.

    Returns:
        A dict with the keys ``'image'``, ``'words'`` and ``'boxes'``, holding
        the extractor output's ``pixel_values``, ``words`` and ``boxes``
        attributes respectively.
    """
    # Convert every rendered image to RGB before feature extraction.
    pages = []
    for rendered in pdf2image.convert_from_bytes(b):
        pages.append(rendered.convert("RGB"))

    features = feature_extractor(pages)
    print('feature_extractor: ', features.keys())

    return {
        'image': features.pixel_values,
        'words': features.words,
        'boxes': features.boxes,
    }
| |
|
| |
|
def setup_logger(which_logger: Optional[str] = None):
    """Configure root logging and return a logger set to DEBUG.

    The root logger is reconfigured on every call (``force=True``) at INFO
    level. Output goes to the file named by the ``LOG_FILE_PATH_LAYOUTLM_V2``
    environment variable; when that variable is unset, ``filename`` is
    ``None`` and ``logging.basicConfig`` picks its default handler.

    Args:
        which_logger: Name of the logger to return; ``None`` selects the
            root logger.

    Returns:
        The requested ``logging.Logger`` with its level set to DEBUG.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s',
        datefmt='%d-%b-%y %H:%M:%S',
        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),
        force=True,
    )

    # The returned logger is more verbose (DEBUG) than the root (INFO).
    named_logger = logging.getLogger(which_logger)
    named_logger.setLevel(logging.DEBUG)
    return named_logger


# Module-wide logger; note that creating it reconfigures root logging
# as a side effect of importing this module.
logger = setup_logger(__name__)
| |
|
| |
|
class Funcs:
    """Stateless helper functions for question-answering post-processing."""

    @staticmethod
    def unnormalize_box(bbox, width, height):
        """Scale a box given on a 0-1000 grid to the given width and height.

        Args:
            bbox: Sequence of four numbers; indices 0 and 2 are scaled by
                *width*, indices 1 and 3 by *height*.
            width: Target width.
            height: Target height.

        Returns:
            A list of four scaled coordinates, in the same order as *bbox*.
        """
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    @staticmethod
    def num_spans(encoding: "BatchEncoding") -> int:
        """Return the number of entries in ``encoding["input_ids"]``."""
        return len(encoding["input_ids"])

    @staticmethod
    def p_mask(num_spans: int, encoding: "BatchEncoding") -> list:
        """Build one boolean mask per span.

        An entry is ``True`` where the token's sequence id is not 1 and
        ``False`` where it is 1.

        Args:
            num_spans: Number of spans to build masks for.
            encoding: Object exposing ``sequence_ids(span_index)``.

        Returns:
            A list of *num_spans* lists of booleans.
        """
        return [
            [seq_id != 1 for seq_id in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]

    @staticmethod
    def token_start_end(encoding, tokenizer):
        """Find the first and last token whose sequence id is 1.

        Args:
            encoding: Object exposing ``sequence_ids()`` and ``input_ids``.
                NOTE(review): ``len(encoding.input_ids)`` is used as the token
                count, which assumes a flat (unbatched) sequence - confirm
                against callers.
            tokenizer: Object exposing ``decode``; used only for the debug
                print of the located span.

        Returns:
            ``(token_start_index, token_end_index)``, both inclusive.

        Raises:
            IndexError: If no token has sequence id 1.
        """
        sequence_ids = encoding.sequence_ids()

        # Scan forward to the first token of sequence 1.
        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        # Scan backward to the last token of sequence 1.
        token_end_index = len(encoding.input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        print("Token start index:", token_start_index)
        print("Token end index:", token_end_index)
        print('token_start_end: ', tokenizer.decode(encoding.input_ids[token_start_index:token_end_index+1]))
        return token_start_index, token_end_index

    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, encoding, tokenizer,
                           token_start_index=None, token_end_index=None):
        """Map a word-index answer span to token positions.

        Args:
            word_idx_start: Word id of the first answer word.
            word_idx_end: Word id of the last answer word.
            encoding: Object exposing ``word_ids()``, ``sequence_ids()`` and
                ``input_ids``.
            tokenizer: Object exposing ``decode``; used for debug output.
            token_start_index: Optional inclusive start of the token window
                to search. When omitted it is taken from
                ``Funcs.token_start_end``.
            token_end_index: Optional inclusive end of the token window to
                search. When omitted it is taken from
                ``Funcs.token_start_end``.

        Returns:
            ``(start_position, end_position)``: the first token in the window
            whose word id equals *word_idx_start* and the last token whose
            word id equals *word_idx_end*.

        Raises:
            ValueError: If either word id does not occur in the window.
        """
        # The window bounds used to be read without ever being defined, so
        # every call failed; derive them when the caller does not pass them.
        if token_start_index is None or token_end_index is None:
            window_start, window_end = Funcs.token_start_end(encoding, tokenizer)
            if token_start_index is None:
                token_start_index = window_start
            if token_end_index is None:
                token_end_index = window_end

        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        print("Word ids:", word_ids)

        # First token (from the left) belonging to the start word.
        start_position = None
        for offset, word_id in enumerate(word_ids):
            if word_id == word_idx_start:
                start_position = token_start_index + offset
                break

        # Last token (from the right) belonging to the end word.
        end_position = None
        for offset, word_id in enumerate(reversed(word_ids)):
            if word_id == word_idx_end:
                end_position = token_end_index - offset
                break

        if start_position is None or end_position is None:
            raise ValueError(
                f"word span ({word_idx_start}, {word_idx_end}) not found in "
                f"tokens {token_start_index}..{token_end_index}"
            )

        print("Reconstructed answer:",
              tokenizer.decode(encoding.input_ids[start_position:end_position+1])
              )
        return start_position, end_position

    @staticmethod
    def sigmoid(_outputs):
        """Apply the logistic function element-wise."""
        return 1.0 / (1.0 + np.exp(-_outputs))

    @staticmethod
    def softmax(_outputs):
        """Apply softmax over the last axis.

        The per-row maximum is subtracted before exponentiating so large
        inputs do not overflow.
        """
        maxes = np.max(_outputs, axis=-1, keepdims=True)
        shifted_exp = np.exp(_outputs - maxes)
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
| |
|
class EndpointHandler:
    """Endpoint handler that runs LayoutLMv2 question answering on a PDF."""

    def __init__(self, path="./"):
        """Load the model, tokenizer and processor from *path*.

        Args:
            path: Directory (or identifier) handed to each ``from_pretrained``
                call.
        """
        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
        # The processor is built around the tokenizer loaded above.
        self.processor = LayoutLMv2Processor.from_pretrained(
            path,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, bytes]):
        """Run the fixed question against every image rendered from a PDF.

        Args:
            data: Request payload. The value under ``"inputs"`` is removed
                from the dict and passed to ``pdf2image.convert_from_bytes``,
                so it is expected to be PDF bytes. If the key is missing, the
                payload itself is passed instead.

        Returns:
            ``{'data': 'success'}``. The predicted answer for each image is
            written to the module logger and is not part of the return value.
            NOTE(review): callers presumably want the answers returned -
            confirm before changing the response shape.
        """
        pdf_bytes = data.pop("inputs", data)
        pages = [x.convert("RGB") for x in pdf2image.convert_from_bytes(pdf_bytes)]

        question = "what is the bill date"
        with torch.no_grad():
            for page in pages:
                encoding = self.processor(
                    page,
                    question,
                    truncation=True,
                    return_tensors=TensorType.PYTORCH,
                )
                print('encoding: ', encoding.keys())

                # One forward pass is enough for prediction. The previous
                # second pass with hard-coded target positions only computed
                # a loss that was never read.
                outputs = self.model(**encoding)

                predicted_start_idx = outputs.start_logits.argmax(-1).item()
                predicted_end_idx = outputs.end_logits.argmax(-1).item()

                # Slice the answer tokens out of the input ids (both ends
                # inclusive) and turn them back into text.
                predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
                predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)

                logger.info(f'''
                START
                predicted_start_idx: {predicted_start_idx}
                predicted_end_idx: {predicted_end_idx}
                ---
                answer: {predicted_answer}

                END''')
        return {'data': 'success'}
| |
|