Skip to content

Commit aa16e4f

Browse files
ShutterSift authored and claude committed
feat: richer analysis progress (ETA, per-photo decision), CPU warning, GPU priority fix
- Progress bar now shows filename + decision color, photo count, elapsed, and dynamic ETA (rolling-window speed estimate via Rich TimeRemainingColumn)
- Yellow warning printed before analysis when no GPU detected, explaining MUSIQ→BRISQUE fallback and how to enable GPU
- Post-analysis summary lists every photo grouped by Keep / Review / Reject with score and rejection reasons, plus total elapsed time and avg ms/photo
- Fix GPU device priority: CUDA → Metal (MPS) → CPU (was: MPS could override CUDA)
- capabilities.py gains gpu_device field ('cuda'/'mps'/'cpu'); summary() and info command display the specific backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7192084 commit aa16e4f

File tree

3 files changed

+117
-34
lines changed

3 files changed

+117
-34
lines changed

src/shuttersift/cli/main.py

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
import typer
99
from rich.console import Console
10-
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn
10+
from rich.progress import (
11+
Progress, SpinnerColumn, BarColumn, TextColumn,
12+
TimeElapsedColumn, TimeRemainingColumn, MofNCompleteColumn,
13+
)
1114
from rich.table import Table
1215
from rich.rule import Rule
1316

