from functools import lru_cache
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
import fitz
import gradio as gr
from rapidocr import RapidOCR
# Resolution used when rasterising PDF pages; the render zoom is DPI / 72.
DPI = 150
# Quality setting passed to the JPEG encoder for pages rendered for OCR.
JPEG_QUALITY = 85
# Model identifier handed to `from_pretrained` for the GLM OCR engine.
GLM_MODEL_ID = "zai-org/GLM-OCR"
# Upper bound on the number of tokens generated per page by the GLM engine.
GLM_MAX_NEW_TOKENS = 4096
# Text instruction sent alongside each page image to the GLM engine.
GLM_PROMPT = "Text Recognition:"
# RapidOCR engine instance, created once at import time and shared by all
# calls to `run_rapidocr`.
rapidocr = RapidOCR()
def render_pages(pdf_path, pages, image_dir):
    """Rasterise the requested PDF pages to JPEG files.

    Args:
        pdf_path: Path of the PDF to open.
        pages: Iterable of 1-based page numbers to render.
        image_dir: Directory (a ``Path``) that receives the JPEG files.

    Returns:
        The written image paths, one per entry of ``pages`` and in the same
        order. Each file is named ``page-<number>.jpg``.
    """
    scale = DPI / 72
    transform = fitz.Matrix(scale, scale)
    written = []
    with fitz.open(pdf_path) as document:
        for number in pages:
            # Page numbers are 1-based for callers, 0-based for PyMuPDF.
            rendered = document.load_page(number - 1).get_pixmap(
                matrix=transform,
                alpha=False,
            )
            target = image_dir / f"page-{number}.jpg"
            target.write_bytes(
                rendered.tobytes("jpeg", jpg_quality=JPEG_QUALITY)
            )
            written.append(target)
    return written
def preview_page(pdf_path, page_number):
    """Render one PDF page to a temporary PNG for the preview widget.

    Args:
        pdf_path: Path of the PDF; a falsy value means no file is loaded.
        page_number: 1-based page number (int or float), or ``None`` when the
            number field is empty.

    Returns:
        The path of the rendered PNG as a string, or ``None`` when there is
        no PDF, no page number, or the page lies outside the document.
    """
    if not pdf_path or page_number is None:
        return None
    page_index = int(page_number) - 1
    with fitz.open(pdf_path) as pdf:
        # PyMuPDF interprets negative indexes as counting back from the last
        # page, so page 0 would silently preview the wrong page, and an index
        # past the end raises. Treat both as "nothing to preview".
        if not 0 <= page_index < pdf.page_count:
            return None
        zoom = DPI / 72
        pixmap = pdf.load_page(page_index).get_pixmap(
            matrix=fitz.Matrix(zoom, zoom),
            alpha=False,
        )
        # Reserve the temp file only once rendering has succeeded, so a
        # failed open/render does not leave an orphaned empty file behind.
        # The handle is closed before saving; the file itself is kept
        # (delete=False) because the caller needs the path afterwards.
        with NamedTemporaryFile(delete=False, suffix=".png") as preview_file:
            preview_path = preview_file.name
        pixmap.save(preview_path)
    return preview_path
def load_pdf(pdf_path, start_page):
    """React to a newly selected PDF.

    Args:
        pdf_path: Path of the uploaded PDF, or a falsy value when cleared.
        start_page: Current value of the start-page field.

    Returns:
        A pair ``(preview, end_page_update)``: the preview image path (or
        ``None``) and a Gradio update that sets the end-page field.
    """
    if pdf_path:
        with fitz.open(pdf_path) as document:
            total_pages = document.page_count
        # Limit to 2 pages so the demo stays responsive.
        suggested_end = min(total_pages, 2)
        return (
            preview_page(pdf_path, start_page),
            gr.update(value=suggested_end),
        )
    # No file selected: clear the preview and reset the end page.
    return None, gr.update(value=1)
def run_rapidocr(image_paths):
    """Recognise text in each image with the shared RapidOCR engine.

    Args:
        image_paths: Iterable of image paths (anything ``str()`` accepts).

    Returns:
        One string per image, in input order: the recognised lines joined
        with newlines and stripped. An image with no text yields ``""``.
    """

    def recognise(path):
        lines = rapidocr(str(path)).txts
        # `txts` may be empty/None when nothing was detected.
        return "\n".join(lines if lines else []).strip()

    return [recognise(path) for path in image_paths]
def pick_device_and_dtype(torch):
    """Choose the device name and tensor dtype to run the model with.

    Args:
        torch: The imported ``torch`` module (passed in so the caller
            controls when it is imported).

    Returns:
        A ``(device, dtype)`` pair: CUDA with bfloat16 when supported
        (float16 otherwise), else MPS with float16, else CPU with float32.
    """
    cuda = torch.cuda
    if cuda.is_available():
        if cuda.is_bf16_supported():
            return "cuda", torch.bfloat16
        return "cuda", torch.float16

    # `torch.backends` only has an `mps` attribute on builds that know MPS.
    has_mps = hasattr(torch.backends, "mps")
    if has_mps and torch.backends.mps.is_available():
        return "mps", torch.float16

    return "cpu", torch.float32
