Initial Commit

This commit is contained in:
MSVstudios 2026-03-16 11:43:26 +01:00
commit 14db6bcd61
16 changed files with 1003 additions and 0 deletions

223
.gitignore vendored Normal file
View File

@ -0,0 +1,223 @@
# project specifics
config.toml
user-docs/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml

5
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
"github.copilot.chat.reasoningEffort": "high",
"github.copilot.chat.responsesApiReasoningEffort": "high",
"github.copilot.selectedModel": ""
}

37
README.md Normal file
View File

@ -0,0 +1,37 @@
# Florence-2 Captioning Pipeline
High-throughput asynchronous captioning pipeline using **Florence-2 Base PromptGen**.
## Goals
- Download images from S3/HTTP concurrently
- Preprocess (resize/normalize)
- Run batched caption generation on GPU
- Persist captions back to a database (async)
## Project structure
- `src/`: implementation code
- `tests/`: unit/integration tests
- `todo.md`: tasks list
- `implementationPlanV2.md`: architecture + design notes
## Quickstart
1. Install dependencies:
```bash
pip install -r requirements.txt
```
2. Copy `config.toml.example` to `config.toml` and adjust the values (see `src/config/settings.py` for the expected settings).
3. Run the pipeline (example):
```bash
python -m src.pipeline --dry-run
```
## Notes
This repo is intended as a foundation for building a fast, async dataset captioning tool.

36
config.toml.example Normal file
View File

@ -0,0 +1,36 @@
# Configuration for Florence-2 caption pipeline
[model]
# HuggingFace model id (replace with your model)
id = "your-org/your-model"
# Device to run model on (cuda/cpu)
device = "cuda"
# Prompt token used for captioning (replace with your token)
prompt_token = "<PROMPT_TOKEN>"
[preprocessing]
# Max side for resizing (longest dimension)
image_max_side = 768
[pipeline]
# Batch sizes and queue sizing
gpu_batch_size = 8
download_concurrency = 16
image_queue_max_size = 64
result_queue_max_size = 128
db_write_batch_size = 64
[database]
# Async DB connection string
dsn = ""
# Table and column names
table = "images"
id_column = "id"
url_column = "url"
caption_column = "caption"
# Default WHERE clause used to fetch rows to caption
query_where = "caption IS NULL"
[debug]
# Dry run disables writing to DB for quick local tests
dry_run = true

8
requirements.txt Normal file
View File

@ -0,0 +1,8 @@
torch>=2.0.0
transformers>=4.40.0
aiohttp>=3.9.0
psycopg[binary]>=3.2
Pillow>=10.0.0
python-dotenv>=1.0.0
pydantic-settings>=2.2.0
pytest>=8.0.0