@@ -144,6 +147,14 @@ def _do_scan(
144147
console.print(f"\n[bold]ShutterSift[/] v{__version__}")
145148
console.print(f"Detected: {caps.summary()}\n")
146149

150+
if not caps.gpu:
151+
console.print(
152+
"[yellow]⚠ No GPU detected — running on CPU.[/] "
153+
"MUSIQ aesthetic scoring will use BRISQUE fallback and analysis will be slower.\n"
154+
"[yellow] To enable GPU:[/] install a CUDA-enabled torch (Windows/Linux) "
155+
"or ensure Metal is available (macOS).\n"
156+
)
157+
147158
# Auto-calibration: run on first use or when --recalibrate is passed
148159
if not cfg.calibrated or recalibrate:
149160
console.print("[1/3] Detecting capabilities... ✓")
@@ -158,20 +169,32 @@ def _do_scan(
158169

159170
engine = Engine(cfg)
160171

172+
_DECISION_STYLE = {"keep": "green", "review": "yellow", "reject": "red"}
173+
161174
with Progress(
162175
SpinnerColumn(),
163-
TextColumn("[progress.description]{task.description}"),
164-
BarColumn(),
165-
TextColumn("{task.completed}/{task.total}"),
176+
TextColumn("{task.description}"),
177+
BarColumn(bar_width=28),
178+
MofNCompleteColumn(),
179+
TextColumn("·"),
166180
TimeElapsedColumn(),
181+
TextColumn("· ETA"),
182+
TimeRemainingColumn(),
167183
console=console,
184+
expand=False,
168185
) as progress:
169-
task_id = progress.add_task("Analyzing...", total=None)
186+
task_id = progress.add_task("[dim]waiting…[/]", total=None)
170187

171188
def on_progress(current: int, total: int, result: PhotoResult) -> None:
172-
progress.update(task_id, completed=current, total=total,
173-
description=f"[cyan]{result.path.name}[/]")
174-
189+
style = _DECISION_STYLE.get(result.decision, "white")
190+
name = result.path.name
191+
if len(name) > 28:
192+
name = "…" + name[-27:]
193+
desc = f"[cyan]{name}[/] [[{style}]{result.decision}[/]]"
194+
progress.update(task_id, completed=current, total=total, description=desc)
195+
196+
import time as _time
197+
_t0 = _time.perf_counter()
175198
try:
176199
result: AnalysisResult = engine.analyze(
177200
input_dir=input_dir,
@@ -185,28 +208,77 @@ def on_progress(current: int, total: int, result: PhotoResult) -> None:
185208
logging.getLogger(__name__).exception("Engine error")
186209
console.print(f"[red]Error:[/] {exc or type(exc).__name__}")
187210
raise typer.Exit(1)
211+
_elapsed = _time.perf_counter() - _t0
212+
213+
_print_summary(result, output_dir, dry_run, _elapsed)
214+
215+
216+
_MAX_LIST_ROWS = 25 # max filenames shown per bucket before truncating
217+
188218

189-
_print_summary(result, output_dir, dry_run)
219+
def _fmt_elapsed(seconds: float) -> str:
220+
m, s = divmod(int(seconds), 60)
221+
h, m = divmod(m, 60)
222+
if h:
223+
return f"{h}h {m:02d}m {s:02d}s"
224+
if m:
225+
return f"{m}m {s:02d}s"
226+
return f"{s}s"
190227

191228

192-
def _print_summary(result: AnalysisResult, output_dir: Path, dry_run: bool) -> None:
229+
def _print_bucket(title: str, style: str, photos: list) -> None:
230+
if not photos:
231+
return
232+
console.print(f"\n[bold {style}]{title}[/] ({len(photos)} photos)")
233+
t = Table(box=None, show_header=False, padding=(0, 1))
234+
shown = photos[:_MAX_LIST_ROWS]
235+
for p in shown:
236+
reasons = ", ".join(p.reasons) if p.reasons else ""
237+
reason_text = f"[dim]— {reasons}[/]" if reasons else ""
238+
t.add_row(
239+
f" [cyan]{p.path.name}[/]",
240+
f"[{style}]{p.score:.0f}[/]",
241+
reason_text,
242+
)
243+
if len(photos) > _MAX_LIST_ROWS:
244+
t.add_row(f" [dim]… and {len(photos) - _MAX_LIST_ROWS} more[/]", "", "")
245+
console.print(t)
246+
247+
248+
def _print_summary(result: AnalysisResult, output_dir: Path, dry_run: bool, elapsed_s: float = 0.0) -> None:
193249
total = len(result.photos)
194250
if total == 0:
195251
console.print("[yellow]No photos found.[/]")
196252
return
197253

254+
# ── Timing ────────────────────────────────────────────────────────────────
255+
avg_ms = (elapsed_s * 1000 / total) if total else 0.0
256+
elapsed_str = _fmt_elapsed(elapsed_s)
257+
258+
# ── Counts table ──────────────────────────────────────────────────────────
198259
console.rule()
199-
table = Table(show_header=False, box=None, padding=(0, 2))
200-
table.add_row("[green]✓ Keep[/]", str(len(result.keep)), f"({len(result.keep)/total:.0%})")
201-
table.add_row("[yellow]◎ Review[/]", str(len(result.review)), f"({len(result.review)/total:.0%})")
202-
table.add_row("[red]✗ Reject[/]", str(len(result.reject)), f"({len(result.reject)/total:.0%})")
203-
console.print(table)
260+
tbl = Table(show_header=False, box=None, padding=(0, 2))
261+
tbl.add_row("[green]✓ Keep[/]", str(len(result.keep)), f"({len(result.keep)/total:.0%})")
262+
tbl.add_row("[yellow]◎ Review[/]", str(len(result.review)), f"({len(result.review)/total:.0%})")
263+
tbl.add_row("[red]✗ Reject[/]", str(len(result.reject)), f"({len(result.reject)/total:.0%})")
264+
tbl.add_row("", "", "")
265+
tbl.add_row("[dim]⏱ Time[/]", f"[dim]{elapsed_str}[/]",
266+
f"[dim](avg {avg_ms:.0f} ms/photo)[/]")
267+
console.print(tbl)
204268
console.rule()
269+
270+
# ── Per-bucket file listings ───────────────────────────────────────────────
271+
_print_bucket("✓ Keep", "green", result.keep)
272+
_print_bucket("◎ Review", "yellow", result.review)
273+
_print_bucket("✗ Reject", "red", result.reject)
274+
275+
# ── Paths ─────────────────────────────────────────────────────────────────
276+
console.print()
205277
if not dry_run:
206-
console.print(f"\nOutput → [bold]{output_dir}[/]")
278+
console.print(f"Output → [bold]{output_dir}[/]")
207279
console.print(f"Report → [bold]{output_dir / 'report.html'}[/]\n")
208280
else:
209-
console.print("\n[yellow]Dry run — no files written[/]\n")
281+
console.print("[yellow]Dry run — no files written[/]\n")
210282

211283

212284
# ── Default command callback (show help when no args given) ───────────────────
@@ -308,7 +380,8 @@ def info() -> None:
308380
table.add_column("Status")
309381
table.add_column("Details")
310382

311-
table.add_row("GPU", "[green]✓[/]" if caps.gpu else "[red]✗[/]", "CUDA or Apple Metal")
383+
_gpu_detail = {"cuda": "CUDA", "mps": "Apple Metal (MPS)", "cpu": "none"}.get(caps.gpu_device, caps.gpu_device)
384+
table.add_row("GPU", "[green]✓[/]" if caps.gpu else "[red]✗[/]", _gpu_detail)
312385
table.add_row("RAW decode", "[green]✓[/]" if caps.rawpy else "[yellow]~[/]", "rawpy" if caps.rawpy else "Using Pillow fallback")
313386
table.add_row("MUSIQ", "[green]✓[/]" if caps.musiq else "[yellow]~[/]", "GPU aesthetic scoring" if caps.musiq else "BRISQUE fallback")
314387
table.add_row("Local VLM", "[green]✓[/]" if caps.gguf_vlm else "[red]✗[/]",

src/shuttersift/engine/analyzers/aesthetic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@ def _load(self) -> None:
2626
self._loaded = True
2727
if _PYIQA_AVAILABLE:
2828
try:
29-
device = "cuda" if self._use_gpu else "cpu"
30-
# Try MPS on Apple Silicon
31-
try:
32-
import torch
33-
if torch.backends.mps.is_available() and self._use_gpu:
34-
device = "mps"
35-
except Exception:
36-
pass
29+
# Priority: CUDA → Metal (MPS) → CPU
30+
device = "cpu"
31+
if self._use_gpu:
32+
try:
33+
import torch
34+
if torch.cuda.is_available():
35+
device = "cuda"
36+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
37+
device = "mps"
38+
except Exception:
39+
pass
3740
self._model = pyiqa.create_metric("musiq", device=device)
3841
self._backend = "musiq"
3942
logger.info("Aesthetic backend: MUSIQ (%s)", device)

src/shuttersift/engine/capabilities.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,37 @@ def _try_import(module: str) -> bool:
1414
return False
1515

1616

17-
def _has_gpu() -> bool:
17+
def _detect_gpu_device() -> str:
18+
"""Returns 'cuda', 'mps', or 'cpu' — in priority order."""
1819
try:
1920
import torch
20-
return torch.cuda.is_available() or (
21-
hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
22-
)
21+
if torch.cuda.is_available():
22+
return "cuda"
23+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
24+
return "mps"
2325
except Exception:
24-
return False
26+
pass
27+
return "cpu"
2528

2629

2730
@dataclass
2831
class Capabilities:
2932
gpu: bool
33+
gpu_device: str # 'cuda', 'mps', or 'cpu'
3034
rawpy: bool
3135
musiq: bool
32-
gguf_vlm: bool # True when a local Moondream .mf model + moondream package present
36+
gguf_vlm: bool # True when a local Moondream .mf model + moondream package present
3337
gguf_model_path: Path | None
3438
api_vlm: bool
3539

3640
@classmethod
3741
def detect(cls) -> "Capabilities":
42+
gpu_device = _detect_gpu_device()
3843
# Moondream models use the .mf (Moondream Format) extension
3944
mf_models = list(MODELS_DIR.glob("*.mf")) if MODELS_DIR.exists() else []
4045
return cls(
41-
gpu=_has_gpu(),
46+
gpu=gpu_device != "cpu",
47+
gpu_device=gpu_device,
4248
rawpy=_try_import("rawpy"),
4349
musiq=_try_import("pyiqa"),
4450
gguf_vlm=bool(mf_models) and _try_import("moondream"),
@@ -52,8 +58,9 @@ def summary(self) -> str:
5258
def flag(val: bool, label: str) -> str:
5359
return f"{label} {'✓' if val else '✗'}"
5460

61+
gpu_label = f"GPU ({self.gpu_device.upper()})" if self.gpu else "GPU"
5562
parts = [
56-
flag(self.gpu, "GPU"),
63+
flag(self.gpu, gpu_label),
5764
flag(self.rawpy, "RAW"),
5865
flag(self.musiq, "MUSIQ"),
5966
flag(self.gguf_vlm, "Local VLM"),

0 commit comments

Comments (0)