From a3b3d5f38178200942d1273a98fea7e29077dbd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Cerveau?= Date: Thu, 19 Feb 2026 11:33:36 +0100 Subject: [PATCH] feat: deduplicate shared URL downloads across test suites Introduce a centralized DownloadManager so each URL is downloaded at most once, both within and across selected suites. Saves re-fetching multi-GB archives like AV1-ARGON shared by 12 suites. DownloadManager (fluster/utils.py): - Thread-safe per-URL caching at resources/.cache/; concurrent get() calls on the same URL block on the in-flight download. - BoundedSemaphore caps HTTP concurrency at 8. - Per-URL retry budget; ChecksumMismatchError poisons immediately. - invalidate(url) lets consumers drop a corrupt cached archive. - Context manager: cleanup() runs via __exit__, honoring keep_file. - filename_from_url() strips query strings for safe on-disk names. TestSuite.download() (fluster/test_suite.py): - Requires a DownloadManager (keyword-only). All three download paths consume pre-downloaded archives. - Multi-TV branch pre-downloads unique URLs in parallel before the multiprocessing extraction pool. - Raw source files are moved out of the cache (no double storage). CLI (fluster/fluster.py): - Three-phase: collect URLs across selected suites, parallel pre-download, per-suite extraction. Cross-suite parallelism is the main user-visible win. - All callers (CLI + 7 scripts/gen_*.py) use the with-statement form. --- README.md | 10 ++ fluster/fluster.py | 99 +++++++++++- fluster/main.py | 4 +- fluster/test_suite.py | 306 +++++++++++++++++++++++++----------- fluster/utils.py | 286 +++++++++++++++++++++++++++++++-- scripts/gen_aac.py | 8 +- scripts/gen_av1_aom.py | 8 +- scripts/gen_av1_chromium.py | 8 +- scripts/gen_jct_vc.py | 8 +- scripts/gen_jvet.py | 8 +- scripts/gen_jvt.py | 8 +- scripts/gen_mpeg2_video.py | 8 +- scripts/gen_mpeg4_video.py | 8 +- 13 files changed, 602 insertions(+), 167 deletions(-) diff --git a/README.md b/README.md index 6f04989c..893da8b6 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,16 @@ For complete setup, usage examples, hardware acceleration configuration, and tro files. You can change the number of parallel processes used with `-j`. It defaults to 2x number of logical cores. + Downloaded archives are cached in `resources/.cache/` during the run, + deduplicated across test suites, and downloaded in parallel. HTTP concurrency is + capped at 8 by default to avoid hammering remote mirrors. With `--keep`, + the cache is preserved after the run. + + Note: running multiple `fluster.py download` instances against the same + `resources/` directory concurrently is not supported — one instance's + cleanup may remove files another is still extracting. Use `--keep` or + point each instance at a separate resource dir if you need parallel runs. + Use the `-c/--codec` option to download test suites for specific codecs: - `./fluster.py download -c H.264,H.265` downloads all H.264 and H.265 test suites - `./fluster.py download AV1-TEST-VECTORS VP9-TEST-VECTORS` downloads the specific AV1-TEST-VECTORS and VP9-TEST-VECTORS test suite diff --git a/fluster/fluster.py b/fluster/fluster.py index 216c825e..06329a3b 100644 --- a/fluster/fluster.py +++ b/fluster/fluster.py @@ -21,11 +21,13 @@ import os import os.path import sys +from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum from functools import lru_cache from shutil import rmtree from typing import Any, Dict, Iterator, List, Optional, Set, Tuple +from fluster import utils from fluster.codec import Codec, Profile from fluster.decoder import DECODERS, Decoder @@ -883,11 +885,92 @@ def download_test_suites( download_test_suites = self.test_suites print(f"Test suites: {[ts.name for ts in download_test_suites]}") - for test_suite in download_test_suites: - test_suite.download( - jobs, - self.resources_dir, - verify=True, - keep_file=keep_file, - retries=retries, - ) + if not download_test_suites: + print("No test suites to download.") + return + + cache_dir = os.path.join(self.resources_dir, ".cache") + with utils.DownloadManager(cache_dir=cache_dir, verify=True, keep_file=keep_file, retries=retries) as manager: + # Phase 1: collect every (url, checksum) across all selected suites, + # deduplicated. Different suites can share a URL (e.g. AV1-ARGON + # archive). + url_checksums: Dict[str, str] = {} + checksum_conflicts: List[Tuple[str, str, str]] = [] + for ts in download_test_suites: + for tv in ts.test_vectors.values(): + existing = url_checksums.get(tv.source) + if existing is None or existing == "__skip__": + # Prefer a real checksum over an unset/__skip__ one. + url_checksums[tv.source] = tv.source_checksum + elif tv.source_checksum not in (existing, "__skip__"): + checksum_conflicts.append((tv.source, existing, tv.source_checksum)) + if checksum_conflicts: + for src, kept, other in checksum_conflicts: + print( + f"ERROR: conflicting checksums for {src}: " + f"{kept} vs {other} — the test-suite definitions disagree." + ) + sys.exit( + f"{len(checksum_conflicts)} URL(s) have conflicting checksums across " + f"selected suites; refusing to download (fix the test-suite JSON)." + ) + + # Phase 2: parallel pre-download. The manager's BoundedSemaphore + # caps actual HTTP concurrency; this pool just feeds it work. + if url_checksums: + max_workers = max(1, min(jobs, len(url_checksums), utils.MAX_PREDOWNLOAD_POOL_WORKERS)) + print( + f"Pre-downloading {len(url_checksums)} unique source(s) across " + f"{len(download_test_suites)} suite(s) using {max_workers} workers" + ) + pre_errors: List[Tuple[str, Exception]] = [] + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(manager.get, url, checksum): url for url, checksum in url_checksums.items()} + for fut in as_completed(futures): + url = futures[fut] + try: + fut.result() + except Exception as exc: # noqa: BLE001 - report all + pre_errors.append((url, exc)) + if pre_errors: + failed_urls = {err_url for err_url, _ in pre_errors} + for err_url, err_exc in pre_errors: + print(f"Error pre-downloading {err_url}: {type(err_exc).__name__}: {err_exc}") + print(f"{len(pre_errors)} URL(s) failed to pre-download — skipping affected suites.") + else: + failed_urls = set() + else: + failed_urls = set() + + # Phase 3: extract per suite. Cache is now warm — TestSuite.download + # hits manager.get() which short-circuits to the cached path. + # Suites whose URLs intersect with failed_urls are skipped so the + # rest of the batch still extracts. + skipped_suites: List[str] = [] + failed_extractions: List[str] = [] + for test_suite in download_test_suites: + suite_urls = {tv.source for tv in test_suite.test_vectors.values()} + if suite_urls & failed_urls: + skipped_suites.append(test_suite.name) + continue + try: + test_suite.download( + jobs, + self.resources_dir, + download_manager=manager, + ) + except utils.BadArchiveError as exc: + # The cache entry was already invalidated inside download(); + # report cleanly and keep going so the rest of the batch + # still extracts. + print(f"\n{test_suite.name}: {exc}") + failed_extractions.append(test_suite.name) + if skipped_suites: + print( + f"\nSkipped {len(skipped_suites)} suite(s) due to pre-download failures: " + f"{skipped_suites}" + ) + if failed_extractions: + print(f"Corrupt archive(s) invalidated for: {failed_extractions} (re-run to retry)") + if skipped_suites or failed_extractions: + sys.exit(1) diff --git a/fluster/main.py b/fluster/main.py index adefd47e..88ce71a8 100644 --- a/fluster/main.py +++ b/fluster/main.py @@ -335,8 +335,8 @@ def _add_download_cmd(self, subparsers: Any) -> None: subparser.add_argument( "-k", "--keep", - help="keep original downloaded file after extracting. Only applicable to compressed " - "files such as .zip, .tar.gz, etc", + help="keep original downloaded file after extracting. Archives are stored in " + "/.cache/. Only applicable to compressed files (.zip, .tar.gz, etc)", action="store_true", ) subparser.add_argument( diff --git a/fluster/test_suite.py b/fluster/test_suite.py index bf8f370b..c2d9587d 100644 --- a/fluster/test_suite.py +++ b/fluster/test_suite.py @@ -19,14 +19,16 @@ import fnmatch import json import os.path +import subprocess import sys import zipfile +from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum from functools import lru_cache from multiprocessing import Pool -from shutil import rmtree +from shutil import move, rmtree from time import perf_counter -from typing import Any, Dict, List, Optional, Set, Type, cast +from typing import Any, Dict, List, Optional, Set, Tuple, Type, cast from unittest.result import TestResult from fluster import utils @@ -36,29 +38,40 @@ from fluster.test_vector import TestVector, TestVectorResult +class _CorruptCacheError(Exception): + """Raised by an extract worker when the cached archive is unusable. + + The worker runs in a multiprocessing.Pool subprocess and cannot mutate + the parent's DownloadManager directly; it raises this instead so the + parent can invalidate the URL and re-download on the next run. + """ + + def __init__(self, source_url: str, original: Exception): + # Pass both args to Exception so __reduce__ pickles them and the + # exception round-trips through multiprocessing.Pool's result queue. + super().__init__(source_url, original) + self.source_url = source_url + self.original = original + + def __str__(self) -> str: + return f"corrupt cache for {self.source_url}: {self.original}" + + class DownloadWork: """Context to pass to download worker""" def __init__( self, out_dir: str, - verify: bool, extract_all: bool, - keep_file: bool, test_suite_name: str, - retries: int, + archive_path: str, + test_vector: Optional[TestVector] = None, ): self.out_dir = out_dir - self.verify = verify self.extract_all = extract_all - self.keep_file = keep_file self.test_suite_name = test_suite_name - self.retries = retries - - # This is added to avoid having to create an extra ancestor class - def set_test_vector(self, test_vector: TestVector) -> None: - """Setter function for member variable test vector""" - + self.archive_path = archive_path self.test_vector = test_vector @@ -68,15 +81,17 @@ class DownloadWorkSingleArchive(DownloadWork): def __init__( self, out_dir: str, - verify: bool, extract_all: bool, - keep_file: bool, test_suite_name: str, test_vectors: Dict[str, TestVector], - retries: int, + archive_path: str, + source_url: str, + download_manager: "utils.DownloadManager", ): - super().__init__(out_dir, verify, extract_all, keep_file, test_suite_name, retries) + super().__init__(out_dir, extract_all, test_suite_name, archive_path) self.test_vectors = test_vectors + self.source_url = source_url + self.download_manager = download_manager class Context: @@ -209,95 +224,100 @@ def to_json_file(self, filename: str) -> None: @staticmethod def _download_single_test_vector(ctx: DownloadWork) -> None: - """Download and extract a single test vector""" + """Extract a single test vector from a pre-downloaded archive. + + The DownloadManager always provides ctx.archive_path for extractable + sources; non-extractable single-file sources skip this worker entirely + (see TestSuite.download). + + Concurrency note: in the multi-TV branch this runs inside a + multiprocessing.Pool subprocess. Subprocesses get a fork-time copy of + the parent's DownloadManager, so any state mutation here (cache + bookkeeping, invalidation, etc.) does NOT propagate back to the + parent. Communicate failures by raising — the parent handles them via + the Pool's error_callback (e.g. _CorruptCacheError → manager + invalidation in TestSuite.download).""" + if ctx.test_vector is None: + raise ValueError("per-TV worker requires a test_vector") + if not ctx.archive_path: + raise ValueError("DownloadManager must provide archive_path") dest_dir = os.path.join(ctx.out_dir, ctx.test_suite_name, ctx.test_vector.name) - dest_path = os.path.join(dest_dir, os.path.basename(ctx.test_vector.source)) os.makedirs(dest_dir, exist_ok=True) - if ( - ctx.verify - and os.path.exists(dest_path) - and ctx.test_vector.source_checksum == utils.file_checksum(dest_path) - ): - # Remove file only in case the input file was extractable. - # Otherwise, we'd be removing the original file we want to work - # with every even time we execute the download subcommand. - if utils.is_extractable(dest_path) and not ctx.keep_file: - os.remove(dest_path) - return - - print(f"\tDownloading test vector {ctx.test_vector.name} from {ctx.test_vector.source}") - utils.download(ctx.test_vector.source, dest_dir, ctx.retries**ctx.retries) - - if ctx.test_vector.source_checksum != "__skip__": - checksum = utils.file_checksum(dest_path) - if ctx.test_vector.source_checksum != checksum: - raise Exception( - f"Checksum mismatch for {ctx.test_vector.name}: {checksum} instead of " - f"{ctx.test_vector.source_checksum}" - ) - - if utils.is_extractable(dest_path): + if utils.is_extractable(ctx.archive_path): + # Skip extraction if the target file is already on disk from a + # previous run. Trusts presence as proof of content; users can + # force re-extraction by removing the file (or the suite dir). + extracted_path = os.path.join(dest_dir, ctx.test_vector.input_file) + if not ctx.extract_all and os.path.exists(extracted_path): + print(f"\tSkipping extraction of {ctx.test_vector.name} (already extracted)") + return print(f"\tExtracting test vector {ctx.test_vector.name} to {dest_dir}") - utils.extract(dest_path, dest_dir, file=ctx.test_vector.input_file if not ctx.extract_all else None) - if not ctx.keep_file: - os.remove(dest_path) + try: + utils.extract( + ctx.archive_path, + dest_dir, + file=ctx.test_vector.input_file if not ctx.extract_all else None, + ) + except (zipfile.BadZipFile, subprocess.CalledProcessError, OSError) as exc: + # Worker runs in a multiprocessing.Pool subprocess (or here in + # the main thread for the single-TV branch); raise so the + # parent can call manager.invalidate(source_url). + raise _CorruptCacheError(ctx.test_vector.source, exc) from exc + else: + # Raw (non-extractable) source file: move from the manager's cache + # into the suite dir so it's stored only once. Non-extractable + # files aren't shared across suites, so there's no dedup value in + # keeping the cache copy. + dest_path = os.path.join(dest_dir, os.path.basename(ctx.archive_path)) + if os.path.exists(dest_path): + print(f"\tSkipping placement of {ctx.test_vector.name} (already exists)") + else: + print(f"\tPlacing test vector {ctx.test_vector.name} at {dest_path}") + move(ctx.archive_path, dest_path) @staticmethod def _download_single_archive(ctx: DownloadWorkSingleArchive) -> None: - """Download a single archive containing many test vectors and extract them""" + """Extract many test vectors from a pre-downloaded single archive. + + The DownloadManager always provides ctx.archive_path.""" first_tv = ctx.test_vectors[next(iter(ctx.test_vectors))] dest_dir = os.path.join(ctx.out_dir, ctx.test_suite_name) - dest_path = os.path.join(dest_dir, os.path.basename(first_tv.source)) os.makedirs(dest_dir, exist_ok=True) - - # Clean up existing corrupt source file - if ( - ctx.verify - and os.path.exists(dest_path) - and utils.is_extractable(dest_path) - and first_tv.source_checksum != utils.file_checksum(dest_path) - ): - os.remove(dest_path) - - print(f"\tDownloading source file from {first_tv.source}") - utils.download(first_tv.source, dest_dir, ctx.retries**ctx.retries) - - # Check that source file was downloaded correctly - if first_tv.source_checksum != "__skip__": - checksum = utils.file_checksum(dest_path) - if first_tv.source_checksum != checksum: - raise Exception( - f"Checksum mismatch for source file {os.path.basename(first_tv.source)}: {checksum} " - f"instead of '{first_tv.source_checksum}'" - ) + if not ctx.archive_path: + raise ValueError("DownloadManager must provide archive_path") try: - with zipfile.ZipFile(dest_path, "r") as zip_file: - print(f"\tExtracting test vectors from {os.path.basename(first_tv.source)}") + with zipfile.ZipFile(ctx.archive_path, "r") as zip_file: + print(f"\tExtracting test vectors from {utils.filename_from_url(first_tv.source)}") + namelist = zip_file.namelist() for tv in ctx.test_vectors.values(): - if tv.input_file in zip_file.namelist(): + # Skip extraction if the target file is already on disk + # from a previous run. Trusts presence as proof of content. + if os.path.exists(os.path.join(dest_dir, tv.input_file)): + continue + if tv.input_file in namelist: zip_file.extract(tv.input_file, dest_dir) else: print( - f"WARNING: test vector {tv.input_file} not found inside {os.path.basename(first_tv.source)}" + f"WARNING: test vector {tv.input_file} not found inside " + f"{utils.filename_from_url(first_tv.source)}" ) except zipfile.BadZipFile as bad_zip_error: - os.remove(dest_path) - raise Exception(f"{dest_path} could not be opened as zip file. File was deleted") from bad_zip_error - - # Remove source file, if applicable - if not ctx.keep_file: - os.remove(dest_path) + # Corrupt archive: ask the DownloadManager to invalidate its + # cache entry so the next run re-downloads from scratch. + ctx.download_manager.invalidate(ctx.source_url) + raise utils.BadArchiveError( + f"{ctx.archive_path} could not be opened as zip file (invalidated)" + ) from bad_zip_error def download( self, jobs: int, out_dir: str, - verify: bool, + *, + download_manager: utils.DownloadManager, extract_all: bool = False, - keep_file: bool = False, - retries: int = 2, ) -> None: """Download the test suite""" os.makedirs(out_dir, exist_ok=True) @@ -306,37 +326,126 @@ def download( if ( len(unique_sources) == 1 and len(self.test_vectors) > 1 - and utils.is_extractable(os.path.basename(next(iter(unique_sources)))) + and utils.is_extractable(utils.filename_from_url(next(iter(unique_sources)))) ): # Download test suite of multiple test vectors from a single archive print(f"Downloading test suite {self.name} using 1 job (single archive)") + first_tv = next(iter(self.test_vectors.values())) + shared_archive_path = download_manager.get(first_tv.source, first_tv.source_checksum) dwork_single = DownloadWorkSingleArchive( - out_dir, verify, extract_all, keep_file, self.name, self.test_vectors, retries + out_dir, + extract_all, + self.name, + self.test_vectors, + shared_archive_path, + first_tv.source, + download_manager, ) self._download_single_archive(dwork_single) elif len(unique_sources) == 1 and len(self.test_vectors) == 1: - # Download test suite of single test vector + # Download test suite of single test vector (extractable or raw). + # The worker handles both cases: extract from archive, or copy + # the raw file from the cache into the suite dir. print(f"Downloading test suite {self.name} using 1 job (single file)") single_tv = next(iter(self.test_vectors.values())) - dwork = DownloadWork(out_dir, verify, extract_all, keep_file, self.name, retries) - dwork.set_test_vector(single_tv) - self._download_single_test_vector(dwork) + single_tv_archive_path = download_manager.get(single_tv.source, single_tv.source_checksum) + dwork = DownloadWork( + out_dir, + extract_all, + self.name, + single_tv_archive_path, + single_tv, + ) + try: + self._download_single_test_vector(dwork) + except _CorruptCacheError as exc: + download_manager.invalidate(exc.source_url) + raise utils.BadArchiveError( + f"corrupt cache for {exc.source_url} (invalidated, re-run to retry)" + ) from exc.original + # The worker move()s non-extractable raw sources out of the cache + # into the suite dir. Tell the manager so cleanup() doesn't later + # try to remove a path it no longer owns. + if not utils.is_extractable(single_tv_archive_path): + download_manager.release(single_tv.source) else: - # Download test suite of multiple test vectors + # Download test suite of multiple test vectors. + # Pre-download all unique source URLs in parallel (deduplicating + # via the thread-safe manager), then dispatch parallel workers that + # only extract from the pre-downloaded archives. + source_paths: Dict[str, str] = {} + unique_source_list = list(unique_sources) + # Map each unique URL to a representative checksum once (O(N)), + # so the per-URL pre-download lookup is O(1) instead of O(N). + url_checksum: Dict[str, str] = {} + for tv in self.test_vectors.values(): + url_checksum.setdefault(tv.source, tv.source_checksum) + + # Fast path: if every URL is already cached (e.g. cross-suite + # phase-2 in fluster.py pre-downloaded everything), skip the + # thread pool and short-circuit to the cached paths. + cached_paths = {url: download_manager.cached_path(url) for url in unique_source_list} + if all(p is not None for p in cached_paths.values()): + source_paths = {url: p for url, p in cached_paths.items() if p is not None} + else: + + def _pre_download(url: str) -> Tuple[str, str]: + # get() owns the entire per-URL retry budget: it counts + # attempts and raises a terminal RuntimeError once the budget + # is exhausted (or immediately for non-retryable errors like + # ChecksumMismatchError). Calling it once is enough — an outer + # retry loop here would re-drive those attempts and busy-spin + # on fast-failing errors. + return (url, download_manager.get(url, url_checksum[url])) + + max_workers = max(1, min(jobs, len(unique_source_list), utils.MAX_PREDOWNLOAD_POOL_WORKERS)) + persistent_errors: List[Tuple[str, Exception]] = [] + with ThreadPoolExecutor(max_workers=max_workers) as dl_pool: + futures = {dl_pool.submit(_pre_download, url): url for url in unique_source_list} + for future in as_completed(futures): + url_in_flight = futures[future] + try: + url, local_path = future.result() + source_paths[url] = local_path + except Exception as exc: + persistent_errors.append((url_in_flight, exc)) + if persistent_errors: + for err_url, err_exc in persistent_errors: + print(f"Error pre-downloading {err_url}: {err_exc}") + raise RuntimeError(f"{len(persistent_errors)} URL(s) failed pre-download for suite {self.name}") + print(f"Downloading test suite {self.name} using {jobs} parallel jobs") error_occurred = False + # Defer manager.invalidate() out of the Pool's result-handler + # thread; do the actual disk work after pool.join() to avoid + # serializing all failures behind a single lock + os.remove. + corrupted_urls: List[str] = [] with Pool(jobs) as pool: def _callback_error(err: Any) -> None: nonlocal error_occurred error_occurred = True - print(f"\nError downloading -> {err}\n") + if isinstance(err, _CorruptCacheError): + corrupted_urls.append(err.source_url) + print( + f"\nCorrupt cached archive {err.source_url} " + f"(will invalidate after job drain). " + f"({err.original})\n" + ) + else: + print(f"\nError downloading -> {err}\n") pool.terminate() downloads = [] for tv in self.test_vectors.values(): - dwork = DownloadWork(out_dir, verify, extract_all, keep_file, self.name, retries) - dwork.set_test_vector(tv) + archive_path = source_paths[tv.source] + dwork = DownloadWork( + out_dir, + extract_all, + self.name, + archive_path, + tv, + ) downloads.append( pool.apply_async( self._download_single_test_vector, @@ -348,6 +457,9 @@ def _callback_error(err: Any) -> None: pool.close() pool.join() + for corrupt_url in corrupted_urls: + download_manager.invalidate(corrupt_url) + if error_occurred: sys.exit("Some download failed") else: @@ -357,6 +469,16 @@ def _callback_error(err: Any) -> None: print("All downloads finished") + def download_with_default_manager(self, jobs: int, *, extract_all: bool = False) -> None: + """Download via a fresh DownloadManager configured for generator scripts. + + Verify off, keep archives, default retries — the common settings shared + by the scripts/gen_*.py. Convenience wrapper so they don't each repeat + the manager boilerplate.""" + cache_dir = os.path.join(self.resources_dir, ".cache") + with utils.DownloadManager(cache_dir=cache_dir, verify=False, keep_file=True, retries=2) as manager: + self.download(jobs, self.resources_dir, extract_all=extract_all, download_manager=manager) + @staticmethod def _rename_test(test: Test, module: str, qualname: str) -> None: test_cls = type(test) diff --git a/fluster/utils.py b/fluster/utils.py index 44b9206f..383530ce 100644 --- a/fluster/utils.py +++ b/fluster/utils.py @@ -28,6 +28,7 @@ import shutil import subprocess import sys +import threading import time import urllib.error import urllib.parse @@ -35,12 +36,55 @@ import wave import zipfile from functools import partial -from threading import Lock -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type TARBALL_EXTS = ("tar.gz", "tgz", "tar.bz2", "tbz2", "tar.xz") -download_lock = Lock() +# Pre-download thread-pool ceiling. HTTP concurrency is capped lower by +# DownloadManager's BoundedSemaphore (default 8); the surplus (2x) lets a +# waiting thread grab a freed slot without spinning up a new worker. +MAX_PREDOWNLOAD_POOL_WORKERS = 16 + +# Serialize concurrent progress-bar prints so lines don't garble under the +# ThreadPoolExecutor pre-download. +_print_lock = threading.Lock() + + +def _locked_print(msg: str) -> None: + """Print *msg* under the shared print lock.""" + with _print_lock: + print(msg) + + +def filename_from_url(url: str) -> str: + """Return a safe filename from *url*, stripping query string and fragment. + + Plain os.path.basename on a URL keeps the query string (e.g. a signed + GCS/S3 URL ending in "?X-Amz-Signature=..."), which then wrecks suffix + checks like is_extractable() and leaves odd filenames on disk. + + Raises ValueError if the URL has no usable filename component (e.g. + "https://host/" or "https://host"). Catching this here surfaces a + clear error instead of an opaque IsADirectoryError further downstream. + """ + filename = os.path.basename(urllib.parse.urlsplit(url).path) + if not filename: + raise ValueError(f"URL {url!r} has no filename component") + return filename + + +class ChecksumMismatchError(Exception): + """Downloaded file's checksum does not match the expected value.""" + + +class BadArchiveError(Exception): + """Downloaded archive is corrupt or unreadable.""" + + +# Errors that make further retries pointless: the remote content genuinely +# differs from the expected checksum, so retrying just re-downloads the same +# wrong file. +_NON_RETRYABLE_DOWNLOAD_ERRORS = (ChecksumMismatchError,) def create_enhanced_opener() -> urllib.request.OpenerDirector: @@ -119,7 +163,7 @@ def _update_progress_bar( else: progress_bar = f"{size_info} | {_format_bytes(int(rate))}/s" - print(f"\t{filename:<40} {progress_bar}") + _locked_print(f"\t{filename:<40} {progress_bar}") return current_time return last_update_time @@ -159,15 +203,17 @@ def download( timeout: int = 300, chunk_size: int = 2048 * 2048, # 4MB ) -> None: - """Downloads a file to a directory with a mutex lock - to avoid conflicts and retries with exponential backoff.""" + """Downloads a file to a directory with retries and full-jitter backoff. + + Between attempts it sleeps a uniform random delay in [1, 2**attempt) + seconds, so the backoff window grows exponentially while the actual + wait is randomized (AWS-style "full jitter").""" os.makedirs(dest_dir, exist_ok=True) - filename = os.path.basename(url) + filename = filename_from_url(url) dest_path = os.path.join(dest_dir, filename) for attempt in range(max_retries): try: - with download_lock: - _download_simple(url, dest_path, filename, timeout, chunk_size) + _download_simple(url, dest_path, filename, timeout, chunk_size) break except ( urllib.error.URLError, @@ -188,6 +234,228 @@ def download( raise RuntimeError(f"Failed to download {url} after {max_retries} attempts: {e}") from e +class DownloadManager: + """Centralized download manager that ensures each URL is downloaded at most once. + + Thread-safe: multiple threads may call get() concurrently. If the same URL + is requested by multiple threads, only one performs the download while + the others wait for the result. Archives managed by this class are cleaned + up via cleanup() unless keep_file is set. + """ + + def __init__( + self, + cache_dir: str, + verify: bool, + keep_file: bool, + retries: int, + max_concurrent_downloads: int = 8, + ): + self._cache_dir = cache_dir + self._cache: Dict[str, str] = {} + self._verify = verify + self._keep_file = keep_file + self._retries = retries + self._managed_files: List[str] = [] + self._lock = threading.Lock() + self._in_progress: Dict[str, threading.Event] = {} + self._errors: Dict[str, Exception] = {} + self._attempts: Dict[str, int] = {} + # Cap concurrent HTTP downloads across all callers. Browsers typically + # use 6-8 connections per host; keep that order of magnitude regardless + # of how many extraction workers the CLI spins up. + self._download_slots = threading.BoundedSemaphore(max(1, max_concurrent_downloads)) + + def get(self, url: str, source_checksum: str) -> str: + """Download *url* once into the session cache dir and return the local path. + + Thread-safe. Concurrent calls for the same URL will block until + the first caller finishes, then reuse the cached result. + """ + while True: + with self._lock: + # Poison the URL only after exceeding the per-URL retry budget. + if url in self._errors and self._attempts.get(url, 0) >= self._retries: + raise RuntimeError( + f"Download of {url} failed after {self._attempts[url]} attempts: {self._errors[url]}" + ) from self._errors[url] + + if url in self._cache and os.path.exists(self._cache[url]): + _locked_print(f"\tReusing cached download for {filename_from_url(url)}") + return self._cache[url] + + if url in self._in_progress: + event = self._in_progress[url] + else: + event = threading.Event() + self._in_progress[url] = event + # Clear any previous error so this new attempt can run. + self._errors.pop(url, None) + break + + event.wait() + + with self._lock: + # A concurrent attempt completed. If it failed and the budget + # is exhausted, poison; otherwise loop and try again ourselves. + if url in self._errors and self._attempts.get(url, 0) >= self._retries: + raise RuntimeError( + f"Download of {url} failed after {self._attempts[url]} attempts: {self._errors[url]}" + ) from self._errors[url] + + done_event: Optional[threading.Event] = None + try: + result, downloaded_now = self._do_download(url, source_checksum) + with self._lock: + self._cache[url] = result + self._errors.pop(url, None) + self._attempts.pop(url, None) + if downloaded_now: + self._managed_files.append(result) + done_event = self._in_progress.pop(url, None) + return result + except Exception as exc: + with self._lock: + self._errors[url] = exc + if isinstance(exc, _NON_RETRYABLE_DOWNLOAD_ERRORS): + # Poison immediately — retrying won't help. + self._attempts[url] = self._retries + else: + self._attempts[url] = self._attempts.get(url, 0) + 1 + done_event = self._in_progress.pop(url, None) + raise + finally: + if done_event is not None: + done_event.set() + + def _do_download(self, url: str, source_checksum: str) -> Tuple[str, bool]: + """Perform the actual download/skip logic. Returns (path, downloaded_now).""" + dest_path = os.path.join(self._cache_dir, filename_from_url(url)) + os.makedirs(self._cache_dir, exist_ok=True) + + # When the cached file's checksum matches the expected one, skip + # redownload regardless of `verify`. For the CLI path, cleanup() wipes + # the cache dir at end-of-run; for scripts (keep_file=True), the + # checksum match is itself the safety gate. + skip = False + if os.path.exists(dest_path): + if source_checksum == "__skip__": + skip = True + elif source_checksum == file_checksum(dest_path): + skip = True + elif self._verify and is_extractable(dest_path): + os.remove(dest_path) + + if skip: + _locked_print(f"\tSkipping download of {filename_from_url(url)} (already exists)") + else: + _locked_print(f"\tDownloading {filename_from_url(url)} from {url}") + # Hold an HTTP slot only for the network transfer; the checksum + # below runs unslotted (it needs no connection). + with self._download_slots: + # download() retries internally with backoff. Pass retries + # straight through; the per-URL budget in get() is the outer + # envelope. (Avoids the old retries**retries blow-up: -r 5 + # used to mean 3125 attempts per URL.) + download(url, self._cache_dir, max_retries=self._retries) + + if source_checksum != "__skip__": + checksum = file_checksum(dest_path) + if source_checksum != checksum: + raise ChecksumMismatchError( + f"Checksum mismatch for {filename_from_url(url)}: {checksum} instead of '{source_checksum}'" + ) + + return dest_path, not skip + + def is_poisoned(self, url: str) -> bool: + """True if this URL has exhausted its retry budget and will fail on next get().""" + with self._lock: + return url in self._errors and self._attempts.get(url, 0) >= self._retries + + def cached_path(self, url: str) -> Optional[str]: + """Return the cached path for *url* if present and on disk, else None. + + Read-only; never triggers a download or alters manager state. Useful + for callers that want to skip a redundant get() round-trip when an + earlier phase has already warmed the cache. + """ + with self._lock: + path = self._cache.get(url) + if path is not None and os.path.exists(path): + return path + return None + + def invalidate(self, url: str) -> None: + """Drop the cached download for *url* and delete the on-disk file. + + Use when a consumer detects the cached archive is unusable (e.g. + a corrupt zip that passed the checksum). The next get() call will + re-download from scratch. + """ + with self._lock: + path = self._cache.pop(url, None) + # Forget tracking so cleanup() won't later try to remove the + # re-downloaded file twice. + if path and path in self._managed_files: + self._managed_files.remove(path) + if path and os.path.exists(path): + with contextlib.suppress(OSError): + os.remove(path) + + def release(self, url: str) -> None: + """Forget the cached path for *url* without deleting the file. + + Use when a consumer has taken ownership of the cached file (e.g. + moved it elsewhere). Differs from invalidate() in that no on-disk + removal happens — the caller now owns the file and the manager + forgets it. Only meaningful in the parent process; subprocess + workers can't mutate the parent's manager state. + """ + with self._lock: + path = self._cache.pop(url, None) + if path and path in self._managed_files: + self._managed_files.remove(path) + + def __enter__(self) -> "DownloadManager": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Any, + ) -> None: + # Skip cleanup when an exception is propagating (especially + # KeyboardInterrupt mid-download), so the user can resume on the + # next run instead of starting over. + if exc_type is None: + self.cleanup() + + def cleanup(self) -> None: + """Remove downloaded archives unless keep_file was requested. + + Not safe against concurrent DownloadManager instances sharing the + same cache_dir (across fluster processes). Each DownloadManager only + tracks files it downloaded itself, so cross-process corruption is + bounded, but callers running concurrent fluster sessions should + either use --keep or point at separate resource dirs. + """ + if self._keep_file: + return + for path in self._managed_files: + # Best-effort: a missing file (already removed) or one we can't + # delete shouldn't abort cleanup of the rest. Suppressing OSError + # also avoids a TOCTOU race against an exists() check. + with contextlib.suppress(OSError): + os.remove(path) + self._managed_files.clear() + # Best-effort: remove the cache dir if it is now empty. Fails quietly + # if the dir still contains files (e.g., concurrent fluster instance). + with contextlib.suppress(OSError): + os.rmdir(self._cache_dir) + + def file_checksum(path: str) -> str: """Calculates the checksum of a file reading chunks of 64KiB""" md5 = hashlib.md5() diff --git a/scripts/gen_aac.py b/scripts/gen_aac.py index fa11c99a..f9283cb9 100755 --- a/scripts/gen_aac.py +++ b/scripts/gen_aac.py @@ -127,13 +127,7 @@ def generate(self, download: bool, jobs: int) -> None: print(f"Download list of compressed bitstreams from {self.url_test_vectors}") if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) # MP4 test suites audio validation if test_suite.name in ["MPEG4_AAC-MP4", "MPEG4_AAC-MP4-ER"]: diff --git a/scripts/gen_av1_aom.py b/scripts/gen_av1_aom.py index 0059fbbf..c2c6b2b8 100755 --- a/scripts/gen_av1_aom.py +++ b/scripts/gen_av1_aom.py @@ -92,13 +92,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[test_vector_name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) for test_vector in test_suite.test_vectors.values(): dest_dir = os.path.join(test_suite.resources_dir, test_suite.name, test_vector.name) diff --git a/scripts/gen_av1_chromium.py b/scripts/gen_av1_chromium.py index 04181c5a..c978a617 100755 --- a/scripts/gen_av1_chromium.py +++ b/scripts/gen_av1_chromium.py @@ -120,13 +120,7 @@ def generate(self, download: bool, jobs: int) -> Any: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) for test_vector in test_suite.test_vectors.values(): dest_dir = os.path.join(test_suite.resources_dir, test_suite.name, test_vector.name) diff --git a/scripts/gen_jct_vc.py b/scripts/gen_jct_vc.py index ae2408e7..0e5a0298 100755 --- a/scripts/gen_jct_vc.py +++ b/scripts/gen_jct_vc.py @@ -106,13 +106,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) if "SHVC" in test_suite.name: for test_vector in test_suite.test_vectors.values(): diff --git a/scripts/gen_jvet.py b/scripts/gen_jvet.py index 34fa102f..2c527468 100755 --- a/scripts/gen_jvet.py +++ b/scripts/gen_jvet.py @@ -103,13 +103,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) for test_vector in test_suite.test_vectors.values(): dest_dir = os.path.join(test_suite.resources_dir, test_suite.name, test_vector.name) diff --git a/scripts/gen_jvt.py b/scripts/gen_jvt.py index 856899b5..ed5b23db 100755 --- a/scripts/gen_jvt.py +++ b/scripts/gen_jvt.py @@ -116,13 +116,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) for test_vector in test_suite.test_vectors.values(): dest_dir = os.path.join(test_suite.resources_dir, test_suite.name, test_vector.name) diff --git a/scripts/gen_mpeg2_video.py b/scripts/gen_mpeg2_video.py index 3a201506..a8ee807a 100755 --- a/scripts/gen_mpeg2_video.py +++ b/scripts/gen_mpeg2_video.py @@ -118,13 +118,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) for test_vector in test_suite.test_vectors.values(): dest_dir = os.path.join(test_suite.resources_dir, test_suite.name, test_vector.name) diff --git a/scripts/gen_mpeg4_video.py b/scripts/gen_mpeg4_video.py index 862fee18..b4f6295a 100755 --- a/scripts/gen_mpeg4_video.py +++ b/scripts/gen_mpeg4_video.py @@ -130,13 +130,7 @@ def generate(self, download: bool, jobs: int) -> None: test_suite.test_vectors[name] = test_vector if download: - test_suite.download( - jobs=jobs, - out_dir=test_suite.resources_dir, - verify=False, - extract_all=True, - keep_file=True, - ) + test_suite.download_with_default_manager(jobs, extract_all=True) original_vectors = { name: {"source": vector.source, "source_checksum": vector.source_checksum}