3
src/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""Florence-2 captioning pipeline package."""

# No names are re-exported at package level; import submodules directly.
__all__ = []

56
src/config.py Normal file
View File

@ -0,0 +1,56 @@
from .config.settings import settings

# NOTE(review): this module is src/config.py while a src/config/ package also
# exists in the tree; Python resolves `src.config` to the *package*, which
# would shadow this module entirely. Verify which one is actually imported.

# Legacy compatibility helpers (so existing code can use config.<NAME>)
# Each constant mirrors a pydantic Settings field; apply_overrides() below is
# responsible for keeping them in sync after runtime overrides.
MODEL_ID = settings.model_id
MODEL_DEVICE = settings.model_device
PROMPT_TOKEN = settings.prompt_token
IMAGE_MAX_SIDE = settings.image_max_side
GPU_BATCH_SIZE = settings.gpu_batch_size
DOWNLOAD_CONCURRENCY = settings.download_concurrency
IMAGE_QUEUE_MAX_SIZE = settings.image_queue_max_size
RESULT_QUEUE_MAX_SIZE = settings.result_queue_max_size
DB_WRITE_BATCH_SIZE = settings.db_write_batch_size
DB_DSN = settings.db_dsn
DRY_RUN = settings.dry_run
DB_TABLE = settings.db_table
DB_ID_COLUMN = settings.db_id_column
DB_URL_COLUMN = settings.db_url_column
DB_CAPTION_COLUMN = settings.db_caption_column
DB_QUERY_WHERE = settings.db_query_where
def apply_overrides(**kwargs):
    """Override config values at runtime (for CLI overrides).

    Accepts either the legacy upper-case constant names (e.g. ``MODEL_ID``)
    or the lower-case pydantic field names (e.g. ``model_id``). ``None``
    values are skipped so absent CLI flags leave settings untouched.
    """
    # Update the pydantic settings object. Keys are normalized to lower case:
    # callers (see pipeline.main) pass the legacy UPPER_SNAKE names, while the
    # Settings fields are lower_snake — the original `hasattr(settings, key)`
    # check failed for every upper-case key and silently dropped all overrides.
    for key, value in kwargs.items():
        if value is None:
            continue
        attr = key.lower()
        if hasattr(settings, attr):
            setattr(settings, attr, value)
    # Sync legacy constants so `config.<NAME>` attribute reads stay fresh.
    globals().update({
        "MODEL_ID": settings.model_id,
        "MODEL_DEVICE": settings.model_device,
        "PROMPT_TOKEN": settings.prompt_token,
        "IMAGE_MAX_SIDE": settings.image_max_side,
        "GPU_BATCH_SIZE": settings.gpu_batch_size,
        "DOWNLOAD_CONCURRENCY": settings.download_concurrency,
        "IMAGE_QUEUE_MAX_SIZE": settings.image_queue_max_size,
        "RESULT_QUEUE_MAX_SIZE": settings.result_queue_max_size,
        "DB_WRITE_BATCH_SIZE": settings.db_write_batch_size,
        "DB_DSN": settings.db_dsn,
        "DRY_RUN": settings.dry_run,
        "DB_TABLE": settings.db_table,
        "DB_ID_COLUMN": settings.db_id_column,
        "DB_URL_COLUMN": settings.db_url_column,
        "DB_CAPTION_COLUMN": settings.db_caption_column,
        "DB_QUERY_WHERE": settings.db_query_where,
    })

1
src/config/__init__.py Normal file
View File

@ -0,0 +1 @@
"""Package for configuration loaded from TOML via pydantic-settings."""

42
src/config/settings.py Normal file
View File

@ -0,0 +1,42 @@
from __future__ import annotations
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings loaded from config.toml (or environment variables)."""

    # Model
    model_id: str = "MiaoshouAI/Florence-2-base-PromptGen-v2"
    model_device: str = "cuda"
    prompt_token: str = "<MORE_DETAILED_CAPTION>"
    # Preprocessing
    image_max_side: int = 768
    # Pipeline sizing
    gpu_batch_size: int = 8
    download_concurrency: int = 16
    image_queue_max_size: int = 64
    result_queue_max_size: int = 128
    db_write_batch_size: int = 64
    # Database
    db_dsn: str = ""
    db_table: str = "images"
    db_id_column: str = "id"
    db_url_column: str = "url"
    db_caption_column: str = "caption"
    db_query_where: str = "caption IS NULL"
    # Other
    dry_run: bool = True
    # NOTE(review): `env_file` is parsed in dotenv (KEY=value) format, but
    # config.toml (see config.toml.example) is a TOML file with [sections] and
    # differently-named keys, so these values will not load as written. A
    # TomlConfigSettingsSource via settings_customise_sources() is likely the
    # intended mechanism — confirm and adjust.
    model_config = SettingsConfigDict(
        env_file="config.toml",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )
# Module-level singleton shared by the rest of the application.
settings = Settings()

71
src/db.py Normal file
View File

