Initial Commit
This commit is contained in:
commit
14db6bcd61
|
|
@ -0,0 +1,223 @@
|
||||||
|
|
||||||
|
# project specifics
|
||||||
|
|
||||||
|
config.toml
|
||||||
|
user-docs/
|
||||||
|
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
"github.copilot.chat.reasoningEffort": "high",
|
||||||
|
"github.copilot.chat.responsesApiReasoningEffort": "high",
|
||||||
|
"github.copilot.selectedModel": ""
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Florence-2 Captioning Pipeline
|
||||||
|
|
||||||
|
High-throughput asynchronous captioning pipeline using **Florence-2 Base PromptGen**.
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
|
||||||
|
- Download images from S3/HTTP concurrently
|
||||||
|
- Preprocess (resize/normalize)
|
||||||
|
- Run batched caption generation on GPU
|
||||||
|
- Persist captions back to a database (async)
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
- `src/`: implementation code
|
||||||
|
- `tests/`: unit/integration tests
|
||||||
|
- `todo.md`: tasks list
|
||||||
|
- `implementationPlanV2.md`: architecture + design notes
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
1. Install dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Configure environment variables (see `src/config.py` for expected vars).
|
||||||
|
|
||||||
|
3. Run the pipeline (example):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.pipeline --dry-run
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
This repo is intended as a foundation for building a fast, async dataset captioning tool.
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Configuration for Florence-2 caption pipeline
|
||||||
|
|
||||||
|
[model]
|
||||||
|
# HuggingFace model id (replace with your model)
|
||||||
|
id = "your-org/your-model"
|
||||||
|
# Device to run model on (cuda/cpu)
|
||||||
|
device = "cuda"
|
||||||
|
# Prompt token used for captioning (replace with your token)
|
||||||
|
prompt_token = "<PROMPT_TOKEN>"
|
||||||
|
|
||||||
|
[preprocessing]
|
||||||
|
# Max side for resizing (longest dimension)
|
||||||
|
image_max_side = 768
|
||||||
|
|
||||||
|
[pipeline]
|
||||||
|
# Batch sizes and queue sizing
|
||||||
|
gpu_batch_size = 8
|
||||||
|
download_concurrency = 16
|
||||||
|
image_queue_max_size = 64
|
||||||
|
result_queue_max_size = 128
|
||||||
|
db_write_batch_size = 64
|
||||||
|
|
||||||
|
[database]
|
||||||
|
# Async DB connection string
|
||||||
|
dsn = ""
|
||||||
|
# Table and column names
|
||||||
|
table = "images"
|
||||||
|
id_column = "id"
|
||||||
|
url_column = "url"
|
||||||
|
caption_column = "caption"
|
||||||
|
# Default WHERE clause used to fetch rows to caption
|
||||||
|
query_where = "caption IS NULL"
|
||||||
|
|
||||||
|
[debug]
|
||||||
|
# Dry run disables writing to DB for quick local tests
|
||||||
|
dry_run = true
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
torch>=2.0.0
|
||||||
|
transformers>=4.40.0
|
||||||
|
aiohttp>=3.9.0
|
||||||
|
psycopg[binary]>=3.2
|
||||||
|
Pillow>=10.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pydantic-settings>=2.2.0
|
||||||
|
pytest>=8.0.0
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
"""Florence-2 captioning pipeline package."""
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
from .config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
# Legacy compatibility helpers (so existing code can use config.<NAME>)
|
||||||
|
MODEL_ID = settings.model_id
|
||||||
|
MODEL_DEVICE = settings.model_device
|
||||||
|
PROMPT_TOKEN = settings.prompt_token
|
||||||
|
|
||||||
|
IMAGE_MAX_SIDE = settings.image_max_side
|
||||||
|
|
||||||
|
GPU_BATCH_SIZE = settings.gpu_batch_size
|
||||||
|
DOWNLOAD_CONCURRENCY = settings.download_concurrency
|
||||||
|
IMAGE_QUEUE_MAX_SIZE = settings.image_queue_max_size
|
||||||
|
RESULT_QUEUE_MAX_SIZE = settings.result_queue_max_size
|
||||||
|
DB_WRITE_BATCH_SIZE = settings.db_write_batch_size
|
||||||
|
|
||||||
|
DB_DSN = settings.db_dsn
|
||||||
|
|
||||||
|
DRY_RUN = settings.dry_run
|
||||||
|
|
||||||
|
DB_TABLE = settings.db_table
|
||||||
|
DB_ID_COLUMN = settings.db_id_column
|
||||||
|
DB_URL_COLUMN = settings.db_url_column
|
||||||
|
DB_CAPTION_COLUMN = settings.db_caption_column
|
||||||
|
DB_QUERY_WHERE = settings.db_query_where
|
||||||
|
|
||||||
|
|
||||||
|
def apply_overrides(**kwargs):
|
||||||
|
"""Override config values at runtime (for CLI overrides)."""
|
||||||
|
|
||||||
|
# Update pydantic settings object
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if hasattr(settings, key):
|
||||||
|
setattr(settings, key, value)
|
||||||
|
|
||||||
|
# Sync legacy constants
|
||||||
|
globals().update({
|
||||||
|
"MODEL_ID": settings.model_id,
|
||||||
|
"MODEL_DEVICE": settings.model_device,
|
||||||
|
"PROMPT_TOKEN": settings.prompt_token,
|
||||||
|
"IMAGE_MAX_SIDE": settings.image_max_side,
|
||||||
|
"GPU_BATCH_SIZE": settings.gpu_batch_size,
|
||||||
|
"DOWNLOAD_CONCURRENCY": settings.download_concurrency,
|
||||||
|
"IMAGE_QUEUE_MAX_SIZE": settings.image_queue_max_size,
|
||||||
|
"RESULT_QUEUE_MAX_SIZE": settings.result_queue_max_size,
|
||||||
|
"DB_WRITE_BATCH_SIZE": settings.db_write_batch_size,
|
||||||
|
"DB_DSN": settings.db_dsn,
|
||||||
|
"DRY_RUN": settings.dry_run,
|
||||||
|
"DB_TABLE": settings.db_table,
|
||||||
|
"DB_ID_COLUMN": settings.db_id_column,
|
||||||
|
"DB_URL_COLUMN": settings.db_url_column,
|
||||||
|
"DB_CAPTION_COLUMN": settings.db_caption_column,
|
||||||
|
"DB_QUERY_WHERE": settings.db_query_where,
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Package for configuration loaded from TOML via pydantic-settings."""
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings loaded from config.toml (or environment variables)."""
|
||||||
|
|
||||||
|
# Model
|
||||||
|
model_id: str = "MiaoshouAI/Florence-2-base-PromptGen-v2"
|
||||||
|
model_device: str = "cuda"
|
||||||
|
prompt_token: str = "<MORE_DETAILED_CAPTION>"
|
||||||
|
|
||||||
|
# Preprocessing
|
||||||
|
image_max_side: int = 768
|
||||||
|
|
||||||
|
# Pipeline sizing
|
||||||
|
gpu_batch_size: int = 8
|
||||||
|
download_concurrency: int = 16
|
||||||
|
image_queue_max_size: int = 64
|
||||||
|
result_queue_max_size: int = 128
|
||||||
|
db_write_batch_size: int = 64
|
||||||
|
|
||||||
|
# Database
|
||||||
|
db_dsn: str = ""
|
||||||
|
db_table: str = "images"
|
||||||
|
db_id_column: str = "id"
|
||||||
|
db_url_column: str = "url"
|
||||||
|
db_caption_column: str = "caption"
|
||||||
|
db_query_where: str = "caption IS NULL"
|
||||||
|
|
||||||
|
# Other
|
||||||
|
dry_run: bool = True
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file="config.toml",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator, Iterable, List, Tuple
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
from .config import DB_DSN
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_connection():
|
||||||
|
if not DB_DSN:
|
||||||
|
raise ValueError("DB_DSN is not set. Set the DB_DSN environment variable.")
|
||||||
|
return await psycopg.AsyncConnection.connect(DB_DSN)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_image_urls(
|
||||||
|
table: str = None,
|
||||||
|
id_column: str = None,
|
||||||
|
url_column: str = None,
|
||||||
|
where_clause: str = None,
|
||||||
|
batch_size: int = 100,
|
||||||
|
) -> AsyncGenerator[Tuple[int, str], None]:
|
||||||
|
"""Yield (image_id, image_url) tuples from the database in batches."""
|
||||||
|
# Prepare query components
|
||||||
|
table = table or "images"
|
||||||
|
id_column = id_column or "id"
|
||||||
|
url_column = url_column or "url"
|
||||||
|
|
||||||
|
query = f"SELECT {id_column}, {url_column} FROM {table}"
|
||||||
|
if where_clause:
|
||||||
|
query += " WHERE " + where_clause
|
||||||
|
|
||||||
|
async with await get_db_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute(query)
|
||||||
|
while True:
|
||||||
|
rows = await cur.fetchmany(batch_size)
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
for row in rows:
|
||||||
|
yield row[0], row[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def write_captions(
|
||||||
|
updates: Iterable[Tuple[int, str]],
|
||||||
|
batch_size: int = 64,
|
||||||
|
):
|
||||||
|
"""Write captions back to the DB in batches."""
|
||||||
|
async with await get_db_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
batch: List[Tuple[str, int]] = []
|
||||||
|
for image_id, caption in updates:
|
||||||
|
batch.append((caption, image_id))
|
||||||
|
if len(batch) >= batch_size:
|
||||||
|
await _execute_update(cur, batch)
|
||||||
|
batch.clear()
|
||||||
|
if batch:
|
||||||
|
await _execute_update(cur, batch)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_update(cur, batch: List[Tuple[str, int]]):
|
||||||
|
await cur.executemany(
|
||||||
|
"""
|
||||||
|
UPDATE images
|
||||||
|
SET caption = %s
|
||||||
|
WHERE id = %s
|
||||||
|
""",
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
await cur.connection.commit()
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
import io
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from . import config
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_image_bytes(session: aiohttp.ClientSession, url: str, timeout: int = 30) -> bytes:
|
||||||
|
async with session.get(url, timeout=timeout) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.read()
|
||||||
|
|
||||||
|
|
||||||
|
async def download_image(
|
||||||
|
url: str,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
retries: int = 2,
|
||||||
|
) -> Optional[Image.Image]:
|
||||||
|
"""Download an image and return a PIL Image (RGB)."""
|
||||||
|
for attempt in range(1, retries + 1):
|
||||||
|
try:
|
||||||
|
data = await fetch_image_bytes(session, url)
|
||||||
|
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||||
|
return img
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.warning("Download failed (attempt %s/%s) for %s: %s", attempt, retries, url, e)
|
||||||
|
if attempt == retries:
|
||||||
|
return None
|
||||||
|
await asyncio.sleep(0.5 * attempt)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadWorker:
|
||||||
|
"""Worker that downloads images and pushes them into a queue."""
|
||||||
|
|
||||||
|
def __init__(self, queue, session: aiohttp.ClientSession):
|
||||||
|
self.queue = queue
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def run(self, image_id: int, url: str):
|
||||||
|
img = await download_image(url, self.session)
|
||||||
|
if img is None:
|
||||||
|
return
|
||||||
|
await self.queue.put((image_id, img))
|
||||||
|
|
||||||
|
|
||||||
|
def create_aiohttp_session() -> aiohttp.ClientSession:
|
||||||
|
connector = aiohttp.TCPConnector(limit=config.DOWNLOAD_CONCURRENCY)
|
||||||
|
return aiohttp.ClientSession(connector=connector)
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Metrics:
|
||||||
|
def __init__(self):
|
||||||
|
self.start_ts = time.time()
|
||||||
|
self.images_processed = 0
|
||||||
|
self.download_failures = 0
|
||||||
|
self.caption_failures = 0
|
||||||
|
|
||||||
|
def mark_image_processed(self):
|
||||||
|
self.images_processed += 1
|
||||||
|
|
||||||
|
def mark_download_failure(self):
|
||||||
|
self.download_failures += 1
|
||||||
|
|
||||||
|
def mark_caption_failure(self):
|
||||||
|
self.caption_failures += 1
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
elapsed = max(1.0, time.time() - self.start_ts)
|
||||||
|
return {
|
||||||
|
"images_per_second": self.images_processed / elapsed,
|
||||||
|
"download_failures": self.download_failures,
|
||||||
|
"caption_failures": self.caption_failures,
|
||||||
|
"elapsed_seconds": elapsed,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||||
|
|
||||||
|
from .config import MODEL_DEVICE, MODEL_ID
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2CaptionModel:
|
||||||
|
"""Wrapper around Florence-2 PromptGen for batched caption generation."""
|
||||||
|
|
||||||
|
def __init__(self, model_id: str = MODEL_ID, device: str = MODEL_DEVICE):
|
||||||
|
self.model_id = model_id
|
||||||
|
self.device = device
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
def load(self, torch_dtype: torch.dtype = torch.float16):
|
||||||
|
"""Load model and processor into memory (once)."""
|
||||||
|
_LOGGER.info("Loading model %s on %s", self.model_id, self.device)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
|
||||||
|
self.model = (
|
||||||
|
AutoModelForCausalLM.from_pretrained(
|
||||||
|
self.model_id,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
.to(self.device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_captions(
|
||||||
|
self,
|
||||||
|
images: List[torch.Tensor],
|
||||||
|
prompt: str,
|
||||||
|
max_new_tokens: int = 128,
|
||||||
|
do_sample: bool = False,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Generate captions for a batch of preprocessed image tensors."""
|
||||||
|
if self.model is None or self.processor is None:
|
||||||
|
raise RuntimeError("Model is not loaded. Call load() first.")
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[prompt] * len(images),
|
||||||
|
images=images,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
generated = self.model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
do_sample=do_sample,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
captions = self.processor.batch_decode(generated, skip_special_tokens=True)
|
||||||
|
return captions
|
||||||
|
|
@ -0,0 +1,283 @@
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from typing import Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from . import config
|
||||||
|
from .db import fetch_image_urls, write_captions
|
||||||
|
from .download import create_aiohttp_session, DownloadWorker
|
||||||
|
from .metrics import Metrics
|
||||||
|
from .model import Florence2CaptionModel
|
||||||
|
from .preprocess import preprocess_image
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_logging(level: int = logging.INFO) -> None:
|
||||||
|
"""Configure root logger for the pipeline."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _producer(queue: asyncio.Queue, urls: Iterable[Tuple[int, str]]):
|
||||||
|
for image_id, url in urls:
|
||||||
|
await queue.put((image_id, url))
|
||||||
|
await queue.put(None) # sentinel to indicate completion
|
||||||
|
|
||||||
|
|
||||||
|
async def _producer_async(queue: asyncio.Queue, urls: AsyncIterable[Tuple[int, str]]):
|
||||||
|
async for image_id, url in urls:
|
||||||
|
await queue.put((image_id, url))
|
||||||
|
await queue.put(None) # sentinel to indicate completion
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_consumer(
|
||||||
|
input_queue: asyncio.Queue,
|
||||||
|
image_queue: asyncio.Queue,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
metrics: Metrics,
|
||||||
|
):
|
||||||
|
worker = DownloadWorker(image_queue, session)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
item = await input_queue.get()
|
||||||
|
input_queue.task_done()
|
||||||
|
if item is None:
|
||||||
|
await image_queue.put(None)
|
||||||
|
break
|
||||||
|
|
||||||
|
image_id, url = item
|
||||||
|
await worker.run(image_id, url)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_LOGGER.info("Download consumer cancelled")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.exception("Download consumer error: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _caption_worker(
|
||||||
|
image_queue: asyncio.Queue,
|
||||||
|
result_queue: asyncio.Queue,
|
||||||
|
model: Florence2CaptionModel,
|
||||||
|
metrics: Metrics,
|
||||||
|
prompt: str,
|
||||||
|
):
|
||||||
|
buffer: List[Tuple[int, torch.Tensor]] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
item = await image_queue.get()
|
||||||
|
image_queue.task_done()
|
||||||
|
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
image_id, img = item
|
||||||
|
try:
|
||||||
|
tensor = preprocess_image(img, device=config.MODEL_DEVICE)
|
||||||
|
buffer.append((image_id, tensor))
|
||||||
|
if len(buffer) >= config.GPU_BATCH_SIZE:
|
||||||
|
await _flush_batch(buffer, model, result_queue, metrics, prompt)
|
||||||
|
except Exception as e:
|
||||||
|
metrics.mark_caption_failure()
|
||||||
|
_LOGGER.exception("Preprocess/inference error for %s: %s", image_id, e)
|
||||||
|
|
||||||
|
if buffer:
|
||||||
|
await _flush_batch(buffer, model, result_queue, metrics, prompt)
|
||||||
|
|
||||||
|
await result_queue.put(None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_LOGGER.info("Caption worker cancelled")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.exception("Caption worker failure: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _flush_batch(
|
||||||
|
buffer: List[Tuple[int, torch.Tensor]],
|
||||||
|
model: Florence2CaptionModel,
|
||||||
|
result_queue: asyncio.Queue,
|
||||||
|
metrics: Metrics,
|
||||||
|
prompt: str,
|
||||||
|
):
|
||||||
|
ids, tensors = zip(*buffer)
|
||||||
|
captions = model.generate_captions(list(tensors), prompt=prompt)
|
||||||
|
for image_id, caption in zip(ids, captions):
|
||||||
|
await result_queue.put((image_id, caption))
|
||||||
|
metrics.mark_image_processed()
|
||||||
|
buffer.clear()
|
||||||
|
|
||||||
|
|
||||||
|
async def _db_writer(result_queue: asyncio.Queue, metrics: Metrics):
|
||||||
|
batch: List[Tuple[int, str]] = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
item = await result_queue.get()
|
||||||
|
result_queue.task_done()
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch.append(item)
|
||||||
|
if len(batch) >= config.DB_WRITE_BATCH_SIZE:
|
||||||
|
await write_captions(batch)
|
||||||
|
batch.clear()
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
await write_captions(batch)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_LOGGER.info("DB writer cancelled")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.exception("DB writer error: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def run_pipeline(
|
||||||
|
urls: Union[Iterable[Tuple[int, str]], AsyncIterable[Tuple[int, str]]],
|
||||||
|
dry_run: bool = config.DRY_RUN,
|
||||||
|
prompt_token: str = config.PROMPT_TOKEN,
|
||||||
|
):
|
||||||
|
_LOGGER.info("Starting pipeline (dry_run=%s)", dry_run)
|
||||||
|
metrics = Metrics()
|
||||||
|
model = Florence2CaptionModel()
|
||||||
|
model.load()
|
||||||
|
|
||||||
|
download_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
|
||||||
|
image_queue: asyncio.Queue = asyncio.Queue(maxsize=config.IMAGE_QUEUE_MAX_SIZE)
|
||||||
|
result_queue: asyncio.Queue = asyncio.Queue(maxsize=config.RESULT_QUEUE_MAX_SIZE)
|
||||||
|
|
||||||
|
async with create_aiohttp_session() as session:
|
||||||
|
if isinstance(urls, AsyncIterable):
|
||||||
|
producer_task = asyncio.create_task(_producer_async(download_queue, urls))
|
||||||
|
else:
|
||||||
|
producer_task = asyncio.create_task(_producer(download_queue, urls))
|
||||||
|
|
||||||
|
downloader_task = asyncio.create_task(
|
||||||
|
_download_consumer(download_queue, image_queue, session, metrics)
|
||||||
|
)
|
||||||
|
caption_task = asyncio.create_task(
|
||||||
|
_caption_worker(image_queue, result_queue, model, metrics, prompt_token)
|
||||||
|
)
|
||||||
|
writer_task = asyncio.create_task(_db_writer(result_queue, metrics))
|
||||||
|
|
||||||
|
tasks = [producer_task, downloader_task, caption_task, writer_task]
|
||||||
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
|
||||||
|
|
||||||
|
# Cancel any remaining tasks if something failed
|
||||||
|
for p in pending:
|
||||||
|
p.cancel()
|
||||||
|
|
||||||
|
# Wait for cancellation to complete
|
||||||
|
await asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
|
||||||
|
# If any task raised an exception, propagate it
|
||||||
|
for t in done:
|
||||||
|
if t.exception():
|
||||||
|
raise t.exception()
|
||||||
|
|
||||||
|
_LOGGER.info("Pipeline completed: %s", metrics.summary())
|
||||||
|
|
||||||
|
|
||||||
|
async def _dry_run_urls() -> List[Tuple[int, str]]:
|
||||||
|
# Example placeholder list for dry run.
|
||||||
|
return [
|
||||||
|
(1, "https://example.com/image1.jpg"),
|
||||||
|
(2, "https://example.com/image2.jpg"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_from_db():
|
||||||
|
"""Run pipeline using async stream from the database."""
|
||||||
|
urls = fetch_image_urls(
|
||||||
|
table=config.DB_TABLE,
|
||||||
|
id_column=config.DB_ID_COLUMN,
|
||||||
|
url_column=config.DB_URL_COLUMN,
|
||||||
|
where_clause=config.DB_QUERY_WHERE,
|
||||||
|
)
|
||||||
|
await run_pipeline(urls, dry_run=False)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Async Florence-2 captioning pipeline")
|
||||||
|
parser.add_argument("--dry-run", action="store_true", help="Run without DB, using sample URLs")
|
||||||
|
parser.add_argument("--log-level", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)")
|
||||||
|
|
||||||
|
# Model overrides
|
||||||
|
parser.add_argument("--model-id", help="HuggingFace model ID")
|
||||||
|
parser.add_argument("--model-device", help="Device to run model on (cuda/cpu)")
|
||||||
|
parser.add_argument("--prompt-token", help="Task token to use for captioning")
|
||||||
|
|
||||||
|
# Pipeline sizing
|
||||||
|
parser.add_argument("--gpu-batch-size", type=int, help="Batch size for GPU inference")
|
||||||
|
parser.add_argument("--download-concurrency", type=int, help="Max concurrent downloads")
|
||||||
|
parser.add_argument("--image-queue-max", type=int, help="Max size for image queue")
|
||||||
|
parser.add_argument("--result-queue-max", type=int, help="Max size for result queue")
|
||||||
|
|
||||||
|
# DB schema configuration
|
||||||
|
parser.add_argument("--db-table", help="Database table containing images")
|
||||||
|
parser.add_argument("--db-id-col", help="ID column name")
|
||||||
|
parser.add_argument("--db-url-col", help="URL column name")
|
||||||
|
parser.add_argument("--db-caption-col", help="Caption column name")
|
||||||
|
parser.add_argument(
|
||||||
|
"--db-where",
|
||||||
|
help="WHERE clause to filter rows (e.g., \"caption IS NULL\")",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
log_level = getattr(logging, args.log_level.upper(), logging.INFO)
|
||||||
|
_configure_logging(log_level)
|
||||||
|
|
||||||
|
config.apply_overrides(
|
||||||
|
MODEL_ID=args.model_id,
|
||||||
|
MODEL_DEVICE=args.model_device,
|
||||||
|
PROMPT_TOKEN=args.prompt_token,
|
||||||
|
GPU_BATCH_SIZE=args.gpu_batch_size,
|
||||||
|
DOWNLOAD_CONCURRENCY=args.download_concurrency,
|
||||||
|
IMAGE_QUEUE_MAX_SIZE=args.image_queue_max,
|
||||||
|
RESULT_QUEUE_MAX_SIZE=args.result_queue_max,
|
||||||
|
DB_TABLE=args.db_table,
|
||||||
|
DB_ID_COLUMN=args.db_id_col,
|
||||||
|
DB_URL_COLUMN=args.db_url_col,
|
||||||
|
DB_CAPTION_COLUMN=args.db_caption_col,
|
||||||
|
DB_QUERY_WHERE=args.db_where,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("Effective config: %s", {k: getattr(config, k) for k in [
|
||||||
|
"MODEL_ID",
|
||||||
|
"MODEL_DEVICE",
|
||||||
|
"PROMPT_TOKEN",
|
||||||
|
"GPU_BATCH_SIZE",
|
||||||
|
"DOWNLOAD_CONCURRENCY",
|
||||||
|
"IMAGE_QUEUE_MAX_SIZE",
|
||||||
|
"RESULT_QUEUE_MAX_SIZE",
|
||||||
|
"DB_TABLE",
|
||||||
|
"DB_ID_COLUMN",
|
||||||
|
"DB_URL_COLUMN",
|
||||||
|
"DB_CAPTION_COLUMN",
|
||||||
|
"DB_QUERY_WHERE",
|
||||||
|
]})
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.dry_run:
|
||||||
|
urls = asyncio.run(_dry_run_urls())
|
||||||
|
asyncio.run(run_pipeline(urls, dry_run=True))
|
||||||
|
else:
|
||||||
|
asyncio.run(_run_from_db())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
_LOGGER.warning("Received interrupt signal, shutting down...")
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception("Pipeline failed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from . import config
|
||||||
|
|
||||||
|
|
||||||
|
def resize_preserve_aspect(img: Image.Image, max_side: int | None = None) -> Image.Image:
|
||||||
|
"""Resize so the longest side is `max_side`, preserving aspect ratio."""
|
||||||
|
if max_side is None:
|
||||||
|
max_side = config.IMAGE_MAX_SIDE
|
||||||
|
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.thumbnail((max_side, max_side), Image.LANCZOS)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_tensor(img: Image.Image) -> torch.Tensor:
|
||||||
|
"""Convert a PIL image to a CHW float tensor scaled to [0, 1]."""
|
||||||
|
arr = torch.from_numpy(np.array(img))
|
||||||
|
tensor = arr.permute(2, 0, 1).float() / 255.0
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Normalize tensor with ImageNet/vision mean/std."""
|
||||||
|
mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device)[..., None, None]
|
||||||
|
std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device)[..., None, None]
|
||||||
|
return (tensor - mean) / std
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(img: Image.Image, device: str = "cpu") -> torch.Tensor:
|
||||||
|
"""Resize, convert, normalize, and move to device."""
|
||||||
|
resized = resize_preserve_aspect(img)
|
||||||
|
tensor = image_to_tensor(resized).to(device)
|
||||||
|
norm = normalize_tensor(tensor)
|
||||||
|
return norm
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from src.preprocess import resize_preserve_aspect
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_preserve_aspect():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
img = Image.new("RGB", (1024, 512), color="white")
|
||||||
|
resized = resize_preserve_aspect(img, max_side=256)
|
||||||
|
assert max(resized.size) == 256
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_writes_captions_to_db(monkeypatch):
|
||||||
|
"""Ensure pipeline calls write_captions with non-empty captions."""
|
||||||
|
|
||||||
|
# Stub out download to avoid network
|
||||||
|
async def fake_download_image(url, session, retries=2):
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.download.download_image", fake_download_image)
|
||||||
|
|
||||||
|
# Stub out preprocess to return a dummy tensor-like object (skip Pillow / torch entirely)
|
||||||
|
def fake_preprocess_image(img, device="cpu"):
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.preprocess.preprocess_image", fake_preprocess_image)
|
||||||
|
|
||||||
|
# Stub out model load + generate
|
||||||
|
def fake_load(self):
|
||||||
|
self.model = True
|
||||||
|
self.processor = True
|
||||||
|
|
||||||
|
def fake_generate_captions(self, images, prompt, **kwargs):
|
||||||
|
return [f"caption-for-{i}" for i in range(len(images))]
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.model.Florence2CaptionModel.load", fake_load)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"src.model.Florence2CaptionModel.generate_captions", fake_generate_captions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Capture writes
|
||||||
|
written = []
|
||||||
|
|
||||||
|
async def fake_write_captions(updates, batch_size=64):
|
||||||
|
written.extend(updates)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.db.write_captions", fake_write_captions)
|
||||||
|
|
||||||
|
# Run pipeline with a small dummy URL list
|
||||||
|
from src import pipeline
|
||||||
|
|
||||||
|
urls = [(1, "http://example.com/image1.jpg"), (2, "http://example.com/image2.jpg")]
|
||||||
|
asyncio.run(pipeline.run_pipeline(urls, dry_run=True))
|
||||||
|
|
||||||
|
assert len(written) == 2
|
||||||
|
assert all(isinstance(c, str) and len(c) > 0 for _, c in written)
|
||||||
Loading…
Reference in New Issue