From 14db6bcd61761ae8abcdf834e71e36bd494832c2 Mon Sep 17 00:00:00 2001 From: MSVstudios <98731643+MSVstudios@users.noreply.github.com> Date: Mon, 16 Mar 2026 11:43:26 +0100 Subject: [PATCH] Initial Commit --- .gitignore | 223 ++++++++++++++++++++++++++++++++ .vscode/settings.json | 5 + README.md | 37 ++++++ config.toml.example | 36 ++++++ requirements.txt | 8 ++ src/__init__.py | 3 + src/config.py | 56 ++++++++ src/config/__init__.py | 1 + src/config/settings.py | 42 ++++++ src/db.py | 71 +++++++++++ src/download.py | 54 ++++++++ src/metrics.py | 27 ++++ src/model.py | 61 +++++++++ src/pipeline.py | 283 +++++++++++++++++++++++++++++++++++++++++ src/preprocess.py | 39 ++++++ tests/test_pipeline.py | 57 +++++++++ 16 files changed, 1003 insertions(+) create mode 100644 .gitignore create mode 100644 .vscode/settings.json create mode 100644 README.md create mode 100644 config.toml.example create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/config.py create mode 100644 src/config/__init__.py create mode 100644 src/config/settings.py create mode 100644 src/db.py create mode 100644 src/download.py create mode 100644 src/metrics.py create mode 100644 src/model.py create mode 100644 src/pipeline.py create mode 100644 src/preprocess.py create mode 100644 tests/test_pipeline.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d11086a --- /dev/null +++ b/.gitignore @@ -0,0 +1,223 @@ + +# project specifics + +config.toml +user-docs/ + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. 
+# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..29f2265 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "github.copilot.chat.reasoningEffort": "high", + "github.copilot.chat.responsesApiReasoningEffort": "high", + "github.copilot.selectedModel": "" +} \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..d64f8d1 --- /dev/null +++ b/README.md @@ -0,0 +1,37 @@ +# Florence-2 Captioning Pipeline + +High-throughput asynchronous captioning pipeline using **Florence-2 Base PromptGen**. + +## Goals + +- Download images from S3/HTTP concurrently +- Preprocess (resize/normalize) +- Run batched caption generation on GPU +- Persist captions back to a database (async) + +## Project structure + +- `src/`: implementation code +- `tests/`: unit/integration tests +- `todo.md`: tasks list +- `implementationPlanV2.md`: architecture + design notes + +## Quickstart + +1. Install dependencies: + +```bash +pip install -r requirements.txt +``` + +2. Configure environment variables (see `src/config.py` for expected vars). + +3. Run the pipeline (example): + +```bash +python -m src.pipeline --dry-run +``` + +## Notes + +This repo is intended as a foundation for building a fast, async dataset captioning tool. 
diff --git a/config.toml.example b/config.toml.example new file mode 100644 index 0000000..b107a14 --- /dev/null +++ b/config.toml.example @@ -0,0 +1,36 @@ +# Configuration for Florence-2 caption pipeline + +[model] +# HuggingFace model id (replace with your model) +id = "your-org/your-model" +# Device to run model on (cuda/cpu) +device = "cuda" +# Prompt token used for captioning (replace with your token) +prompt_token = "" + +[preprocessing] +# Max side for resizing (longest dimension) +image_max_side = 768 + +[pipeline] +# Batch sizes and queue sizing +gpu_batch_size = 8 +download_concurrency = 16 +image_queue_max_size = 64 +result_queue_max_size = 128 +db_write_batch_size = 64 + +[database] +# Async DB connection string +dsn = "" +# Table and column names +table = "images" +id_column = "id" +url_column = "url" +caption_column = "caption" +# Default WHERE clause used to fetch rows to caption +query_where = "caption IS NULL" + +[debug] +# Dry run disables writing to DB for quick local tests +dry_run = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..073a6f8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch>=2.0.0 +transformers>=4.40.0 +aiohttp>=3.9.0 +psycopg[binary]>=3.2 +Pillow>=10.0.0 +python-dotenv>=1.0.0 +pydantic-settings>=2.2.0 +pytest>=8.0.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..8684eba --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,3 @@ +"""Florence-2 captioning pipeline package.""" + +__all__ = [] diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..773f4bc --- /dev/null +++ b/src/config.py @@ -0,0 +1,56 @@ +from .config.settings import settings + + +# Legacy compatibility helpers (so existing code can use config.) 
# ---- src/config.py (legacy constant facade over Settings) ----
MODEL_ID = settings.model_id
MODEL_DEVICE = settings.model_device
PROMPT_TOKEN = settings.prompt_token