@ -0,0 +1,71 @@
import logging
from typing import AsyncGenerator, Iterable, List, Tuple
import psycopg
from .config import DB_DSN
_LOGGER = logging.getLogger(__name__)
async def get_db_connection():
    """Open and return a new async psycopg connection using the configured DSN.

    Raises:
        ValueError: if DB_DSN is empty or unset.
    """
    if not DB_DSN:
        raise ValueError("DB_DSN is not set. Set the DB_DSN environment variable.")
    return await psycopg.AsyncConnection.connect(DB_DSN)
async def fetch_image_urls(
    table: str | None = None,
    id_column: str | None = None,
    url_column: str | None = None,
    where_clause: str | None = None,
    batch_size: int = 100,
) -> AsyncGenerator[Tuple[int, str], None]:
    """Yield (image_id, image_url) tuples from the database in batches.

    Args:
        table: table to read from; defaults to "images".
        id_column: primary-key column name; defaults to "id".
        url_column: URL column name; defaults to "url".
        where_clause: optional raw SQL appended after WHERE.
        batch_size: rows fetched per round-trip via fetchmany().
    """
    # Prepare query components (conventional defaults when not configured).
    table = table or "images"
    id_column = id_column or "id"
    url_column = url_column or "url"
    # NOTE(review): identifiers and the WHERE clause are interpolated directly
    # into the SQL string. Values come from config, not user input, but
    # psycopg.sql.Identifier would be the safer construction.
    query = f"SELECT {id_column}, {url_column} FROM {table}"
    if where_clause:
        query += " WHERE " + where_clause
    async with await get_db_connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(query)
            # Stream rows in fetchmany() batches so large tables are never
            # fully materialized in memory.
            while True:
                rows = await cur.fetchmany(batch_size)
                if not rows:
                    break
                for row in rows:
                    yield row[0], row[1]
async def write_captions(
    updates: Iterable[Tuple[int, str]],
    batch_size: int = 64,
):
    """Persist (image_id, caption) pairs, flushing in chunks of `batch_size`."""
    connection = await get_db_connection()
    async with connection as conn:
        async with conn.cursor() as cur:
            pending: List[Tuple[str, int]] = []
            for image_id, caption in updates:
                # Parameter order is (caption, id) to match the UPDATE statement.
                pending.append((caption, image_id))
                if len(pending) < batch_size:
                    continue
                await _execute_update(cur, pending)
                pending.clear()
            # Flush whatever remains after the loop.
            if pending:
                await _execute_update(cur, pending)
async def _execute_update(cur, batch: List[Tuple[str, int]]):
    """Run one batched caption UPDATE and commit.

    Args:
        cur: open async psycopg cursor.
        batch: (caption, image_id) parameter tuples.
    """
    # Late import so the table/column names reflect runtime overrides applied
    # via config.apply_overrides(). The original hard-coded "images", "caption"
    # and "id", silently ignoring the [database] settings in config.toml that
    # fetch_image_urls honors.
    from . import config

    query = (
        f"UPDATE {config.DB_TABLE} "
        f"SET {config.DB_CAPTION_COLUMN} = %s "
        f"WHERE {config.DB_ID_COLUMN} = %s"
    )
    await cur.executemany(query, batch)
    await cur.connection.commit()

54
src/download.py Normal file
View File

@ -0,0 +1,54 @@
import io
import asyncio
import logging
from typing import Optional, Tuple
import aiohttp
from PIL import Image
from . import config
_LOGGER = logging.getLogger(__name__)
async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, timeout: int = 30) -> bytes:
    """Download the raw bytes at `url`, raising on an HTTP error status.

    Args:
        session: shared aiohttp client session.
        url: image URL to fetch.
        timeout: total request timeout in seconds.
    """
    # aiohttp expects a ClientTimeout object here; passing a bare number to
    # `timeout=` is deprecated and scheduled for removal.
    request_timeout = aiohttp.ClientTimeout(total=timeout)
    async with session.get(url, timeout=request_timeout) as resp:
        resp.raise_for_status()
        return await resp.read()
