Upload 2 files
Browse files- utils/cache_manager.py +49 -0
- utils/modality_router.py +63 -0
utils/cache_manager.py
CHANGED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
from typing import Iterable, List, Tuple
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
BATCH_SIZE = 50
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
|
| 9 |
+
if isinstance(labels, tuple):
|
| 10 |
+
return labels
|
| 11 |
+
return tuple(labels)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@lru_cache(maxsize=5)
|
| 15 |
+
def cached_inference(image_path, labels, model, processor):
|
| 16 |
+
import torch
|
| 17 |
+
from PIL import Image
|
| 18 |
+
|
| 19 |
+
label_tuple: Tuple[str, ...] = _ensure_tuple(labels)
|
| 20 |
+
|
| 21 |
+
with Image.open(image_path).convert("RGB") as img:
|
| 22 |
+
tensor_image = img.copy()
|
| 23 |
+
|
| 24 |
+
device = next(model.parameters()).device
|
| 25 |
+
dtype = next(model.parameters()).dtype
|
| 26 |
+
logits: List[float] = []
|
| 27 |
+
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
for start in range(0, len(label_tuple), BATCH_SIZE):
|
| 30 |
+
batch = label_tuple[start : start + BATCH_SIZE]
|
| 31 |
+
inputs = processor(images=tensor_image, text=list(batch), return_tensors="pt", padding=True)
|
| 32 |
+
|
| 33 |
+
prepared = {}
|
| 34 |
+
for key, value in inputs.items():
|
| 35 |
+
if torch.is_tensor(value):
|
| 36 |
+
moved = value.to(device)
|
| 37 |
+
if torch.is_floating_point(moved):
|
| 38 |
+
moved = moved.to(dtype=dtype)
|
| 39 |
+
prepared[key] = moved
|
| 40 |
+
else:
|
| 41 |
+
prepared[key] = value
|
| 42 |
+
|
| 43 |
+
outputs = model(**prepared)
|
| 44 |
+
batch_logits = outputs.logits_per_image[0].detach().cpu().tolist()
|
| 45 |
+
logits.extend(batch_logits)
|
| 46 |
+
|
| 47 |
+
tensor_image.close()
|
| 48 |
+
scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
|
| 49 |
+
return scores
|
utils/modality_router.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from transformers import pipeline
|
| 4 |
+
|
| 5 |
+
_router = None
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def init_router():
|
| 9 |
+
global _router
|
| 10 |
+
if _router is None:
|
| 11 |
+
_router = pipeline("image-classification", model="Matthijs/mobilevit-small")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def detect_modality(image_path: str) -> str:
|
| 15 |
+
"""Hybrid modality detection: filename hints + visual + lightweight model"""
|
| 16 |
+
name = image_path.lower()
|
| 17 |
+
# 1️⃣ تحليل الاسم
|
| 18 |
+
if any(k in name for k in ["xray", "chest", "lung"]):
|
| 19 |
+
return "xray"
|
| 20 |
+
if any(k in name for k in ["ct", "abdomen", "liver"]):
|
| 21 |
+
return "ct"
|
| 22 |
+
if any(k in name for k in ["ultrasound", "us", "sonogram"]):
|
| 23 |
+
return "ultrasound"
|
| 24 |
+
if any(k in name for k in ["mri", "brain", "spine"]):
|
| 25 |
+
return "mri"
|
| 26 |
+
if any(k in name for k in ["histopath", "slide", "micro"]):
|
| 27 |
+
return "pathology"
|
| 28 |
+
if any(k in name for k in ["skin", "derma"]):
|
| 29 |
+
return "skin"
|
| 30 |
+
if any(k in name for k in ["eye", "retina", "fundus"]):
|
| 31 |
+
return "eye"
|
| 32 |
+
if any(k in name for k in ["cardio", "echo", "heart"]):
|
| 33 |
+
return "cardio"
|
| 34 |
+
if any(k in name for k in ["msk", "musculoskeletal", "orthopedic"]):
|
| 35 |
+
return "musculoskeletal"
|
| 36 |
+
|
| 37 |
+
# 2️⃣ تحليل بصري بسيط
|
| 38 |
+
with Image.open(image_path).convert("RGB") as img:
|
| 39 |
+
arr = np.array(img)
|
| 40 |
+
mean_sat = np.std(arr, axis=(0, 1)).mean()
|
| 41 |
+
if mean_sat < 15:
|
| 42 |
+
return "xray"
|
| 43 |
+
elif mean_sat < 25:
|
| 44 |
+
return "mri"
|
| 45 |
+
elif mean_sat > 45:
|
| 46 |
+
return "skin"
|
| 47 |
+
|
| 48 |
+
# 3️⃣ نموذج صغير fallback
|
| 49 |
+
if _router is None:
|
| 50 |
+
init_router()
|
| 51 |
+
pred = _router(image_path)[0]["label"].lower()
|
| 52 |
+
valid = {
|
| 53 |
+
"xray",
|
| 54 |
+
"ct",
|
| 55 |
+
"ultrasound",
|
| 56 |
+
"mri",
|
| 57 |
+
"pathology",
|
| 58 |
+
"skin",
|
| 59 |
+
"eye",
|
| 60 |
+
"cardio",
|
| 61 |
+
"musculoskeletal",
|
| 62 |
+
}
|
| 63 |
+
return "general" if pred not in valid else pred
|