-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvision.py
More file actions
93 lines (76 loc) · 2.67 KB
/
vision.py
File metadata and controls
93 lines (76 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import base64
import io
import json
import re
from dataclasses import dataclass
import httpx
import pillow_heif
from PIL import Image
# Import-time side effect: register pillow_heif's opener with Pillow so that
# Image.open() (used by _prepare_image below) can also read HEIF/HEIC files.
pillow_heif.register_heif_opener()
@dataclass
class TagResult:
    """Tags and captions parsed from the vision model's JSON reply.

    Each attribute mirrors the same-named key of the JSON object that
    ``get_tags_and_caption`` extracts from the model output.
    """

    # Tag lists; presumably English / German, judging by the _en / _de suffixes.
    tags_en: list[str]
    tags_de: list[str]

    # Captions, with the same _en / _de split as the tag lists.
    caption_en: str
    caption_de: str
def _prepare_image(path: str, max_size: int) -> str:
    """Load an image, downscale it if needed, and return it as base64 JPEG.

    Args:
        path: Filesystem path of the image; any format Pillow can open
            (including those added by registered plugins).
        max_size: Maximum allowed length, in pixels, of the longer side.
            Larger images are scaled down proportionally; smaller ones are
            left at their original size.

    Returns:
        The JPEG-encoded (quality 85) image as a base64 text string.
    """
    # Image.open() keeps the underlying file open lazily; the context manager
    # guarantees the handle is released. convert() loads the pixel data and
    # returns an independent image, so it stays usable after the file closes.
    with Image.open(path) as source:
        img = source.convert("RGB")

    w, h = img.size
    longest = max(w, h)
    if longest > max_size:
        ratio = max_size / longest
        # Clamp to 1px: int() truncation could otherwise yield a zero-sized
        # dimension for extreme aspect ratios, which resize() rejects.
        new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
        img = img.resize(new_size, Image.LANCZOS)

    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()
def _extract_json(text: str) -> dict:
"""Parse JSON from model output, stripping markdown fences if present."""
# Strip ```json ... ``` or ``` ... ``` fences
match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
if match:
text = match.group(1)
else:
# Fall back to first { ... } block in the text
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
text = match.group(0)
return json.loads(text)
async def get_tags_and_caption(image_path: str, config) -> TagResult:
    """Send an image to the configured chat-completions endpoint and parse its reply.

    The image is re-encoded as a base64 JPEG data URI and posted together with
    the system and user prompts from ``config``. The model is expected to
    answer with a JSON object holding ``tags_en``, ``tags_de``, ``caption_en``
    and ``caption_de``; missing or null keys fall back to empty values.

    Args:
        image_path: Path of the image to analyse.
        config: Object providing ``vision.image_max_size``, ``vision.model``,
            ``vision.max_tokens``, ``vision.api_key``, ``vision.base_url``,
            ``prompts.system`` and ``prompts.user``.

    Returns:
        A ``TagResult`` built from the model's JSON answer.

    Raises:
        RuntimeError: If the HTTP request is not successful, or the response
            body does not have the expected ``choices[0].message.content``
            text.
        json.JSONDecodeError: If the model's text contains no parseable JSON
            object.
    """
    b64 = _prepare_image(image_path, config.vision.image_max_size)
    data_uri = f"data:image/jpeg;base64,{b64}"

    payload = {
        "model": config.vision.model,
        "max_tokens": config.vision.max_tokens,
        "messages": [
            {
                "role": "system",
                "content": config.prompts.system,
            },
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_uri}},
                    {"type": "text", "text": config.prompts.user},
                ],
            },
        ],
    }

    headers = {"Content-Type": "application/json"}
    # The Authorization header is only sent when an API key is configured.
    if config.vision.api_key:
        headers["Authorization"] = f"Bearer {config.vision.api_key}"

    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            f"{config.vision.base_url.rstrip('/')}/chat/completions",
            json=payload,
            headers=headers,
        )

    if not response.is_success:
        raise RuntimeError(f"HTTP {response.status_code}: {response.text}")

    # A 2xx answer can still carry an unexpected body (non-JSON, an error
    # object, an empty choices list); report that the same way as an HTTP
    # failure instead of leaking a bare KeyError/IndexError/TypeError.
    try:
        raw = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as err:
        raise RuntimeError(
            f"Unexpected response format: {response.text}"
        ) from err
    if not isinstance(raw, str):
        raise RuntimeError(
            f"Response contains no text content: {response.text}"
        )

    result = _extract_json(raw)
    # `or` also covers keys that are present but null, so TagResult never
    # receives None where a list or string is declared.
    return TagResult(
        tags_en=result.get("tags_en") or [],
        tags_de=result.get("tags_de") or [],
        caption_en=result.get("caption_en") or "",
        caption_de=result.get("caption_de") or "",
    )