async def download_image(
    url: str,
    session: aiohttp.ClientSession,
    retries: int = 2,
) -> Optional[Image.Image]:
    """Fetch `url` and decode it into an RGB PIL image.

    Returns None once all `retries` attempts have failed; sleeps with a
    linearly growing backoff between attempts.
    """
    attempt = 0
    while attempt < retries:
        attempt += 1
        try:
            payload = await fetch_image_bytes(session, url)
            return Image.open(io.BytesIO(payload)).convert("RGB")
        except Exception as exc:
            _LOGGER.warning("Download failed (attempt %s/%s) for %s: %s", attempt, retries, url, exc)
            if attempt == retries:
                return None
            await asyncio.sleep(0.5 * attempt)
class DownloadWorker:
    """Worker that downloads images and pushes them into a queue."""

    def __init__(self, queue, session: aiohttp.ClientSession):
        # Destination queue receiving (image_id, PIL.Image) tuples.
        self.queue = queue
        # Shared aiohttp session; its lifecycle is owned by the caller.
        self.session = session

    async def run(self, image_id: int, url: str):
        """Download one image; enqueue (image_id, image), or drop it on failure."""
        image = await download_image(url, self.session)
        if image is not None:
            await self.queue.put((image_id, image))
def create_aiohttp_session() -> aiohttp.ClientSession:
    """Build a ClientSession whose connection-pool limit matches the
    configured download concurrency."""
    connector = aiohttp.TCPConnector(limit=config.DOWNLOAD_CONCURRENCY)
    return aiohttp.ClientSession(connector=connector)

27
src/metrics.py Normal file
View File

@ -0,0 +1,27 @@
import time
class Metrics:
    """Lightweight throughput and failure counters for one pipeline run."""

    def __init__(self):
        # time.monotonic() is immune to wall-clock adjustments (NTP steps,
        # DST), which time.time() is not — elapsed stays correct either way.
        self.start_ts = time.monotonic()
        self.images_processed = 0
        self.download_failures = 0
        self.caption_failures = 0

    def mark_image_processed(self):
        """Count one successfully captioned image."""
        self.images_processed += 1

    def mark_download_failure(self):
        """Count one image that could not be downloaded."""
        self.download_failures += 1

    def mark_caption_failure(self):
        """Count one image that failed preprocessing or inference."""
        self.caption_failures += 1

    def summary(self):
        """Return a dict snapshot of rates and counters.

        Elapsed time is clamped to >= 1s so images_per_second stays finite
        and is not absurdly inflated on very short runs.
        """
        elapsed = max(1.0, time.monotonic() - self.start_ts)
        return {
            "images_per_second": self.images_processed / elapsed,
            "download_failures": self.download_failures,
            "caption_failures": self.caption_failures,
            "elapsed_seconds": elapsed,
        }

61
src/model.py Normal file
View File

