| | |
| | """ |
| | Model Preloader for Multilingual Audio Intelligence System - Enhanced Version |
| | |
| | Key improvements: |
| | 1. Smart local cache detection with corruption checking |
| | 2. Fallback to download if local files don't exist or are corrupted |
| | 3. Better error handling and retry mechanisms |
| | 4. Consistent approach across all model types |
| | """ |
| |
|
| | import os |
| | import sys |
| | import logging |
| | import time |
| | from pathlib import Path |
| | from typing import Dict, Any, Optional |
| | import json |
| | from datetime import datetime |
| |
|
| | |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | import whisper |
| | from pyannote.audio import Pipeline |
| | from rich.console import Console |
| | from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn |
| | from rich.panel import Panel |
| | from rich.text import Text |
| | import psutil |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
# Make the project's src/ package importable when this script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
| |
|
| | |
# Configure root logging once for the whole preloader run.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Shared Rich console used for all user-facing status output.
console = Console()
| |
|
| | class ModelPreloader: |
| | """Comprehensive model preloader with enhanced local cache detection.""" |
| | |
| | def __init__(self, cache_dir: str = "./model_cache", device: str = "auto"): |
| | self.cache_dir = Path(cache_dir) |
| | self.cache_dir.mkdir(exist_ok=True) |
| | |
| | |
| | if device == "auto": |
| | self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| | else: |
| | self.device = device |
| | |
| | self.models = {} |
| | self.model_info = {} |
| | |
| | |
| | self.model_configs = { |
| | "speaker_diarization": { |
| | "name": "pyannote/speaker-diarization-3.1", |
| | "type": "pyannote", |
| | "description": "Speaker Diarization Pipeline", |
| | "size_mb": 32 |
| | }, |
| | "whisper_small": { |
| | "name": "openai/whisper-small", |
| | "type": "whisper", |
| | "description": "Whisper Speech Recognition (Small)", |
| | "size_mb": 484 |
| | }, |
| | "mbart_translation": { |
| | "name": "facebook/mbart-large-50-many-to-many-mmt", |
| | "type": "mbart", |
| | "description": "mBART Neural Machine Translation", |
| | "size_mb": 2440 |
| | }, |
| | |
| | "opus_mt_ja_en": { |
| | "name": "Helsinki-NLP/opus-mt-ja-en", |
| | "type": "opus_mt", |
| | "description": "Japanese to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_es_en": { |
| | "name": "Helsinki-NLP/opus-mt-es-en", |
| | "type": "opus_mt", |
| | "description": "Spanish to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_fr_en": { |
| | "name": "Helsinki-NLP/opus-mt-fr-en", |
| | "type": "opus_mt", |
| | "description": "French to English Translation", |
| | "size_mb": 303 |
| | }, |
| | |
| | "opus_mt_hi_en": { |
| | "name": "Helsinki-NLP/opus-mt-hi-en", |
| | "type": "opus_mt", |
| | "description": "Hindi to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_ta_en": { |
| | "name": "Helsinki-NLP/opus-mt-ta-en", |
| | "type": "opus_mt", |
| | "description": "Tamil to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_bn_en": { |
| | "name": "Helsinki-NLP/opus-mt-bn-en", |
| | "type": "opus_mt", |
| | "description": "Bengali to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_te_en": { |
| | "name": "Helsinki-NLP/opus-mt-te-en", |
| | "type": "opus_mt", |
| | "description": "Telugu to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_mr_en": { |
| | "name": "Helsinki-NLP/opus-mt-mr-en", |
| | "type": "opus_mt", |
| | "description": "Marathi to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_gu_en": { |
| | "name": "Helsinki-NLP/opus-mt-gu-en", |
| | "type": "opus_mt", |
| | "description": "Gujarati to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_kn_en": { |
| | "name": "Helsinki-NLP/opus-mt-kn-en", |
| | "type": "opus_mt", |
| | "description": "Kannada to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_pa_en": { |
| | "name": "Helsinki-NLP/opus-mt-pa-en", |
| | "type": "opus_mt", |
| | "description": "Punjabi to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_ml_en": { |
| | "name": "Helsinki-NLP/opus-mt-ml-en", |
| | "type": "opus_mt", |
| | "description": "Malayalam to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_ne_en": { |
| | "name": "Helsinki-NLP/opus-mt-ne-en", |
| | "type": "opus_mt", |
| | "description": "Nepali to English Translation", |
| | "size_mb": 303 |
| | }, |
| | "opus_mt_ur_en": { |
| | "name": "Helsinki-NLP/opus-mt-ur-en", |
| | "type": "opus_mt", |
| | "description": "Urdu to English Translation", |
| | "size_mb": 303 |
| | } |
| | } |
| | |
| | def check_local_model_files(self, model_name: str, model_type: str) -> bool: |
| | """ |
| | Check if model files exist locally and are not corrupted. |
| | Returns True if valid local files exist, False otherwise. |
| | """ |
| | try: |
| | if model_type == "whisper": |
| | |
| | whisper_cache = self.cache_dir / "whisper" / "models--Systran--faster-whisper-small" |
| | required_files = ["config.json", "model.bin", "tokenizer.json", "vocabulary.txt"] |
| | |
| | |
| | snapshots_dir = whisper_cache / "snapshots" |
| | if not snapshots_dir.exists(): |
| | return False |
| | |
| | |
| | snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()] |
| | if not snapshot_dirs: |
| | return False |
| | |
| | |
| | snapshot_path = snapshot_dirs[0] |
| | for file in required_files: |
| | file_path = snapshot_path / file |
| | if not file_path.exists() or file_path.stat().st_size == 0: |
| | return False |
| | |
| | return True |
| | |
| | elif model_type in ["mbart", "opus_mt"]: |
| | |
| | if model_type == "mbart": |
| | model_cache_path = self.cache_dir / "mbart" / f"models--{model_name.replace('/', '--')}" |
| | else: |
| | model_cache_path = self.cache_dir / "opus_mt" / f"{model_name.replace('/', '--')}" / f"models--{model_name.replace('/', '--')}" |
| | |
| | required_files = ["config.json", "tokenizer_config.json"] |
| | |
| | model_files = ["pytorch_model.bin", "model.safetensors"] |
| | |
| | |
| | snapshots_dir = model_cache_path / "snapshots" |
| | if not snapshots_dir.exists(): |
| | return False |
| | |
| | |
| | snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()] |
| | if not snapshot_dirs: |
| | return False |
| | |
| | |
| | snapshot_path = max(snapshot_dirs, key=lambda x: x.stat().st_mtime) |
| | |
| | |
| | for file in required_files: |
| | file_path = snapshot_path / file |
| | if not file_path.exists() or file_path.stat().st_size == 0: |
| | return False |
| | |
| | |
| | model_file_exists = any( |
| | (snapshot_path / model_file).exists() and (snapshot_path / model_file).stat().st_size > 0 |
| | for model_file in model_files |
| | ) |
| | |
| | return model_file_exists |
| | |
| | elif model_type == "pyannote": |
| | |
| | |
| | return False |
| | |
| | except Exception as e: |
| | logger.warning(f"Error checking local files for {model_name}: {e}") |
| | return False |
| | |
| | return False |
| |
|
| | def load_transformers_model_with_cache_check(self, model_name: str, cache_path: Path, model_type: str = "seq2seq") -> Optional[Dict[str, Any]]: |
| | """ |
| | Load transformers model with intelligent cache checking and fallback. |
| | """ |
| | try: |
| | |
| | has_local_files = self.check_local_model_files(model_name, "mbart" if "mbart" in model_name else "opus_mt") |
| | |
| | if has_local_files: |
| | console.print(f"[green]Found valid local cache for {model_name}, loading from cache...[/green]") |
| | try: |
| | |
| | tokenizer = AutoTokenizer.from_pretrained( |
| | model_name, |
| | cache_dir=str(cache_path), |
| | local_files_only=True |
| | ) |
| | |
| | model = AutoModelForSeq2SeqLM.from_pretrained( |
| | model_name, |
| | cache_dir=str(cache_path), |
| | local_files_only=True, |
| | torch_dtype=torch.float32 if self.device == "cpu" else torch.float16 |
| | ) |
| | |
| | console.print(f"[green]SUCCESS: Successfully loaded {model_name} from local cache[/green]") |
| | |
| | except Exception as e: |
| | console.print(f"[yellow]Local cache load failed for {model_name}, will download: {e}[/yellow]") |
| | has_local_files = False |
| | |
| | if not has_local_files: |
| | console.print(f"[yellow]No valid local cache for {model_name}, downloading...[/yellow]") |
| | |
| | tokenizer = AutoTokenizer.from_pretrained( |
| | model_name, |
| | cache_dir=str(cache_path) |
| | ) |
| | |
| | model = AutoModelForSeq2SeqLM.from_pretrained( |
| | model_name, |
| | cache_dir=str(cache_path), |
| | torch_dtype=torch.float32 if self.device == "cpu" else torch.float16 |
| | ) |
| | |
| | console.print(f"[green]SUCCESS: Successfully downloaded and loaded {model_name}[/green]") |
| | |
| | |
| | if self.device != "cpu": |
| | model = model.to(self.device) |
| | |
| | |
| | test_input = tokenizer("Hello world", return_tensors="pt") |
| | if self.device != "cpu": |
| | test_input = {k: v.to(self.device) for k, v in test_input.items()} |
| | |
| | with torch.no_grad(): |
| | output = model.generate(**test_input, max_length=10) |
| | |
| | return { |
| | "model": model, |
| | "tokenizer": tokenizer |
| | } |
| | |
| | except Exception as e: |
| | console.print(f"[red]✗ Failed to load {model_name}: {e}[/red]") |
| | logger.error(f"Model loading failed for {model_name}: {e}") |
| | return None |
| |
|
| | def get_system_info(self) -> Dict[str, Any]: |
| | """Get system information for optimal model loading.""" |
| | return { |
| | "cpu_count": psutil.cpu_count(), |
| | "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2), |
| | "available_memory_gb": round(psutil.virtual_memory().available / (1024**3), 2), |
| | "device": self.device, |
| | "torch_version": torch.__version__, |
| | "cuda_available": torch.cuda.is_available(), |
| | "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None |
| | } |
| | |
| | def check_model_cache(self, model_key: str) -> bool: |
| | """Check if model is already cached and working.""" |
| | cache_file = self.cache_dir / f"{model_key}_info.json" |
| | if not cache_file.exists(): |
| | return False |
| | |
| | try: |
| | with open(cache_file, 'r') as f: |
| | cache_info = json.load(f) |
| | |
| | |
| | cache_time = datetime.fromisoformat(cache_info['timestamp']) |
| | days_old = (datetime.now() - cache_time).days |
| | |
| | if days_old > 7: |
| | logger.info(f"Cache for {model_key} is {days_old} days old, will refresh") |
| | return False |
| | |
| | return cache_info.get('status') == 'success' |
| | except Exception as e: |
| | logger.warning(f"Error reading cache for {model_key}: {e}") |
| | return False |
| | |
| | def save_model_cache(self, model_key: str, status: str, info: Dict[str, Any]): |
| | """Save model loading information to cache.""" |
| | cache_file = self.cache_dir / f"{model_key}_info.json" |
| | cache_data = { |
| | "timestamp": datetime.now().isoformat(), |
| | "status": status, |
| | "device": self.device, |
| | "info": info |
| | } |
| | |
| | try: |
| | with open(cache_file, 'w') as f: |
| | json.dump(cache_data, f, indent=2) |
| | except Exception as e: |
| | logger.warning(f"Error saving cache for {model_key}: {e}") |
| | |
| | def load_pyannote_pipeline(self, task_id: str) -> Optional[Pipeline]: |
| | """Load pyannote speaker diarization pipeline with container-safe settings.""" |
| | try: |
| | console.print(f"[yellow]Loading pyannote.audio pipeline...[/yellow]") |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | hf_token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN') |
| | if not hf_token: |
| | console.print("[red]Warning: HUGGINGFACE_TOKEN not found. Some models may not be accessible.[/red]") |
| | |
| | |
| | import warnings |
| | import logging |
| | |
| | |
| | old_warning_filters = warnings.filters[:] |
| | warnings.filterwarnings("ignore") |
| | |
| | |
| | os.environ['ORT_LOGGING_LEVEL'] = '3' |
| | |
| | |
| | |
| | logging.getLogger('transformers').setLevel(logging.ERROR) |
| | |
| | try: |
| | pipeline = Pipeline.from_pretrained( |
| | "pyannote/speaker-diarization-3.1", |
| | use_auth_token=hf_token, |
| | cache_dir=str(self.cache_dir / "pyannote") |
| | ) |
| | |
| | |
| | if hasattr(pipeline, '_models'): |
| | for model_name, model in pipeline._models.items(): |
| | if hasattr(model, 'to'): |
| | model.to('cpu') |
| | |
| | console.print(f"[green]SUCCESS: pyannote.audio pipeline loaded successfully on CPU[/green]") |
| | return pipeline |
| | |
| | finally: |
| | |
| | warnings.filters[:] = old_warning_filters |
| | |
| | except Exception as e: |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | logger.error(f"Pyannote loading failed: {e}") |
| | return None |
| |
|
| | def load_whisper_model(self, task_id: str) -> Optional[whisper.Whisper]: |
| | """Load Whisper speech recognition model with enhanced cache checking.""" |
| | try: |
| | console.print(f"[yellow]Loading Whisper model (small)...[/yellow]") |
| | |
| | whisper_cache_dir = self.cache_dir / "whisper" |
| | |
| | |
| | has_local_files = self.check_local_model_files("small", "whisper") |
| | |
| | if has_local_files: |
| | console.print(f"[green]Found valid local Whisper cache, loading from cache...[/green]") |
| | else: |
| | console.print(f"[yellow]No valid local Whisper cache found, will download...[/yellow]") |
| | |
| | |
| | model = whisper.load_model("small", device=self.device) |
| | |
| | |
| | import numpy as np |
| | dummy_audio = np.zeros(16000, dtype=np.float32) |
| | result = model.transcribe(dummy_audio, language="en") |
| | |
| | console.print(f"[green]SUCCESS: Whisper model loaded successfully on {self.device}[/green]") |
| | |
| | return model |
| | |
| | except Exception as e: |
| | console.print(f"[red]ERROR: Failed to load Whisper model: {e}[/red]") |
| | logger.error(f"Whisper loading failed: {e}") |
| | return None |
| | |
| | def load_mbart_model(self, task_id: str) -> Optional[Dict[str, Any]]: |
| | """Load mBART translation model with enhanced cache checking.""" |
| | console.print(f"[yellow]Loading mBART translation model...[/yellow]") |
| | |
| | model_name = "facebook/mbart-large-50-many-to-many-mmt" |
| | cache_path = self.cache_dir / "mbart" |
| | cache_path.mkdir(exist_ok=True) |
| | |
| | return self.load_transformers_model_with_cache_check(model_name, cache_path, "seq2seq") |
| | |
| | def load_opus_mt_model(self, task_id: str, model_name: str) -> Optional[Dict[str, Any]]: |
| | """Load Opus-MT translation model with enhanced cache checking.""" |
| | console.print(f"[yellow]Loading Opus-MT model: {model_name}...[/yellow]") |
| | |
| | cache_path = self.cache_dir / "opus_mt" / model_name.replace("/", "--") |
| | cache_path.mkdir(parents=True, exist_ok=True) |
| | |
| | return self.load_transformers_model_with_cache_check(model_name, cache_path, "seq2seq") |
| | |
| | def preload_all_models(self) -> Dict[str, Any]: |
| | """Preload all models with progress tracking.""" |
| | |
| | |
| | sys_info = self.get_system_info() |
| | |
| | info_panel = Panel.fit( |
| | f"""System Information |
| | |
| | • CPU Cores: {sys_info['cpu_count']} |
| | • Total Memory: {sys_info['memory_gb']} GB |
| | • Available Memory: {sys_info['available_memory_gb']} GB |
| | • Device: {sys_info['device'].upper()} |
| | • PyTorch: {sys_info['torch_version']} |
| | • CUDA Available: {sys_info['cuda_available']} |
| | {f"• GPU: {sys_info['gpu_name']}" if sys_info['gpu_name'] else ""}""", |
| | title="[bold blue]Audio Intelligence System[/bold blue]", |
| | border_style="blue" |
| | ) |
| | console.print(info_panel) |
| | console.print() |
| | |
| | results = { |
| | "system_info": sys_info, |
| | "models": {}, |
| | "total_time": 0, |
| | "success_count": 0, |
| | "total_count": len(self.model_configs) |
| | } |
| | |
| | start_time = time.time() |
| | |
| | with Progress( |
| | SpinnerColumn(), |
| | TextColumn("[progress.description]{task.description}"), |
| | BarColumn(), |
| | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), |
| | TimeRemainingColumn(), |
| | console=console |
| | ) as progress: |
| | |
| | |
| | main_task = progress.add_task("[cyan]Loading AI Models...", total=len(self.model_configs)) |
| | |
| | |
| | for model_key, config in self.model_configs.items(): |
| | task_id = progress.add_task(f"[yellow]{config['description']}", total=100) |
| | |
| | |
| | if self.check_model_cache(model_key): |
| | console.print(f"[green]SUCCESS: {config['description']} found in cache[/green]") |
| | progress.update(task_id, completed=100) |
| | progress.update(main_task, advance=1) |
| | results["models"][model_key] = {"status": "cached", "time": 0} |
| | results["success_count"] += 1 |
| | continue |
| | |
| | model_start_time = time.time() |
| | progress.update(task_id, completed=10) |
| | |
| | |
| | if config["type"] == "pyannote": |
| | model = self.load_pyannote_pipeline(task_id) |
| | elif config["type"] == "whisper": |
| | model = self.load_whisper_model(task_id) |
| | elif config["type"] == "mbart": |
| | model = self.load_mbart_model(task_id) |
| | elif config["type"] == "opus_mt": |
| | model = self.load_opus_mt_model(task_id, config["name"]) |
| | else: |
| | model = None |
| | |
| | model_time = time.time() - model_start_time |
| | |
| | if model is not None: |
| | self.models[model_key] = model |
| | progress.update(task_id, completed=100) |
| | results["models"][model_key] = {"status": "success", "time": model_time} |
| | results["success_count"] += 1 |
| | |
| | |
| | self.save_model_cache(model_key, "success", { |
| | "load_time": model_time, |
| | "device": self.device, |
| | "model_name": config["name"] |
| | }) |
| | else: |
| | progress.update(task_id, completed=100) |
| | results["models"][model_key] = {"status": "failed", "time": model_time} |
| | |
| | |
| | self.save_model_cache(model_key, "failed", { |
| | "load_time": model_time, |
| | "device": self.device, |
| | "error": "Model loading failed" |
| | }) |
| | |
| | progress.update(main_task, advance=1) |
| | |
| | results["total_time"] = time.time() - start_time |
| | |
| | |
| | console.print() |
| | if results["success_count"] == results["total_count"]: |
| | status_text = "[bold green]SUCCESS: All models loaded successfully![/bold green]" |
| | status_color = "green" |
| | elif results["success_count"] > 0: |
| | status_text = f"[bold yellow]WARNING: {results['success_count']}/{results['total_count']} models loaded[/bold yellow]" |
| | status_color = "yellow" |
| | else: |
| | status_text = "[bold red]ERROR: No models loaded successfully[/bold red]" |
| | status_color = "red" |
| | |
| | summary_panel = Panel.fit( |
| | f"""{status_text} |
| | |
| | • Loading Time: {results['total_time']:.1f} seconds |
| | • Device: {self.device.upper()} |
| | • Memory Usage: {psutil.virtual_memory().percent:.1f}% |
| | • Models Ready: {results['success_count']}/{results['total_count']}""", |
| | title="[bold]Model Loading Summary[/bold]", |
| | border_style=status_color |
| | ) |
| | console.print(summary_panel) |
| | |
| | return results |
| | |
| | def get_models(self) -> Dict[str, Any]: |
| | """Get loaded models.""" |
| | return self.models |
| | |
| | def cleanup(self): |
| | """Cleanup resources.""" |
| | |
| | if torch.cuda.is_available(): |
| | torch.cuda.empty_cache() |
| |
|
| |
|
def main():
    """Entry point: preload every model and report overall success.

    Returns:
        True when at least one model loaded successfully, False otherwise
        (including user interruption and unexpected errors).
    """
    banner = Panel.fit(
        "[bold blue]Multilingual Audio Intelligence System[/bold blue]\n[yellow]Model Preloader[/yellow]",
        border_style="blue"
    )
    console.print(banner)
    console.print()

    preloader = ModelPreloader()

    try:
        results = preloader.preload_all_models()
    except KeyboardInterrupt:
        console.print("\n[yellow]Model preloading interrupted by user[/yellow]")
        return False
    except Exception as e:
        console.print(f"\n[bold red]✗ Model preloading failed: {e}[/bold red]")
        logger.error(f"Preloading failed: {e}")
        return False
    else:
        if results["success_count"] > 0:
            console.print("\n[bold green]SUCCESS: Model preloading completed![/bold green]")
            console.print(f"[dim]Models cached in: {preloader.cache_dir}[/dim]")
            return True
        console.print("\n[bold red]ERROR: Model preloading failed![/bold red]")
        return False
    finally:
        # Always release GPU memory, whatever the outcome.
        preloader.cleanup()
| |
|
| |
|
| | if __name__ == "__main__": |
| | success = main() |
| | sys.exit(0 if success else 1) |