IMAGE_MAX_SIDE = settings.image_max_side

GPU_BATCH_SIZE = settings.gpu_batch_size
DOWNLOAD_CONCURRENCY = settings.download_concurrency
IMAGE_QUEUE_MAX_SIZE = settings.image_queue_max_size
RESULT_QUEUE_MAX_SIZE = settings.result_queue_max_size
DB_WRITE_BATCH_SIZE = settings.db_write_batch_size

DB_DSN = settings.db_dsn

DRY_RUN = settings.dry_run

DB_TABLE = settings.db_table
DB_ID_COLUMN = settings.db_id_column
DB_URL_COLUMN = settings.db_url_column
DB_CAPTION_COLUMN = settings.db_caption_column
DB_QUERY_WHERE = settings.db_query_where


def apply_overrides(**kwargs):
    """Override config values at runtime (for CLI overrides).

    Accepts either the pydantic field names (``model_id``) or the legacy
    UPPER_SNAKE constant names (``MODEL_ID``).  ``None`` values are skipped
    so CLI flags that were not supplied leave the configuration untouched.
    """
    for key, value in kwargs.items():
        if value is None:
            continue
        # BUG FIX: src/pipeline.py passes UPPER_SNAKE keys, but Settings
        # declares lower-case fields, so hasattr(settings, key) never matched
        # and every CLI override was silently dropped.  Normalise the name
        # to the Settings field before assigning.
        field = key if hasattr(settings, key) else key.lower()
        if hasattr(settings, field):
            setattr(settings, field, value)

    # Re-sync the legacy module-level constants with the (possibly updated)
    # settings object so `config.MODEL_ID`-style reads observe the overrides.
    globals().update({
        "MODEL_ID": settings.model_id,
        "MODEL_DEVICE": settings.model_device,
        "PROMPT_TOKEN": settings.prompt_token,
        "IMAGE_MAX_SIDE": settings.image_max_side,
        "GPU_BATCH_SIZE": settings.gpu_batch_size,
        "DOWNLOAD_CONCURRENCY": settings.download_concurrency,
        "IMAGE_QUEUE_MAX_SIZE": settings.image_queue_max_size,
        "RESULT_QUEUE_MAX_SIZE": settings.result_queue_max_size,
        "DB_WRITE_BATCH_SIZE": settings.db_write_batch_size,
        "DB_DSN": settings.db_dsn,
        "DRY_RUN": settings.dry_run,
        "DB_TABLE": settings.db_table,
        "DB_ID_COLUMN": settings.db_id_column,
        "DB_URL_COLUMN": settings.db_url_column,
        "DB_CAPTION_COLUMN": settings.db_caption_column,
        "DB_QUERY_WHERE": settings.db_query_where,
    })
# ---- src/config/settings.py ----
from __future__ import annotations

from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    TomlConfigSettingsSource,
)


