-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtag_existing.py
More file actions
164 lines (133 loc) · 5.7 KB
/
tag_existing.py
File metadata and controls
164 lines (133 loc) · 5.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
"""
Script to add tags to existing photos using stored CLIP embeddings.
This is much faster than rescanning as it doesn't require GPU inference.
"""
import logging
import sqlite3
import argparse
from pathlib import Path
logger = logging.getLogger("facet.tag_existing")
from models.tagger import CLIPTagger
from utils import tags_to_string, get_tag_params
from db.tags import migrate_tags_to_lookup
def tag_untagged_photos(db_path, tagger, threshold=0.22, max_tags=5, verbose=False, force=False):
"""
Tag photos that have CLIP embeddings but no tags.
Args:
db_path: Path to the SQLite database
tagger: Initialized CLIPTagger instance
threshold: Minimum similarity threshold for tags
max_tags: Maximum tags per image
verbose: Print each tagged photo
force: Re-tag all photos, not just untagged ones
Returns:
Number of photos tagged
"""
with sqlite3.connect(db_path) as conn:
conn.row_factory = sqlite3.Row
if force:
rows = conn.execute('''
SELECT path, filename, clip_embedding
FROM photos
WHERE clip_embedding IS NOT NULL
''').fetchall()
else:
rows = conn.execute('''
SELECT path, filename, clip_embedding
FROM photos
WHERE clip_embedding IS NOT NULL
AND (tags IS NULL OR tags = '')
''').fetchall()
if not rows:
return 0
tagged_count = 0
for row in rows:
tag_list = tagger.get_tags_from_embedding(
row['clip_embedding'], threshold=threshold, max_tags=max_tags
)
if tag_list:
tags_str = tags_to_string(tag_list)
conn.execute('UPDATE photos SET tags = ? WHERE path = ?', (tags_str, row['path']))
tagged_count += 1
if verbose:
logger.info(" %s: %s", row['filename'], tags_str)
conn.commit()
return tagged_count
def run_tagging(db_path, tagger, config):
"""
Run tagging using settings from config.
Args:
db_path: Path to SQLite database
tagger: Initialized CLIPTagger instance
config: ScoringConfig instance
Returns:
Number of photos tagged, or None if disabled
"""
tag_settings = config.get_tagging_settings()
if not tag_settings.get('enabled', True):
return None
threshold, max_tags = get_tag_params(config)
return tag_untagged_photos(db_path, tagger, threshold, max_tags)
def main():
parser = argparse.ArgumentParser(description='Tag existing photos using stored CLIP embeddings')
parser.add_argument('--db', default='photo_scores_pro.db', help='Database path')
parser.add_argument('--config', default='scoring_config.json', help='Config file path')
parser.add_argument('--threshold', type=float, default=None, help='Override similarity threshold')
parser.add_argument('--max-tags', type=int, default=None, help='Override max tags per image')
parser.add_argument('--dry-run', action='store_true', help='Show what would be tagged without saving')
parser.add_argument('--force', action='store_true', help='Re-tag all photos, not just untagged ones')
args = parser.parse_args()
# Load config
from config import ScoringConfig
config = ScoringConfig(args.config)
clip_settings = config.get_clip_settings()
tag_settings = config.get_tagging_settings()
threshold = args.threshold or clip_settings.get('similarity_threshold_percent', 22) / 100
max_tags = args.max_tags or tag_settings.get('max_tags', 5)
logger.info("Loading CLIP model for tagging...")
logger.info(" Threshold: %s", threshold)
logger.info(" Max tags: %s", max_tags)
# Load CLIP model and tagger
import torch
clip_cfg = config.get_clip_config()
clip_model_name = clip_cfg.get('model_name', 'ViT-L-14')
clip_backend = clip_cfg.get('backend', 'open_clip')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if clip_backend == 'transformers':
from transformers import AutoModel
model = AutoModel.from_pretrained(clip_model_name, trust_remote_code=True)
model = model.to(device).eval()
else:
import open_clip
clip_pretrained = clip_cfg.get('pretrained', 'laion2b_s32b_b82k')
model, _, _ = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_pretrained)
model = model.to(device).eval()
tagger = CLIPTagger(model, device, config=config, model_name=clip_model_name,
backend=clip_backend)
logger.info("Tagger initialized with %d tag categories", len(tagger.tag_vocabulary))
# Count photos to tag
conn = sqlite3.connect(args.db)
if args.force:
count = conn.execute('''
SELECT COUNT(*) FROM photos WHERE clip_embedding IS NOT NULL
''').fetchone()[0]
else:
count = conn.execute('''
SELECT COUNT(*) FROM photos
WHERE clip_embedding IS NOT NULL AND (tags IS NULL OR tags = '')
''').fetchone()[0]
conn.close()
mode = "[FORCE] Re-tagging all" if args.force else "Found"
logger.info("%s %d photos to tag", mode, count)
if args.dry_run:
logger.info("[DRY RUN] Would tag up to %d photos", count)
else:
tagged = tag_untagged_photos(args.db, tagger, threshold, max_tags, verbose=True, force=args.force)
logger.info("Tagged %d photos", tagged)
# Update photo_tags lookup table for fast queries
if tagged > 0:
logger.info("Updating photo_tags lookup table...")
migrate_tags_to_lookup(args.db)
if __name__ == '__main__':
main()