import os
import json
import signal
import sys
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Sequence, Set, Any
from multiprocessing import Pool, cpu_count

import fitz
import pypdfium2 as pdfium
import torch
from doclayout_yolo import YOLOv10
from huggingface_hub import hf_hub_download
from loguru import logger
from PIL import Image
import numpy as np

try:
    import pymupdf4llm
except ImportError:
    pymupdf4llm = None


# Inference device: CUDA if available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DocLayout-YOLO weights published on the Hugging Face Hub.
MODEL_SIZE = 1024
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
WEIGHTS_FILE = f"doclayout_yolo_docstructbench_imgsz{MODEL_SIZE}.pt"

# Minimum detection confidence kept from the YOLO output.
CONF_THRESHOLD = 0.25

# None = auto-detect (cpu_count() - 1); multiprocessing is only used on CPU.
NUM_WORKERS = None
USE_MULTIPROCESSING = True


# RGB (0-255) color per detected class, used for the annotated layout PDF.
CLASS_COLORS = {
    "text": (0, 128, 0),
    "title": (192, 0, 0),
    "figure": (0, 0, 192),
    "table": (218, 165, 32),
    "list": (128, 0, 128),
    "header": (0, 128, 128),
    "footer": (100, 100, 100),
    "figure_caption": (0, 0, 128),
    "table_caption": (139, 69, 19),
    "table_footnote": (128, 0, 128),
}


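# Note: these are 8-bit RGB values; draw_layout_pdf() divides them by 255
# because PyMuPDF expects color components as floats in the 0-1 range.

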
_model = None
_shutdown_requested = False


def signal_handler(signum, frame):
    """Handle interrupt signals gracefully."""
    global _shutdown_requested
    if not _shutdown_requested:
        _shutdown_requested = True
        logger.warning("\n⚠️ Interrupt received! Finishing current page and shutting down gracefully...")
        logger.warning("Press Ctrl+C again to force quit (may leave incomplete files)")
    else:
        logger.error("\n❌ Force quit requested. Exiting immediately.")
        sys.exit(1)


def setup_signal_handlers():
    """Setup signal handlers for graceful shutdown."""
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)


def get_model():
    """Lazy load the model (only once per process)."""
    global _model
    if _model is None:
        weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)
        _model = YOLOv10(weights_path)
        logger.info(f"✅ Model loaded in worker process (PID: {os.getpid()})")
    return _model


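# Note: hf_hub_download() reuses the local Hugging Face cache, so the weights
# are only fetched from the network when they are not already cached; the
# YOLOv10 instance itself is cached per process via the _model global.

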
def init_worker():
    """Initialize worker process - loads model once at startup."""
    try:
        get_model()
        logger.success(f"Worker {os.getpid()} ready")
    except Exception as e:
        logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
        raise


def detect_page(pil_img: Image.Image) -> List[dict]:
    """Detect layout elements using YOLO model."""
    model = get_model()
    img_cv = np.array(pil_img)
    results = model.predict(
        img_cv,
        imgsz=MODEL_SIZE,
        conf=CONF_THRESHOLD,
        device=DEVICE,
        verbose=False
    )
    dets = []
    for i, box in enumerate(results[0].boxes):
        cls_id = int(box.cls.item())
        name = results[0].names[cls_id]
        conf = float(box.conf.item())
        x0, y0, x1, y1 = box.xyxy[0].cpu().numpy().tolist()
        dets.append({
            "name": name,
            "bbox": [x0, y0, x1, y1],
            "conf": conf,
            "source": "yolo",
            "index": i
        })
    return dets


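# Each detection is a plain dict; illustrative shape (values are made up):
# {"name": "figure", "bbox": [120.0, 340.5, 980.0, 1100.2], "conf": 0.91,
#  "source": "yolo", "index": 3}
# bbox coordinates are pixels in the rendered page image, not PDF points.

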
def get_union_box(box1: List[float], box2: List[float]) -> List[float]:
    """Get the bounding box enclosing two boxes."""
    x0 = min(box1[0], box2[0])
    y0 = min(box1[1], box2[1])
    x1 = max(box1[2], box2[2])
    y1 = max(box1[3], box2[3])
    return [x0, y0, x1, y1]


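# Worked example (illustrative): get_union_box([10, 10, 50, 50], [40, 60, 80, 90])
# returns [10, 10, 80, 90] -- the smallest axis-aligned box enclosing both inputs.

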
def collect_caption_elements(
    element: Dict,
    all_dets: List[Dict],
    target_name: str,
    max_vertical_gap: float = 60.0,
    min_overlap: float = 0.25,
) -> List[Dict]:
    """
    Collect contiguous caption detections directly below a figure/table.
    """
    base_box = element["bbox"]
    base_bottom = base_box[3]
    selected: List[Dict] = []
    last_bottom = base_bottom

    relevant = [
        d for d in all_dets
        if d["name"] == target_name and d["bbox"][1] >= base_bottom - 5
    ]

    relevant.sort(key=lambda d: d["bbox"][1])

    for cand in relevant:
        cand_box = cand["bbox"]
        top = cand_box[1]
        if selected and top - last_bottom > max_vertical_gap:
            break

        if selected:
            overlap = _horizontal_overlap_ratio(selected[-1]["bbox"], cand_box)
        else:
            overlap = _horizontal_overlap_ratio(base_box, cand_box)

        if overlap < min_overlap:
            continue

        selected.append(cand)
        last_bottom = cand_box[3]

    return selected


def collect_title_and_text_segments(
    element: Dict,
    all_dets: List[Dict],
    processed_indices: Set[int],
    settings: Optional[Dict[str, float]] = None,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Locate a title below the element and any contiguous text blocks directly beneath it.
    """
    if settings is None:
        settings = TITLE_TEXT_ASSOCIATION

    if not element.get("bbox"):
        return [], []

    figure_box = element["bbox"]
    figure_bottom = figure_box[3]

    candidates = [
        d for d in all_dets
        if d.get("bbox") and d["index"] not in processed_indices
    ]
    candidates.sort(key=lambda d: d["bbox"][1])

    titles: List[Dict] = []
    texts: List[Dict] = []

    for idx, det in enumerate(candidates):
        if det["name"] != "title":
            continue

        title_box = det["bbox"]
        if title_box[1] < figure_bottom - 5:
            continue

        vertical_gap = title_box[1] - figure_bottom
        if vertical_gap > settings["max_title_gap"]:
            break

        overlap = _horizontal_overlap_ratio(figure_box, title_box)
        if overlap < settings["min_overlap"]:
            continue

        titles.append(det)
        last_bottom = title_box[3]

        for follower in candidates[idx + 1:]:
            if follower["name"] == "title":
                break
            if follower["name"] != "text":
                continue
            text_box = follower["bbox"]
            if text_box[1] < title_box[1]:
                continue

            gap = text_box[1] - last_bottom
            if gap > settings["max_text_gap"]:
                break

            if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
                continue

            texts.append(follower)
            last_bottom = text_box[3]

        # Only the first qualifying title (and its trailing text) is collected.
        break

    return titles, texts


def save_layout_elements(pil_img: Image.Image, page_num: int,
                         dets: List[dict], out_dir: Path) -> List[dict]:
    """Save figure and table crops, merging captions."""
    fig_dir = out_dir / "figures"
    tab_dir = out_dir / "tables"
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(tab_dir, exist_ok=True)

    infos = []
    fig_count = 0
    tab_count = 0

    processed_indices = set()

    for d in dets:
        if d["index"] in processed_indices:
            continue

        name = d["name"].lower()
        final_box = d["bbox"]
        caption_segments: List[Dict] = []
        title_segments: List[Dict] = []
        text_segments: List[Dict] = []

        if name == "figure":
            elem_type = "figure"
            path_template = fig_dir / f"page_{page_num + 1}_fig_{fig_count}.png"
            fig_count += 1
            caption_segments = collect_caption_elements(d, dets, "figure_caption")
            for cap in caption_segments:
                final_box = get_union_box(final_box, cap["bbox"])
                processed_indices.add(cap["index"])
            title_segments, text_segments = collect_title_and_text_segments(
                d, dets, processed_indices
            )
            for seg in title_segments + text_segments:
                final_box = get_union_box(final_box, seg["bbox"])
                processed_indices.add(seg["index"])

        elif name == "table":
            elem_type = "table"
            path_template = tab_dir / f"page_{page_num + 1}_tab_{tab_count}.png"
            tab_count += 1
            caption_segments = collect_caption_elements(d, dets, "table_caption")
            for cap in caption_segments:
                final_box = get_union_box(final_box, cap["bbox"])
                processed_indices.add(cap["index"])
        else:
            continue

        x0, y0, x1, y1 = map(int, final_box)
        crop = pil_img.crop((x0, y0, x1, y1))

        if crop.mode == "CMYK":
            crop = crop.convert("RGB")

        crop.save(path_template)

        info_data = {
            "type": elem_type,
            "page": page_num + 1,
            "bbox_pixels": final_box,
            "conf": d["conf"],
            "source": d.get("source", "yolo"),
            "image_path": str(path_template.relative_to(out_dir)),
            "width": int(x1 - x0),
            "height": int(y1 - y0),
            "page_width": pil_img.width,
            "page_height": pil_img.height,
        }
        if caption_segments:
            info_data["captions"] = [
                {
                    "bbox": cap["bbox"],
                    "conf": cap.get("conf"),
                    "index": cap["index"],
                    "source": cap.get("source"),
                    "page": page_num + 1,
                }
                for cap in caption_segments
            ]
        if title_segments:
            info_data["titles"] = [
                {
                    "bbox": seg["bbox"],
                    "conf": seg.get("conf"),
                    "index": seg["index"],
                    "source": seg.get("source"),
                    "page": page_num + 1,
                }
                for seg in title_segments
            ]
        if text_segments:
            info_data["texts"] = [
                {
                    "bbox": seg["bbox"],
                    "conf": seg.get("conf"),
                    "index": seg["index"],
                    "source": seg.get("source"),
                    "page": page_num + 1,
                }
                for seg in text_segments
            ]

        infos.append(info_data)

    return infos


# All tolerances below are in rendered-pixel units (pages are rasterized at
# scale 2.0 before detection, so coordinates are roughly 144 DPI).
TABLE_STITCH_TOLERANCES = {
    "x_tol": 60,
    "y_tol": 60,
    "width_tol": 120,
    "height_tol": 120,
}

CROSS_PAGE_CAPTION_THRESHOLDS = {
    "max_top_ratio": 0.35,
    "max_top_pixels": 220,
    "x_tol": 120,
    "width_tol": 200,
    "min_overlap": 0.05,
}

TITLE_TEXT_ASSOCIATION = {
    "max_title_gap": 220,
    "max_text_gap": 160,
    "min_overlap": 0.2,
}


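# Example of how CROSS_PAGE_CAPTION_THRESHOLDS is applied (see
# attach_cross_page_figure_captions): on a 2200 px tall rendered page, a
# next-page caption qualifies only if its top edge lies above
# min(220, int(2200 * 0.35)) = min(220, 770) = 220 px, i.e. near the top.

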
def _horizontal_overlap_ratio(box1: List[float], box2: List[float]) -> float:
    """Compute horizontal overlap ratio between two bounding boxes."""
    x_left = max(box1[0], box2[0])
    x_right = min(box1[2], box2[2])
    overlap = max(0.0, x_right - x_left)
    if overlap <= 0:
        return 0.0
    width_union = max(box1[2], box2[2]) - min(box1[0], box2[0])
    if width_union <= 0:
        return 0.0
    return overlap / width_union


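# Worked example (illustrative): for boxes [0, 0, 100, 10] and [50, 20, 150, 30]
# the shared horizontal extent is 100 - 50 = 50 and the union width is
# 150 - 0 = 150, so the ratio is 50 / 150 ≈ 0.33. Vertical positions are ignored.

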
def _bbox_to_rect(bbox: List[float]) -> Tuple[int, int, int, int]:
    """Convert [x0, y0, x1, y1] into (x, y, w, h)."""
    x0, y0, x1, y1 = bbox
    return int(x0), int(y0), int(x1 - x0), int(y1 - y0)


def _open_table_image(elem: Dict, out_dir: Path) -> Optional[Image.Image]:
    """Open a table image relative to the output directory."""
    image_path = out_dir / elem["image_path"]
    if not image_path.exists():
        logger.warning(f"Missing table crop for stitching: {image_path}")
        return None
    img = Image.open(image_path)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def _pad_width(img: Image.Image, target_width: int) -> Image.Image:
    if img.width >= target_width:
        return img
    canvas = Image.new("RGB", (target_width, img.height), color=(255, 255, 255))
    canvas.paste(img, (0, 0))
    return canvas


def _pad_height(img: Image.Image, target_height: int) -> Image.Image:
    if img.height >= target_height:
        return img
    canvas = Image.new("RGB", (img.width, target_height), color=(255, 255, 255))
    canvas.paste(img, (0, 0))
    return canvas


def _append_segment_image(
    base_img: Image.Image,
    segment_img: Image.Image,
    resize_to_base: bool = False,
) -> Image.Image:
    """Append segment image below base image with optional width alignment."""
    if base_img.mode != "RGB":
        base_img = base_img.convert("RGB")
    if segment_img.mode != "RGB":
        segment_img = segment_img.convert("RGB")

    if resize_to_base and segment_img.width > 0 and base_img.width > 0:
        segment_img = segment_img.resize(
            (
                base_img.width,
                max(1, int(segment_img.height * (base_img.width / segment_img.width))),
            ),
            Image.Resampling.LANCZOS,
        )

    target_width = max(base_img.width, segment_img.width)
    base_img = _pad_width(base_img, target_width)
    segment_img = _pad_width(segment_img, target_width)

    stitched = Image.new(
        "RGB",
        (target_width, base_img.height + segment_img.height),
        color=(255, 255, 255),
    )
    stitched.paste(base_img, (0, 0))
    stitched.paste(segment_img, (0, base_img.height))
    return stitched


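# Behavior note: with resize_to_base=False, a 500 px wide segment appended
# below an 800 px wide base is padded with white to 800 px rather than scaled;
# with resize_to_base=True it is first resized to the base width, preserving
# its aspect ratio.

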
def _render_pdf_page(
    pdf_doc: pdfium.PdfDocument,
    page_index: int,
    scale: float,
    cache: Dict[int, Image.Image],
) -> Optional[Image.Image]:
    """Render a PDF page to a PIL image with caching."""
    if page_index in cache:
        return cache[page_index]

    try:
        page = pdf_doc[page_index]
        bitmap = page.render(scale=scale)
        pil_img = bitmap.to_pil()
        page.close()
    except Exception as exc:
        logger.error(f"Failed to render page {page_index + 1} for caption stitching: {exc}")
        return None

    cache[page_index] = pil_img
    return pil_img


def _crop_pdf_region(
    page_img: Optional[Image.Image], bbox: List[float]
) -> Optional[Image.Image]:
    """Crop a region from a rendered PDF page."""
    if page_img is None:
        return None

    x0, y0, x1, y1 = map(int, bbox)
    x0 = max(0, x0)
    y0 = max(0, y0)
    x1 = min(page_img.width, max(x0 + 1, x1))
    y1 = min(page_img.height, max(y0 + 1, y1))

    if x0 >= x1 or y0 >= y1:
        return None

    crop = page_img.crop((x0, y0, x1, y1))
    if crop.mode == "CMYK":
        crop = crop.convert("RGB")
    return crop


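# Clamping example (illustrative): on a 1200x1600 page image, the bbox
# [-5, 10, 5000, 900] is clamped to the crop window (0, 10, 1200, 900), so
# out-of-page coordinates never raise; fully out-of-page boxes return None.

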
def write_markdown_document(pdf_path: Path, out_dir: Path) -> Optional[Path]:
    """
    Extract markdown text from a PDF using PyMuPDF4LLM and write it to disk.
    """
    if pymupdf4llm is None:
        logger.warning(
            f"Skipping markdown extraction for {pdf_path.name} because pymupdf4llm is not installed."
        )
        return None

    try:
        markdown_content = pymupdf4llm.to_markdown(str(pdf_path))
    except Exception as exc:
        logger.error(f" Failed to create markdown for {pdf_path.name}: {exc}")
        return None

    if isinstance(markdown_content, list):
        markdown_content = "\n\n".join(
            part for part in markdown_content if isinstance(part, str)
        )

    if not isinstance(markdown_content, str):
        logger.error(
            f" Unexpected markdown output type {type(markdown_content)} for {pdf_path.name}"
        )
        return None

    markdown_content = markdown_content.strip()
    if not markdown_content:
        logger.warning(f" No textual content extracted from {pdf_path.name}")
        return None

    if not markdown_content.endswith("\n"):
        markdown_content += "\n"

    md_path = out_dir / f"{pdf_path.stem}.md"
    md_path.write_text(markdown_content, encoding="utf-8")
    logger.info(f" Saved markdown to {md_path.name}")
    return md_path


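# Hypothetical usage (paths are examples only):
#   write_markdown_document(Path("pdfs/report.pdf"), Path("output/report"))
# returns the Path of the written .md file, or None when pymupdf4llm is
# missing or no text could be extracted.

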
def _collect_text_under_title_cross_page(
    title_det: Dict,
    sorted_dets: List[Dict],
    start_idx: int,
    page_idx: int,
    used_indices: Set[Tuple[int, int]],
    settings: Optional[Dict[str, float]] = None,
) -> List[Dict]:
    """Collect text elements directly below a title on the next page."""
    if settings is None:
        settings = TITLE_TEXT_ASSOCIATION
    texts: List[Dict] = []
    title_box = title_det["bbox"]
    last_bottom = title_box[3]

    for follower in sorted_dets[start_idx + 1:]:
        det_index = follower.get("index")
        if det_index is None or (page_idx, det_index) in used_indices:
            continue

        if follower["name"] == "title":
            break

        if follower["name"] != "text":
            continue

        text_box = follower["bbox"]
        if text_box[1] < title_box[1]:
            continue

        gap = text_box[1] - last_bottom
        if gap > settings["max_text_gap"]:
            break

        if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
            continue

        texts.append(follower)
        last_bottom = text_box[3]

    return texts


def attach_cross_page_figure_captions(
    elements: List[Dict],
    all_dets: Sequence[Optional[List[Dict[str, Any]]]],
    pdf_bytes: bytes,
    out_dir: Path,
    scale: float,
) -> List[Dict]:
    """
    If a figure caption appears on the next page, stitch it to the prior figure.
    """
    figures = [elem for elem in elements if elem.get("type") == "figure"]
    if not figures or not all_dets:
        return elements

    try:
        pdf_doc = pdfium.PdfDocument(pdf_bytes)
    except Exception as exc:
        logger.error(f"Unable to reopen PDF for figure caption stitching: {exc}")
        return elements

    page_cache: Dict[int, Image.Image] = {}
    used_following_ids: Set[Tuple[int, int]] = set()

    # Mark every segment already attached to a figure as used, so it cannot
    # be stitched a second time.
    for elem in figures:
        for key in ("captions", "titles", "texts"):
            for seg in elem.get(key, []) or []:
                idx = seg.get("index")
                page_no = seg.get("page")
                if idx is None or page_no is None:
                    continue
                used_following_ids.add((page_no - 1, idx))

    for elem in figures:
        page_no = elem.get("page")
        bbox = elem.get("bbox_pixels")
        if page_no is None or bbox is None:
            continue

        current_idx = page_no - 1
        next_idx = current_idx + 1
        if next_idx >= len(all_dets):
            continue

        next_dets = all_dets[next_idx]
        if not next_dets:
            continue

        page_img = _render_pdf_page(pdf_doc, next_idx, scale, page_cache)
        if page_img is None:
            continue

        next_page_height = page_img.height
        max_top_allowed = min(
            CROSS_PAGE_CAPTION_THRESHOLDS["max_top_pixels"],
            int(next_page_height * CROSS_PAGE_CAPTION_THRESHOLDS["max_top_ratio"]),
        )

        sorted_next = sorted(
            [det for det in next_dets if det.get("bbox")],
            key=lambda det: det["bbox"][1],
        )

        caption_candidate: Optional[Tuple[Dict, int]] = None
        caption_candidates = []
        for det in sorted_next:
            if det.get("name") != "figure_caption":
                continue
            det_index = det.get("index")
            if det_index is None or (next_idx, det_index) in used_following_ids:
                continue

            det_bbox = det.get("bbox")
            if not det_bbox or det_bbox[1] > max_top_allowed:
                continue

            overlap = _horizontal_overlap_ratio(bbox, det_bbox)
            x_diff = abs(bbox[0] - det_bbox[0])
            width_diff = abs((bbox[2] - bbox[0]) - (det_bbox[2] - det_bbox[0]))

            if overlap < CROSS_PAGE_CAPTION_THRESHOLDS["min_overlap"]:
                if (
                    x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
                    or width_diff > CROSS_PAGE_CAPTION_THRESHOLDS["width_tol"]
                ):
                    continue

            score = width_diff + 0.5 * x_diff
            caption_candidates.append((score, det, det_index))

        if caption_candidates:
            caption_candidates.sort(key=lambda item: item[0])
            _, best_det, best_index = caption_candidates[0]
            caption_candidate = (best_det, best_index)

        title_candidate: Optional[Tuple[Dict, int]] = None
        title_texts: List[Dict] = []
        for idx_sorted, det in enumerate(sorted_next):
            if det.get("name") != "title":
                continue
            det_index = det.get("index")
            if det_index is None or (next_idx, det_index) in used_following_ids:
                continue

            det_bbox = det.get("bbox")
            if not det_bbox or det_bbox[1] > max_top_allowed:
                continue

            overlap = _horizontal_overlap_ratio(bbox, det_bbox)
            x_diff = abs(bbox[0] - det_bbox[0])
            if (
                overlap < TITLE_TEXT_ASSOCIATION["min_overlap"]
                and x_diff > CROSS_PAGE_CAPTION_THRESHOLDS["x_tol"]
            ):
                continue

            title_candidate = (det, det_index)
            title_texts = _collect_text_under_title_cross_page(
                det, sorted_next, idx_sorted, next_idx, used_following_ids
            )
            break

        if not caption_candidate and not title_candidate and not title_texts:
            continue

        figure_path = out_dir / elem["image_path"]
        if not figure_path.exists():
            continue

        figure_img = Image.open(figure_path)
        if figure_img.mode == "CMYK":
            figure_img = figure_img.convert("RGB")

        segments_added = False

        if caption_candidate:
            cap_det, cap_index = caption_candidate
            caption_crop = _crop_pdf_region(page_img, cap_det["bbox"])
            if caption_crop is not None:
                figure_img = _append_segment_image(
                    figure_img, caption_crop, resize_to_base=True
                )
                elem.setdefault("captions", [])
                elem["captions"].append(
                    {
                        "bbox": cap_det["bbox"],
                        "conf": cap_det.get("conf"),
                        "index": cap_index,
                        "source": cap_det.get("source"),
                        "page": next_idx + 1,
                    }
                )
                used_following_ids.add((next_idx, cap_index))
                segments_added = True

        if title_candidate:
            title_det, title_index = title_candidate
            title_crop = _crop_pdf_region(page_img, title_det["bbox"])
            if title_crop is not None:
                figure_img = _append_segment_image(figure_img, title_crop)
                elem.setdefault("titles", [])
                elem["titles"].append(
                    {
                        "bbox": title_det["bbox"],
                        "conf": title_det.get("conf"),
                        "index": title_index,
                        "source": title_det.get("source"),
                        "page": next_idx + 1,
                    }
                )
                used_following_ids.add((next_idx, title_index))
                segments_added = True

        for text_det in title_texts:
            text_index = text_det.get("index")
            text_crop = _crop_pdf_region(page_img, text_det["bbox"])
            if text_crop is None:
                continue
            figure_img = _append_segment_image(figure_img, text_crop)
            elem.setdefault("texts", [])
            elem["texts"].append(
                {
                    "bbox": text_det["bbox"],
                    "conf": text_det.get("conf"),
                    "index": text_index,
                    "source": text_det.get("source"),
                    "page": next_idx + 1,
                }
            )
            if text_index is not None:
                used_following_ids.add((next_idx, text_index))
            segments_added = True

        if not segments_added:
            continue

        figure_img.save(figure_path)
        elem["width"] = figure_img.width
        elem["height"] = figure_img.height

        span = elem.get("page_span")
        if span:
            if next_idx + 1 not in span:
                span.append(next_idx + 1)
        else:
            base_page = elem.get("page")
            new_span = [page for page in (base_page, next_idx + 1) if page is not None]
            elem["page_span"] = new_span

    pdf_doc.close()
    return elements


def _stitch_table_pair(
    base_elem: Dict,
    candidate_elem: Dict,
    out_dir: Path,
    merge_index: int,
    stitch_type: str,
) -> Optional[Dict]:
    """Stitch two table crops either vertically or horizontally."""
    base_img = _open_table_image(base_elem, out_dir)
    candidate_img = _open_table_image(candidate_elem, out_dir)
    if base_img is None or candidate_img is None:
        return None

    tables_dir = out_dir / "tables"
    tables_dir.mkdir(parents=True, exist_ok=True)

    if stitch_type == "vertical":
        target_width = max(base_img.width, candidate_img.width)
        base_img = _pad_width(base_img, target_width)
        candidate_img = _pad_width(candidate_img, target_width)
        merged_height = base_img.height + candidate_img.height
        stitched = Image.new("RGB", (target_width, merged_height), color=(255, 255, 255))
        stitched.paste(base_img, (0, 0))
        stitched.paste(candidate_img, (0, base_img.height))
    else:
        target_height = max(base_img.height, candidate_img.height)
        base_img = _pad_height(base_img, target_height)
        candidate_img = _pad_height(candidate_img, target_height)
        merged_width = base_img.width + candidate_img.width
        stitched = Image.new("RGB", (merged_width, target_height), color=(255, 255, 255))
        stitched.paste(base_img, (0, 0))
        stitched.paste(candidate_img, (base_img.width, 0))

    merged_name = (
        f"page_{base_elem['page']}_to_{candidate_elem['page']}_"
        f"table_merged_{merge_index}.png"
    )
    merged_path = tables_dir / merged_name
    stitched.save(merged_path)

    # The per-page crops are superseded by the merged image.
    (out_dir / base_elem["image_path"]).unlink(missing_ok=True)
    (out_dir / candidate_elem["image_path"]).unlink(missing_ok=True)

    new_bbox = [
        min(base_elem["bbox_pixels"][0], candidate_elem["bbox_pixels"][0]),
        min(base_elem["bbox_pixels"][1], candidate_elem["bbox_pixels"][1]),
        max(base_elem["bbox_pixels"][2], candidate_elem["bbox_pixels"][2]),
        max(base_elem["bbox_pixels"][3], candidate_elem["bbox_pixels"][3]),
    ]

    merged_elem = base_elem.copy()
    merged_elem["page_span"] = [base_elem["page"], candidate_elem["page"]]
    merged_elem["box_refs"] = [
        {"page": base_elem["page"], "image_path": base_elem["image_path"]},
        {"page": candidate_elem["page"], "image_path": candidate_elem["image_path"]},
    ]
    merged_elem["bbox_pixels"] = new_bbox
    merged_elem["image_path"] = str(merged_path.relative_to(out_dir))
    merged_elem["width"] = stitched.width
    merged_elem["height"] = stitched.height
    merged_elem["page_height"] = stitched.height
    merged_elem["conf"] = min(
        base_elem.get("conf", 1.0), candidate_elem.get("conf", 1.0)
    )
    return merged_elem


def merge_spanning_tables(elements: List[Dict], out_dir: Path) -> List[Dict]:
    """
    Stitch table crops that continue across adjacent pages using the heuristic
    from the legacy OpenCV-based extractor.
    """
    if not elements:
        return elements

    tables_by_page: Dict[int, List[Dict]] = {}
    non_tables: List[Dict] = []

    for elem in elements:
        if elem.get("type") != "table":
            non_tables.append(elem)
            continue
        page = elem.get("page")
        if not isinstance(page, int):
            non_tables.append(elem)
            continue
        tables_by_page.setdefault(page, []).append(elem)

    merged_results: List[Dict] = []
    used_next: Dict[int, Set[int]] = {}
    merge_counter = 0

    for page in sorted(tables_by_page.keys()):
        current_tables = tables_by_page.get(page, [])
        next_page_tables = tables_by_page.get(page + 1, [])
        next_used_indices = used_next.get(page + 1, set())
        current_used_indices = used_next.get(page, set())

        for idx_current, table_elem in enumerate(current_tables):
            if idx_current in current_used_indices:
                continue

            if not next_page_tables:
                merged_results.append(table_elem)
                continue

            x, y, w, h = _bbox_to_rect(table_elem["bbox_pixels"])
            matched = False

            for idx, candidate in enumerate(next_page_tables):
                if idx in next_used_indices:
                    continue
                if candidate.get("type") != "table":
                    continue

                cx, cy, cw, ch = _bbox_to_rect(candidate["bbox_pixels"])

                vertical_match = (
                    abs(x - cx) <= TABLE_STITCH_TOLERANCES["x_tol"]
                    and abs((x + w) - (cx + cw)) <= TABLE_STITCH_TOLERANCES["width_tol"]
                )
                horizontal_match = (
                    abs(y - cy) <= TABLE_STITCH_TOLERANCES["y_tol"]
                    and abs((y + h) - (cy + ch))
                    <= TABLE_STITCH_TOLERANCES["height_tol"]
                )

                stitch_type = "vertical" if vertical_match else None
                if not stitch_type and horizontal_match:
                    stitch_type = "horizontal"

                if not stitch_type:
                    continue

                merge_counter += 1
                merged_elem = _stitch_table_pair(
                    table_elem, candidate, out_dir, merge_counter, stitch_type
                )
                if merged_elem is None:
                    continue

                merged_results.append(merged_elem)
                next_used_indices.add(idx)
                matched = True
                break

            if not matched:
                merged_results.append(table_elem)

        used_next[page + 1] = next_used_indices

    merged_results.extend(non_tables)
    return merged_results


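# Matching example (illustrative): a table at x=100, width=800 on page 3 and a
# table at x=130, width=820 on page 4 satisfy the vertical-match test, since
# the left edges differ by 30 <= x_tol (60) and the right edges by
# |900 - 950| = 50 <= width_tol (120), so the two crops are stacked vertically.

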
def draw_layout_pdf(pdf_bytes: bytes, all_dets: List[List[dict]],
                    scale: float, out_path: Path):
    """Annotate PDF with semi-transparent bounding boxes and labels."""
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")

    for page_no, dets in enumerate(all_dets):
        page = doc[page_no]

        for d in dets:
            rgb = CLASS_COLORS.get(d["name"], (0, 0, 0))
            # Detection bboxes are in rendered-pixel space; divide by the
            # render scale to get back to PDF point coordinates.
            rect = fitz.Rect([c / scale for c in d["bbox"]])

            border_color = [c / 255 for c in rgb]
            fill_color = [c / 255 for c in rgb]
            fill_opacity = 0.15
            border_width = 1.5

            page.draw_rect(
                rect,
                color=border_color,
                fill=fill_color,
                width=border_width,
                overlay=True,
                fill_opacity=fill_opacity
            )

            label = f"{d['name']} {d['conf']:.2f}"
            if d.get("source"):
                label += f" [{d['source'][0].upper()}]"

            # White label background. PyMuPDF would read a 4-float color as
            # CMYK, so opacity is passed via fill_opacity instead.
            text_bg = fitz.Rect(rect.x0, rect.y0 - 10, rect.x0 + 60, rect.y0)
            page.draw_rect(text_bg, color=None, fill=(1, 1, 1), fill_opacity=0.6, overlay=True)

            page.insert_text(
                (rect.x0 + 2, rect.y0 - 8),
                label,
                fontsize=6.5,
                color=border_color,
                overlay=True
            )

    doc.save(str(out_path))
    doc.close()


def process_page(task_data: Tuple[int, bytes, float, Path, str]) -> Optional[Tuple[int, List[dict], List[dict]]]:
    """
    Process a single page of a PDF in a worker process.
    Returns: (page_number, detections, elements) or None on failure
    """
    pno, pdf_bytes, scale, out_dir, pdf_name = task_data

    if _shutdown_requested:
        return None

    pdf_pdfium = None
    try:
        pdf_pdfium = pdfium.PdfDocument(pdf_bytes)

        page = pdf_pdfium[pno]
        bitmap = page.render(scale=scale)
        pil = bitmap.to_pil()

        dets = detect_page(pil)
        elements = save_layout_elements(pil, pno, dets, out_dir)

        page_figures = len([d for d in dets if d['name'] == 'figure'])
        page_tables = len([d for d in dets if d['name'] == 'table'])
        logger.info(f" [{pdf_name}] Page {pno + 1}: {page_figures} figs, {page_tables} tables")

        page.close()
        pdf_pdfium.close()

        return (pno, dets, elements)

    except Exception as e:
        logger.error(f"Failed to process page {pno + 1} of {pdf_name}: {e}")
        if pdf_pdfium:
            pdf_pdfium.close()
        return None


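# A task tuple, as built by process_pdf_with_pool (hypothetical values):
#   (0, pdf_bytes, 2.0, Path("output/mydoc"), "mydoc.pdf")
# i.e. (page index, raw PDF bytes, render scale, output dir, display name).
# The raw bytes are re-opened in each worker because pdfium document handles
# are not picklable across processes.

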
def process_pdf_with_pool(
    pdf_path: Path,
    out_dir: Path,
    pool: Optional[Pool] = None,
    *,
    extract_images: bool = True,
    extract_markdown: bool = True,
):
    """
    Main processing pipeline for a PDF file.
    If pool is provided, uses it. Otherwise processes serially.
    """
    if _shutdown_requested:
        logger.warning(f"Skipping {pdf_path.name} due to shutdown request")
        return

    stem = pdf_path.stem
    logger.info(f"Processing {pdf_path.name}")

    pdf_bytes = pdf_path.read_bytes()

    doc = None
    try:
        doc = pdfium.PdfDocument(pdf_bytes)
        page_count = len(doc)
    except Exception as e:
        logger.error(f"Failed to open PDF {pdf_path.name}: {e}. Skipping.")
        return
    finally:
        if doc is not None:
            doc.close()

    scale = 2.0
    all_elements: List[Dict] = []
    filtered_dets: List[List[dict]] = []

    if extract_images:
        all_dets: List[Optional[List[dict]]] = [None] * page_count

        if pool is not None and USE_MULTIPROCESSING:
            logger.info(f" Using worker pool for {page_count} pages...")

            tasks = [
                (pno, pdf_bytes, scale, out_dir, pdf_path.name)
                for pno in range(page_count)
            ]

            try:
                results = pool.map(process_page, tasks)

                for res in results:
                    if res:
                        pno, dets, elements = res
                        all_dets[pno] = dets
                        all_elements.extend(elements)

            except KeyboardInterrupt:
                logger.warning("Processing interrupted during parallel execution")
                raise

        else:
            logger.info("Using serial processing...")

            try:
                pdf_pdfium = pdfium.PdfDocument(pdf_bytes)

                for pno in range(page_count):
                    if _shutdown_requested:
                        logger.warning(
                            f"Stopping at page {pno + 1}/{page_count} due to shutdown request"
                        )
                        break

                    try:
                        logger.info(f" Processing page {pno + 1}/{page_count}")

                        page = pdf_pdfium[pno]
                        bitmap = page.render(scale=scale)
                        pil = bitmap.to_pil()

                        dets = detect_page(pil)
                        all_dets[pno] = dets

                        elements = save_layout_elements(pil, pno, dets, out_dir)
                        all_elements.extend(elements)

                        page_figures = len([d for d in dets if d["name"] == "figure"])
                        page_tables = len([d for d in dets if d["name"] == "table"])
                        logger.info(
                            f" Found {page_figures} figures and {page_tables} tables"
                        )

                        page.close()

                    except Exception as e:
                        logger.error(f"Failed to process page {pno + 1}: {e}. Skipping page.")

                pdf_pdfium.close()

            except Exception as e:
                logger.error(f"Fatal error processing {pdf_path.name}: {e}")
                if "pdf_pdfium" in locals() and pdf_pdfium:
                    pdf_pdfium.close()
                return

        # all_dets already holds one (possibly None) detection list per page.
        dets_per_page: List[Optional[List[Dict[str, Any]]]] = list(all_dets)

        # Keep one entry per page (empty list when detection failed) so page
        # numbering stays aligned inside draw_layout_pdf.
        filtered_dets = [d if d is not None else [] for d in all_dets]

        if all_elements:
            all_elements = merge_spanning_tables(all_elements, out_dir)
            all_elements = attach_cross_page_figure_captions(
                all_elements, dets_per_page, pdf_bytes, out_dir, scale
            )

        if all_elements:
            content_list_path = out_dir / f"{stem}_content_list.json"
            with open(content_list_path, "w", encoding="utf-8") as f:
                json.dump(all_elements, f, ensure_ascii=False, indent=4)
            logger.info(f" Saved {len(all_elements)} elements to JSON")

        if any(filtered_dets):
            draw_layout_pdf(
                pdf_bytes, filtered_dets, scale, out_dir / f"{stem}_layout.pdf"
            )
            logger.info(" Generated annotated PDF")
        else:
            logger.warning(f"No detections found for {stem}. Skipping layout PDF.")

    else:
        logger.info(" Image extraction skipped per configuration.")

    markdown_path = None
    if extract_markdown:
        markdown_path = write_markdown_document(pdf_path, out_dir)
        if markdown_path is None:
            logger.warning(f" Markdown extraction yielded no content for {stem}.")

    if _shutdown_requested:
        logger.warning(f"⚠️ Partial results saved for {stem} → {out_dir}")
    else:
        if extract_images:
            logger.success(
                f"✅ {stem} → {out_dir} ({len(all_elements)} elements extracted)"
            )
        else:
            logger.success(f"✅ {stem} → {out_dir} (image extraction skipped)")


if __name__ == "__main__":
    # Use 'spawn' so each worker starts with a clean interpreter; forking a
    # process that has already initialized torch can hang or crash.
    torch.multiprocessing.set_start_method('spawn', force=True)

    setup_signal_handlers()

    INPUT_DIR = Path("./pdfs")
    OUTPUT_DIR = Path("./output")

    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    pdf_files = list(INPUT_DIR.glob("*.pdf"))
    if not pdf_files:
        logger.warning("No PDF files found in ./pdfs")
        logger.info("Please add PDF files to the ./pdfs directory")
        logger.info("The script will exit gracefully. No errors occurred.")
        sys.exit(0)

    logger.info(f"Found {len(pdf_files)} PDF file(s) to process")
    logger.info(f"Settings: MODEL_SIZE={MODEL_SIZE}, CONF={CONF_THRESHOLD}")

    total_cpus = cpu_count()
    if NUM_WORKERS is None:
        num_workers = max(1, total_cpus - 1)
    else:
        num_workers = max(1, min(NUM_WORKERS, total_cpus))

    # Pool parallelism only pays off on CPU; on GPU the model is the bottleneck.
    use_pool = USE_MULTIPROCESSING and DEVICE == "cpu" and total_cpus >= 4

    if use_pool:
        logger.info(f"🚀 Creating persistent worker pool with {num_workers} workers...")
    else:
        if not USE_MULTIPROCESSING:
            logger.info("Multiprocessing disabled by configuration")
        elif DEVICE != "cpu":
            logger.info(f"Using serial GPU processing (device: {DEVICE})")
        else:
            logger.info(f"Using serial CPU processing (CPU count {total_cpus} too low)")

    pool = None
    try:
        if use_pool:
            pool = Pool(processes=num_workers, initializer=init_worker)
            logger.success(f"✅ Worker pool ready with {num_workers} workers\n")
        else:
            logger.info("Initializing model in main process...")
            get_model()
            logger.success(f"✅ Model loaded (device: {DEVICE})\n")

        for i, pdf_path in enumerate(pdf_files, 1):
            if _shutdown_requested:
                logger.warning(f"\nShutdown requested. Processed {i - 1}/{len(pdf_files)} files.")
                break

            logger.info(f"\n{'=' * 60}")
            logger.info(f"📄 File {i}/{len(pdf_files)}: {pdf_path.name}")
            logger.info(f"{'=' * 60}")

            sub_out = OUTPUT_DIR / pdf_path.stem
            os.makedirs(sub_out, exist_ok=True)

            try:
                process_pdf_with_pool(pdf_path, sub_out, pool)
            except KeyboardInterrupt:
                logger.warning(f"\nInterrupted while processing {pdf_path.name}")
                break
            except Exception as e:
                logger.error(f"Error processing {pdf_path.name}: {e}")
                if _shutdown_requested:
                    break
                logger.info("Continuing with next file...")
                continue

        if _shutdown_requested:
            logger.warning(f"\n⚠️ Processing interrupted. Partial results saved in {OUTPUT_DIR}")
        else:
            logger.success(f"\n✨ All done! Results are in {OUTPUT_DIR}")

    except KeyboardInterrupt:
        logger.error("\n❌ Processing interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n❌ Fatal error: {e}")
        sys.exit(1)
    finally:
        if pool is not None:
            logger.info("\n🧹 Shutting down worker pool...")
            pool.close()
            pool.join()
            logger.success("✅ Worker pool closed cleanly")