NATSU-FlorenceCaptioner/tests/test_pipeline.py

58 lines
1.8 KiB
Python

import asyncio
from src.preprocess import resize_preserve_aspect
def test_resize_preserve_aspect():
    """Check the longest side of a resized image against ``max_side``.

    A white 1024x512 RGB image is passed to ``resize_preserve_aspect`` with
    ``max_side=256``; the larger of the two resulting dimensions must be
    exactly 256.
    """
    # Pillow is imported inside the test body rather than at module level.
    from PIL import Image

    source = Image.new("RGB", (1024, 512), color="white")
    result = resize_preserve_aspect(source, max_side=256)

    longest_side = max(result.size)
    assert longest_side == 256
def test_pipeline_writes_captions_to_db(monkeypatch):
    """Run the pipeline with stubbed stages and inspect what gets written.

    Download, preprocessing, model loading, caption generation and the
    caption write are each replaced through ``monkeypatch``. The test passes
    when the stubbed ``write_captions`` has recorded exactly two updates and
    the second element of every update is a non-empty string.
    """
    # Everything handed to the stubbed ``write_captions`` accumulates here.
    recorded = []

    async def _stub_download_image(url, session, retries=2):
        # Stands in for the real download so the test avoids the network.
        return object()

    def _stub_preprocess_image(img, device="cpu"):
        # Opaque placeholder instead of a real tensor (skips Pillow / torch).
        return object()

    def _stub_load(self):
        # Mark the model as loaded without loading anything.
        self.model = self.processor = True

    def _stub_generate_captions(self, images, prompt, **kwargs):
        # One deterministic caption per input, numbered by position.
        count = len(images)
        return [f"caption-for-{i}" for i in range(count)]

    async def _stub_write_captions(updates, batch_size=64):
        recorded.extend(updates)

    # NOTE(review): the patches target the defining modules, and `pipeline`
    # is imported only afterwards -- presumably so that its imports pick up
    # the stubs; verify against how `src.pipeline` imports these names.
    replacements = (
        ("src.download.download_image", _stub_download_image),
        ("src.preprocess.preprocess_image", _stub_preprocess_image),
        ("src.model.Florence2CaptionModel.load", _stub_load),
        (
            "src.model.Florence2CaptionModel.generate_captions",
            _stub_generate_captions,
        ),
        ("src.db.write_captions", _stub_write_captions),
    )
    for target, stub in replacements:
        monkeypatch.setattr(target, stub)

    from src import pipeline

    # Small dummy list of (id, url) pairs.
    urls = [
        (1, "http://example.com/image1.jpg"),
        (2, "http://example.com/image2.jpg"),
    ]
    # NOTE(review): dry_run=True is passed while a write is still expected
    # below -- confirm this matches run_pipeline's dry-run semantics.
    asyncio.run(pipeline.run_pipeline(urls, dry_run=True))

    assert len(recorded) == 2
    for _, caption in recorded:
        assert isinstance(caption, str)
        assert len(caption) > 0