class Settings(BaseSettings):
    """Application settings loaded from config.toml and environment variables.

    Precedence (highest first): init kwargs, environment variables,
    top-level keys in ``config.toml``, then the field defaults below.

    BUG FIX: the previous revision passed ``env_file="config.toml"``, but
    pydantic-settings parses ``env_file`` as *dotenv*, so the TOML file was
    silently never loaded.  A real TOML source is wired in through
    ``settings_customise_sources`` below.
    NOTE(review): keys are read from the *top level* of the TOML file
    (e.g. ``model_id = "..."``); the nested ``[section]`` tables shown in
    config.toml.example do not map onto these flat fields — confirm and
    align the example file.
    """

    # Model
    model_id: str = "MiaoshouAI/Florence-2-base-PromptGen-v2"
    model_device: str = "cuda"
    prompt_token: str = ""

    # Preprocessing
    image_max_side: int = 768

    # Pipeline sizing
    gpu_batch_size: int = 8
    download_concurrency: int = 16
    image_queue_max_size: int = 64
    result_queue_max_size: int = 128
    db_write_batch_size: int = 64

    # Database
    db_dsn: str = ""
    db_table: str = "images"
    db_id_column: str = "id"
    db_url_column: str = "url"
    db_caption_column: str = "caption"
    db_query_where: str = "caption IS NULL"

    # Other
    dry_run: bool = True

    model_config = SettingsConfigDict(
        toml_file="config.toml",
        case_sensitive=False,
        # Ignore unknown keys so an extended config.toml does not crash startup.
        extra="ignore",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Insert a TOML source so config.toml is parsed as TOML, not dotenv."""
        return (
            init_settings,
            env_settings,
            TomlConfigSettingsSource(settings_cls),
            file_secret_settings,
        )


settings = Settings()
# ---- src/db.py ----
import logging
from typing import AsyncGenerator, Iterable, List, Optional, Tuple

import psycopg
from psycopg import sql

from .config import DB_DSN

_LOGGER = logging.getLogger(__name__)


async def get_db_connection():
    """Open a new async connection using the configured DSN.

    Raises:
        ValueError: if no DSN is configured.
    """
    if not DB_DSN:
        raise ValueError("DB_DSN is not set. Set the DB_DSN environment variable.")
    return await psycopg.AsyncConnection.connect(DB_DSN)


async def fetch_image_urls(
    table: Optional[str] = None,
    id_column: Optional[str] = None,
    url_column: Optional[str] = None,
    where_clause: Optional[str] = None,
    batch_size: int = 100,
) -> AsyncGenerator[Tuple[int, str], None]:
    """Yield (image_id, image_url) tuples from the database in batches.

    Args:
        table / id_column / url_column: schema names; default to the
            historical ``images(id, url)`` layout.
        where_clause: raw SQL filter appended verbatim — must come from
            trusted configuration, never from user input.
        batch_size: rows fetched per server round-trip.
    """
    table = table or "images"
    id_column = id_column or "id"
    url_column = url_column or "url"

    # Identifiers cannot be bound as query parameters; compose them with
    # psycopg's sql module instead of f-string interpolation, which broke on
    # reserved words / unusual names and invited injection from config.
    query = sql.SQL("SELECT {}, {} FROM {}").format(
        sql.Identifier(id_column),
        sql.Identifier(url_column),
        sql.Identifier(table),
    )
    if where_clause:
        query = sql.SQL("{} WHERE {}").format(query, sql.SQL(where_clause))

    async with await get_db_connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(query)
            while True:
                rows = await cur.fetchmany(batch_size)
                if not rows:
                    break
                for row in rows:
                    yield row[0], row[1]


async def write_captions(
    updates: Iterable[Tuple[int, str]],
    batch_size: int = 64,
    table: str = "images",
    id_column: str = "id",
    caption_column: str = "caption",
):
    """Write captions back to the DB in batches.

    BUG FIX: the table/column names were hard-coded even though
    fetch_image_urls() and the [database] config section are configurable;
    they are now parameters with the old values as backward-compatible
    defaults.
    """
    stmt = sql.SQL("UPDATE {} SET {} = %s WHERE {} = %s").format(
        sql.Identifier(table),
        sql.Identifier(caption_column),
        sql.Identifier(id_column),
    )
    async with await get_db_connection() as conn:
        async with conn.cursor() as cur:
            batch: List[Tuple[str, int]] = []
            for image_id, caption in updates:
                batch.append((caption, image_id))
                if len(batch) >= batch_size:
                    await _execute_update(cur, stmt, batch)
                    batch.clear()
            if batch:
                await _execute_update(cur, stmt, batch)


async def _execute_update(cur, stmt, batch: List[Tuple[str, int]]):
    """Run one UPDATE batch and commit it immediately."""
    await cur.executemany(stmt, batch)
    # Commit per batch so a late failure does not roll back earlier work.
    await cur.connection.commit()
# ---- src/download.py ----
import io
import asyncio
import logging
from typing import Optional, Tuple

import aiohttp
from PIL import Image

from . import config

_LOGGER = logging.getLogger(__name__)


async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, timeout: int = 30) -> bytes:
    """GET `url` and return the raw body bytes; raises on HTTP error status.

    Args:
        timeout: total per-request budget in seconds.
    """
    # BUG FIX: aiohttp expects a ClientTimeout object; passing a bare number
    # is deprecated and its meaning is ambiguous.  `total` caps the whole
    # request (connect + read).
    request_timeout = aiohttp.ClientTimeout(total=timeout)
    async with session.get(url, timeout=request_timeout) as resp:
        resp.raise_for_status()
        return await resp.read()


async def download_image(
    url: str,
    session: aiohttp.ClientSession,
    retries: int = 2,
) -> Optional[Image.Image]:
    """Download an image and return a PIL Image (RGB), or None on failure.

    Retries with linear backoff (0.5s, 1.0s, ...); network and decode errors
    are logged, never raised.
    """
    for attempt in range(1, retries + 1):
        try:
            data = await fetch_image_bytes(session, url)
            return Image.open(io.BytesIO(data)).convert("RGB")
        except Exception as e:
            _LOGGER.warning(
                "Download failed (attempt %s/%s) for %s: %s", attempt, retries, url, e
            )
            if attempt == retries:
                return None
            await asyncio.sleep(0.5 * attempt)
    # Defensive: reached only when retries < 1 (loop body never ran).
    return None


class DownloadWorker:
    """Worker that downloads images and pushes (image_id, image) into a queue."""

    def __init__(self, queue, session: aiohttp.ClientSession, metrics=None):
        # `metrics` is optional for backward compatibility; when provided,
        # failed downloads are recorded via mark_download_failure().
        self.queue = queue
        self.session = session
        self.metrics = metrics

    async def run(self, image_id: int, url: str):
        """Download one image; enqueue the result or record the failure."""
        img = await download_image(url, self.session)
        if img is None:
            # BUG FIX: failures were previously dropped without ever being
            # reflected in the run metrics (Metrics.mark_download_failure was
            # dead code).
            if self.metrics is not None:
                self.metrics.mark_download_failure()
            return
        await self.queue.put((image_id, img))


def create_aiohttp_session() -> aiohttp.ClientSession:
    """Build a session whose connector caps concurrent connections."""
    connector = aiohttp.TCPConnector(limit=config.DOWNLOAD_CONCURRENCY)
    return aiohttp.ClientSession(connector=connector)
# ---- src/metrics.py ----
import time


class Metrics:
    """Mutable counters for one pipeline run plus a derived throughput summary."""

    def __init__(self):
        # Wall-clock start of the run; summary() measures elapsed from here.
        self.start_ts = time.time()
        self.images_processed = 0
        self.download_failures = 0
        self.caption_failures = 0

    def mark_image_processed(self):
        """Record one successfully captioned image."""
        self._bump("images_processed")

    def mark_download_failure(self):
        """Record one failed download."""
        self._bump("download_failures")

    def mark_caption_failure(self):
        """Record one failed preprocess/caption attempt."""
        self._bump("caption_failures")

    def _bump(self, counter: str) -> None:
        # Single point of increment keeps the mark_* methods uniform.
        setattr(self, counter, getattr(self, counter) + 1)

    def summary(self):
        """Return counters plus images/sec over the elapsed wall-clock time."""
        # Clamp to >= 1s so an instantaneous run cannot divide by ~zero.
        runtime = max(1.0, time.time() - self.start_ts)
        report = {
            "images_per_second": self.images_processed / runtime,
            "download_failures": self.download_failures,
            "caption_failures": self.caption_failures,
            "elapsed_seconds": runtime,
        }
        return report
# ---- src/model.py ----
import logging
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

from .config import MODEL_DEVICE, MODEL_ID

_LOGGER = logging.getLogger(__name__)


class Florence2CaptionModel:
    """Wrapper around Florence-2 PromptGen for batched caption generation."""

    def __init__(self, model_id: str = MODEL_ID, device: str = MODEL_DEVICE):
        # Model/processor are lazily loaded in load() so construction is cheap.
        self.model_id = model_id
        self.device = device
        self.model = None
        self.processor = None

    def load(self, torch_dtype: Optional[torch.dtype] = None):
        """Load model and processor into memory (once).

        Args:
            torch_dtype: weight dtype.  Defaults to float16 on CUDA and
                float32 elsewhere.  BUG FIX: the old unconditional float16
                default broke device="cpu", where fp16 inference is
                unsupported or pathologically slow.
        """
        if torch_dtype is None:
            torch_dtype = (
                torch.float16 if str(self.device).startswith("cuda") else torch.float32
            )
        _LOGGER.info("Loading model %s on %s (dtype=%s)", self.model_id, self.device, torch_dtype)
        self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )

    def generate_captions(
        self,
        images: List[torch.Tensor],
        prompt: str,
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.0,
    ) -> List[str]:
        """Generate captions for a batch of preprocessed image tensors.

        Args:
            images: batch of image tensors (one per caption).
            prompt: task token repeated for every image in the batch.
            temperature: only forwarded when do_sample=True.

        Raises:
            RuntimeError: if load() has not been called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded. Call load() first.")

        inputs = self.processor(
            text=[prompt] * len(images),
            images=images,
            return_tensors="pt",
        ).to(self.device)

        gen_kwargs: Dict = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            # BUG FIX: temperature is only meaningful when sampling; passing
            # it alongside do_sample=False makes transformers emit warnings
            # (and newer versions reject it outright).
            gen_kwargs["temperature"] = temperature

        # inference_mode: no autograd graph -> lower memory, faster decode.
        with torch.inference_mode():
            generated = self.model.generate(**inputs, **gen_kwargs)

        return self.processor.batch_decode(generated, skip_special_tokens=True)
# ---- src/pipeline.py (worker stages) ----

async def _download_consumer(
    input_queue: asyncio.Queue,
    image_queue: asyncio.Queue,
    session: "aiohttp.ClientSession",
    metrics: "Metrics",
):
    """Pull (image_id, url) work items and download them into image_queue.

    Stops on the `None` sentinel and forwards it downstream so the caption
    stage also knows the stream ended.
    """
    worker = DownloadWorker(image_queue, session)
    try:
        while True:
            task = await input_queue.get()
            input_queue.task_done()
            if task is None:
                # Propagate end-of-stream to the caption stage.
                await image_queue.put(None)
                return
            await worker.run(*task)
    except asyncio.CancelledError:
        _LOGGER.info("Download consumer cancelled")
        raise
    except Exception as exc:
        _LOGGER.exception("Download consumer error: %s", exc)
        raise


async def _caption_worker(
    image_queue: asyncio.Queue,
    result_queue: asyncio.Queue,
    model: "Florence2CaptionModel",
    metrics: "Metrics",
    prompt: str,
):
    """Batch decoded images and emit (image_id, caption) results.

    Accumulates preprocessed tensors until config.GPU_BATCH_SIZE is reached,
    flushes that batch through the model, and finally forwards the `None`
    sentinel so the DB writer can terminate.
    """
    pending: list = []

    try:
        while True:
            entry = await image_queue.get()
            image_queue.task_done()

            if entry is None:
                break

            image_id, img = entry
            try:
                pending.append((image_id, preprocess_image(img, device=config.MODEL_DEVICE)))
                if len(pending) >= config.GPU_BATCH_SIZE:
                    await _flush_batch(pending, model, result_queue, metrics, prompt)
            except Exception as exc:
                # Per-item failures are counted but do not kill the worker.
                metrics.mark_caption_failure()
                _LOGGER.exception("Preprocess/inference error for %s: %s", image_id, exc)

        if pending:
            await _flush_batch(pending, model, result_queue, metrics, prompt)

        await result_queue.put(None)
    except asyncio.CancelledError:
        _LOGGER.info("Caption worker cancelled")
        raise
    except Exception as exc:
        _LOGGER.exception("Caption worker failure: %s", exc)
        raise


async def _flush_batch(
    buffer: list,
    model: "Florence2CaptionModel",
    result_queue: asyncio.Queue,
    metrics: "Metrics",
    prompt: str,
):
    """Run one model batch over `buffer` and push (image_id, caption) pairs.

    Empties `buffer` in place on success so callers can reuse the same list
    between batches.
    """
    image_ids, tensors = zip(*buffer)
    captions = model.generate_captions(list(tensors), prompt=prompt)
    for image_id, caption in zip(image_ids, captions):
        await result_queue.put((image_id, caption))
        metrics.mark_image_processed()
    buffer.clear()


async def _db_writer(result_queue: asyncio.Queue, metrics: "Metrics"):
    """Drain caption results and persist them in DB_WRITE_BATCH_SIZE chunks."""
    staged: list = []
    try:
        while True:
            result = await result_queue.get()
            result_queue.task_done()
            if result is None:
                break

            staged.append(result)
            if len(staged) >= config.DB_WRITE_BATCH_SIZE:
                await write_captions(staged)
                staged.clear()

        if staged:
            await write_captions(staged)
    except asyncio.CancelledError:
        _LOGGER.info("DB writer cancelled")
        raise
    except Exception as exc:
        _LOGGER.exception("DB writer error: %s", exc)
        raise
async def run_pipeline(
    urls: Union[Iterable[Tuple[int, str]], AsyncIterable[Tuple[int, str]]],
    dry_run: Optional[bool] = None,
    prompt_token: Optional[str] = None,
):
    """Run the download -> caption -> persist pipeline over `urls`.

    Args:
        urls: sync or async iterable of (image_id, image_url) pairs.
        dry_run: skip DB writes (results are drained and logged instead).
            Defaults to config.DRY_RUN resolved *at call time*.
        prompt_token: task token passed to the model; defaults to
            config.PROMPT_TOKEN at call time.
    """
    # BUG FIX: the previous revision bound these defaults at import time (so
    # CLI overrides never applied) and then never consulted `dry_run` at all,
    # meaning captions were written to the database even in dry-run mode.
    if dry_run is None:
        dry_run = config.DRY_RUN
    if prompt_token is None:
        prompt_token = config.PROMPT_TOKEN

    _LOGGER.info("Starting pipeline (dry_run=%s)", dry_run)
    metrics = Metrics()
    model = Florence2CaptionModel()
    model.load()

    download_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    image_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    result_queue: asyncio.Queue = asyncio.Queue(maxsize=config.RESULT_QUEUE_MAX_SIZE)

    async def _drain_results(queue: asyncio.Queue):
        # Dry-run sink: consume results without touching the database.
        while True:
            item = await queue.get()
            queue.task_done()
            if item is None:
                break
            _LOGGER.debug("dry-run caption %s: %s", item[0], item[1])

    async with create_aiohttp_session() as session:
        if isinstance(urls, AsyncIterable):
            producer_task = asyncio.create_task(_producer_async(download_queue, urls))
        else:
            producer_task = asyncio.create_task(_producer(download_queue, urls))

        downloader_task = asyncio.create_task(
            _download_consumer(download_queue, image_queue, session, metrics)
        )
        caption_task = asyncio.create_task(
            _caption_worker(image_queue, result_queue, model, metrics, prompt_token)
        )
        if dry_run:
            writer_task = asyncio.create_task(_drain_results(result_queue))
        else:
            writer_task = asyncio.create_task(_db_writer(result_queue, metrics))

        tasks = [producer_task, downloader_task, caption_task, writer_task]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

        # Cancel any remaining tasks if something failed, and wait for the
        # cancellations to settle before inspecting results.
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        # Propagate the first failure, if any.
        for task in done:
            exc = task.exception()
            if exc is not None:
                raise exc

    _LOGGER.info("Pipeline completed: %s", metrics.summary())
async def _dry_run_urls() -> List[Tuple[int, str]]:
    """Return a small static URL list for --dry-run smoke tests."""
    return [
        (1, "https://example.com/image1.jpg"),
        (2, "https://example.com/image2.jpg"),
    ]


async def _run_from_db():
    """Run the pipeline over the async URL stream from the database."""
    urls = fetch_image_urls(
        table=config.DB_TABLE,
        id_column=config.DB_ID_COLUMN,
        url_column=config.DB_URL_COLUMN,
        where_clause=config.DB_QUERY_WHERE,
    )
    await run_pipeline(urls, dry_run=False)


def main():
    """CLI entry point: parse args, apply config overrides, run the pipeline."""
    parser = argparse.ArgumentParser(description="Async Florence-2 captioning pipeline")
    parser.add_argument("--dry-run", action="store_true", help="Run without DB, using sample URLs")
    parser.add_argument("--log-level", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)")

    # Model overrides
    parser.add_argument("--model-id", help="HuggingFace model ID")
    parser.add_argument("--model-device", help="Device to run model on (cuda/cpu)")
    parser.add_argument("--prompt-token", help="Task token to use for captioning")

    # Pipeline sizing
    parser.add_argument("--gpu-batch-size", type=int, help="Batch size for GPU inference")
    parser.add_argument("--download-concurrency", type=int, help="Max concurrent downloads")
    parser.add_argument("--image-queue-max", type=int, help="Max size for image queue")
    parser.add_argument("--result-queue-max", type=int, help="Max size for result queue")
    parser.add_argument("--db-write-batch-size", type=int, help="Rows per DB write batch")

    # DB schema configuration
    parser.add_argument("--db-table", help="Database table containing images")
    parser.add_argument("--db-id-col", help="ID column name")
    parser.add_argument("--db-url-col", help="URL column name")
    parser.add_argument("--db-caption-col", help="Caption column name")
    parser.add_argument(
        "--db-where",
        help="WHERE clause to filter rows (e.g., \"caption IS NULL\")",
    )

    args = parser.parse_args()

    log_level = getattr(logging, args.log_level.upper(), logging.INFO)
    _configure_logging(log_level)

    # BUG FIX: overrides were previously passed with the legacy UPPER_SNAKE
    # names, which apply_overrides() could not match against the lower-case
    # Settings fields, so every CLI flag was silently ignored.  Pass the
    # Settings field names directly.
    config.apply_overrides(
        model_id=args.model_id,
        model_device=args.model_device,
        prompt_token=args.prompt_token,
        gpu_batch_size=args.gpu_batch_size,
        download_concurrency=args.download_concurrency,
        image_queue_max_size=args.image_queue_max,
        result_queue_max_size=args.result_queue_max,
        db_write_batch_size=args.db_write_batch_size,
        db_table=args.db_table,
        db_id_column=args.db_id_col,
        db_url_column=args.db_url_col,
        db_caption_column=args.db_caption_col,
        db_query_where=args.db_where,
    )

    _LOGGER.debug("Effective config: %s", {k: getattr(config, k) for k in [
        "MODEL_ID",
        "MODEL_DEVICE",
        "PROMPT_TOKEN",
        "GPU_BATCH_SIZE",
        "DOWNLOAD_CONCURRENCY",
        "IMAGE_QUEUE_MAX_SIZE",
        "RESULT_QUEUE_MAX_SIZE",
        "DB_WRITE_BATCH_SIZE",
        "DB_TABLE",
        "DB_ID_COLUMN",
        "DB_URL_COLUMN",
        "DB_CAPTION_COLUMN",
        "DB_QUERY_WHERE",
    ]})

    try:
        if args.dry_run:
            urls = asyncio.run(_dry_run_urls())
            asyncio.run(run_pipeline(urls, dry_run=True))
        else:
            asyncio.run(_run_from_db())
    except KeyboardInterrupt:
        _LOGGER.warning("Received interrupt signal, shutting down...")
    except Exception:
        _LOGGER.exception("Pipeline failed")
        # BUG FIX: failures were swallowed and the process exited 0, which
        # hides errors from shells and schedulers.
        raise SystemExit(1)


if __name__ == "__main__":
    main()
# ---- src/preprocess.py ----
from typing import Tuple

import numpy as np
import torch
from PIL import Image

from . import config


def resize_preserve_aspect(img: Image.Image, max_side: int | None = None) -> Image.Image:
    """Downscale so the longest side is at most `max_side`, keeping aspect.

    Uses PIL's thumbnail(), which only ever shrinks: images already smaller
    than `max_side` are returned unchanged (apart from RGB conversion).
    `max_side` defaults to config.IMAGE_MAX_SIDE.
    """
    if max_side is None:
        max_side = config.IMAGE_MAX_SIDE

    img = img.convert("RGB")
    img.thumbnail((max_side, max_side), Image.LANCZOS)
    return img


def image_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image (or HxWxC array) to a CHW float tensor in [0, 1]."""
    arr = np.asarray(img)
    if arr.ndim == 2:
        # BUG FIX: grayscale inputs yield a 2-D array, and permute(2, 0, 1)
        # crashed on it; replicate the single channel into 3 so the output
        # is always CHW with C == 3.
        arr = np.stack([arr, arr, arr], axis=-1)
    # .copy(): guarantee a writable, contiguous buffer for from_numpy.
    tensor = torch.from_numpy(arr.copy()).permute(2, 0, 1).float() / 255.0
    return tensor


def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a CHW tensor with the standard ImageNet mean/std."""
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device)[..., None, None]
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device)[..., None, None]
    return (tensor - mean) / std