@ -0,0 +1,61 @@
import logging
from typing import Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from .config import MODEL_DEVICE, MODEL_ID
_LOGGER = logging.getLogger(__name__)
class Florence2CaptionModel:
    """Wrapper around Florence-2 PromptGen for batched caption generation."""

    def __init__(self, model_id: str = MODEL_ID, device: str = MODEL_DEVICE):
        self.model_id = model_id
        self.device = device
        self.model = None       # populated by load()
        self.processor = None   # populated by load()

    def load(self, torch_dtype: torch.dtype | None = None):
        """Load model and processor into memory (once).

        Args:
            torch_dtype: parameter dtype. Defaults to float16 on CUDA and
                float32 elsewhere — fp16 inference on CPU is poorly supported
                and the original unconditionally forced float16.
        """
        if torch_dtype is None:
            torch_dtype = torch.float16 if "cuda" in self.device else torch.float32
        _LOGGER.info("Loading model %s on %s", self.model_id, self.device)
        self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )

    def generate_captions(
        self,
        images: List[torch.Tensor],
        prompt: str,
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.0,
    ) -> List[str]:
        """Generate captions for a batch of preprocessed image tensors.

        Raises:
            RuntimeError: if load() has not been called first.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded. Call load() first.")
        # NOTE(review): the pipeline passes normalized CHW tensors here, while
        # Florence-2 processors typically expect PIL images — confirm the
        # processor accepts tensors in this form.
        inputs = self.processor(
            text=[prompt] * len(images),
            images=images,
            return_tensors="pt",
        ).to(self.device)
        # Forward `temperature` only when sampling: HF generate() warns on —
        # and temperature=0.0 is invalid for — greedy decoding.
        gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            gen_kwargs["temperature"] = temperature
        generated = self.model.generate(**inputs, **gen_kwargs)
        return self.processor.batch_decode(generated, skip_special_tokens=True)

283
src/pipeline.py Normal file
View File

@ -0,0 +1,283 @@
import argparse
import asyncio
import logging
from collections.abc import AsyncIterable
from typing import Iterable, List, Optional, Tuple, Union
import aiohttp
import torch
from . import config
from .db import fetch_image_urls, write_captions
from .download import create_aiohttp_session, DownloadWorker
from .metrics import Metrics
from .model import Florence2CaptionModel
from .preprocess import preprocess_image
_LOGGER = logging.getLogger(__name__)
def _configure_logging(level: int = logging.INFO) -> None:
"""Configure root logger for the pipeline."""
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
async def _producer(queue: asyncio.Queue, urls: Iterable[Tuple[int, str]]):
for image_id, url in urls:
await queue.put((image_id, url))
await queue.put(None) # sentinel to indicate completion
async def _producer_async(queue: asyncio.Queue, urls: AsyncIterable[Tuple[int, str]]):
async for image_id, url in urls:
await queue.put((image_id, url))
await queue.put(None) # sentinel to indicate completion
async def _download_consumer(
    input_queue: asyncio.Queue,
    image_queue: asyncio.Queue,
    session: "aiohttp.ClientSession",
    metrics: "Metrics",
):
    """Pull (id, url) items off `input_queue` and download them into `image_queue`.

    Forwards the None end-of-stream sentinel downstream when the producer
    is exhausted.
    """
    worker = DownloadWorker(image_queue, session)
    try:
        while True:
            item = await input_queue.get()
            input_queue.task_done()
            if item is None:
                # Propagate end-of-stream to the caption stage and stop.
                await image_queue.put(None)
                return
            queued_id, queued_url = item
            await worker.run(queued_id, queued_url)
    except asyncio.CancelledError:
        _LOGGER.info("Download consumer cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("Download consumer error: %s", e)
        raise
async def _caption_worker(
    image_queue: asyncio.Queue,
    result_queue: asyncio.Queue,
    model: Florence2CaptionModel,
    metrics: Metrics,
    prompt: str,
):
    """Batch images off `image_queue`, run caption inference, emit results.

    Accumulates preprocessed tensors until GPU_BATCH_SIZE is reached, flushes
    each full batch through the model, and forwards a None sentinel to the DB
    writer once the upstream stage finishes.
    """
    buffer: List[Tuple[int, torch.Tensor]] = []
    try:
        while True:
            item = await image_queue.get()
            image_queue.task_done()
            if item is None:
                break
            image_id, img = item
            try:
                tensor = preprocess_image(img, device=config.MODEL_DEVICE)
                buffer.append((image_id, tensor))
                if len(buffer) >= config.GPU_BATCH_SIZE:
                    # NOTE: a generate failure here loses the entire buffered
                    # batch, not just the current image.
                    await _flush_batch(buffer, model, result_queue, metrics, prompt)
            except Exception as e:
                metrics.mark_caption_failure()
                _LOGGER.exception("Preprocess/inference error for %s: %s", image_id, e)
        # Flush any partial batch remaining after the sentinel, then signal
        # the DB writer that no more results will arrive.
        if buffer:
            await _flush_batch(buffer, model, result_queue, metrics, prompt)
        await result_queue.put(None)
    except asyncio.CancelledError:
        _LOGGER.info("Caption worker cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("Caption worker failure: %s", e)
        raise
async def _flush_batch(
    buffer: List[Tuple[int, torch.Tensor]],
    model: "Florence2CaptionModel",
    result_queue: asyncio.Queue,
    metrics: "Metrics",
    prompt: str,
):
    """Run one inference batch and push (id, caption) pairs downstream.

    Empties `buffer` in place so the caller keeps reusing the same list.
    """
    image_ids, tensors = zip(*buffer)
    captions = model.generate_captions(list(tensors), prompt=prompt)
    for paired in zip(image_ids, captions):
        await result_queue.put(paired)
        metrics.mark_image_processed()
    buffer.clear()
async def _db_writer(result_queue: asyncio.Queue, metrics: "Metrics"):
    """Drain (id, caption) results and persist them in DB_WRITE_BATCH_SIZE chunks."""
    pending: List[Tuple[int, str]] = []
    try:
        while True:
            result = await result_queue.get()
            result_queue.task_done()
            if result is None:
                break
            pending.append(result)
            if len(pending) >= config.DB_WRITE_BATCH_SIZE:
                await write_captions(pending)
                pending.clear()
        # Persist the final partial batch, if any.
        if pending:
            await write_captions(pending)
    except asyncio.CancelledError:
        _LOGGER.info("DB writer cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("DB writer error: %s", e)
        raise
async def run_pipeline(
    urls: Union[Iterable[Tuple[int, str]], AsyncIterable[Tuple[int, str]]],
    dry_run: bool = config.DRY_RUN,
    prompt_token: str = config.PROMPT_TOKEN,
):
    """Run the download → caption → DB-write pipeline over `urls`.

    Args:
        urls: sync or async iterable of (image_id, image_url) pairs.
        dry_run: NOTE(review) — accepted and logged but never consulted;
            the DB-writer stage runs regardless of this flag. Confirm intent.
        prompt_token: Florence-2 task token forwarded to the model.

    Note: the dry_run/prompt_token defaults are captured at import time, so
    later config.apply_overrides() calls do not change them.
    """
    _LOGGER.info("Starting pipeline (dry_run=%s)", dry_run)
    metrics = Metrics()
    model = Florence2CaptionModel()
    model.load()
    # Bounded queues provide back-pressure between pipeline stages.
    download_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    image_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    result_queue: asyncio.Queue = asyncio.Queue(maxsize=config.RESULT_QUEUE_MAX_SIZE)
    async with create_aiohttp_session() as session:
        if isinstance(urls, AsyncIterable):
            producer_task = asyncio.create_task(_producer_async(download_queue, urls))
        else:
            producer_task = asyncio.create_task(_producer(download_queue, urls))
        downloader_task = asyncio.create_task(
            _download_consumer(download_queue, image_queue, session, metrics)
        )
        caption_task = asyncio.create_task(
            _caption_worker(image_queue, result_queue, model, metrics, prompt_token)
        )
        writer_task = asyncio.create_task(_db_writer(result_queue, metrics))
        tasks = [producer_task, downloader_task, caption_task, writer_task]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        # Cancel any remaining tasks if something failed
        for p in pending:
            p.cancel()
        # Wait for cancellation to complete
        await asyncio.gather(*pending, return_exceptions=True)
        # If any task raised an exception, propagate it
        for t in done:
            if t.exception():
                raise t.exception()
    _LOGGER.info("Pipeline completed: %s", metrics.summary())
async def _dry_run_urls() -> List[Tuple[int, str]]:
# Example placeholder list for dry run.
return [
(1, "https://example.com/image1.jpg"),
(2, "https://example.com/image2.jpg"),
]
async def _run_from_db():
    """Run pipeline using async stream from the database."""
    # fetch_image_urls returns an async generator; rows stream through the
    # pipeline without being fully materialized in memory.
    urls = fetch_image_urls(
        table=config.DB_TABLE,
        id_column=config.DB_ID_COLUMN,
        url_column=config.DB_URL_COLUMN,
        where_clause=config.DB_QUERY_WHERE,
    )
    await run_pipeline(urls, dry_run=False)
def main():
    """CLI entry point: parse arguments, apply config overrides, run the pipeline."""
    parser = argparse.ArgumentParser(description="Async Florence-2 captioning pipeline")
    parser.add_argument("--dry-run", action="store_true", help="Run without DB, using sample URLs")
    parser.add_argument("--log-level", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    # Model overrides
    parser.add_argument("--model-id", help="HuggingFace model ID")
    parser.add_argument("--model-device", help="Device to run model on (cuda/cpu)")
    parser.add_argument("--prompt-token", help="Task token to use for captioning")
    # Pipeline sizing
    parser.add_argument("--gpu-batch-size", type=int, help="Batch size for GPU inference")
    parser.add_argument("--download-concurrency", type=int, help="Max concurrent downloads")
    parser.add_argument("--image-queue-max", type=int, help="Max size for image queue")
    parser.add_argument("--result-queue-max", type=int, help="Max size for result queue")
    # DB schema configuration
    parser.add_argument("--db-table", help="Database table containing images")
    parser.add_argument("--db-id-col", help="ID column name")
    parser.add_argument("--db-url-col", help="URL column name")
    parser.add_argument("--db-caption-col", help="Caption column name")
    parser.add_argument(
        "--db-where",
        help="WHERE clause to filter rows (e.g., \"caption IS NULL\")",
    )
    args = parser.parse_args()
    # Unrecognized level names silently fall back to INFO.
    log_level = getattr(logging, args.log_level.upper(), logging.INFO)
    _configure_logging(log_level)
    # NOTE(review): these upper-case names must be accepted by
    # config.apply_overrides(); the Settings fields are lower_snake-cased,
    # so verify the mapping actually takes effect.
    config.apply_overrides(
        MODEL_ID=args.model_id,
        MODEL_DEVICE=args.model_device,
        PROMPT_TOKEN=args.prompt_token,
        GPU_BATCH_SIZE=args.gpu_batch_size,
        DOWNLOAD_CONCURRENCY=args.download_concurrency,
        IMAGE_QUEUE_MAX_SIZE=args.image_queue_max,
        RESULT_QUEUE_MAX_SIZE=args.result_queue_max,
        DB_TABLE=args.db_table,
        DB_ID_COLUMN=args.db_id_col,
        DB_URL_COLUMN=args.db_url_col,
        DB_CAPTION_COLUMN=args.db_caption_col,
        DB_QUERY_WHERE=args.db_where,
    )
    _LOGGER.debug("Effective config: %s", {k: getattr(config, k) for k in [
        "MODEL_ID",
        "MODEL_DEVICE",
        "PROMPT_TOKEN",
        "GPU_BATCH_SIZE",
        "DOWNLOAD_CONCURRENCY",
        "IMAGE_QUEUE_MAX_SIZE",
        "RESULT_QUEUE_MAX_SIZE",
        "DB_TABLE",
        "DB_ID_COLUMN",
        "DB_URL_COLUMN",
        "DB_CAPTION_COLUMN",
        "DB_QUERY_WHERE",
    ]})
    try:
        if args.dry_run:
            # Sample URLs only; no database connection is attempted.
            urls = asyncio.run(_dry_run_urls())
            asyncio.run(run_pipeline(urls, dry_run=True))
        else:
            asyncio.run(_run_from_db())
    except KeyboardInterrupt:
        _LOGGER.warning("Received interrupt signal, shutting down...")
    except Exception:
        _LOGGER.exception("Pipeline failed")
if __name__ == "__main__":
    main()

39
src/preprocess.py Normal file
View File

@ -0,0 +1,39 @@
from typing import Tuple
import numpy as np
import torch
from PIL import Image
from . import config
def resize_preserve_aspect(img: Image.Image, max_side: int | None = None) -> Image.Image:
    """Shrink `img` so its longest side is at most `max_side`, keeping aspect.

    Images already smaller than `max_side` are returned unscaled (PIL's
    thumbnail() never upscales — the original docstring overstated this).
    Always returns a new RGB image; the caller's image is not mutated.

    Args:
        img: source image (any mode).
        max_side: cap for the longest dimension; defaults to
            config.IMAGE_MAX_SIDE.
    """
    if max_side is None:
        max_side = config.IMAGE_MAX_SIDE
    # convert() returns a copy, so the in-place thumbnail() below cannot
    # clobber the caller's image.
    rgb = img.convert("RGB")
    # Image.Resampling is the non-deprecated constant namespace
    # (Pillow >= 9.1; this project pins Pillow >= 10).
    rgb.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
    return rgb
def image_to_tensor(img: "Image.Image") -> torch.Tensor:
    """Convert a PIL image to a CHW float tensor scaled to [0, 1]."""
    pixels = np.array(img)
    # HWC uint8 -> CHW float32 in [0, 1].
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float()
    return chw / 255.0
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize tensor with ImageNet/vision mean/std."""
    stats_shape = (3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).reshape(stats_shape)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).reshape(stats_shape)
    return tensor.sub(mean).div(std)