@lru_cache
def glm_ocr_model():
    """Load the GLM OCR processor and model, caching the result.

    The heavy imports happen inside the function so they are only paid for
    when this engine is actually used; ``lru_cache`` makes later calls
    return the same objects.

    Returns:
        A tuple ``(processor, model, torch, device)`` where ``torch`` is the
        imported module and ``device`` is the device name the model was
        moved to.
    """
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    target_device, weights_dtype = pick_device_and_dtype(torch)
    processor = AutoProcessor.from_pretrained(GLM_MODEL_ID)

    load_kwargs = {"low_cpu_mem_usage": True}
    try:
        loaded = AutoModelForImageTextToText.from_pretrained(
            GLM_MODEL_ID,
            dtype=weights_dtype,
            **load_kwargs,
        )
    except TypeError:
        # Retry with the `torch_dtype` keyword when `dtype` is rejected.
        loaded = AutoModelForImageTextToText.from_pretrained(
            GLM_MODEL_ID,
            torch_dtype=weights_dtype,
            **load_kwargs,
        )

    model = loaded.to(target_device)
    model.eval()
    return processor, model, torch, target_device
def run_glm_ocr(image_paths):
    """Recognise text in each image with the GLM OCR model.

    Args:
        image_paths: Iterable of image paths (anything ``str()`` accepts).

    Returns:
        One string per image, in input order: the decoded model output
        followed by a blank line and an ``_OCR time: …s_`` footer with the
        elapsed time for that image.
    """
    import time

    processor, model, torch, _ = glm_ocr_model()

    def transcribe(image_path):
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": str(image_path)},
                    {"type": "text", "text": GLM_PROMPT},
                ],
            }
        ]
        t0 = time.perf_counter()
        batch = processor.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        # Drop `token_type_ids` (if present) before forwarding to generate().
        batch.pop("token_type_ids", None)
        with torch.inference_mode():
            generated = model.generate(
                **batch,
                max_new_tokens=GLM_MAX_NEW_TOKENS,
                do_sample=False,
            )
        # The output echoes the prompt tokens; keep only what follows them.
        prompt_length = batch["input_ids"].shape[1]
        decoded = processor.decode(
            generated[0][prompt_length:],
            skip_special_tokens=True,
        )
        seconds = time.perf_counter() - t0
        return f"{decoded}\n\n_OCR time: {seconds:.1f}s_"

    return [transcribe(image_path) for image_path in image_paths]
def ocr_pdf(pdf_file, start_page, end_page, engine):
    """OCR a page range of a PDF and return the text as Markdown sections.

    Args:
        pdf_file: Path of the uploaded PDF (falsy when nothing is uploaded).
        start_page: First page to OCR, 1-based (int or float).
        end_page: Last page to OCR, 1-based and inclusive (int or float).
        engine: ``"rapidocr"`` selects RapidOCR; any other value selects the
            GLM OCR model.

    Returns:
        One ``## Page N`` section per page, separated by blank lines.

    Raises:
        gr.Error: If no PDF is given, a page field is empty, or the range is
            not within ``1..page_count`` with ``start_page <= end_page``.
    """
    if not pdf_file:
        raise gr.Error("Upload a PDF before running OCR.")
    if start_page is None or end_page is None:
        raise gr.Error("Enter both a start page and an end page.")

    first_page, last_page = int(start_page), int(end_page)
    with fitz.open(pdf_file) as pdf:
        page_count = pdf.page_count
    # PyMuPDF treats negative indexes as counting back from the last page,
    # so without this check page 0 would OCR the wrong page and label it
    # "Page 0"; pages past the end would fail deep inside rendering.
    if not 1 <= first_page <= last_page <= page_count:
        raise gr.Error(
            f"Invalid page range {first_page}-{last_page}: choose pages "
            f"between 1 and {page_count}, with start page <= end page."
        )

    pages = list(range(first_page, last_page + 1))
    # Rendered images are only needed while OCR runs; the directory is
    # removed as soon as the texts have been extracted.
    with TemporaryDirectory() as image_dir:
        image_paths = render_pages(pdf_file, pages, Path(image_dir))
        if engine == "rapidocr":
            page_texts = run_rapidocr(image_paths)
        else:
            page_texts = run_glm_ocr(image_paths)

    return "\n\n".join(
        f"## Page {page_number}\n\n{text}"
        for page_number, text in zip(pages, page_texts)
    )
# Build the Gradio interface: inputs and page preview in one column, the OCR
# text in the other, followed by the event wiring.
with gr.Blocks(title="PDF OCR demo") as demo:
    gr.Markdown("# PDF OCR demo")
    with gr.Row():
        with gr.Column():
            # `type="filepath"` makes the callbacks receive the PDF as a path.
            pdf = gr.File(label="PDF", file_types=[".pdf"], type="filepath")
            with gr.Row():
                start_page = gr.Number(label="Start page", value=1, precision=0)
                end_page = gr.Number(label="End page", value=1, precision=0)
            # The choice values are the strings `ocr_pdf` compares against.
            engine = gr.Radio(
                label="OCR engine",
                choices=["rapidocr", "glm_ocr"],
                value="rapidocr",
            )
            button = gr.Button("OCR pages", variant="primary")
            # Read-only preview fed with the image path from `preview_page`.
            preview = gr.Image(
                label="Page preview",
                type="filepath",
                height=520,
                interactive=False,
            )
        with gr.Column():
            output = gr.Textbox(label="OCR text", lines=24)
    # Selecting a PDF refreshes the preview and the end-page field.
    pdf.change(load_pdf, [pdf, start_page], [preview, end_page])
    # Changing the start page re-renders the preview for that page.
    start_page.change(preview_page, [pdf, start_page], preview)
    # The button runs OCR on the chosen range and shows the resulting text.
    button.click(ocr_pdf, [pdf, start_page, end_page, engine], output)