def preprocess_image(img: Image.Image, device: str = "cpu") -> torch.Tensor:
    """Resize, convert to a tensor, normalize, and move to `device`."""
    resized = resize_preserve_aspect(img)
    tensor = image_to_tensor(resized).to(device)
    return normalize_tensor(tensor)
import asyncio

from src.preprocess import resize_preserve_aspect


def test_resize_preserve_aspect():
    from PIL import Image

    img = Image.new("RGB", (1024, 512), color="white")
    resized = resize_preserve_aspect(img, max_side=256)
    assert max(resized.size) == 256


def test_pipeline_writes_captions_to_db(monkeypatch):
    """Ensure pipeline calls write_captions with non-empty captions."""

    # Stub out download to avoid network.  DownloadWorker.run looks
    # `download_image` up in src.download's globals, so patching it there
    # is effective.
    async def fake_download_image(url, session, retries=2):
        return object()

    monkeypatch.setattr("src.download.download_image", fake_download_image)

    # BUG FIX: patch names *where they are looked up*.  src/pipeline.py does
    # `from .preprocess import preprocess_image` and `from .db import
    # write_captions`, so the old patches on src.preprocess / src.db never
    # affected the pipeline module and the test exercised the real
    # functions (hitting torch/Pillow and the real database path).
    def fake_preprocess_image(img, device="cpu"):
        return object()

    monkeypatch.setattr("src.pipeline.preprocess_image", fake_preprocess_image)

    # Stub out model load + generate so no weights are downloaded.
    def fake_load(self, torch_dtype=None):
        self.model = True
        self.processor = True

    def fake_generate_captions(self, images, prompt, **kwargs):
        return [f"caption-for-{i}" for i in range(len(images))]

    monkeypatch.setattr("src.model.Florence2CaptionModel.load", fake_load)
    monkeypatch.setattr(
        "src.model.Florence2CaptionModel.generate_captions", fake_generate_captions
    )

    # Capture writes.
    written = []

    async def fake_write_captions(updates, batch_size=64):
        written.extend(updates)

    monkeypatch.setattr("src.pipeline.write_captions", fake_write_captions)

    from src import pipeline

    urls = [(1, "http://example.com/image1.jpg"), (2, "http://example.com/image2.jpg")]
    # dry_run=False so the pipeline persists via the stub; a correct
    # implementation skips the DB writer entirely in dry-run mode.
    asyncio.run(pipeline.run_pipeline(urls, dry_run=False))

    assert len(written) == 2
    assert all(isinstance(c, str) and len(c) > 0 for _, c in written)