def preprocess_image(img: "Image.Image", device: str = "cpu") -> torch.Tensor:
    """Resize, convert, normalize, and move to device."""
    scaled = resize_preserve_aspect(img)
    on_device = image_to_tensor(scaled).to(device)
    return normalize_tensor(on_device)

57
tests/test_pipeline.py Normal file
View File

@ -0,0 +1,57 @@
import asyncio
from src.preprocess import resize_preserve_aspect
def test_resize_preserve_aspect():
    """The longest side of the resized image must equal the requested max_side."""
    from PIL import Image

    canvas = Image.new("RGB", (1024, 512), color="white")
    shrunk = resize_preserve_aspect(canvas, max_side=256)
    assert max(shrunk.size) == 256
def test_pipeline_writes_captions_to_db(monkeypatch):
    """Ensure pipeline calls write_captions with non-empty captions.

    Patches must target the names in src.pipeline's own namespace: pipeline
    does `from .preprocess import preprocess_image` and `from .db import
    write_captions`, so patching src.preprocess / src.db (as originally
    written) leaves pipeline's references pointing at the real functions.
    """
    from src import pipeline

    # DownloadWorker.run looks download_image up in src.download's globals at
    # call time, so patching the download module itself works here.
    async def fake_download_image(url, session, retries=2):
        return object()

    monkeypatch.setattr("src.download.download_image", fake_download_image)

    # Stub preprocess in pipeline's namespace (skip Pillow / torch entirely).
    def fake_preprocess_image(img, device="cpu"):
        return object()

    monkeypatch.setattr("src.pipeline.preprocess_image", fake_preprocess_image)

    # Stub model load + generate; patching the class object affects pipeline
    # too, since both modules share the same class identity.
    def fake_load(self):
        self.model = True
        self.processor = True

    def fake_generate_captions(self, images, prompt, **kwargs):
        return [f"caption-for-{i}" for i in range(len(images))]

    monkeypatch.setattr("src.model.Florence2CaptionModel.load", fake_load)
    monkeypatch.setattr(
        "src.model.Florence2CaptionModel.generate_captions", fake_generate_captions
    )

    # Capture writes via pipeline's own reference to write_captions.
    written = []

    async def fake_write_captions(updates, batch_size=64):
        written.extend(updates)

    monkeypatch.setattr("src.pipeline.write_captions", fake_write_captions)

    # Run pipeline with a small dummy URL list.
    urls = [(1, "http://example.com/image1.jpg"), (2, "http://example.com/image2.jpg")]
    asyncio.run(pipeline.run_pipeline(urls, dry_run=True))

    assert len(written) == 2
    assert all(isinstance(c, str) and len(c) > 0 for _, c in written)