Initial Commit
This commit is contained in:
commit
14db6bcd61
|
|
@ -0,0 +1,223 @@
|
|||
|
||||
# project specifics
|
||||
|
||||
config.toml
|
||||
user-docs/
|
||||
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
# .idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"github.copilot.chat.reasoningEffort": "high",
|
||||
"github.copilot.chat.responsesApiReasoningEffort": "high",
|
||||
"github.copilot.selectedModel": ""
|
||||
}
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# Florence-2 Captioning Pipeline
|
||||
|
||||
High-throughput asynchronous captioning pipeline using **Florence-2 Base PromptGen**.
|
||||
|
||||
## Goals
|
||||
|
||||
- Download images from S3/HTTP concurrently
|
||||
- Preprocess (resize/normalize)
|
||||
- Run batched caption generation on GPU
|
||||
- Persist captions back to a database (async)
|
||||
|
||||
## Project structure
|
||||
|
||||
- `src/`: implementation code
|
||||
- `tests/`: unit/integration tests
|
||||
- `todo.md`: tasks list
|
||||
- `implementationPlanV2.md`: architecture + design notes
|
||||
|
||||
## Quickstart
|
||||
|
||||
1. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Configure settings in `config.toml` (see `src/config/settings.py` for the expected fields; environment variables can override them).
|
||||
|
||||
3. Run the pipeline (example):
|
||||
|
||||
```bash
|
||||
python -m src.pipeline --dry-run
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
This repo is intended as a foundation for building a fast, async dataset captioning tool.
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Configuration for Florence-2 caption pipeline
|
||||
|
||||
[model]
|
||||
# HuggingFace model id (replace with your model)
|
||||
id = "your-org/your-model"
|
||||
# Device to run model on (cuda/cpu)
|
||||
device = "cuda"
|
||||
# Prompt token used for captioning (replace with your token)
|
||||
prompt_token = "<PROMPT_TOKEN>"
|
||||
|
||||
[preprocessing]
|
||||
# Max side for resizing (longest dimension)
|
||||
image_max_side = 768
|
||||
|
||||
[pipeline]
|
||||
# Batch sizes and queue sizing
|
||||
gpu_batch_size = 8
|
||||
download_concurrency = 16
|
||||
image_queue_max_size = 64
|
||||
result_queue_max_size = 128
|
||||
db_write_batch_size = 64
|
||||
|
||||
[database]
|
||||
# Async DB connection string
|
||||
dsn = ""
|
||||
# Table and column names
|
||||
table = "images"
|
||||
id_column = "id"
|
||||
url_column = "url"
|
||||
caption_column = "caption"
|
||||
# Default WHERE clause used to fetch rows to caption
|
||||
query_where = "caption IS NULL"
|
||||
|
||||
[debug]
|
||||
# Dry run disables writing to DB for quick local tests
|
||||
dry_run = true
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
torch>=2.0.0
|
||||
transformers>=4.40.0
|
||||
aiohttp>=3.9.0
|
||||
psycopg[binary]>=3.2
|
||||
Pillow>=10.0.0
|
||||
python-dotenv>=1.0.0
|
||||
pydantic-settings>=2.2.0
|
||||
pytest>=8.0.0
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""Florence-2 captioning pipeline package."""

# Explicit public API: nothing is re-exported at package level yet;
# consumers import submodules directly (src.pipeline, src.config, ...).
__all__ = []
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
from .config.settings import settings
|
||||
|
||||
|
||||
# Legacy compatibility helpers (so existing code can use config.<NAME>)
#
# NOTE: these are snapshots taken from `settings` at import time; they are
# re-synchronized by apply_overrides() below after runtime overrides.

# Model selection, target device, and the captioning task token.
MODEL_ID = settings.model_id
MODEL_DEVICE = settings.model_device
PROMPT_TOKEN = settings.prompt_token

# Longest image side after preprocessing resize.
IMAGE_MAX_SIDE = settings.image_max_side

# Pipeline sizing: GPU batch size, download parallelism, queue capacities,
# and how many captions are grouped per DB write.
GPU_BATCH_SIZE = settings.gpu_batch_size
DOWNLOAD_CONCURRENCY = settings.download_concurrency
IMAGE_QUEUE_MAX_SIZE = settings.image_queue_max_size
RESULT_QUEUE_MAX_SIZE = settings.result_queue_max_size
DB_WRITE_BATCH_SIZE = settings.db_write_batch_size

# Async database connection string.
DB_DSN = settings.db_dsn

# When true, intended to skip DB writes for local testing.
DRY_RUN = settings.dry_run

# Database schema: table/column names and the row-selection filter.
DB_TABLE = settings.db_table
DB_ID_COLUMN = settings.db_id_column
DB_URL_COLUMN = settings.db_url_column
DB_CAPTION_COLUMN = settings.db_caption_column
DB_QUERY_WHERE = settings.db_query_where
|
||||
|
||||
|
||||
def apply_overrides(**kwargs):
    """Override config values at runtime (for CLI overrides).

    Accepts either settings attribute names (``model_id``) or the legacy
    constant names (``MODEL_ID``); keys are lowercased before lookup.
    ``None`` values are skipped so CLI flags that were not provided leave
    the current configuration untouched.  Unknown keys are ignored.
    """
    # Update the pydantic settings object.  Callers (see pipeline.main)
    # pass the legacy UPPER_CASE names while Settings exposes lower_case
    # attributes, so normalize before the hasattr check -- without this,
    # every override such as MODEL_ID=... was silently dropped because
    # hasattr(settings, "MODEL_ID") is always False.
    for key, value in kwargs.items():
        if value is None:
            continue
        attr = key.lower()
        if hasattr(settings, attr):
            setattr(settings, attr, value)

    # Sync the legacy module-level constants from the (possibly updated)
    # settings object so code reading config.<NAME> sees the overrides.
    globals().update({
        "MODEL_ID": settings.model_id,
        "MODEL_DEVICE": settings.model_device,
        "PROMPT_TOKEN": settings.prompt_token,
        "IMAGE_MAX_SIDE": settings.image_max_side,
        "GPU_BATCH_SIZE": settings.gpu_batch_size,
        "DOWNLOAD_CONCURRENCY": settings.download_concurrency,
        "IMAGE_QUEUE_MAX_SIZE": settings.image_queue_max_size,
        "RESULT_QUEUE_MAX_SIZE": settings.result_queue_max_size,
        "DB_WRITE_BATCH_SIZE": settings.db_write_batch_size,
        "DB_DSN": settings.db_dsn,
        "DRY_RUN": settings.dry_run,
        "DB_TABLE": settings.db_table,
        "DB_ID_COLUMN": settings.db_id_column,
        "DB_URL_COLUMN": settings.db_url_column,
        "DB_CAPTION_COLUMN": settings.db_caption_column,
        "DB_QUERY_WHERE": settings.db_query_where,
    })
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Package for configuration loaded from TOML via pydantic-settings."""
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from config.toml (or environment variables).

    NOTE(review): ``env_file`` uses pydantic-settings' dotenv parser, so
    pointing it at ``config.toml`` will not parse TOML sections such as
    ``[model]``; the file's ``[section]`` + short key layout does not map
    onto these flat field names.  Loading the TOML file properly needs
    ``TomlConfigSettingsSource`` (pydantic-settings >= 2.0) and matching
    key names -- confirm intended behavior.
    """

    # Model: HuggingFace id, target device, and captioning task token.
    model_id: str = "MiaoshouAI/Florence-2-base-PromptGen-v2"
    model_device: str = "cuda"
    prompt_token: str = "<MORE_DETAILED_CAPTION>"

    # Preprocessing: longest side after resize.
    image_max_side: int = 768

    # Pipeline sizing: batch sizes, download parallelism, queue capacities.
    gpu_batch_size: int = 8
    download_concurrency: int = 16
    image_queue_max_size: int = 64
    result_queue_max_size: int = 128
    db_write_batch_size: int = 64

    # Database connection and schema description.
    db_dsn: str = ""
    db_table: str = "images"
    db_id_column: str = "id"
    db_url_column: str = "url"
    db_caption_column: str = "caption"
    db_query_where: str = "caption IS NULL"

    # Dry run is intended to disable DB writes for quick local tests.
    dry_run: bool = True

    model_config = SettingsConfigDict(
        env_file="config.toml",
        env_file_encoding="utf-8",
        case_sensitive=False,
        # Fields named model_* collide with pydantic v2's protected
        # "model_" namespace and trigger warnings; clear the namespace
        # since these are legitimate field names here.
        protected_namespaces=(),
    )


# Shared singleton imported by the rest of the package.
settings = Settings()
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import logging
|
||||
from typing import AsyncGenerator, Iterable, List, Tuple
|
||||
|
||||
import psycopg
|
||||
|
||||
from .config import DB_DSN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_db_connection():
    """Open and return a new async psycopg connection for the configured DSN.

    Raises:
        ValueError: when no DSN has been configured.
    """
    if DB_DSN:
        return await psycopg.AsyncConnection.connect(DB_DSN)
    raise ValueError("DB_DSN is not set. Set the DB_DSN environment variable.")
|
||||
|
||||
|
||||
async def fetch_image_urls(
    table: str | None = None,
    id_column: str | None = None,
    url_column: str | None = None,
    where_clause: str | None = None,
    batch_size: int = 100,
) -> AsyncGenerator[Tuple[int, str], None]:
    """Yield (image_id, image_url) tuples from the database in batches.

    Args:
        table: Source table; defaults to "images".
        id_column: Primary-key column name; defaults to "id".
        url_column: Image URL column name; defaults to "url".
        where_clause: Optional raw SQL filter appended after WHERE.
        batch_size: Rows fetched per cursor round-trip.
    """
    # Fill in defaults (annotations fixed: these parameters are optional).
    table = table or "images"
    id_column = id_column or "id"
    url_column = url_column or "url"

    # NOTE(review): identifiers and the WHERE clause are interpolated
    # directly into the SQL string.  Today these values come from trusted
    # config, but this is injection-prone if they ever come from user
    # input -- consider psycopg.sql.Identifier for the identifiers.
    query = f"SELECT {id_column}, {url_column} FROM {table}"
    if where_clause:
        query += " WHERE " + where_clause

    async with await get_db_connection() as conn:
        async with conn.cursor() as cur:
            await cur.execute(query)
            # Stream rows in batches so huge tables are never fully
            # materialized in memory.
            while True:
                rows = await cur.fetchmany(batch_size)
                if not rows:
                    break
                for row in rows:
                    yield row[0], row[1]
|
||||
|
||||
|
||||
async def write_captions(
    updates: Iterable[Tuple[int, str]],
    batch_size: int = 64,
):
    """Persist (image_id, caption) pairs, flushing in groups of *batch_size*.

    A trailing partial batch is flushed before the connection closes.
    """
    async with await get_db_connection() as conn:
        async with conn.cursor() as cur:
            pending: List[Tuple[str, int]] = []
            for image_id, caption in updates:
                # Parameter order matches the UPDATE statement: caption first,
                # then the row id.
                pending.append((caption, image_id))
                if len(pending) >= batch_size:
                    await _execute_update(cur, pending)
                    pending = []
            if pending:
                await _execute_update(cur, pending)
|
||||
|
||||
|
||||
async def _execute_update(cur, batch: List[Tuple[str, int]]):
|
||||
await cur.executemany(
|
||||
"""
|
||||
UPDATE images
|
||||
SET caption = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
batch,
|
||||
)
|
||||
await cur.connection.commit()
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import io
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
|
||||
from . import config
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, timeout: int = 30) -> bytes:
    """GET *url* and return the raw response body.

    Raises an aiohttp client error on non-2xx responses and a timeout
    error when the total request time exceeds *timeout* seconds.
    """
    # aiohttp expects a ClientTimeout structure here; passing a bare
    # number to ``timeout=`` is deprecated and rejected by newer releases.
    async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
        resp.raise_for_status()
        return await resp.read()
|
||||
|
||||
|
||||
async def download_image(
    url: str,
    session: aiohttp.ClientSession,
    retries: int = 2,
) -> Optional[Image.Image]:
    """Fetch *url* and decode it into an RGB PIL image.

    Returns None after *retries* failed attempts; failures (network or
    decode) are logged, with a short linear backoff between attempts.
    """
    for attempt in range(1, retries + 1):
        try:
            # Both the fetch and the decode are covered by the retry.
            raw = await fetch_image_bytes(session, url)
            return Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as exc:
            _LOGGER.warning("Download failed (attempt %s/%s) for %s: %s", attempt, retries, url, exc)
            if attempt == retries:
                return None
            # Back off a little longer each time before retrying.
            await asyncio.sleep(0.5 * attempt)
|
||||
|
||||
|
||||
class DownloadWorker:
    """Downloads one image at a time and enqueues (image_id, PIL.Image).

    Failed downloads are dropped silently here; download_image already
    logs each failed attempt.
    """

    def __init__(self, queue, session: aiohttp.ClientSession):
        self.queue = queue
        self.session = session

    async def run(self, image_id: int, url: str):
        """Fetch *url*; on success, enqueue the image keyed by *image_id*."""
        image = await download_image(url, self.session)
        if image is not None:
            await self.queue.put((image_id, image))
|
||||
|
||||
|
||||
def create_aiohttp_session() -> aiohttp.ClientSession:
    """Build a ClientSession whose connection-pool size matches the
    configured download concurrency."""
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(limit=config.DOWNLOAD_CONCURRENCY)
    )
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
import time
|
||||
|
||||
|
||||
class Metrics:
    """Mutable counters tracking pipeline throughput and failures."""

    def __init__(self):
        # Wall-clock start, used for the throughput calculation.
        self.start_ts = time.time()
        self.images_processed = 0
        self.download_failures = 0
        self.caption_failures = 0

    def mark_image_processed(self):
        """Count one successfully captioned image."""
        self.images_processed = self.images_processed + 1

    def mark_download_failure(self):
        """Count one image that could not be downloaded."""
        self.download_failures = self.download_failures + 1

    def mark_caption_failure(self):
        """Count one image that failed preprocessing or inference."""
        self.caption_failures = self.caption_failures + 1

    def summary(self):
        """Return a snapshot dict: throughput, failure counts, elapsed time."""
        # Clamp elapsed to one second so a fast run cannot divide by ~zero.
        elapsed = max(1.0, time.time() - self.start_ts)
        return {
            "images_per_second": self.images_processed / elapsed,
            "download_failures": self.download_failures,
            "caption_failures": self.caption_failures,
            "elapsed_seconds": elapsed,
        }
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
|
||||
from .config import MODEL_DEVICE, MODEL_ID
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Florence2CaptionModel:
    """Wrapper around Florence-2 PromptGen for batched caption generation."""

    def __init__(self, model_id: str = MODEL_ID, device: str = MODEL_DEVICE):
        self.model_id = model_id
        self.device = device
        self.model = None      # populated by load()
        self.processor = None  # populated by load()

    def load(self, torch_dtype: torch.dtype = torch.float16):
        """Load model and processor into memory (once).

        Args:
            torch_dtype: Weight dtype; defaults to fp16 for GPU inference.
        """
        _LOGGER.info("Loading model %s on %s", self.model_id, self.device)
        # trust_remote_code is required: Florence-2 ships custom modeling code.
        self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )

    def generate_captions(
        self,
        images: List[torch.Tensor],
        prompt: str,
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.0,
    ) -> List[str]:
        """Generate captions for a batch of preprocessed image tensors.

        Args:
            images: Batch of image inputs for the processor.
            prompt: Task token repeated for every image in the batch.
            max_new_tokens: Generation length cap.
            do_sample/temperature: Sampling controls (greedy by default).

        Returns:
            One decoded caption string per input image.

        Raises:
            RuntimeError: if load() has not been called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded. Call load() first.")

        inputs = self.processor(
            text=[prompt] * len(images),
            images=images,
            return_tensors="pt",
        ).to(self.device)

        # Inference only: disable autograd so generate() does not build a
        # computation graph and retain activations in GPU memory.
        with torch.inference_mode():
            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
            )

        captions = self.processor.batch_decode(generated, skip_special_tokens=True)
        return captions
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
|
||||
from . import config
|
||||
from .db import fetch_image_urls, write_captions
|
||||
from .download import create_aiohttp_session, DownloadWorker
|
||||
from .metrics import Metrics
|
||||
from .model import Florence2CaptionModel
|
||||
from .preprocess import preprocess_image
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _configure_logging(level: int = logging.INFO) -> None:
|
||||
"""Configure root logger for the pipeline."""
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
async def _producer(queue: asyncio.Queue, urls: Iterable[Tuple[int, str]]):
|
||||
for image_id, url in urls:
|
||||
await queue.put((image_id, url))
|
||||
await queue.put(None) # sentinel to indicate completion
|
||||
|
||||
|
||||
async def _producer_async(queue: asyncio.Queue, urls: AsyncIterable[Tuple[int, str]]):
|
||||
async for image_id, url in urls:
|
||||
await queue.put((image_id, url))
|
||||
await queue.put(None) # sentinel to indicate completion
|
||||
|
||||
|
||||
async def _download_consumer(
    input_queue: asyncio.Queue,
    image_queue: asyncio.Queue,
    session: aiohttp.ClientSession,
    metrics: Metrics,
):
    """Pull (image_id, url) items off *input_queue* and download each one.

    Successful downloads land on *image_queue*; the None sentinel is
    forwarded so the caption stage also terminates.

    NOTE(review): *metrics* is accepted but never used -- download
    failures are not counted anywhere; confirm whether that is intended.
    """
    worker = DownloadWorker(image_queue, session)
    try:
        while (item := await input_queue.get()) is not None:
            input_queue.task_done()
            image_id, url = item
            await worker.run(image_id, url)
        # Sentinel consumed: account for it, then propagate end-of-stream.
        input_queue.task_done()
        await image_queue.put(None)
    except asyncio.CancelledError:
        _LOGGER.info("Download consumer cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("Download consumer error: %s", e)
        raise
|
||||
|
||||
|
||||
async def _caption_worker(
    image_queue: asyncio.Queue,
    result_queue: asyncio.Queue,
    model: Florence2CaptionModel,
    metrics: Metrics,
    prompt: str,
):
    """Batch incoming images, run GPU inference, emit (id, caption) pairs.

    Consumes (image_id, PIL.Image) items from image_queue until the None
    sentinel arrives; flushes full batches of config.GPU_BATCH_SIZE, a
    final partial batch at shutdown, then forwards the sentinel to
    result_queue so the DB writer terminates.
    """
    # Pending (image_id, tensor) pairs awaiting a full GPU batch.
    buffer: List[Tuple[int, torch.Tensor]] = []

    try:
        while True:
            item = await image_queue.get()
            image_queue.task_done()

            # None is the end-of-stream sentinel from the download stage.
            if item is None:
                break

            image_id, img = item
            try:
                tensor = preprocess_image(img, device=config.MODEL_DEVICE)
                buffer.append((image_id, tensor))
                # Flush as soon as a full GPU batch has accumulated.
                if len(buffer) >= config.GPU_BATCH_SIZE:
                    await _flush_batch(buffer, model, result_queue, metrics, prompt)
            except Exception as e:
                # Per-item failure: count it and keep consuming.
                # NOTE(review): if _flush_batch itself raised, the buffer is
                # NOT cleared, so those items are retried on the next flush
                # and the failure is attributed to the current image_id --
                # confirm this is intended.
                metrics.mark_caption_failure()
                _LOGGER.exception("Preprocess/inference error for %s: %s", image_id, e)

        # Flush any partial batch left over after the sentinel.
        if buffer:
            await _flush_batch(buffer, model, result_queue, metrics, prompt)

        # Propagate end-of-stream to the DB writer.
        await result_queue.put(None)
    except asyncio.CancelledError:
        _LOGGER.info("Caption worker cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("Caption worker failure: %s", e)
        raise
|
||||
|
||||
|
||||
async def _flush_batch(
    buffer: List[Tuple[int, torch.Tensor]],
    model: Florence2CaptionModel,
    result_queue: asyncio.Queue,
    metrics: Metrics,
    prompt: str,
):
    """Run one GPU batch over *buffer* and enqueue (id, caption) results.

    Empties *buffer* in place once every result has been emitted.
    Callers must pass a non-empty buffer.
    """
    ids, tensors = zip(*buffer)
    captions = model.generate_captions(list(tensors), prompt=prompt)
    for pair in zip(ids, captions):
        await result_queue.put(pair)
        metrics.mark_image_processed()
    buffer.clear()
|
||||
|
||||
|
||||
async def _db_writer(result_queue: asyncio.Queue, metrics: Metrics):
    """Drain (image_id, caption) results and persist them in batches.

    Stops when the None sentinel arrives, flushing any final partial
    batch before returning.
    """
    pending: List[Tuple[int, str]] = []
    try:
        while (item := await result_queue.get()) is not None:
            result_queue.task_done()
            pending.append(item)
            if len(pending) >= config.DB_WRITE_BATCH_SIZE:
                await write_captions(pending)
                pending.clear()

        # Sentinel consumed: account for it, then flush the remainder.
        result_queue.task_done()
        if pending:
            await write_captions(pending)
    except asyncio.CancelledError:
        _LOGGER.info("DB writer cancelled")
        raise
    except Exception as e:
        _LOGGER.exception("DB writer error: %s", e)
        raise
|
||||
|
||||
|
||||
async def run_pipeline(
    urls: Union[Iterable[Tuple[int, str]], AsyncIterable[Tuple[int, str]]],
    dry_run: Optional[bool] = None,
    prompt_token: Optional[str] = None,
):
    """Run the full download -> caption -> DB-write pipeline over *urls*.

    Args:
        urls: Sync or async iterable of (image_id, image_url) pairs.
        dry_run: Defaults to config.DRY_RUN.  NOTE(review): currently only
            logged; the DB writer runs regardless of this flag -- confirm
            whether dry-run should suppress writes.
        prompt_token: Captioning task token; defaults to config.PROMPT_TOKEN.

    Defaults are resolved at call time rather than bound at import time,
    so overrides applied via config.apply_overrides() (e.g. from CLI
    flags) actually take effect -- the previous ``= config.DRY_RUN`` /
    ``= config.PROMPT_TOKEN`` defaults were frozen when the module was
    first imported.
    """
    if dry_run is None:
        dry_run = config.DRY_RUN
    if prompt_token is None:
        prompt_token = config.PROMPT_TOKEN

    _LOGGER.info("Starting pipeline (dry_run=%s)", dry_run)
    metrics = Metrics()
    model = Florence2CaptionModel()
    model.load()

    # Bounded queues provide backpressure between the stages.
    download_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    image_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
    result_queue: asyncio.Queue = asyncio.Queue(maxsize=config.RESULT_QUEUE_MAX_SIZE)

    async with create_aiohttp_session() as session:
        if isinstance(urls, AsyncIterable):
            producer_task = asyncio.create_task(_producer_async(download_queue, urls))
        else:
            producer_task = asyncio.create_task(_producer(download_queue, urls))

        downloader_task = asyncio.create_task(
            _download_consumer(download_queue, image_queue, session, metrics)
        )
        caption_task = asyncio.create_task(
            _caption_worker(image_queue, result_queue, model, metrics, prompt_token)
        )
        writer_task = asyncio.create_task(_db_writer(result_queue, metrics))

        tasks = [producer_task, downloader_task, caption_task, writer_task]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

        # Cancel any remaining tasks if something failed.
        for p in pending:
            p.cancel()

        # Wait for cancellation to complete.
        await asyncio.gather(*pending, return_exceptions=True)

        # If any finished task raised, propagate the first exception found.
        for t in done:
            exc = t.exception()
            if exc is not None:
                raise exc

    _LOGGER.info("Pipeline completed: %s", metrics.summary())
|
||||
|
||||
|
||||
async def _dry_run_urls() -> List[Tuple[int, str]]:
|
||||
# Example placeholder list for dry run.
|
||||
return [
|
||||
(1, "https://example.com/image1.jpg"),
|
||||
(2, "https://example.com/image2.jpg"),
|
||||
]
|
||||
|
||||
|
||||
async def _run_from_db():
    """Run the pipeline over the async row stream from the database,
    using the configured table/column names and row filter."""
    url_stream = fetch_image_urls(
        table=config.DB_TABLE,
        id_column=config.DB_ID_COLUMN,
        url_column=config.DB_URL_COLUMN,
        where_clause=config.DB_QUERY_WHERE,
    )
    await run_pipeline(url_stream, dry_run=False)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, apply config overrides, run the pipeline."""
    parser = argparse.ArgumentParser(description="Async Florence-2 captioning pipeline")
    parser.add_argument("--dry-run", action="store_true", help="Run without DB, using sample URLs")
    parser.add_argument("--log-level", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)")

    # Model overrides
    parser.add_argument("--model-id", help="HuggingFace model ID")
    parser.add_argument("--model-device", help="Device to run model on (cuda/cpu)")
    parser.add_argument("--prompt-token", help="Task token to use for captioning")

    # Pipeline sizing
    parser.add_argument("--gpu-batch-size", type=int, help="Batch size for GPU inference")
    parser.add_argument("--download-concurrency", type=int, help="Max concurrent downloads")
    parser.add_argument("--image-queue-max", type=int, help="Max size for image queue")
    parser.add_argument("--result-queue-max", type=int, help="Max size for result queue")

    # DB schema configuration
    parser.add_argument("--db-table", help="Database table containing images")
    parser.add_argument("--db-id-col", help="ID column name")
    parser.add_argument("--db-url-col", help="URL column name")
    parser.add_argument("--db-caption-col", help="Caption column name")
    parser.add_argument(
        "--db-where",
        help="WHERE clause to filter rows (e.g., \"caption IS NULL\")",
    )

    args = parser.parse_args()

    # Unknown level names fall back to INFO via getattr's default.
    log_level = getattr(logging, args.log_level.upper(), logging.INFO)
    _configure_logging(log_level)

    # Flags the user did not pass parse as None and are skipped by
    # apply_overrides.  NOTE(review): these UPPER_CASE keys must be
    # accepted by config.apply_overrides (the Settings attributes are
    # lower_case) -- verify the overrides actually reach the settings.
    config.apply_overrides(
        MODEL_ID=args.model_id,
        MODEL_DEVICE=args.model_device,
        PROMPT_TOKEN=args.prompt_token,
        GPU_BATCH_SIZE=args.gpu_batch_size,
        DOWNLOAD_CONCURRENCY=args.download_concurrency,
        IMAGE_QUEUE_MAX_SIZE=args.image_queue_max,
        RESULT_QUEUE_MAX_SIZE=args.result_queue_max,
        DB_TABLE=args.db_table,
        DB_ID_COLUMN=args.db_id_col,
        DB_URL_COLUMN=args.db_url_col,
        DB_CAPTION_COLUMN=args.db_caption_col,
        DB_QUERY_WHERE=args.db_where,
    )

    # Debug dump of the effective (post-override) configuration.
    _LOGGER.debug("Effective config: %s", {k: getattr(config, k) for k in [
        "MODEL_ID",
        "MODEL_DEVICE",
        "PROMPT_TOKEN",
        "GPU_BATCH_SIZE",
        "DOWNLOAD_CONCURRENCY",
        "IMAGE_QUEUE_MAX_SIZE",
        "RESULT_QUEUE_MAX_SIZE",
        "DB_TABLE",
        "DB_ID_COLUMN",
        "DB_URL_COLUMN",
        "DB_CAPTION_COLUMN",
        "DB_QUERY_WHERE",
    ]})

    try:
        if args.dry_run:
            # Dry run: sample URL list instead of a DB stream.  Note the
            # two separate asyncio.run calls (two event loops); fine here
            # because _dry_run_urls returns a plain list.
            urls = asyncio.run(_dry_run_urls())
            asyncio.run(run_pipeline(urls, dry_run=True))
        else:
            asyncio.run(_run_from_db())
    except KeyboardInterrupt:
        _LOGGER.warning("Received interrupt signal, shutting down...")
    except Exception:
        # Top-level boundary: log with traceback and exit quietly.
        _LOGGER.exception("Pipeline failed")


# Script entry point.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
def resize_preserve_aspect(img: Image.Image, max_side: int | None = None) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side*.

    Uses PIL's ``thumbnail``, which only ever shrinks: images already
    smaller than *max_side* keep their original size (the previous
    docstring implied unconditional resizing).  Aspect ratio is
    preserved and the image is converted to RGB.

    Args:
        img: Source image.
        max_side: Longest-side cap; defaults to config.IMAGE_MAX_SIDE.
    """
    if max_side is None:
        max_side = config.IMAGE_MAX_SIDE

    img = img.convert("RGB")
    # thumbnail() resizes in place and never enlarges.
    img.thumbnail((max_side, max_side), Image.LANCZOS)
    return img
|
||||
|
||||
|
||||
def image_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a CHW float tensor scaled to [0, 1]."""
    # HWC uint8 array -> float -> CHW, scaled into the unit interval.
    hwc = torch.from_numpy(np.array(img)).float()
    return hwc.permute(2, 0, 1) / 255.0
|
||||
|
||||
|
||||
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize tensor with ImageNet/vision mean/std."""
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # Shape (C, 1, 1) so the statistics broadcast over H and W.
    mean = torch.tensor(channel_mean, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(channel_std, device=tensor.device).view(-1, 1, 1)
    return tensor.sub(mean).div(std)
|
||||
|
||||
|
||||
def preprocess_image(img: Image.Image, device: str = "cpu") -> torch.Tensor:
    """Full preprocessing: resize, tensorize, move to *device*, normalize."""
    return normalize_tensor(
        image_to_tensor(resize_preserve_aspect(img)).to(device)
    )
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import asyncio
|
||||
|
||||
from src.preprocess import resize_preserve_aspect
|
||||
|
||||
|
||||
def test_resize_preserve_aspect():
    """A 1024x512 image capped at 256 must come back with a 256 longest side."""
    from PIL import Image

    source = Image.new("RGB", (1024, 512), color="white")
    shrunk = resize_preserve_aspect(source, max_side=256)
    assert max(shrunk.size) == 256
|
||||
|
||||
|
||||
def test_pipeline_writes_captions_to_db(monkeypatch):
    """Ensure pipeline calls write_captions with non-empty captions."""

    # Stub out download to avoid network.  DownloadWorker looks the name
    # up on src.download at call time, so patching that module works.
    async def fake_download_image(url, session, retries=2):
        return object()

    monkeypatch.setattr("src.download.download_image", fake_download_image)

    # src.pipeline binds preprocess_image via ``from .preprocess import
    # preprocess_image``, so the name the pipeline actually calls lives on
    # src.pipeline.  Patching src.preprocess.preprocess_image (as before)
    # leaves that binding untouched and the real PIL/torch code runs.
    def fake_preprocess_image(img, device="cpu"):
        return object()

    monkeypatch.setattr("src.pipeline.preprocess_image", fake_preprocess_image)

    # Stub out model load + generate
    def fake_load(self):
        self.model = True
        self.processor = True

    def fake_generate_captions(self, images, prompt, **kwargs):
        return [f"caption-for-{i}" for i in range(len(images))]

    monkeypatch.setattr("src.model.Florence2CaptionModel.load", fake_load)
    monkeypatch.setattr(
        "src.model.Florence2CaptionModel.generate_captions", fake_generate_captions
    )

    # Capture writes.  Same from-import issue as above: _db_writer calls
    # the ``write_captions`` name bound in src.pipeline, not src.db.
    written = []

    async def fake_write_captions(updates, batch_size=64):
        written.extend(updates)

    monkeypatch.setattr("src.pipeline.write_captions", fake_write_captions)

    # Run pipeline with a small dummy URL list
    from src import pipeline

    urls = [(1, "http://example.com/image1.jpg"), (2, "http://example.com/image2.jpg")]
    asyncio.run(pipeline.run_pipeline(urls, dry_run=True))

    assert len(written) == 2
    assert all(isinstance(c, str) and len(c) > 0 for _, c in written)
|
||||
Loading…
Reference in New Issue