fokan committed on
Commit 5721a78 · verified · 1 Parent(s): 01df01b

Upload 2 files

Files changed (2)
  1. utils/cache_manager.py +49 -0
  2. utils/modality_router.py +63 -0
utils/cache_manager.py ADDED
@@ -0,0 +1,49 @@
+ from functools import lru_cache
+ from typing import Iterable, List, Tuple
+
+
+ BATCH_SIZE = 50
+
+
+ def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
+     if isinstance(labels, tuple):
+         return labels
+     return tuple(labels)
+
+
+ @lru_cache(maxsize=5)
+ def cached_inference(image_path, labels, model, processor):
+     # NOTE: lru_cache hashes the arguments before the body runs, so callers
+     # must pass labels as a hashable iterable (ideally a tuple); a plain list
+     # raises TypeError at the call site, before _ensure_tuple is reached.
+     import torch
+     from PIL import Image
+
+     label_tuple: Tuple[str, ...] = _ensure_tuple(labels)
+
+     # Load the image eagerly and keep an in-memory copy so the source file
+     # can be released before inference starts.
+     with Image.open(image_path).convert("RGB") as img:
+         tensor_image = img.copy()
+
+     device = next(model.parameters()).device
+     dtype = next(model.parameters()).dtype
+     logits: List[float] = []
+
+     # Run the labels through the model in batches of BATCH_SIZE and collect
+     # the image-to-text logits for the single input image.
+     with torch.no_grad():
+         for start in range(0, len(label_tuple), BATCH_SIZE):
+             batch = label_tuple[start : start + BATCH_SIZE]
+             inputs = processor(images=tensor_image, text=list(batch), return_tensors="pt", padding=True)
+
+             # Move tensors to the model's device and match its floating dtype.
+             prepared = {}
+             for key, value in inputs.items():
+                 if torch.is_tensor(value):
+                     moved = value.to(device)
+                     if torch.is_floating_point(moved):
+                         moved = moved.to(dtype=dtype)
+                     prepared[key] = moved
+                 else:
+                     prepared[key] = value
+
+             outputs = model(**prepared)
+             batch_logits = outputs.logits_per_image[0].detach().cpu().tolist()
+             logits.extend(batch_logits)
+
+     tensor_image.close()
+     # Softmax over all labels at once so scores stay comparable across batches.
+     scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
+     return scores
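
For context, here is a minimal usage sketch for `cached_inference`, assuming a CLIP-style zero-shot setup; the checkpoint name, image path, and labels below are illustrative placeholders, not part of this commit. Because `lru_cache` hashes the arguments before the function body runs, `labels` has to be passed as a tuple (a plain list would raise `TypeError` before `_ensure_tuple` is ever reached):

```python
# Hypothetical usage sketch -- checkpoint, path and labels are assumptions.
from transformers import CLIPModel, CLIPProcessor

from utils.cache_manager import cached_inference

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Pass labels as a tuple: lru_cache hashes the call arguments up front,
# so an unhashable list never reaches _ensure_tuple.
labels = ("chest x-ray", "abdominal ct scan", "brain mri")
scores = cached_inference("samples/chest.png", labels, model, processor)
print(dict(zip(labels, scores)))
```

Repeating the call with identical arguments returns the cached scores without re-running the model; the cache holds at most 5 entries.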
utils/modality_router.py ADDED
@@ -0,0 +1,63 @@
+ import numpy as np
+ from PIL import Image
+ from transformers import pipeline
+
+ _router = None
+
+
+ def init_router():
+     """Lazily create the lightweight image-classification pipeline."""
+     global _router
+     if _router is None:
+         _router = pipeline("image-classification", model="Matthijs/mobilevit-small")
+
+
+ def detect_modality(image_path: str) -> str:
+     """Hybrid modality detection: filename hints + visual statistics + lightweight model."""
+     name = image_path.lower()
+     # 1️⃣ Filename analysis
+     # NOTE: short keys such as "ct" and "us" match as substrings (e.g. "us"
+     # inside "fundus"), so the order of these checks matters.
+     if any(k in name for k in ["xray", "chest", "lung"]):
+         return "xray"
+     if any(k in name for k in ["ct", "abdomen", "liver"]):
+         return "ct"
+     if any(k in name for k in ["ultrasound", "us", "sonogram"]):
+         return "ultrasound"
+     if any(k in name for k in ["mri", "brain", "spine"]):
+         return "mri"
+     if any(k in name for k in ["histopath", "slide", "micro"]):
+         return "pathology"
+     if any(k in name for k in ["skin", "derma"]):
+         return "skin"
+     if any(k in name for k in ["eye", "retina", "fundus"]):
+         return "eye"
+     if any(k in name for k in ["cardio", "echo", "heart"]):
+         return "cardio"
+     if any(k in name for k in ["msk", "musculoskeletal", "orthopedic"]):
+         return "musculoskeletal"
+
+     # 2️⃣ Simple visual analysis: mean of the per-channel standard deviations,
+     # a rough contrast/colourfulness proxy.
+     with Image.open(image_path).convert("RGB") as img:
+         arr = np.array(img)
+     mean_sat = np.std(arr, axis=(0, 1)).mean()
+     if mean_sat < 15:
+         return "xray"
+     elif mean_sat < 25:
+         return "mri"
+     elif mean_sat > 45:
+         return "skin"
+
+     # 3️⃣ Fallback to the small classification model
+     if _router is None:
+         init_router()
+     pred = _router(image_path)[0]["label"].lower()
+     valid = {
+         "xray",
+         "ct",
+         "ultrasound",
+         "mri",
+         "pathology",
+         "skin",
+         "eye",
+         "cardio",
+         "musculoskeletal",
+     }
+     return "general" if pred not in valid else pred
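
A small usage sketch for the router follows; the directory and file names are placeholders, not part of this commit. The keyword pass resolves descriptive filenames on its own, the pixel-statistics check handles clearly monochrome or highly colourful images, and only the remaining cases hit the MobileViT pipeline:

```python
# Hypothetical usage sketch -- the image paths are assumptions.
from utils.modality_router import detect_modality, init_router

init_router()  # optional: pre-load the classification pipeline once at startup

for path in ["images/chest_xray_017.png", "images/retina_scan.jpg", "images/scan_001.png"]:
    print(path, "->", detect_modality(path))

# "chest_xray_017.png" is resolved by the filename keywords alone and
# "retina_scan.jpg" matches the eye branch without opening either file;
# "scan_001.png" falls through to the pixel statistics and, if still
# ambiguous, to the classifier.
```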