#!/usr/bin/env python3
"""
Image processing script for OCR and entity extraction using an OpenAI-compatible API.
Processes images from a downloads folder and extracts structured data.
"""
import os
import json
import re
import base64
import argparse
import concurrent.futures
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict

from openai import OpenAI
from tqdm import tqdm
from dotenv import load_dotenv


@dataclass
class ProcessingResult:
    """Structure for processing results"""
    filename: str
    success: bool
    data: Optional[Dict] = None
    error: Optional[str] = None


class ImageProcessor:
    """Process images using an OpenAI-compatible vision API"""

    def __init__(self, api_url: str, api_key: str, model: str = "gpt-4o",
                 index_file: str = "processing_index.json",
                 downloads_dir: Optional[str] = None):
        self.client = OpenAI(api_key=api_key, base_url=api_url)
        self.model = model
        self.downloads_dir = Path(downloads_dir) if downloads_dir else Path.home() / "Downloads"
        self.index_file = index_file
        self.processed_files = self.load_index()

    def load_index(self) -> set:
        """Load the index of already-processed files"""
        if os.path.exists(self.index_file):
            try:
                with open(self.index_file, 'r') as f:
                    data = json.load(f)
                return set(data.get('processed_files', []))
            except Exception as e:
                print(f"⚠️ Warning: Could not load index file: {e}")
                return set()
        return set()

    def save_index(self, failed_files=None):
        """Save the current index of processed files"""
        data = {
            'processed_files': sorted(self.processed_files),
            'last_updated': datetime.now().isoformat()
        }
        if failed_files:
            data['failed_files'] = failed_files
        with open(self.index_file, 'w') as f:
            json.dump(data, f, indent=2)

    def mark_processed(self, filename: str):
        """Mark a file as processed and update the index"""
        self.processed_files.add(filename)
        self.save_index()
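
    # For reference, the index file maintained by load_index()/save_index()
    # looks roughly like this (illustrative sketch; the filenames shown are
    # hypothetical placeholders):
    #
    # {
    #   "processed_files": ["scans/page_001.jpg", "scans/page_002.jpg"],
    #   "last_updated": "2024-01-01T12:00:00",
    #   "failed_files": [
    #     {"filename": "scans/page_003.jpg", "error": "..."}
    #   ]
    # }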

    def get_image_files(self) -> List[Path]:
        """Get all image files from the downloads folder (recursively)"""
        image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
        # Collect into a set: on case-insensitive filesystems the lower- and
        # upper-case globs can match the same file twice.
        image_files = set()
        for ext in image_extensions:
            image_files.update(self.downloads_dir.glob(f'**/*{ext}'))
            image_files.update(self.downloads_dir.glob(f'**/*{ext.upper()}'))
        return sorted(image_files)

    def get_relative_path(self, file_path: Path) -> str:
        """Get the path relative to the downloads directory for unique indexing"""
        try:
            return str(file_path.relative_to(self.downloads_dir))
        except ValueError:
            # If the file is not under downloads_dir, use the full path
            return str(file_path)

    def get_unprocessed_files(self) -> List[Path]:
        """Get only the files that haven't been processed yet"""
        all_files = self.get_image_files()
        return [f for f in all_files if self.get_relative_path(f) not in self.processed_files]

    def encode_image(self, image_path: Path) -> str:
        """Encode an image file as base64"""
        with open(image_path, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')

    def get_system_prompt(self) -> str:
        """Get the system prompt for structured extraction"""
        return """You are an expert OCR and document analysis system.
Extract ALL text from the image in READING ORDER to create a digital twin of the document.

IMPORTANT: Transcribe text exactly as it appears on the page, from top to bottom, left to right, including:
- All printed text
- All handwritten text (inline where it appears)
- Stamps and annotations (inline where they appear)
- Signatures (note location)

Preserve the natural reading flow. Mix printed and handwritten text together in the order they appear.

Return ONLY valid JSON in this exact structure:
{
  "document_metadata": {
    "page_number": "string or null",
    "document_number": "string or null",
    "date": "string or null",
    "document_type": "string or null",
    "has_handwriting": true/false,
    "has_stamps": true/false
  },
  "full_text": "Complete text transcription in reading order. Include ALL text - printed, handwritten, stamps, etc. - exactly as it appears from top to bottom.",
  "text_blocks": [
    {
      "type": "printed|handwritten|stamp|signature|other",
      "content": "text content",
      "position": "top|middle|bottom|header|footer|margin"
    }
  ],
  "entities": {
    "people": ["list of person names"],
    "organizations": ["list of organizations"],
    "locations": ["list of locations"],
    "dates": ["list of dates found"],
    "reference_numbers": ["list of any reference/ID numbers"]
  },
  "additional_notes": "Any observations about document quality, redactions, damage, etc."
}"""

    def fix_json_with_llm(self, base64_image: str, broken_json: str, error_msg: str) -> dict:
        """Ask the LLM to fix its own broken JSON by replaying the conversation"""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": self.get_system_prompt()
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Extract all text and entities from this image. Return only valid JSON."
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                },
                {
                    "role": "assistant",
                    "content": broken_json
                },
                {
                    "role": "user",
                    "content": f"Your JSON response has an error: {error_msg}\n\nPlease fix the JSON and return ONLY the corrected valid JSON. Do not explain, just return the fixed JSON."
                }
            ],
            max_tokens=4096,
            temperature=0.1
        )
        content = response.choices[0].message.content.strip()
        # Extract JSON using the same heuristics as process_image()
        json_match = re.search(r'```(?:json)?\s*\n(.*?)\n```', content, re.DOTALL)
        if json_match:
            content = json_match.group(1).strip()
        else:
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            if json_match:
                content = json_match.group(0).strip()
        return json.loads(content)

    def process_image(self, image_path: Path) -> ProcessingResult:
        """Process a single image through the API"""
        try:
            # Encode image
            base64_image = self.encode_image(image_path)
            # Make API call using the OpenAI client
            # (note: the data-URL MIME type is hardcoded to JPEG for all formats)
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": self.get_system_prompt()
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": "Extract all text and entities from this image. Return only valid JSON."
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=4096,
                temperature=0.1
            )
            # Parse response
            content = response.choices[0].message.content
            original_content = content  # Keep the original for retry
            # Robust JSON extraction
            content = content.strip()
            # 1. Try to find JSON between markdown code fences
            json_match = re.search(r'```(?:json)?\s*\n(.*?)\n```', content, re.DOTALL)
            if json_match:
                content = json_match.group(1).strip()
            else:
                # 2. Try to find JSON between curly braces
                json_match = re.search(r'\{.*\}', content, re.DOTALL)
                if json_match:
                    content = json_match.group(0).strip()
                else:
                    # 3. Strip markdown fences manually
                    if content.startswith('```json'):
                        content = content[7:]
                    elif content.startswith('```'):
                        content = content[3:]
                    if content.endswith('```'):
                        content = content[:-3]
                    content = content.strip()
            # Try to parse JSON
            try:
                extracted_data = json.loads(content)
            except json.JSONDecodeError as e:
                # Try to salvage by finding the first complete JSON object
                try:
                    # Find the first '{' and its matching '}'
                    start = content.find('{')
                    if start == -1:
                        raise ValueError("No JSON object found")
                    brace_count = 0
                    end = start
                    for i in range(start, len(content)):
                        if content[i] == '{':
                            brace_count += 1
                        elif content[i] == '}':
                            brace_count -= 1
                            if brace_count == 0:
                                end = i + 1
                                break
                    if end > start:
                        content = content[start:end]
                        extracted_data = json.loads(content)
                    else:
                        raise ValueError("Could not find complete JSON object")
                except Exception:
                    # Last resort: ask the LLM to fix its own JSON
                    try:
                        extracted_data = self.fix_json_with_llm(base64_image, original_content, str(e))
                    except Exception:
                        # Save the ORIGINAL LLM response to the errors directory
                        # (not our partially extracted version)
                        self.save_broken_json(self.get_relative_path(image_path), original_content)
                        # If even that fails, re-raise the original parse error
                        raise e
            return ProcessingResult(
                filename=self.get_relative_path(image_path),
                success=True,
                data=extracted_data
            )
        except Exception as e:
            return ProcessingResult(
                filename=self.get_relative_path(image_path),
                success=False,
                error=str(e)
            )

    def process_all(self, max_workers: int = 5, limit: Optional[int] = None, resume: bool = True) -> List[ProcessingResult]:
        """Process all images with parallel workers"""
        if resume:
            image_files = self.get_unprocessed_files()
            total_files = len(self.get_image_files())
            already_processed = len(self.processed_files)
            print(f"Found {total_files} total image files")
            print(f"Already processed: {already_processed}")
            print(f"Remaining to process: {len(image_files)}")
        else:
            image_files = self.get_image_files()
            print(f"Found {len(image_files)} image files to process")
        if limit:
            image_files = image_files[:limit]
            print(f"Limited to {limit} files for this run")
        if not image_files:
            print("No files to process!")
            return []
        results = []
        failed_files = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(self.process_image, img): img for img in image_files}
            with tqdm(total=len(image_files), desc="Processing images") as pbar:
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    results.append(result)
                    # Save each individual result to file as it completes
                    if result.success:
                        self.save_individual_result(result)
                        tqdm.write(f"✅ Processed: {result.filename}")
                    else:
                        # Track failed files
                        failed_files.append({
                            'filename': result.filename,
                            'error': result.error
                        })
                        tqdm.write(f"❌ Failed: {result.filename} - {result.error}")
                    # Mark as processed regardless of success/failure
                    self.mark_processed(result.filename)
                    pbar.update(1)
        # Save failed files to the index for reference
        if failed_files:
            self.save_index(failed_files=failed_files)
            print(f"\n⚠️ {len(failed_files)} files failed - logged in {self.index_file}")
        return results

    def save_individual_result(self, result: ProcessingResult):
        """Save an individual result to ./results/<folder>/<imagename>.json"""
        # Create an output path mirroring the source structure
        result_path = (Path("./results") / result.filename).with_suffix('.json')
        # Create parent directories
        result_path.parent.mkdir(parents=True, exist_ok=True)
        # Save the extracted data
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(result.data, f, indent=2, ensure_ascii=False)

    def save_broken_json(self, filename: str, broken_content: str):
        """Save broken JSON to the errors directory"""
        error_path = (Path("./errors") / filename).with_suffix('.json')
        # Create parent directories
        error_path.parent.mkdir(parents=True, exist_ok=True)
        # Save the broken content as-is
        with open(error_path, 'w', encoding='utf-8') as f:
            f.write(broken_content)

    def save_results(self, results: List[ProcessingResult], output_file: str = "processed_results.json"):
        """Save summary results to a JSON file"""
        output_data = {
            "total_processed": len(results),
            "successful": sum(1 for r in results if r.success),
            "failed": sum(1 for r in results if not r.success),
            "results": [asdict(r) for r in results]
        }
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Summary saved to {output_file}")
        print("   Individual results saved to ./results/")
        print(f"   Successful: {output_data['successful']}")
        print(f"   Failed: {output_data['failed']}")
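

# For reference, the summary file written by save_results() has this general
# shape, one entry per ProcessingResult (illustrative sketch; the entries
# shown are hypothetical):
#
# {
#   "total_processed": 2,
#   "successful": 1,
#   "failed": 1,
#   "results": [
#     {"filename": "scans/page_001.jpg", "success": true, "data": {...}, "error": null},
#     {"filename": "scans/page_002.jpg", "success": false, "data": null, "error": "..."}
#   ]
# }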


def main():
    # Load environment variables
    load_dotenv()
    parser = argparse.ArgumentParser(description="Process images with OCR and entity extraction")
    parser.add_argument("--api-url", help="OpenAI-compatible API base URL (default: OPENAI_API_URL from the environment or .env)")
    parser.add_argument("--api-key", help="API key (default: OPENAI_API_KEY from the environment or .env)")
    parser.add_argument("--model", help="Model name (default: OPENAI_MODEL from the environment or .env, else meta-llama/Llama-4-Maverick-17B-128E-Instruct)")
    parser.add_argument("--workers", type=int, default=5, help="Number of parallel workers (default: 5)")
    parser.add_argument("--limit", type=int, help="Limit the number of images to process (for testing)")
    parser.add_argument("--output", default="processed_results.json", help="Output JSON file")
    parser.add_argument("--index", default="processing_index.json", help="Index file used to track processed files")
    parser.add_argument("--downloads-dir", default="./downloads", help="Directory containing images (default: ./downloads)")
    parser.add_argument("--no-resume", action="store_true", help="Process all files, ignoring the index")
    args = parser.parse_args()
    # Get values from args or environment variables
    api_url = args.api_url or os.getenv("OPENAI_API_URL", "http://...")
    api_key = args.api_key or os.getenv("OPENAI_API_KEY", "abcd1234")
    model = args.model or os.getenv("OPENAI_MODEL", "meta-llama/Llama-4-Maverick-17B-128E-Instruct")
    processor = ImageProcessor(
        api_url=api_url,
        api_key=api_key,
        model=model,
        index_file=args.index,
        downloads_dir=args.downloads_dir
    )
    results = processor.process_all(
        max_workers=args.workers,
        limit=args.limit,
        resume=not args.no_resume
    )
    processor.save_results(results, args.output)


if __name__ == "__main__":
    main()
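
# Example invocations (assuming OPENAI_API_URL / OPENAI_API_KEY / OPENAI_MODEL
# are set in the environment or a .env file; the endpoint URL below is a
# hypothetical placeholder):
#
#   python process_images.py --limit 10 --workers 3
#   python process_images.py --downloads-dir ./scans --output scan_results.json
#   python process_images.py --api-url http://localhost:8000/v1 --model gpt-4o --no-resume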