"""
FIXED PixelText OCR Model with proper Hugging Face Hub support.
This version has the from_pretrained method and works with AutoModel.from_pretrained().
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import ( |
| | PaliGemmaForConditionalGeneration, |
| | PaliGemmaProcessor, |
| | AutoTokenizer, |
| | PreTrainedModel, |
| | PretrainedConfig |
| | ) |
| | from PIL import Image |
| | import warnings |
| | warnings.filterwarnings("ignore") |
| |
|
class PixelTextConfig(PretrainedConfig):
    """Configuration for the PixelText OCR model.

    Holds the Hub id of the PaliGemma backbone plus the dimensions that
    the model re-exposes (e.g. through ``get_model_info()``).
    """

    model_type = "pixeltext"

    def __init__(self, base_model="google/paligemma-3b-pt-224",
                 hidden_size=2048, vocab_size=257216, **kwargs):
        # Forward any extra standard HF config fields to the base class.
        super().__init__(**kwargs)
        self.base_model = base_model    # Hub checkpoint id of the backbone
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
| |
|
class FixedPixelTextOCR(PreTrainedModel):
    """PixelText OCR model wrapping a PaliGemma backbone.

    Subclasses ``PreTrainedModel`` so the checkpoint loads through
    ``AutoModel.from_pretrained()`` / the Hugging Face Hub.  The heavy
    lifting is delegated to the bundled PaliGemma model; this class adds
    OCR-oriented prompting, output cleaning, and a heuristic confidence
    score.
    """

    config_class = PixelTextConfig

    def __init__(self, config=None):
        """Load the PaliGemma backbone, processor, and tokenizer.

        Args:
            config: Optional ``PixelTextConfig``; a default instance is
                created when omitted.

        Raises:
            Exception: re-raised after logging if any component fails to
                load (e.g. the checkpoint cannot be downloaded).
        """
        if config is None:
            config = PixelTextConfig()
        super().__init__(config)

        print(f"🚀 Loading FIXED PixelText OCR...")

        # fp16 on GPU halves memory; CPU inference needs fp32.
        # NOTE(review): stored as `_device` to avoid clashing with the
        # `device` property PreTrainedModel already defines.
        if torch.cuda.is_available():
            self._device = "cuda"
            self.torch_dtype = torch.float16
        else:
            self._device = "cpu"
            self.torch_dtype = torch.float32

        print(f"🔧 Device: {self._device}")

        try:
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                config.base_model,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
            ).to(self._device)

            self.processor = PaliGemmaProcessor.from_pretrained(config.base_model)
            self.tokenizer = AutoTokenizer.from_pretrained(config.base_model)

            print("✅ FIXED PixelText OCR ready!")
        except Exception as e:
            print(f"❌ Failed to load components: {e}")
            raise

        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size

    def forward(self, **kwargs):
        """Forward pass: delegate straight to the PaliGemma backbone."""
        return self.base_model(**kwargs)

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """Extract text from a single image.

        Args:
            image: PIL Image, file path (str), or numpy array.
            prompt: Prompt for the backbone; an ``<image>`` placeholder
                is prepended automatically when missing.
            max_length: Maximum total sequence length (prompt tokens
                included) passed to ``generate()``.

        Returns:
            dict with keys ``text``, ``confidence``, ``success``,
            ``method``, and either ``raw_output`` (on success) or
            ``error`` (on failure).

        Raises:
            ValueError: if ``image`` is not one of the accepted types.
        """
        # Normalise the input to an RGB PIL image.
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif hasattr(image, 'shape'):  # duck-typed numpy array
            image = Image.fromarray(image).convert('RGB')
        elif not isinstance(image, Image.Image):
            raise ValueError("Image must be PIL Image, file path, or numpy array")

        # PaliGemma requires the "<image>" placeholder in the prompt.
        if "<image>" not in prompt:
            prompt = f"<image>{prompt}"

        try:
            inputs = self.processor(text=prompt, images=image, return_tensors="pt")

            # Move every tensor in the processor output to the model device.
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(self._device)

            # Greedy decoding for deterministic OCR output.
            with torch.no_grad():
                generated_ids = self.base_model.generate(
                    **inputs,
                    max_length=max_length,
                    do_sample=False,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            generated_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
            )[0]

            text = self._clean_text(generated_text, prompt)
            confidence = self._calculate_confidence(text)

            return {
                'text': text,
                'confidence': confidence,
                'success': True,
                'method': 'fixed_pixeltext',
                'raw_output': generated_text
            }

        except Exception as e:
            # Best-effort API: report the failure instead of raising so
            # batch_ocr() can keep processing the remaining images.
            return {
                'text': "",
                'confidence': 0.0,
                'success': False,
                'method': 'error',
                'error': str(e)
            }

    def _clean_text(self, generated_text, prompt):
        """Strip the prompt echo and common narration artifacts.

        FIX: surrounding quotes are stripped exactly once, after the
        artifact loop — the previous placement inside the loop could
        peel nested quote pairs repeatedly.
        """
        # Remove the echoed prompt text (without the image placeholder).
        clean_prompt = prompt.replace("<image>", "").strip()
        if clean_prompt and clean_prompt in generated_text:
            text = generated_text.replace(clean_prompt, "").strip()
        else:
            text = generated_text.strip()

        # Narration prefixes the backbone sometimes adds before the text.
        artifacts = [
            "The image shows", "The text in the image says",
            "The image contains", "I can see", "The text reads",
            "This image shows", "The picture shows"
        ]

        for artifact in artifacts:
            if text.lower().startswith(artifact.lower()):
                text = text[len(artifact):].strip()
                if text.startswith(":"):
                    text = text[1:].strip()

        # Drop a single pair of wrapping quotes, if present.
        if text.startswith('"') and text.endswith('"'):
            text = text[1:-1].strip()

        return text

    def _calculate_confidence(self, text):
        """Heuristic confidence in [0.0, 0.95] based on text length and
        character classes — no model logits are consulted."""
        if not text:
            return 0.0

        confidence = 0.5  # base score for any non-empty extraction

        # Longer extractions are (heuristically) more trustworthy.
        if len(text) > 10:
            confidence += 0.2
        if len(text) > 50:
            confidence += 0.1
        if len(text) > 100:
            confidence += 0.1

        # Reward plausible content: letters and digits.
        if any(c.isalpha() for c in text):
            confidence += 0.1
        if any(c.isdigit() for c in text):
            confidence += 0.05

        # Penalise near-empty results.
        if len(text.strip()) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """Run :meth:`generate_ocr_text` over a sequence of images.

        Returns a list of per-image result dicts, in input order; failed
        images yield ``success=False`` entries instead of raising.
        """
        results = []

        for i, image in enumerate(images):
            print(f"📄 Processing image {i+1}/{len(images)}...")
            result = self.generate_ocr_text(image, prompt, max_length)
            results.append(result)

            if result['success']:
                print(f" ✅ Success: {len(result['text'])} characters")
            else:
                print(f" ❌ Failed: {result.get('error', 'Unknown error')}")

        return results

    def get_model_info(self):
        """Return a static metadata dict describing this model build."""
        return {
            'model_name': 'FIXED PixelText OCR',
            'base_model': 'PaliGemma-3B',
            'device': self._device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'repository': 'BabaK07/pixeltext-ai',
            'status': 'FIXED - Hub loading works!',
            'features': [
                'Hub loading support',
                'from_pretrained method',
                'Fast OCR extraction',
                'Multi-language support',
                'Batch processing',
                'Production ready'
            ]
        }
| |
|
# Backward-compatibility alias: callers that imported the model under
# its earlier name keep working.
WorkingQwenOCRModel = FixedPixelTextOCR
| |
|