From a6ad8db6d98a04efab451a76bf5f78aa8e1360ad Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Tue, 28 Apr 2026 20:37:01 -0600 Subject: [PATCH 1/2] improve cache --- CHANGELOG.md | 27 ++ docs/docs/bulk-downloads.md | 23 ++ docs/docs/caching.md | 67 ++++- pyproject.toml | 7 +- src/oda_reader/__init__.py | 84 +++++- src/oda_reader/_cache/config.py | 45 ++- src/oda_reader/_cache/dataframe.py | 36 ++- src/oda_reader/_cache/legacy.py | 13 +- src/oda_reader/_cache/manager.py | 155 ++++++---- src/oda_reader/common.py | 26 +- src/oda_reader/crs.py | 8 + src/oda_reader/dac2a.py | 4 + src/oda_reader/download/download_tools.py | 280 +++++++++--------- src/oda_reader/download/query_builder.py | 2 +- src/oda_reader/download/version_discovery.py | 4 +- src/oda_reader/exceptions.py | 85 ++++++ src/oda_reader/multisystem.py | 5 +- src/oda_reader/tools.py | 7 +- tests/cache/__init__.py | 0 tests/cache/conftest.py | 92 ++++++ tests/cache/test_cache_manager_activated.py | 216 ++++++++++++++ tests/cache/test_dataframe_atomic.py | 81 +++++ tests/cache/test_deprecation_shims.py | 93 ++++++ tests/cache/test_freezegun_ttl.py | 69 +++++ tests/cache/test_lru_eviction.py | 112 +++++++ tests/common/unit/test_cache.py | 2 +- tests/conftest.py | 6 +- .../datasets/crs/integration/test_crs_e2e.py | 8 +- tests/datasets/dac2a/unit/test_dac2a_bulk.py | 7 +- tests/download/unit/test_deflate64.py | 6 +- tests/download/unit/test_download_tools.py | 69 +++-- tests/download/unit/test_version_discovery.py | 6 +- tests/schemas/unit/test_dac1_translation.py | 6 +- tests/utils.py | 6 +- uv.lock | 16 +- 35 files changed, 1369 insertions(+), 304 deletions(-) create mode 100644 src/oda_reader/exceptions.py create mode 100644 tests/cache/__init__.py create mode 100644 tests/cache/conftest.py create mode 100644 tests/cache/test_cache_manager_activated.py create mode 100644 tests/cache/test_dataframe_atomic.py create mode 100644 tests/cache/test_deprecation_shims.py create mode 100644 
tests/cache/test_freezegun_ttl.py create mode 100644 tests/cache/test_lru_eviction.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cab7fde..b564989 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,32 @@ # Changelog for oda_reader +## 1.6.0 (2026-04-28) +- Adds `use_raw_cache=False` to `bulk_download_crs`, `download_crs_file`, `bulk_download_dac2a` + and `bulk_download_multisystem` for the cases where you want to bypass the bulk cache and + re-download fresh on every call. Caching remains on by default. +- Adds a typed `BulkPayloadCorruptError` (importable from `oda_reader`) for the rare case + where a freshly downloaded zip arrives corrupt. The corrupt entry is removed before the + exception is raised, so the next call cleanly re-downloads. The error message tells you + what to do next. +- Strengthens corruption detection by validating freshly downloaded zips end-to-end (full + member CRC check), not just the central directory. Cached files are trusted on hit so + this doesn't slow normal use. +- The bulk cache now self-maintains: it keeps the two most recently downloaded bulk files + (across all datasets) and evicts older ones, and sweeps stale temp files left behind by interrupted downloads + (older than 24h) on startup. No more manually clearing the cache directory to free space. +- Temp-file naming now includes the hostname alongside the PID, preventing collisions when + the cache directory lives on a shared / NFS mount used by multiple machines. +- `clear_cache`, `set_cache_dir`, `enable_cache` and `disable_cache` now emit a + `DeprecationWarning` for users who also import `oda_data`, pointing at the umbrella + `oda_data.cache.*` API. Standalone `oda_reader` users see no warning. The shims continue + to work through the `1.x` series and will be removed in `2.0`. 
+ +## 1.5.2 (2026-04-28) +- Cache directory is now versioned by the installed package version (via `importlib.metadata`) rather than a hardcoded string, so upgrades automatically invalidate old caches that may contain partial or corrupt downloads from prior versions. +- Bulk-download cache writes are now atomic: downloads stream into a sibling temp file and are only renamed over the destination on success, so partial downloads no longer pollute the cache on interruption or error. +- On `BadZipFile`, the corrupt cached archive is removed so the next call cleanly re-downloads instead of looping on the same poisoned entry. +- Cached archives are validated with `zipfile.is_zipfile` before reuse; a corrupt entry is removed and re-downloaded transparently. + ## 1.5.1 (2026-04-15) - Adds support for Deflate64-compressed ZIP files in bulk downloads. The OECD switched the full CRS bulk file to Deflate64 compression, which Python's standard library does not support. This release patches `zipfile` at runtime using the `inflate64` library to handle Deflate64 transparently. - Adds `inflate64` as a dependency. diff --git a/docs/docs/bulk-downloads.md b/docs/docs/bulk-downloads.md index 6892632..a780714 100644 --- a/docs/docs/bulk-downloads.md +++ b/docs/docs/bulk-downloads.md @@ -110,6 +110,24 @@ print(f"Education projects: {education_count}") print(f"Total commitments: ${education_amount/1e9:.1f}B") ``` +## Forcing a Fresh Download + +By default, bulk downloads are cached on disk so a second call returns +instantly. If you need to bypass that cache (for example, in a CI job that +should always pull the latest file), pass `use_raw_cache=False`: + +```python +# Always download fresh; the zip is extracted to a temp dir and discarded +crs = bulk_download_crs(use_raw_cache=False) +``` + +The integrity check on the freshly downloaded zip still runs; only the +on-disk caching is skipped. 
This flag is available on `bulk_download_crs`, +`download_crs_file`, `bulk_download_dac2a` and `bulk_download_multisystem`. + +See [Caching & Performance](caching.md#bulk-file-cache) for how the bulk +cache is managed (LRU eviction, TTL, integrity validation). + ## Year-Specific CRS Files OECD also provides individual files for specific years: @@ -215,6 +233,11 @@ See [Schema Translation](schema-translation.md) for detailed comparison. **File not found errors**: Older CRS year-specific files use grouped years (e.g., "1995-99"). Check which grouping includes your target year. +**`BulkPayloadCorruptError`**: The OECD's bulk endpoint occasionally serves a +truncated or malformed zip. The corrupt entry is removed automatically before +the exception is raised, so the next call cleanly re-downloads. Retry the +call, or pass `use_raw_cache=False` to skip the cache for that invocation. + ## Next Steps - **[Caching & Performance](caching.md)** - Understand how bulk downloads are cached diff --git a/docs/docs/caching.md b/docs/docs/caching.md index bfa04c3..5d66438 100644 --- a/docs/docs/caching.md +++ b/docs/docs/caching.md @@ -4,12 +4,14 @@ ODA Reader uses caching to make repeated queries fast and reduce dependency on O ## How Caching Works -ODA Reader caches two types of data: +ODA Reader caches three types of data: 1. **HTTP responses**: Raw API responses before processing 2. **DataFrames**: Processed pandas DataFrames after schema translation +3. **Bulk files**: Large parquet/zip files downloaded by `bulk_download_crs`, + `download_crs_file`, `bulk_download_dac2a` and `bulk_download_multisystem` -Both caches are automatic and transparent - you don't need to change your code to benefit from caching. +All three caches are automatic and transparent - you don't need to change your code to benefit from caching. **Example of caching in action**: @@ -103,9 +105,16 @@ This removes all cached API responses and DataFrames. 
Your next query will hit t - Cache has grown too large - You're troubleshooting unexpected results +**Using `oda_reader` alongside `oda_data`?** `clear_cache`, `set_cache_dir`, +`enable_cache` and `disable_cache` are deprecated under the umbrella package +and emit a `DeprecationWarning` pointing at the `oda_data.cache.*` API +(e.g. `oda_data.cache.clear("all")`). Standalone `oda_reader` users see no +warning. The shims continue to work through the `1.x` series and will be +removed in `2.0`. + ### Automatic Cache Cleanup -ODA Reader automatically enforces cache limits: +ODA Reader automatically enforces cache limits across the cache root: - **Max size**: 2.5 GB - **Max age**: 7 days @@ -116,6 +125,58 @@ When you import oda_reader, it checks cache limits: This happens automatically - you don't need to do anything. +### Bulk File Cache + +The bulk file cache (used by `bulk_download_crs`, `download_crs_file`, +`bulk_download_dac2a` and `bulk_download_multisystem`) is governed separately +because the files are large (~1 GB each): + +- **LRU eviction**: only the two most recent bulk files are kept; older + entries are removed automatically the next time you import oda_reader. +- **Per-entry TTL**: an entry is considered stale after 30 days and refetched + on next use. +- **Integrity validation**: every freshly downloaded zip is end-to-end checked + before being trusted. A corrupt download is removed from the cache and + raises `BulkPayloadCorruptError` so you can simply retry. Cached files are + trusted on hit (no recheck on every call). +- **Self-healing**: temp files left behind by interrupted downloads (older + than 24 hours) are swept on startup, so an aborted download can't pollute + the cache directory indefinitely. + +#### Bypassing the Bulk File Cache + +If you need a fresh download every call (e.g. 
for a CI job that should always +hit the source), pass `use_raw_cache=False`: + +```python +from oda_reader import bulk_download_crs + +# Download to a temp directory and discard the zip after extraction +crs = bulk_download_crs(use_raw_cache=False) +``` + +Validation still runs in this mode; only the on-disk caching is skipped. The +flag is available on `bulk_download_crs`, `download_crs_file`, +`bulk_download_dac2a` and `bulk_download_multisystem`. `download_aiddata` +takes a different code path and is not affected. + +#### Handling Corrupt Downloads + +The OECD's bulk endpoint occasionally serves a truncated or malformed file. +When that happens, a `BulkPayloadCorruptError` is raised and the bad entry is +already removed from disk by the time you see it, so the next call cleanly +re-downloads: + +```python +from oda_reader import bulk_download_crs, BulkPayloadCorruptError + +try: + crs = bulk_download_crs() +except BulkPayloadCorruptError: + # Bad entry already removed — just retry + crs = bulk_download_crs() +``` + ## HTTP Caching (Separate from DataFrame Cache) ODA Reader also caches raw HTTP responses using `requests-cache`: diff --git a/pyproject.toml b/pyproject.toml index cac7b43..f275846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "oda_reader" -version = "1.5.1" +version = "1.6.0" description = "A simple package to import ODA data from the OECD's API and AidData's database" readme = "README.md" license = "MIT" @@ -53,6 +53,7 @@ docs = [ "mkdocstrings[python]>=0.24.0", ] test = [ + "freezegun>=1.4.0", "pytest>=9.0.3", "pytest-mock>=3.12", "pytest-cov>=4.1", @@ -120,9 +121,10 @@ unfixable = [] # Allow longer lines in tests and examples "tests/**/*.py" = ["E501"] "docs/examples/**/*.py" = ["E501"] -# Allow global variable usage in cache and common modules (architectural decision) +# Allow global variable usage in singleton-managing modules (architectural decision) "src/oda_reader/_cache/*.py" = ["PLW0603"] 
"src/oda_reader/common.py" = ["PLW0603"] +"src/oda_reader/_http_primitives.py" = ["PLW0603"] # Allow try-except without from (external libraries may not support cause chaining) "src/oda_reader/download/download_tools.py" = ["B904"] # Allow camelcase import in tools (QueryBuilder is the standard name) @@ -149,6 +151,7 @@ python_functions = ["test_*"] markers = [ "unit: Fast unit tests (no external dependencies)", "integration: Tests that call real OECD API", + "network: Tests that require live network access (opt-in via RUN_NETWORK_TESTS=1)", "slow: Long-running tests (bulk downloads)", "cache: Tests that verify cache behavior", ] diff --git a/src/oda_reader/__init__.py b/src/oda_reader/__init__.py index 55364ce..4d97c88 100644 --- a/src/oda_reader/__init__.py +++ b/src/oda_reader/__init__.py @@ -3,21 +3,23 @@ specifically designed to work with OECD DAC data. """ -# Core data download functions -# Cache management (new system) +import sys +import warnings +from collections.abc import Callable +from typing import Any + from oda_reader._cache import ( bulk_cache_manager, cache_dir, # Deprecated alias - clear_cache, dataframe_cache, - disable_cache, - # Legacy functions (for backward compatibility) - enable_cache, enforce_cache_limits, get_cache_dir, reset_cache_dir, - set_cache_dir, ) +from oda_reader._cache.config import set_cache_dir as _impl_set_cache_dir +from oda_reader._cache.legacy import clear_cache as _impl_clear_cache +from oda_reader._cache.legacy import disable_cache as _impl_disable_cache +from oda_reader._cache.legacy import enable_cache as _impl_enable_cache from oda_reader.aiddata import download_aiddata from oda_reader.common import ( API_RATE_LIMITER, @@ -26,15 +28,79 @@ enable_http_cache, get_http_cache_info, ) -from oda_reader.download.version_discovery import clear_version_cache from oda_reader.crs import bulk_download_crs, download_crs, download_crs_file from oda_reader.dac1 import download_dac1 from oda_reader.dac2a import bulk_download_dac2a, 
download_dac2a from oda_reader.download.query_builder import QueryBuilder +from oda_reader.download.version_discovery import clear_version_cache +from oda_reader.exceptions import BulkDownloadHTTPError, BulkPayloadCorruptError from oda_reader.multisystem import bulk_download_multisystem, download_multisystem from oda_reader.tools import get_available_filters +# Each shim emits a one-time-per-session DeprecationWarning when oda_data is +# also imported (umbrella users should migrate to oda_data.cache.*); standalone +# oda_reader users see no warning. +_WARNED_SHIMS: set[str] = set() + + +def _warn_once_if_oda_data_imported(name: str, replacement: str) -> None: + if name in _WARNED_SHIMS or "oda_data" not in sys.modules: + return + warnings.warn( + f"oda_reader.{name} is deprecated for users who also import oda_data; " + f"use {replacement} for the umbrella API. This shim is preserved for " + "standalone oda_reader users through 1.x and removed in 2.0.", + DeprecationWarning, + stacklevel=3, + ) + _WARNED_SHIMS.add(name) + + +def _make_deprecation_shim( + name: str, replacement: str, impl: Callable[..., Any], one_liner: str +) -> Callable[..., Any]: + def shim(*args: Any, **kwargs: Any) -> Any: + _warn_once_if_oda_data_imported(name, replacement) + return impl(*args, **kwargs) + + shim.__name__ = name + shim.__qualname__ = name + shim.__doc__ = ( + f"{one_liner} Deprecated under the oda_data umbrella; use {replacement}." 
+ ) + return shim + + +clear_cache = _make_deprecation_shim( + "clear_cache", + "oda_data.cache.clear('all')", + _impl_clear_cache, + "Clear the cache directory.", +) +set_cache_dir = _make_deprecation_shim( + "set_cache_dir", + "oda_data.set_cache_root() or the ODA_DATA_CACHE_DIR env var", + _impl_set_cache_dir, + "Set a custom cache directory path.", +) +enable_cache = _make_deprecation_shim( + "enable_cache", + "oda_data.cache.enable_cache('all')", + _impl_enable_cache, + "Enable caching globally.", +) +disable_cache = _make_deprecation_shim( + "disable_cache", + "oda_data.cache.disable_cache('all')", + _impl_disable_cache, + "Disable caching globally.", +) + + __all__ = [ + # Boundary contract + "BulkPayloadCorruptError", + "BulkDownloadHTTPError", # Data download "QueryBuilder", "download_dac1", @@ -63,7 +129,7 @@ "bulk_cache_manager", # Rate limiting "API_RATE_LIMITER", - # Legacy (backward compatibility) + # Legacy (backward compatibility - deprecated for oda_data users) "enable_cache", "disable_cache", "clear_cache", diff --git a/src/oda_reader/_cache/config.py b/src/oda_reader/_cache/config.py index 7827567..b11bb8d 100644 --- a/src/oda_reader/_cache/config.py +++ b/src/oda_reader/_cache/config.py @@ -5,17 +5,25 @@ """ import os +import socket +from collections.abc import Callable +from importlib.metadata import version from pathlib import Path from platformdirs import user_cache_dir -# Version for cache versioning (hardcoded to avoid circular import) -# This should match the version in __init__.py -__version__ = "1.3.1" +# Resolved at import time so the cache path tracks the installed package version. +_CACHE_VERSION = version("oda_reader") + +_HOSTNAME = socket.gethostname() -# Global override for cache directory (set via set_cache_dir) _CACHE_DIR_OVERRIDE: Path | None = None +# Callbacks invoked when the cache root changes, so module-level singletons +# (CacheManager, DataFrameCache) can rebuild against the new path. 
Modules +# register on import via ``register_cache_dir_change_listener``. +_CACHE_DIR_LISTENERS: list[Callable[[], None]] = [] + def get_cache_dir() -> Path: """Get the cache directory path. @@ -30,25 +38,22 @@ def get_cache_dir() -> Path: Returns: Path: The cache directory path. """ - # Priority 1: Programmatic override if _CACHE_DIR_OVERRIDE is not None: return _CACHE_DIR_OVERRIDE - # Priority 2: Environment variable if env_dir := os.getenv("ODA_READER_CACHE_DIR"): return Path(env_dir).expanduser().resolve() - # Priority 3: Platform default with version - # This ensures cache is invalidated on package upgrades base = Path(user_cache_dir("oda-reader", "oda-reader")) - return base / __version__ + return base / _CACHE_VERSION def set_cache_dir(path: str | Path) -> None: """Set a custom cache directory path. This takes precedence over environment variables and platform defaults. - Changes affect all future cache operations. + Changes affect all future cache operations and reset any module-level + cache singletons so they pick up the new directory. Args: path: The directory path to use for caching. @@ -59,16 +64,34 @@ def set_cache_dir(path: str | Path) -> None: """ global _CACHE_DIR_OVERRIDE _CACHE_DIR_OVERRIDE = Path(path).expanduser().resolve() + _notify_cache_dir_changed() def reset_cache_dir() -> None: """Reset cache directory to default (remove override). After calling this, cache directory will be determined by environment - variable or platform default. + variable or platform default. Resets module-level cache singletons. """ global _CACHE_DIR_OVERRIDE _CACHE_DIR_OVERRIDE = None + _notify_cache_dir_changed() + + +def register_cache_dir_change_listener(callback: Callable[[], None]) -> None: + """Register a callback to fire when the cache directory changes. + + Args: + callback: Zero-argument callable invoked after set_cache_dir or + reset_cache_dir mutates the override. Used by cache singletons + to rebuild against the new directory. 
+ """ + _CACHE_DIR_LISTENERS.append(callback) + + +def _notify_cache_dir_changed() -> None: + for callback in _CACHE_DIR_LISTENERS: + callback() def get_http_cache_path() -> Path: diff --git a/src/oda_reader/_cache/dataframe.py b/src/oda_reader/_cache/dataframe.py index 1136551..a3d2b10 100644 --- a/src/oda_reader/_cache/dataframe.py +++ b/src/oda_reader/_cache/dataframe.py @@ -15,12 +15,17 @@ import hashlib import json import logging +import os from pathlib import Path from typing import Any import pandas as pd -from oda_reader._cache.config import get_dataframe_cache_dir +from oda_reader._cache.config import ( + _HOSTNAME, + get_dataframe_cache_dir, + register_cache_dir_change_listener, +) logger = logging.getLogger("oda_reader") @@ -63,12 +68,12 @@ def _make_cache_key( **kwargs, } - # Create a deterministic JSON string (sorted keys) - params_json = json.dumps(params, sort_keys=True) + # default=str so non-JSON-serializable kwargs (Path, datetime, ...) hash + # by their str() rather than crashing inside the cache layer. + params_json = json.dumps(params, sort_keys=True, default=str) - # Hash it to create a short key - hash_obj = hashlib.sha256(params_json.encode()) - return hash_obj.hexdigest()[:16] # First 16 chars of hash + # 16 hex chars = 64 bits of namespace; ample for any realistic workload. + return hashlib.sha256(params_json.encode()).hexdigest()[:16] class DataFrameCache: @@ -170,10 +175,17 @@ def set( cache_file = self.cache_dir / f"{cache_key}.parquet" + # Cache writes are best-effort: a full disk or unwritable cache dir + # must not turn a successful download into a user-visible failure. + # We still write atomically (tmp + rename) so a crash mid-write can't + # leave a half-baked parquet at the destination. 
+ tmp_path = Path(f"{cache_file}.tmp-{_HOSTNAME}-{os.getpid()}") try: - df.to_parquet(cache_file) + df.to_parquet(tmp_path) + tmp_path.replace(cache_file) logger.info(f"Cached DataFrame (key: {cache_key})") except Exception as e: + tmp_path.unlink(missing_ok=True) logger.warning(f"Failed to cache DataFrame: {e}") def clear(self) -> None: @@ -205,7 +217,6 @@ def disable(self) -> None: self._enabled = False -# Global singleton _DATAFRAME_CACHE: DataFrameCache | None = None @@ -219,3 +230,12 @@ def dataframe_cache() -> DataFrameCache: if _DATAFRAME_CACHE is None: _DATAFRAME_CACHE = DataFrameCache() return _DATAFRAME_CACHE + + +def _reset_dataframe_cache() -> None: + """Reset the singleton so the next access rebuilds against the current cache dir.""" + global _DATAFRAME_CACHE + _DATAFRAME_CACHE = None + + +register_cache_dir_change_listener(_reset_dataframe_cache) diff --git a/src/oda_reader/_cache/legacy.py b/src/oda_reader/_cache/legacy.py index 361e383..e402c72 100644 --- a/src/oda_reader/_cache/legacy.py +++ b/src/oda_reader/_cache/legacy.py @@ -14,6 +14,9 @@ from pathlib import Path from oda_reader._cache.config import get_cache_dir +from oda_reader._cache.config import set_cache_dir as _impl_set_cache_dir +from oda_reader._cache.dataframe import dataframe_cache +from oda_reader.common import disable_http_cache, enable_http_cache logger = logging.getLogger("oda_reader") @@ -58,9 +61,7 @@ def set_cache_dir(path) -> None: Use oda_reader._cache.config.set_cache_dir() instead. """ - from oda_reader._cache.config import set_cache_dir as new_set_cache_dir - - new_set_cache_dir(path) + _impl_set_cache_dir(path) def _human_mb(byte_count: float) -> float: @@ -195,9 +196,6 @@ def disable_cache() -> None: This disables HTTP caching and DataFrame caching. 
""" - from oda_reader._cache.dataframe import dataframe_cache - from oda_reader.common import disable_http_cache - disable_http_cache() dataframe_cache().disable() logger.info("Caching disabled globally.") @@ -208,9 +206,6 @@ def enable_cache() -> None: This enables HTTP caching and DataFrame caching. """ - from oda_reader._cache.dataframe import dataframe_cache - from oda_reader.common import enable_http_cache - enable_http_cache() dataframe_cache().enable() logger.info("Caching enabled globally.") diff --git a/src/oda_reader/_cache/manager.py b/src/oda_reader/_cache/manager.py index 9fc762d..aaca18b 100644 --- a/src/oda_reader/_cache/manager.py +++ b/src/oda_reader/_cache/manager.py @@ -15,12 +15,17 @@ from filelock import FileLock -from oda_reader._cache.config import get_bulk_cache_dir +from oda_reader._cache.config import ( + _HOSTNAME, + get_bulk_cache_dir, + register_cache_dir_change_listener, +) +from oda_reader.exceptions import validate_zip_or_raise logger = logging.getLogger("oda_reader") -# ISO format for datetime serialization ISO_FORMAT = "%Y-%m-%dT%H:%M:%S%z" +_TMP_MAX_AGE_SECONDS = 86_400 # 24 hours @dataclass(frozen=True) @@ -46,59 +51,72 @@ class CacheManager: """Manages cached bulk files with a JSON manifest. Provides: - - Atomic writes (temp file + rename) + - Atomic writes (temp file + rename, host+pid suffix for NFS safety) - Cross-process locking (FileLock) - TTL-based expiration - Version-based invalidation + - LRU eviction per dataset key prefix (keep_n most-recent) - Observable cache state (list_records, stats) + - Startup sweep of stale tmp files Storage layout: {base_dir}/ ├── manifest.json # Metadata tracking ├── .cache.lock # FileLock for coordination - └── *.parquet # Cached datasets + └── *.zip # Cached datasets """ - def __init__(self, base_dir: Path | None = None): + def __init__(self, base_dir: Path | None = None, *, keep_n: int = 2) -> None: """Initialize cache manager. Args: base_dir: Directory for cache storage. 
If None, uses get_bulk_cache_dir(). + keep_n: Maximum number of entries to retain per dataset key prefix. + Oldest entries beyond this limit are evicted on init. """ self.base_dir = base_dir or get_bulk_cache_dir() + self.keep_n = keep_n self.base_dir.mkdir(parents=True, exist_ok=True) self.manifest_path = self.base_dir / "manifest.json" self.lock_path = self.base_dir / ".cache.lock" self._lock = FileLock(str(self.lock_path), timeout=1200) + # Best-effort startup maintenance. If another process holds the lock + # (e.g., mid-fetch), skip — it'll run on the next clean startup. + try: + with self._lock.acquire(timeout=0): + self._sweep_stale_tmp_files() + self._evict_lru() + except Exception as e: + logger.debug(f"Skipping startup cache maintenance: {e}") + def ensure(self, entry: CacheEntry, refresh: bool = False) -> Path: """Ensure cached file exists, fetching if necessary. + On a cache-miss or refresh, the downloaded file is validated with + ``is_zipfile`` and ``testzip``. On a cache-hit the manifest record is + trusted (no CRC walk on every call). + Args: - entry: CacheEntry describing the dataset - refresh: If True, bypass cache and fetch fresh data + entry: CacheEntry describing the dataset. + refresh: If True, bypass cache and fetch fresh data. Returns: - Path: Path to the cached file - - Example: - >>> manager = bulk_cache_manager() - >>> entry = CacheEntry( - ... key="crs_full", - ... filename="crs_full.parquet", - ... fetcher=lambda p: download_crs_to_path(p), - ... ttl_days=90 - ... ) - >>> path = manager.ensure(entry) - >>> df = pd.read_parquet(path) + Path: Path to the cached file. On cache-miss / refetch the file is + guaranteed to satisfy ``is_zipfile(path) and + ZipFile(path).testzip() is None``. On cache-hit the manifest + record is trusted. + + Raises: + BulkPayloadCorruptError: If a freshly downloaded file fails + integrity validation. The cache entry is removed before raising. 
""" with self._lock: manifest = self._load_manifest() record = manifest.get(entry.key) path = self.base_dir / entry.filename - # Check if we need to fetch needs_fetch = ( refresh or record is None @@ -110,11 +128,13 @@ def ensure(self, entry: CacheEntry, refresh: bool = False) -> Path: logger.info(f"Loading {entry.key} from cache") return path - # Fetch fresh data logger.info(f"Fetching {entry.key} (refresh={refresh})") self._fetch_and_cache(entry, path) - # Update manifest + # Validate only on fetch - trusting the manifest on hit avoids a + # 10-30 s CRC walk per call on cached 976 MB zips. + validate_zip_or_raise(path) + manifest[entry.key] = { "filename": entry.filename, "downloaded_at": datetime.now(timezone.utc).strftime(ISO_FORMAT), @@ -130,22 +150,16 @@ def clear(self, key: str | None = None) -> None: Args: key: If provided, clear only this entry. If None, clear all. - - Example: - >>> manager.clear("crs_full") # Clear specific entry - >>> manager.clear() # Clear all """ with self._lock: manifest = self._load_manifest() if key is None: - # Clear all for record in manifest.values(): file_path = self.base_dir / record["filename"] file_path.unlink(missing_ok=True) manifest.clear() logger.info("Cleared all bulk cache entries") - # Clear specific entry elif key in manifest: file_path = self.base_dir / manifest[key]["filename"] file_path.unlink(missing_ok=True) @@ -161,11 +175,6 @@ def list_records(self) -> list[dict[str, Any]]: Returns: List of dicts containing cache record information. - - Example: - >>> for record in manager.list_records(): - ... print(f"{record['key']}: {record['size_mb']:.1f} MB, " - ... f"age: {record['age_days']:.1f} days") """ with self._lock: manifest = self._load_manifest() @@ -200,10 +209,6 @@ def stats(self) -> dict[str, Any]: Returns: Dict with total_entries, total_size_mb, stale_entries. 
- - Example: - >>> stats = manager.stats() - >>> print(f"Cache size: {stats['total_size_mb']:.1f} MB") """ records = self.list_records() total_size = sum(r["size_mb"] for r in records) @@ -216,14 +221,11 @@ def stats(self) -> dict[str, Any]: } def _fetch_and_cache(self, entry: CacheEntry, path: Path) -> None: - """Fetch data and write atomically. - - Uses temp-file-then-rename pattern to prevent corruption. - """ - tmp_path = Path(f"{path}.tmp-{os.getpid()}") + """Fetch data and write atomically using host+pid tmp suffix.""" + tmp_path = Path(f"{path}.tmp-{_HOSTNAME}-{os.getpid()}") try: entry.fetcher(tmp_path) - tmp_path.replace(path) # Atomic rename + tmp_path.replace(path) finally: tmp_path.unlink(missing_ok=True) @@ -234,11 +236,9 @@ def _is_stale(self, record: dict, entry: CacheEntry) -> bool: 1. Version changed (explicit cache bust) 2. Age exceeds TTL """ - # Check version mismatch if entry.version is not None and entry.version != record.get("version"): return True - # Check TTL expiration downloaded = datetime.strptime(record["downloaded_at"], ISO_FORMAT) age = datetime.now(timezone.utc) - downloaded ttl = timedelta(days=entry.ttl_days) @@ -258,8 +258,8 @@ def _load_manifest(self) -> dict: return {} def _save_manifest(self, manifest: dict) -> None: - """Save manifest to disk atomically.""" - tmp_path = Path(f"{self.manifest_path}.tmp-{os.getpid()}") + """Save manifest to disk atomically using host+pid tmp suffix.""" + tmp_path = Path(f"{self.manifest_path}.tmp-{_HOSTNAME}-{os.getpid()}") try: with tmp_path.open("w") as f: json.dump(manifest, f, indent=2) @@ -267,8 +267,53 @@ def _save_manifest(self, manifest: dict) -> None: finally: tmp_path.unlink(missing_ok=True) + def _sweep_stale_tmp_files(self) -> None: + """Remove *.tmp-* files in base_dir that are older than 24 hours.""" + now = datetime.now(timezone.utc).timestamp() + for tmp_file in self.base_dir.glob("*.tmp-*"): + try: + age = now - tmp_file.stat().st_mtime + if age > _TMP_MAX_AGE_SECONDS: + 
tmp_file.unlink(missing_ok=True) + logger.info(f"Swept stale tmp file: {tmp_file}") + except OSError as e: + logger.warning(f"Could not inspect/remove tmp file {tmp_file}: {e}") + + def _evict_lru(self) -> None: + """Evict oldest entries beyond keep_n per dataset key prefix. + + A "key prefix" is the full manifest key (e.g., ``sha1_hexdigest``). + Since all entries in the bulk cache share the same key space and each + URL produces a unique sha1 key, LRU is applied across *all* entries + collectively: keep the ``keep_n`` most-recently-downloaded, evict the + rest. + """ + manifest = self._load_manifest() + if len(manifest) <= self.keep_n: + return + + try: + ordered = sorted( + manifest.items(), + key=lambda kv: datetime.strptime(kv[1]["downloaded_at"], ISO_FORMAT), + ) + except (KeyError, ValueError) as e: + logger.warning(f"LRU eviction skipped due to manifest parse error: {e}") + return + + to_evict = ordered[: len(ordered) - self.keep_n] + for key, record in to_evict: + file_path = self.base_dir / record["filename"] + try: + file_path.unlink(missing_ok=True) + logger.info(f"LRU eviction: removed {file_path}") + except OSError as e: + logger.warning(f"LRU eviction: could not remove {file_path}: {e}") + del manifest[key] + + self._save_manifest(manifest) + -# Global singleton _BULK_CACHE_MANAGER: CacheManager | None = None @@ -277,13 +322,17 @@ def bulk_cache_manager() -> CacheManager: Returns: CacheManager: The global cache manager instance. 
- - Example: - >>> from oda_reader.cache_manager import bulk_cache_manager - >>> manager = bulk_cache_manager() - >>> stats = manager.stats() """ global _BULK_CACHE_MANAGER if _BULK_CACHE_MANAGER is None: _BULK_CACHE_MANAGER = CacheManager() return _BULK_CACHE_MANAGER + + +def _reset_bulk_cache_manager() -> None: + """Reset the singleton so the next access rebuilds against the current cache dir.""" + global _BULK_CACHE_MANAGER + _BULK_CACHE_MANAGER = None + + +register_cache_dir_change_listener(_reset_bulk_cache_manager) diff --git a/src/oda_reader/common.py b/src/oda_reader/common.py index e3a4e59..0ec67e8 100644 --- a/src/oda_reader/common.py +++ b/src/oda_reader/common.py @@ -6,12 +6,16 @@ import pandas as pd -import oda_reader._http_primitives as _http_primitives +from oda_reader import _http_primitives from oda_reader._http_primitives import ( API_RATE_LIMITER, RateLimiter, _get_http_session, - get_response_content as _get_response_content, +) +from oda_reader._http_primitives import ( + get_response_content as _get_response_content, # noqa: F401 # re-exported +) +from oda_reader._http_primitives import ( get_response_text as _get_response_text, ) from oda_reader.download.version_discovery import ( @@ -19,6 +23,22 @@ get_dimension_count, ) +# Re-exports of rate-limiting primitives. They live in _http_primitives +# to break a circular import with version_discovery; this module is the +# stable public re-export surface. 
+__all__ = [ + "API_RATE_LIMITER", + "RateLimiter", + "api_response_to_df", + "clear_http_cache", + "disable_http_cache", + "enable_http_cache", + "get_data_from_api", + "get_http_cache_info", + "logger", + "text_to_stringio", +] + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger("oda_importer") @@ -276,7 +296,7 @@ def get_data_from_api(url: str, compressed: bool = True) -> str: def api_response_to_df( - url: str, read_csv_options: dict = None, compressed: bool = True + url: str, read_csv_options: dict | None = None, compressed: bool = True ) -> pd.DataFrame: """Download a CSV file from an API endpoint and return it as a DataFrame. diff --git a/src/oda_reader/crs.py b/src/oda_reader/crs.py index 87a4e57..b1386a4 100644 --- a/src/oda_reader/crs.py +++ b/src/oda_reader/crs.py @@ -42,6 +42,7 @@ def download_crs_file( save_to_path: Path | str | None = None, *, as_iterator: bool = False, + use_raw_cache: bool = True, ) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: """ Download a year of CRS data from the bulk download service. The file is large. @@ -52,6 +53,8 @@ def download_crs_file( year: The year of CRS data to download. save_to_path: The path to save the file to. Optional. If not provided, a DataFrame is returned. as_iterator: If ``True`` yields ``DataFrame`` chunks instead of a single ``DataFrame``. + use_raw_cache: If False, the raw zip is downloaded to a temporary directory and + deleted after extraction. Each call hits the network. Validation still runs. Returns: pd.DataFrame | Iterator[pd.DataFrame] | None @@ -64,6 +67,7 @@ def download_crs_file( file_id=file_id, save_to_path=save_to_path, as_iterator=as_iterator, + use_raw_cache=use_raw_cache, ) @@ -72,6 +76,7 @@ def bulk_download_crs( reduced_version: bool = False, *, as_iterator: bool = False, + use_raw_cache: bool = True, ) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: """ Bulk download the CRS data from the bulk download service. 
The file is very large. @@ -82,6 +87,8 @@ def bulk_download_crs( save_to_path: The path to save the file to. Optional. If not provided, a DataFrame is returned. reduced_version: Whether to download the reduced version of the CRS data. as_iterator: If ``True`` yields ``DataFrame`` chunks instead of a single ``DataFrame``. + use_raw_cache: If False, the raw zip is downloaded to a temporary directory and + deleted after extraction. Each call hits the network. Validation still runs. Returns: pd.DataFrame | Iterator[pd.DataFrame] | None @@ -97,6 +104,7 @@ def bulk_download_crs( file_id=file_id, save_to_path=save_to_path, as_iterator=as_iterator, + use_raw_cache=use_raw_cache, ) diff --git a/src/oda_reader/dac2a.py b/src/oda_reader/dac2a.py index f41fd67..0bf9990 100644 --- a/src/oda_reader/dac2a.py +++ b/src/oda_reader/dac2a.py @@ -84,6 +84,7 @@ def bulk_download_dac2a( save_to_path: Path | str | None = None, *, as_iterator: bool = False, + use_raw_cache: bool = True, ) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: """ Bulk download the DAC2a data from the bulk download service. The file is very large. @@ -93,6 +94,8 @@ def bulk_download_dac2a( Args: save_to_path: The path to save the file to. Optional. If not provided, a DataFrame is returned. as_iterator: If ``True`` yields ``DataFrame`` chunks instead of a single ``DataFrame``. + use_raw_cache: If False, the raw zip is downloaded to a temporary directory and + deleted after extraction. Each call hits the network. Validation still runs. 
Returns: pd.DataFrame | Iterator[pd.DataFrame] | None @@ -104,4 +107,5 @@ def bulk_download_dac2a( file_id=file_id, save_to_path=save_to_path, as_iterator=as_iterator, + use_raw_cache=use_raw_cache, ) diff --git a/src/oda_reader/download/download_tools.py b/src/oda_reader/download/download_tools.py index c6cd7e0..05b6898 100644 --- a/src/oda_reader/download/download_tools.py +++ b/src/oda_reader/download/download_tools.py @@ -10,14 +10,13 @@ import zipfile from pathlib import Path -import oda_reader.download._deflate64 # noqa: F401 — adds Deflate64 support - import pandas as pd import pyarrow.parquet as pq import requests -from oda_reader._cache.config import get_bulk_cache_dir +import oda_reader.download._deflate64 # noqa: F401 # adds Deflate64 support from oda_reader._cache.dataframe import dataframe_cache +from oda_reader._cache.manager import CacheEntry, bulk_cache_manager from oda_reader.common import ( API_RATE_LIMITER, _get_response_content, @@ -27,6 +26,11 @@ ) from oda_reader.download.query_builder import QueryBuilder from oda_reader.download.version_discovery import discover_latest_version +from oda_reader.exceptions import ( + BulkDownloadHTTPError, + BulkPayloadCorruptError, + validate_zip_or_raise, +) from oda_reader.schemas.crs_translation import convert_crs_to_dotstat_codes from oda_reader.schemas.dac1_translation import convert_dac1_to_dotstat_codes from oda_reader.schemas.dac2_translation import convert_dac2a_to_dotstat_codes @@ -52,7 +56,6 @@ ) - def _detect_delimiter(file_obj, sample_size: int = 8192) -> str: """Detect the delimiter used in a CSV/text file. 
@@ -105,7 +108,7 @@ def _iter_frames(response_content: bytes | Path) -> typing.Iterator[pd.DataFrame def download( version: str, dataflow_id: str, - dataflow_version: str = None, + dataflow_version: str | None = None, start_year: int | None = None, end_year: int | None = None, filters: dict | None = None, @@ -328,76 +331,6 @@ def _save_or_return_parquet_files_from_content( raise ValueError("No parquet, csv, or txt files found in the zip archive.") -def _save_or_return_parquet_files_from_txt_in_zip( - response_content: bytes | Path, - save_to_path: Path | str | None = None, -) -> list[pd.DataFrame] | None: - """Extract csv or txt files from a zipped archive supplied as bytes or a file path. - - The file is read as CSV (with auto-detected delimiter) and optionally saved - as a parquet file. - - Args: - response_content: Bytes or ``Path`` pointing to the zipped archive. - save_to_path (Path | str | None): The path to save the file to. Optional. If - not provided, a list of DataFrames is returned. - - Returns: - list[pd.DataFrame]: The extracted DataFrames if save_to_path is not provided. 
- """ - # Convert the save_to_path to a Path object - save_to_path = Path(save_to_path).expanduser().resolve() if save_to_path else None - - with _open_zip(response_content=response_content) as z: - # Find all csv/txt files in the zip archive - files = [ - name - for name in z.namelist() - if name.endswith(".txt") or name.endswith(".csv") - ] - - # If save_to_path is provided, save the files to the path - if save_to_path: - save_to_path.mkdir(parents=True, exist_ok=True) - for file_name in files: - clean_name = ( - file_name.replace(".txt", ".parquet") - .replace(".csv", ".parquet") - .lower() - .replace(" ", "_") - ) - logger.info(f"Saving {clean_name}") - with z.open(file_name) as f_in: - delimiter = _detect_delimiter(f_in) - logger.info(f"Detected delimiter: '{delimiter}'") - pd.read_csv( - f_in, - delimiter=delimiter, - encoding="utf-8", - quotechar='"', - low_memory=False, - ).to_parquet(save_to_path / clean_name) - return None - - # If save_to_path is not provided, return the DataFrames - logger.info(f"Reading {len(files)} files.") - dfs = [] - for file_name in files: - with z.open(file_name) as f_in: - delimiter = _detect_delimiter(f_in) - logger.info(f"Detected delimiter for {file_name}: '{delimiter}'") - dfs.append( - pd.read_csv( - f_in, - delimiter=delimiter, - encoding="utf-8", - quotechar='"', - low_memory=False, - ) - ) - return dfs - - def _save_or_return_excel_files_from_content( response_content: bytes, save_to_path: Path | str | None = None, @@ -455,7 +388,7 @@ def _stream_to_file(url: str, headers: dict, path: Path) -> None: API_RATE_LIMITER.wait() with requests.get(url, headers=headers, stream=True) as r: if r.status_code > 299: - raise ConnectionError(f"Error {r.status_code}: {r.text}") + raise BulkDownloadHTTPError(status_code=r.status_code, url=url, body=r.text) with path.open("wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -464,46 +397,95 @@ def _stream_to_file(url: str, headers: dict, path: Path) -> None: def 
_stream_to_tempfile(url: str, headers: dict) -> Path: - """Download content to a temporary file using streaming.""" + """Download content to a temporary file using streaming. - with tempfile.NamedTemporaryFile(delete=False) as tmp: - _stream_to_file(url, headers, Path(tmp.name)) - return Path(tmp.name) + On stream failure the partial temp file is removed so callers don't have + to track a partial download to clean up. + """ + fd, name = tempfile.mkstemp() + os.close(fd) + path = Path(name) + try: + _stream_to_file(url, headers, path) + except BaseException: + path.unlink(missing_ok=True) + raise + return path -def _cached_stream_to_file(url: str, headers: dict) -> Path: - """Stream a URL to a cached file and return its path.""" +def _drain_then_unlink( + iterable: typing.Iterable[pd.DataFrame], path: Path +) -> typing.Iterator[pd.DataFrame]: + """Wrap an iterable so the source temp file is deleted on completion or close. - downloads = get_bulk_cache_dir() - downloads.mkdir(parents=True, exist_ok=True) - file_name = hashlib.sha1(url.encode()).hexdigest() + ".zip" - destination = downloads / file_name - if destination.exists(): - logger.info(f"Loading {url} from bulk file cache") - return destination + Cleanup runs on normal exhaustion, on caller-side exceptions during + iteration, and on generator close (CPython guarantees close() during + garbage collection of the wrapping generator). + """ + try: + yield from iterable + finally: + path.unlink(missing_ok=True) - _stream_to_file(url, headers, destination) - return destination +def _consume_bulk_zip( + *, + zip_path: Path, + save_to_path: Path | str | None, + as_iterator: bool, + use_raw_cache: bool, + manager: typing.Any, # CacheManager | None — typed loosely to avoid an import cycle here + url_key: str, +) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: + """Run extraction with cleanup tied to the no-cache temp file lifecycle. 
-def _get_temp_file(file_url: str, use_cache: bool = True) -> tuple[Path, bool]: - """Download file to a temporary location and return the path and a cleanup flag. + Three exit paths leak the temp file unless they're handled explicitly: + BadZipFile, any other extraction error, and the lazy-iterator early + return. This helper covers all three. + """ + try: + files = _save_or_return_parquet_files_from_content( + response_content=zip_path, + save_to_path=save_to_path, + as_iterator=as_iterator, + ) + except zipfile.BadZipFile: + if manager is not None: + manager.clear(url_key) + else: + zip_path.unlink(missing_ok=True) + raise BulkPayloadCorruptError( + zip_path, + reason="zipfile.BadZipFile raised when reading members", + ) + except BaseException: + if not use_raw_cache: + zip_path.unlink(missing_ok=True) + raise + + if as_iterator: + if files is None: + # save_to_path was provided; files were written to disk and there + # is no iteration to wrap. Clean up immediately. + if not use_raw_cache: + zip_path.unlink(missing_ok=True) + return None + # Iterator construction is lazy. For the no-cache path, defer unlink + # until the wrapping generator completes or is closed (CPython + # guarantees close() on garbage collection). + if not use_raw_cache: + return _drain_then_unlink(files, zip_path) + return files - Args: - file_url: URL to download - use_cache: If True, cache the file. If False, use a temp file. 
+ if not use_raw_cache: + zip_path.unlink(missing_ok=True) - Returns: - tuple[Path, bool]: Path to file and whether it should be cleaned up - """ - headers = {"Accept-Encoding": "gzip"} - if use_cache: - temp_zip = _cached_stream_to_file(file_url, headers) - cleanup = False - else: - temp_zip = _stream_to_tempfile(file_url, headers) - cleanup = True - return temp_zip, cleanup + if files: + combined_df = pd.concat(files, ignore_index=True) + logger.info("File downloaded / retrieved correctly.") + return combined_df + + return None def bulk_download_parquet( @@ -512,6 +494,7 @@ def bulk_download_parquet( is_txt: bool | None = None, *, as_iterator: bool = False, + use_raw_cache: bool = True, ) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: """Download data from the stats.oecd.org file download service. @@ -520,17 +503,25 @@ def bulk_download_parquet( The file type is auto-detected from the zip contents. Args: - file_id (str): The ID of the file to download. - save_to_path (Path | str | None): The path to save the file to. Optional. + file_id: The ID of the file to download. + save_to_path: The path to save the file to. Optional. If not provided, the contents are returned. - is_txt (bool | None): Deprecated. File type is now auto-detected. + is_txt: Deprecated. File type is now auto-detected. This parameter is ignored and will be removed in a future version. - as_iterator (bool): When ``True`` return an iterator over ``DataFrame`` + as_iterator: When ``True`` return an iterator over ``DataFrame`` chunks instead of a single ``DataFrame``. Useful for large files. Only supported for parquet files. + use_raw_cache: If True (default), the raw zip is cached on disk and + reused across calls. If False, the zip is downloaded to a + temporary directory and deleted after extraction; each call hits + the network. Integrity validation (is_zipfile + testzip) still + runs in both modes. 
Returns: pd.DataFrame | Iterator[pd.DataFrame] | None + + Raises: + BulkPayloadCorruptError: If the downloaded zip fails integrity validation. """ if is_txt is not None: warnings.warn( @@ -539,44 +530,38 @@ def bulk_download_parquet( DeprecationWarning, stacklevel=2, ) - # Construct the URL + file_url = BULK_DOWNLOAD_URL + file_id + headers = {"Accept-Encoding": "gzip"} - # Inform the user about what the function will do (save or return) if save_to_path: logger.info(f"The file will be saved to {save_to_path}.") else: - logger.info("The file will be returned as a DataFrame. ") + logger.info("The file will be returned as a DataFrame.") - # Download the zip file to avoid loading it fully in memory - temp_zip_path, cleanup = _get_temp_file(file_url) + url_key = hashlib.sha1(file_url.encode()).hexdigest() - try: - # Auto-detect file type (parquet or txt) and process - files = _save_or_return_parquet_files_from_content( - response_content=temp_zip_path, - save_to_path=save_to_path, - as_iterator=as_iterator, - ) - if as_iterator: - return files - except zipfile.BadZipFile: - if cleanup: - os.unlink(temp_zip_path) - raise Exception( - f"Failed to read parquet files from {temp_zip_path}. " - "Ensure the file is a valid zip archive containing parquet files." 
+ if use_raw_cache: + entry = CacheEntry( + key=url_key, + filename=f"{url_key}.zip", + fetcher=lambda p: _stream_to_file(file_url, headers, p), ) + manager = bulk_cache_manager() + zip_path = manager.ensure(entry) + else: + manager = None + zip_path = _stream_to_tempfile(file_url, headers) + validate_zip_or_raise(zip_path) - if cleanup: - os.unlink(temp_zip_path) - - if files: - combined_df = pd.concat(files, ignore_index=True) - logger.info("File downloaded / retrieved correctly.") - return combined_df - - return None + return _consume_bulk_zip( + zip_path=zip_path, + save_to_path=save_to_path, + as_iterator=as_iterator, + use_raw_cache=use_raw_cache, + manager=manager, + url_key=url_key, + ) def _download_aiddata_response() -> bytes: @@ -587,11 +572,18 @@ def _download_aiddata_response() -> bytes: """ logger.info("Downloading AidData. This may take a while...") headers = {"Accept-Encoding": "gzip"} - status, response, from_cache = _get_response_content( + status, response, _from_cache = _get_response_content( AIDDATA_DOWNLOAD_URL, headers=headers ) if status > 299: - raise ConnectionError(f"Error {status}: {response}") + body = ( + response.decode("utf-8", errors="replace") + if isinstance(response, bytes) + else str(response) + ) + raise BulkDownloadHTTPError( + status_code=status, url=AIDDATA_DOWNLOAD_URL, body=body + ) return response @@ -711,9 +703,7 @@ def _try_version(version: float | int | str) -> str | None: if result is not None: return result except (ConnectionError, ValueError): - logger.info( - "Version discovery failed; falling back to version scan." 
- ) + logger.info("Version discovery failed; falling back to version scan.") # --- Step 3: fall back to a decrement scan from 2.0 --- start = latest_flow if latest_flow is not None else 2.0 diff --git a/src/oda_reader/download/query_builder.py b/src/oda_reader/download/query_builder.py index a2af202..e314e67 100644 --- a/src/oda_reader/download/query_builder.py +++ b/src/oda_reader/download/query_builder.py @@ -24,7 +24,7 @@ class QueryBuilder: def __init__( self, dataflow_id: str, - dataflow_version: str = None, + dataflow_version: str | None = None, api_version: int = 1, ) -> None: """ diff --git a/src/oda_reader/download/version_discovery.py b/src/oda_reader/download/version_discovery.py index 809868b..3d538e9 100644 --- a/src/oda_reader/download/version_discovery.py +++ b/src/oda_reader/download/version_discovery.py @@ -137,9 +137,7 @@ def get_dimension_count(dataflow_id: str, version: str) -> int: if local_name == "Dimension": count += 1 if count == 0: - raise ValueError( - f"No dimensions found in DSD '{dsd_id}' version {version}." - ) + raise ValueError(f"No dimensions found in DSD '{dsd_id}' version {version}.") return count diff --git a/src/oda_reader/exceptions.py b/src/oda_reader/exceptions.py new file mode 100644 index 0000000..963a009 --- /dev/null +++ b/src/oda_reader/exceptions.py @@ -0,0 +1,85 @@ +"""Typed exceptions for the oda_reader boundary contract.""" + +import zipfile +import zlib +from pathlib import Path + +BULK_PAYLOAD_CORRUPT_HINT = ( + "Call the bulk_download function again to refetch (the corrupt entry " + "has been removed), run oda_data.cache.clear('raw') to wipe the raw " + "cache, or call with use_raw_cache=False to bypass." +) + +_HTTP_BODY_PREVIEW = 500 + + +class BulkPayloadCorruptError(Exception): + """Raised when a downloaded bulk payload fails integrity validation. + + Attributes: + path: The path of the failed cache entry. The entry has already + been removed from disk by the time this exception is raised. 
+ reason: A short human-readable description of which check failed + (e.g. "is_zipfile() returned False", + "testzip() reported member 'crs.parquet'"). + """ + + def __init__(self, path: Path, *, reason: str) -> None: + self.path: Path = path + self.reason: str = reason + super().__init__( + f"Cached payload at {path} failed integrity validation " + f"({reason}). {BULK_PAYLOAD_CORRUPT_HINT}" + ) + + +class BulkDownloadHTTPError(ConnectionError): + """Raised when a bulk download HTTP request returns a non-2xx status. + + Subclasses ``ConnectionError`` for backward compatibility with callers + that catch the previous untyped exception. + + Attributes: + status_code: The HTTP status code returned by the server. + url: The URL that was requested. + body: A truncated preview of the response body (max 500 chars). + """ + + def __init__(self, *, status_code: int, url: str, body: str) -> None: + self.status_code = status_code + self.url = url + self.body = body[:_HTTP_BODY_PREVIEW] + super().__init__(f"HTTP {status_code} from {url}: {self.body}") + + +def validate_zip_or_raise(path: Path) -> None: + """Validate a zip file with is_zipfile + testzip; on failure unlink and raise. + + Any exception that ``testzip()`` itself raises (BadZipFile from a damaged + central directory, zlib.error from a corrupt compressed member) is + converted into BulkPayloadCorruptError so callers see a single boundary + exception and the corrupt file is always removed. + + Args: + path: Path to the zip file to validate. + + Raises: + BulkPayloadCorruptError: If the file fails either check. The file is + unlinked before raising so callers can simply retry. 
+ """ + if not zipfile.is_zipfile(path): + path.unlink(missing_ok=True) + raise BulkPayloadCorruptError(path, reason="is_zipfile() returned False") + try: + with zipfile.ZipFile(path) as zf: + bad_member = zf.testzip() + except (zipfile.BadZipFile, zlib.error) as e: + path.unlink(missing_ok=True) + raise BulkPayloadCorruptError( + path, reason=f"testzip() raised {type(e).__name__}: {e}" + ) from e + if bad_member is not None: + path.unlink(missing_ok=True) + raise BulkPayloadCorruptError( + path, reason=f"testzip() reported member {bad_member!r}" + ) diff --git a/src/oda_reader/multisystem.py b/src/oda_reader/multisystem.py index cd1a730..a81444b 100644 --- a/src/oda_reader/multisystem.py +++ b/src/oda_reader/multisystem.py @@ -27,6 +27,7 @@ def bulk_download_multisystem( save_to_path: Path | str | None = None, *, as_iterator: bool = False, + use_raw_cache: bool = True, ) -> pd.DataFrame | None | typing.Iterator[pd.DataFrame]: """ Download the Multisystem data from the bulk download service. The file is very large. @@ -36,7 +37,8 @@ def bulk_download_multisystem( Args: save_to_path: The path to save the file to. Optional. If not provided, a DataFrame is returned. as_iterator: If `True` yields `DataFrame` chunks instead of a single `DataFrame`. - + use_raw_cache: If False, the raw zip is downloaded to a temporary directory and + deleted after extraction. Each call hits the network. Validation still runs. 
Returns: pd.DataFrame | Iterator[pd.DataFrame] | None @@ -49,6 +51,7 @@ def bulk_download_multisystem( file_id=file_id, save_to_path=save_to_path, as_iterator=as_iterator, + use_raw_cache=use_raw_cache, ) diff --git a/src/oda_reader/tools.py b/src/oda_reader/tools.py index 7492cb9..aef0932 100644 --- a/src/oda_reader/tools.py +++ b/src/oda_reader/tools.py @@ -1,6 +1,7 @@ """Additional tools for the API wrapper""" from collections import OrderedDict +from pprint import pprint def get_available_filters(source: str, quiet: bool = False) -> dict: @@ -15,9 +16,9 @@ def get_available_filters(source: str, quiet: bool = False) -> dict: Returns: dict: The available filters. """ - from pprint import pprint - - from oda_reader import QueryBuilder as qb + # Local import: oda_reader/__init__.py imports from this module, so a + # top-level import would form a cycle. + from oda_reader import QueryBuilder as qb # noqa: PLC0415 match source: case "dac1": diff --git a/tests/cache/__init__.py b/tests/cache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cache/conftest.py b/tests/cache/conftest.py new file mode 100644 index 0000000..18a0d46 --- /dev/null +++ b/tests/cache/conftest.py @@ -0,0 +1,92 @@ +"""Shared fixtures for oda_reader cache tests.""" + +import io +import os +import zipfile +from pathlib import Path + +import pytest + + +@pytest.fixture() +def tmp_cache_root(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: + """Temporary cache root isolated from real cache. + + Sets ODA_READER_CACHE_DIR to a unique tmpdir so tests never touch + the user's real cache. Yields the root path; env var is restored on + teardown by monkeypatch automatically. + + Args: + monkeypatch: pytest monkeypatch fixture. + tmp_path: pytest temporary directory. + + Yields: + Path: Temporary cache root directory. 
+ """ + cache_dir = tmp_path / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("ODA_READER_CACHE_DIR", str(cache_dir)) + yield cache_dir + + +def _make_valid_zip(target: Path) -> None: + """Write a ~1 KB valid zip containing one parquet-like entry.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf: + # Minimal PAR1 magic + padding to reach ~1 KB + zf.writestr("data.parquet", b"PAR1" + b"\x00" * 1000) + target.write_bytes(buf.getvalue()) + + +@pytest.fixture() +def valid_tiny_zip(tmp_path: Path) -> Path: + """Path to a ~1 KB valid zip file with one parquet-like entry. + + Args: + tmp_path: pytest temporary directory. + + Returns: + Path: Path to the valid zip file. + """ + path = tmp_path / "valid.zip" + _make_valid_zip(path) + return path + + +@pytest.fixture() +def corrupt_zip_file(tmp_path: Path) -> Path: + """Path to a non-zip 1 KB file for corruption tests. + + Args: + tmp_path: pytest temporary directory. + + Returns: + Path: Path to the corrupt (non-zip) file. + """ + path = tmp_path / "corrupt.zip" + path.write_bytes(b"NOT A ZIP" + b"\xff" * 1000) + return path + + +@pytest.fixture() +def monkeypatched_fetcher(): + """Return a callable writing a 1 KB valid zip to any target path. + + The returned callable matches the ``fetcher(target_path: Path)`` + signature expected by ``CacheManager.ensure``. + + Returns: + Callable[[Path], None]: Fake fetcher writing a valid zip. 
+ """ + + def _fetcher(target_path: Path) -> None: + _make_valid_zip(target_path) + + return _fetcher + + +@pytest.fixture() +def skip_if_no_network() -> None: + """Skip the current test when RUN_NETWORK_TESTS is not set to '1'.""" + if os.environ.get("RUN_NETWORK_TESTS") != "1": + pytest.skip("set RUN_NETWORK_TESTS=1 to run network tests") diff --git a/tests/cache/test_cache_manager_activated.py b/tests/cache/test_cache_manager_activated.py new file mode 100644 index 0000000..b49ac41 --- /dev/null +++ b/tests/cache/test_cache_manager_activated.py @@ -0,0 +1,216 @@ +"""Tests for the activated CacheManager.ensure() path and related behavior.""" + +import contextlib +import hashlib +import inspect +import io +import zipfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from oda_reader._cache.config import get_bulk_cache_dir +from oda_reader._cache.manager import CacheEntry, CacheManager +from oda_reader.aiddata import download_aiddata +from oda_reader.download import download_tools +from oda_reader.download.download_tools import bulk_download_aiddata +from oda_reader.exceptions import BulkPayloadCorruptError, validate_zip_or_raise + + +def _make_valid_zip(target: Path) -> None: + """Write a ~1 KB valid zip with one parquet-like entry to target.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf: + zf.writestr("data.parquet", b"PAR1" + b"\x00" * 1000) + target.write_bytes(buf.getvalue()) + + +def _make_entry(key: str, fetcher) -> CacheEntry: + return CacheEntry(key=key, filename=f"{key}.zip", fetcher=fetcher) + + +def test_bulk_payload_corrupt_attributes() -> None: + """BulkPayloadCorruptError carries .path, .reason, and an actionable str().""" + path = Path("/tmp/x.zip") + exc = BulkPayloadCorruptError(path, reason="testzip failure") + assert exc.path == path + assert exc.reason == "testzip failure" + msg = str(exc) + assert str(path) in msg + assert "testzip failure" in msg + assert 
"use_raw_cache=False" in msg + + +def test_bulk_download_parquet_uses_cache_manager(tmp_path, monkeypatch) -> None: + """bulk_download_parquet routes through CacheManager.ensure(). + + The cached zip must live at base_dir/.zip and the manifest + must contain the entry. + """ + file_id = "TESTID123" + url = download_tools.BULK_DOWNLOAD_URL + file_id + expected_key = hashlib.sha1(url.encode()).hexdigest() + + def fake_stream(url_: str, headers: dict, path: Path) -> None: + _make_valid_zip(path) + + monkeypatch.setattr(download_tools, "_stream_to_file", fake_stream) + monkeypatch.setenv("ODA_READER_CACHE_DIR", str(tmp_path / "cache")) + + # Extraction of fake zip may fail; we only check the zip presence. + with contextlib.suppress(Exception): + download_tools.bulk_download_parquet(file_id, use_raw_cache=True) + + bulk_dir = get_bulk_cache_dir() + cached_zip = bulk_dir / f"{expected_key}.zip" + assert cached_zip.exists(), f"Expected cached zip at {cached_zip}" + + +def test_testzip_post_condition_raises_on_miss(tmp_path) -> None: + """Cache-miss branch: BulkPayloadCorruptError raised for a non-zip payload, entry removed.""" + manager = CacheManager(base_dir=tmp_path) + + def bad_fetcher(p: Path) -> None: + p.write_bytes(b"NOT A ZIP" * 50) + + entry = _make_entry("crs_full", bad_fetcher) + + with pytest.raises(BulkPayloadCorruptError) as exc_info: + manager.ensure(entry) + + exc = exc_info.value + assert exc.reason != "" + assert not (tmp_path / "crs_full.zip").exists() + + +def test_validate_zip_wraps_testzip_exceptions(tmp_path, monkeypatch) -> None: + """If testzip() itself raises (corrupt member), it is converted to + BulkPayloadCorruptError and the file is unlinked — not propagated raw.""" + target = tmp_path / "valid_outer_corrupt_inner.zip" + _make_valid_zip(target) + + def boom(self): + raise zipfile.BadZipFile("synthetic corruption inside testzip") + + monkeypatch.setattr(zipfile.ZipFile, "testzip", boom) + + with pytest.raises(BulkPayloadCorruptError) as 
exc_info: + validate_zip_or_raise(target) + + assert "testzip() raised BadZipFile" in exc_info.value.reason + assert not target.exists(), "corrupt zip must be removed" + + +def test_testzip_skipped_on_cache_hit(tmp_path) -> None: + """Phase-2 fix #10: testzip must NOT be called on cache-hit paths.""" + manager = CacheManager(base_dir=tmp_path) + + def fetcher(p: Path) -> None: + _make_valid_zip(p) + + entry = _make_entry("crs_full", fetcher) + manager.ensure(entry) + + with patch("zipfile.ZipFile.testzip") as mock_testzip: + manager.ensure(entry) + mock_testzip.assert_not_called() + + +def test_use_raw_cache_false_skips_cache(tmp_path, monkeypatch) -> None: + """use_raw_cache=False: no zip ends up in the cache dir after the call.""" + + def fake_stream(url_: str, headers: dict, path: Path) -> None: + _make_valid_zip(path) + + monkeypatch.setattr(download_tools, "_stream_to_file", fake_stream) + monkeypatch.setenv("ODA_READER_CACHE_DIR", str(tmp_path / "cache")) + + with contextlib.suppress(Exception): + download_tools.bulk_download_parquet("FAKEID", use_raw_cache=False) + + cache_root = tmp_path / "cache" + cache_zips = list(cache_root.rglob("*.zip")) if cache_root.exists() else [] + assert cache_zips == [], f"Found unexpected cached zips: {cache_zips}" + + +def test_use_raw_cache_false_validates_corrupt(tmp_path, monkeypatch) -> None: + """use_raw_cache=False: validation still runs and raises BulkPayloadCorruptError.""" + + def bad_stream(url_: str, headers: dict, path: Path) -> None: + path.write_bytes(b"NOT A ZIP" * 50) + + monkeypatch.setattr(download_tools, "_stream_to_file", bad_stream) + monkeypatch.setenv("ODA_READER_CACHE_DIR", str(tmp_path / "cache")) + + with pytest.raises(BulkPayloadCorruptError) as exc_info: + download_tools.bulk_download_parquet("FAKEID", use_raw_cache=False) + + assert exc_info.value.reason != "" + + +def test_use_raw_cache_false_iterator_cleans_temp_on_exhaustion( + monkeypatch, +) -> None: + """Iterator + use_raw_cache=False must 
delete the temp zip on completion.""" + written_paths: list[Path] = [] + + def fake_stream(url_: str, headers: dict, path: Path) -> None: + written_paths.append(path) + _make_valid_zip(path) + + monkeypatch.setattr(download_tools, "_stream_to_file", fake_stream) + + it = download_tools.bulk_download_parquet( + "FAKEID", as_iterator=True, use_raw_cache=False + ) + assert it is not None + # Iterator is lazy; temp file still exists. + assert written_paths and written_paths[0].exists() + + # Exhaust the iterator (the fake zip has no real parquet inside, so + # iteration may raise — that's still a path to cleanup). + with contextlib.suppress(Exception): + list(it) + + assert not written_paths[0].exists(), "temp zip should be deleted" + + +def test_use_raw_cache_false_extraction_error_cleans_temp(monkeypatch) -> None: + """A non-BadZipFile error during extraction must still delete the temp zip.""" + written_paths: list[Path] = [] + + def fake_stream(url_: str, headers: dict, path: Path) -> None: + written_paths.append(path) + _make_valid_zip(path) + + def boom(*args, **kwargs): + raise RuntimeError("extraction broke") + + monkeypatch.setattr(download_tools, "_stream_to_file", fake_stream) + monkeypatch.setattr( + download_tools, "_save_or_return_parquet_files_from_content", boom + ) + + with pytest.raises(RuntimeError, match="extraction broke"): + download_tools.bulk_download_parquet("FAKEID", use_raw_cache=False) + + assert written_paths and not written_paths[0].exists() + + +def test_is_txt_still_accepts_positional() -> None: + """is_txt remains a positional parameter for backward-compat with existing + callers that pass it as the third positional arg. + + We only check the signature shape; behavior is covered by the unit tests + that mock the download pipeline. 
+ """ + sig = inspect.signature(download_tools.bulk_download_parquet) + is_txt = sig.parameters["is_txt"] + assert is_txt.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + + +def test_aiddata_unaffected() -> None: + """download_aiddata and bulk_download_aiddata must not have use_raw_cache.""" + assert "use_raw_cache" not in inspect.signature(download_aiddata).parameters + assert "use_raw_cache" not in inspect.signature(bulk_download_aiddata).parameters diff --git a/tests/cache/test_dataframe_atomic.py b/tests/cache/test_dataframe_atomic.py new file mode 100644 index 0000000..eed24ae --- /dev/null +++ b/tests/cache/test_dataframe_atomic.py @@ -0,0 +1,81 @@ +"""Tests for atomic DataFrameCache.set writes.""" + +from pathlib import Path +from unittest.mock import patch + +import pandas as pd + +import oda_reader._cache.dataframe as df_module +from oda_reader._cache.dataframe import DataFrameCache + + +def test_set_is_atomic(tmp_path: Path, caplog) -> None: + """A failed write must not leave a partial file at the destination, must + not leak a tmp sibling, and must not propagate as an exception (cache + writes are best-effort).""" + cache = DataFrameCache(cache_dir=tmp_path) + df = pd.DataFrame({"a": [1, 2, 3]}) + + with patch.object(pd.DataFrame, "to_parquet", side_effect=OSError("disk error")): + cache.set( + df, + dataflow_id="DSD_CRS@DF_CRS", + dataflow_version="1.0", + url="http://example.com", + pre_process=True, + dotstat_codes=True, + ) + + assert list(tmp_path.glob("*.parquet")) == [] + assert list(tmp_path.glob("*.tmp-*")) == [] + assert any("Failed to cache DataFrame" in r.message for r in caplog.records) + + +def test_set_no_tmp_file_left_on_success(tmp_path: Path) -> None: + """DataFrameCache.set leaves no *.tmp-* sibling after a successful write.""" + cache = DataFrameCache(cache_dir=tmp_path) + df = pd.DataFrame({"a": [1, 2, 3]}) + + cache.set( + df, + dataflow_id="DSD_CRS@DF_CRS", + dataflow_version="1.0", + url="http://example.com", + pre_process=True, 
+ dotstat_codes=True, + ) + + tmp_files = list(tmp_path.glob("*.tmp-*")) + assert tmp_files == [], f"Stale tmp files found: {tmp_files}" + + parquet_files = list(tmp_path.glob("*.parquet")) + assert len(parquet_files) == 1 + + +def test_set_uses_host_pid_suffix(tmp_path: Path, monkeypatch) -> None: + """DataFrameCache.set temp file uses host+pid suffix.""" + recorded_tmp: list[Path] = [] + original_to_parquet = pd.DataFrame.to_parquet + + def capturing_to_parquet(self, path, *args, **kwargs): + recorded_tmp.append(Path(path)) + return original_to_parquet(self, path, *args, **kwargs) + + monkeypatch.setattr(pd.DataFrame, "to_parquet", capturing_to_parquet) + monkeypatch.setattr(df_module, "_HOSTNAME", "testhost") + + cache = DataFrameCache(cache_dir=tmp_path) + df = pd.DataFrame({"x": [1]}) + + cache.set( + df, + dataflow_id="DSD", + dataflow_version="1", + url="http://x.com", + pre_process=False, + dotstat_codes=False, + ) + + assert len(recorded_tmp) == 1 + tmp_name = recorded_tmp[0].name + assert "tmp-testhost-" in tmp_name diff --git a/tests/cache/test_deprecation_shims.py b/tests/cache/test_deprecation_shims.py new file mode 100644 index 0000000..016e5a3 --- /dev/null +++ b/tests/cache/test_deprecation_shims.py @@ -0,0 +1,93 @@ +"""Tests for oda_reader top-level deprecation shims (Phase-2 fix #25).""" + +import inspect +import sys +import types +import warnings + +import pytest + +# The oda_reader deprecation shims are gated on ``"oda_data" in sys.modules`` +# so they stay silent for standalone oda_reader users. In oda-importer's test +# suite oda_data is not installed, so we insert a synthetic placeholder to +# satisfy the gate without requiring the real package. 
+if "oda_data" not in sys.modules: + sys.modules["oda_data"] = types.ModuleType("oda_data") + +import oda_reader +from oda_reader._cache.config import get_cache_dir +from oda_reader.aiddata import download_aiddata +from oda_reader.download.download_tools import bulk_download_aiddata + + +@pytest.fixture(autouse=True) +def reset_shim_flags(): + """Ensure each test starts with fresh warned flags.""" + oda_reader._WARNED_SHIMS.clear() + yield + oda_reader._WARNED_SHIMS.clear() + + +def test_clear_cache_emits_deprecation_warning_once() -> None: + """oda_reader.clear_cache emits exactly one DeprecationWarning across two calls.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + oda_reader.clear_cache() + oda_reader.clear_cache() + + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert "oda_data.cache.clear" in str(deprecations[0].message) + + +def test_set_cache_dir_emits_deprecation_warning_once(tmp_path) -> None: + """oda_reader.set_cache_dir emits exactly one DeprecationWarning across two calls.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + oda_reader.set_cache_dir(str(tmp_path)) + oda_reader.set_cache_dir(str(tmp_path)) + + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert "oda_data.set_cache_root" in str(deprecations[0].message) + + +def test_enable_cache_emits_deprecation_warning_once() -> None: + """oda_reader.enable_cache emits exactly one DeprecationWarning across two calls.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + oda_reader.enable_cache() + oda_reader.enable_cache() + + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert "oda_data.cache.enable_cache" in str(deprecations[0].message) + + +def test_disable_cache_emits_deprecation_warning_once() 
-> None: + """oda_reader.disable_cache emits exactly one DeprecationWarning across two calls.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + oda_reader.disable_cache() + oda_reader.disable_cache() + + deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecations) == 1 + assert "oda_data.cache.disable_cache" in str(deprecations[0].message) + + +def test_aiddata_unaffected() -> None: + """download_aiddata and bulk_download_aiddata must not have use_raw_cache (Phase-2 fix #18).""" + assert "use_raw_cache" not in inspect.signature(download_aiddata).parameters + assert "use_raw_cache" not in inspect.signature(bulk_download_aiddata).parameters + + +def test_shims_forward_correctly(tmp_path) -> None: + """After the warning, the shims still forward to the underlying implementation.""" + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + oda_reader.set_cache_dir(str(tmp_path)) + + # set_cache_dir resolves symlinks (e.g. macOS /tmp -> /private/tmp), so + # compare against the resolved form. 
+ assert get_cache_dir() == tmp_path.resolve() diff --git a/tests/cache/test_freezegun_ttl.py b/tests/cache/test_freezegun_ttl.py new file mode 100644 index 0000000..5c3c5ce --- /dev/null +++ b/tests/cache/test_freezegun_ttl.py @@ -0,0 +1,69 @@ +"""Tests for CacheManager TTL behavior using frozen clocks.""" + +import io +import zipfile +from pathlib import Path + +import pytest + +freeze_time = pytest.importorskip("freezegun").freeze_time + +from oda_reader._cache.manager import CacheEntry, CacheManager # noqa: E402 + + +def _make_valid_zip(target: Path) -> None: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf: + zf.writestr("data.parquet", b"PAR1" + b"\x00" * 100) + target.write_bytes(buf.getvalue()) + + +def _make_entry(key: str, fetcher, ttl_days: int = 30) -> CacheEntry: + return CacheEntry( + key=key, filename=f"{key}.zip", fetcher=fetcher, ttl_days=ttl_days + ) + + +@pytest.fixture() +def counting_fetcher(): + calls = [] + + def fetcher(p: Path) -> None: + calls.append(p) + _make_valid_zip(p) + + return fetcher, calls + + +def test_ensure_serves_fresh_within_ttl(tmp_path: Path, counting_fetcher) -> None: + """Within the TTL window, ensure returns the cached path without re-fetching.""" + fetcher, calls = counting_fetcher + entry = _make_entry("crs", fetcher, ttl_days=30) + + with freeze_time("2024-01-01"): + manager = CacheManager(base_dir=tmp_path) + manager.ensure(entry) + assert len(calls) == 1 + + # Advance to day 20 (within TTL of 30 days) + with freeze_time("2024-01-21"): + manager2 = CacheManager(base_dir=tmp_path) + manager2.ensure(entry) + assert len(calls) == 1, "Should NOT re-fetch within TTL" + + +def test_ensure_refetches_after_ttl(tmp_path: Path, counting_fetcher) -> None: + """After the TTL window, ensure re-fetches the data.""" + fetcher, calls = counting_fetcher + entry = _make_entry("crs", fetcher, ttl_days=30) + + with freeze_time("2024-01-01"): + manager = CacheManager(base_dir=tmp_path) + 
manager.ensure(entry) + assert len(calls) == 1 + + # Advance past TTL (31 days later) + with freeze_time("2024-02-01"): + manager2 = CacheManager(base_dir=tmp_path) + manager2.ensure(entry) + assert len(calls) == 2, "Should re-fetch after TTL expires" diff --git a/tests/cache/test_lru_eviction.py b/tests/cache/test_lru_eviction.py new file mode 100644 index 0000000..55accf3 --- /dev/null +++ b/tests/cache/test_lru_eviction.py @@ -0,0 +1,112 @@ +"""Tests for CacheManager LRU eviction on __init__.""" + +import io +import json +import zipfile +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import patch + +from oda_reader._cache.manager import CacheManager + +ISO_FORMAT = "%Y-%m-%dT%H:%M:%S%z" + + +def _make_valid_zip(target: Path) -> None: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as zf: + zf.writestr("data.parquet", b"PAR1" + b"\x00" * 100) + target.write_bytes(buf.getvalue()) + + +def _write_manifest(base_dir: Path, entries: dict) -> None: + """Write a manifest.json with the given entries dict.""" + manifest_path = base_dir / "manifest.json" + with manifest_path.open("w") as f: + json.dump(entries, f, indent=2) + + +def _make_manifest_entry(filename: str, downloaded_at: str) -> dict: + return { + "filename": filename, + "downloaded_at": downloaded_at, + "ttl_days": 30, + "version": None, + } + + +def _ts(year: int, month: int, day: int) -> str: + return datetime(year, month, day, tzinfo=timezone.utc).strftime(ISO_FORMAT) + + +def test_keeps_n_most_recent(tmp_path: Path) -> None: + """Pre-populate manifest with 4 entries; CacheManager(keep_n=2) evicts 2 oldest.""" + base = tmp_path / "cache" + base.mkdir() + + entries = { + "key_a": _make_manifest_entry("key_a.zip", _ts(2024, 1, 1)), + "key_b": _make_manifest_entry("key_b.zip", _ts(2024, 2, 1)), + "key_c": _make_manifest_entry("key_c.zip", _ts(2024, 3, 1)), + "key_d": _make_manifest_entry("key_d.zip", _ts(2024, 4, 1)), + } + 
_write_manifest(base, entries) + + # Write the actual zip files so unlink can succeed. + for key in entries: + _make_valid_zip(base / entries[key]["filename"]) + + CacheManager(base_dir=base, keep_n=2) + + # Should keep the 2 most recent (key_c, key_d); evict key_a and key_b. + remaining = list((base).glob("*.zip")) + remaining_names = {p.name for p in remaining} + + assert "key_c.zip" in remaining_names + assert "key_d.zip" in remaining_names + assert "key_a.zip" not in remaining_names + assert "key_b.zip" not in remaining_names + + +def test_no_eviction_when_within_limit(tmp_path: Path) -> None: + """With 2 entries and keep_n=2, no eviction occurs.""" + base = tmp_path / "cache" + base.mkdir() + + entries = { + "key_a": _make_manifest_entry("key_a.zip", _ts(2024, 1, 1)), + "key_b": _make_manifest_entry("key_b.zip", _ts(2024, 2, 1)), + } + _write_manifest(base, entries) + for key in entries: + _make_valid_zip(base / entries[key]["filename"]) + + CacheManager(base_dir=base, keep_n=2) + + remaining = {p.name for p in base.glob("*.zip")} + assert remaining == {"key_a.zip", "key_b.zip"} + + +def test_unlink_failure_logs_but_does_not_raise(tmp_path: Path) -> None: + """LRU eviction: unlink failure logs a warning but does not propagate.""" + base = tmp_path / "cache" + base.mkdir() + + entries = { + "key_a": _make_manifest_entry("key_a.zip", _ts(2024, 1, 1)), + "key_b": _make_manifest_entry("key_b.zip", _ts(2024, 2, 1)), + "key_c": _make_manifest_entry("key_c.zip", _ts(2024, 3, 1)), + } + _write_manifest(base, entries) + # Do not write the actual files — unlink(missing_ok=True) should be silent. + # We patch Path.unlink to raise OSError. + original_unlink = Path.unlink + + def raising_unlink(self, missing_ok=False): + if self.name.endswith(".zip"): + raise OSError("simulated failure") + return original_unlink(self, missing_ok=missing_ok) + + with patch.object(Path, "unlink", raising_unlink): + # Should not raise despite unlink errors. 
+ CacheManager(base_dir=base, keep_n=2) diff --git a/tests/common/unit/test_cache.py b/tests/common/unit/test_cache.py index 3c18d55..eb4cc12 100644 --- a/tests/common/unit/test_cache.py +++ b/tests/common/unit/test_cache.py @@ -2,8 +2,8 @@ import pytest -import oda_reader._http_primitives as _http_primitives from oda_reader import ( + _http_primitives, clear_http_cache, disable_http_cache, enable_http_cache, diff --git a/tests/conftest.py b/tests/conftest.py index 7863ac5..6758bbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,15 +4,13 @@ import pytest -from oda_reader import disable_http_cache, enable_http_cache +from oda_reader import _http_primitives, disable_http_cache, enable_http_cache from oda_reader.common import RateLimiter @pytest.fixture(autouse=True) def disable_cache_for_tests(): """Disable HTTP cache for all tests by default.""" - import oda_reader._http_primitives as _http_primitives - disable_http_cache() yield # Reset session before re-enabling to avoid SQLite contention @@ -32,8 +30,6 @@ def temp_cache_dir(tmp_path, monkeypatch): Yields: Path: Path to the temporary cache directory """ - import oda_reader._http_primitives as _http_primitives - cache_dir = tmp_path / "test_cache" cache_dir.mkdir() monkeypatch.setenv("ODA_READER_CACHE_DIR", str(cache_dir)) diff --git a/tests/datasets/crs/integration/test_crs_e2e.py b/tests/datasets/crs/integration/test_crs_e2e.py index 4199666..6723485 100644 --- a/tests/datasets/crs/integration/test_crs_e2e.py +++ b/tests/datasets/crs/integration/test_crs_e2e.py @@ -14,12 +14,14 @@ def test_crs_microdata_query(self): """Test CRS microdata query returns project-level data.""" enable_http_cache() - # Small query: US education projects (microdata) - # Using pre_process=False and dotstat_codes=False to test raw API + # Small query: US Primary Education projects (microdata). + # CRS microdata (MD_DIM=DD) is keyed on *leaf* DAC purpose codes + # (11220 = Primary Education), not the parent category 110. 
+ # Using pre_process=False and dotstat_codes=False to test raw API. df = download_crs( start_year=2023, end_year=2023, - filters={"donor": "USA", "sector": "110"}, # US, Education sector + filters={"donor": "USA", "sector": "11220"}, pre_process=False, dotstat_codes=False, ) diff --git a/tests/datasets/dac2a/unit/test_dac2a_bulk.py b/tests/datasets/dac2a/unit/test_dac2a_bulk.py index 0f1e979..d4b8954 100644 --- a/tests/datasets/dac2a/unit/test_dac2a_bulk.py +++ b/tests/datasets/dac2a/unit/test_dac2a_bulk.py @@ -1,5 +1,6 @@ """Unit tests for DAC2a bulk download functionality.""" +import pandas as pd import pytest from oda_reader.dac2a import bulk_download_dac2a, get_full_dac2a_parquet_id @@ -28,8 +29,6 @@ def test_get_full_dac2a_parquet_id_calls_correct_function(self, mocker): def test_bulk_download_dac2a_returns_dataframe(self, mocker): """Test that bulk_download_dac2a returns DataFrame when no save path.""" - import pandas as pd - mock_df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) mocker.patch( @@ -48,6 +47,7 @@ def test_bulk_download_dac2a_returns_dataframe(self, mocker): file_id="mock-file-id", save_to_path=None, as_iterator=False, + use_raw_cache=True, ) def test_bulk_download_dac2a_saves_to_path(self, mocker, tmp_path): @@ -68,11 +68,11 @@ def test_bulk_download_dac2a_saves_to_path(self, mocker, tmp_path): file_id="mock-file-id", save_to_path=tmp_path, as_iterator=False, + use_raw_cache=True, ) def test_bulk_download_dac2a_as_iterator(self, mocker): """Test that bulk_download_dac2a passes as_iterator flag correctly.""" - import pandas as pd def mock_iterator(): yield pd.DataFrame({"col1": [1]}) @@ -95,4 +95,5 @@ def mock_iterator(): file_id="mock-file-id", save_to_path=None, as_iterator=True, + use_raw_cache=True, ) diff --git a/tests/download/unit/test_deflate64.py b/tests/download/unit/test_deflate64.py index 1450bae..63a1c3f 100644 --- a/tests/download/unit/test_deflate64.py +++ b/tests/download/unit/test_deflate64.py @@ -15,10 +15,12 @@ import 
pytest import oda_reader.download._deflate64 # noqa: F401 — ensure patch is active -from oda_reader.download.download_tools import _save_or_return_parquet_files_from_content +from oda_reader.download.download_tools import ( + _save_or_return_parquet_files_from_content, +) -def _create_deflate64_zip(files: dict[str, bytes]) -> bytes: +def _create_deflate64_zip(files: dict[str, bytes]) -> bytes: # noqa: PLR0915 # binary ZIP layout is inherently statement-heavy """Create a ZIP archive using Deflate64 (type 9) compression. Manually constructs the ZIP binary format since Python's ``zipfile`` diff --git a/tests/download/unit/test_download_tools.py b/tests/download/unit/test_download_tools.py index bc3e5ed..9af91ca 100644 --- a/tests/download/unit/test_download_tools.py +++ b/tests/download/unit/test_download_tools.py @@ -53,7 +53,9 @@ def test_get_data_from_api_404_triggers_version_discovery(self, mocker): return_value=7, ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) result = get_data_from_api(url) assert mock_get_response.call_count == 2 @@ -70,7 +72,9 @@ def test_get_data_from_api_discovered_version_matches_raises(self, mocker): return_value="2.0", # same as URL version ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) with pytest.raises(ConnectionError, match="matches the attempted version"): get_data_from_api(url) @@ -92,7 +96,9 @@ def test_get_data_from_api_incompatible_dsd_raises(self, mocker): side_effect=[7, 8], # old has 7, new has 8 — breaking change ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) with pytest.raises(ConnectionError, match="breaking schema change"): 
get_data_from_api(url) @@ -114,7 +120,9 @@ def test_get_data_from_api_compatible_upgrade_succeeds(self, mocker): return_value=7, # same count — compatible ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) result = get_data_from_api(url) assert result == "DONOR,VALUE\n1,100" @@ -136,7 +144,9 @@ def test_get_data_from_api_dsd_check_fails_gracefully(self, mocker): side_effect=ConnectionError("DSD endpoint down"), ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) result = get_data_from_api(url) assert result == "DONOR,VALUE\n1,100" @@ -155,7 +165,9 @@ def test_get_data_from_api_retry_also_fails_raises(self, mocker): return_value=7, ) - url = "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + url = ( + "https://sdmx.oecd.org/public/rest/data/OECD.DCD.FSD,DSD_DAC1@DF_DAC1,2.0/" + ) with pytest.raises(ConnectionError, match="even after version discovery"): get_data_from_api(url) @@ -397,49 +409,40 @@ def test_txt_iterator_raises(self): class TestDeprecationWarnings: """Test deprecation warnings for backward compatibility.""" - def test_is_txt_parameter_emits_deprecation_warning(self, mocker): - """Test that using is_txt parameter emits a deprecation warning.""" - # Mock the internal functions to avoid actual downloads + @staticmethod + def _mock_download_pipeline(mocker): + """Stub the cache manager + content extractor to avoid real downloads.""" + fake_manager = mocker.Mock() + fake_manager.ensure.return_value = "/fake/path" mocker.patch( - "oda_reader.download.download_tools._get_temp_file", - return_value=("/fake/path", False), + "oda_reader.download.download_tools.bulk_cache_manager", + return_value=fake_manager, ) mocker.patch( 
"oda_reader.download.download_tools._save_or_return_parquet_files_from_content", return_value=[pd.DataFrame({"col": [1, 2]})], ) + def test_is_txt_parameter_emits_deprecation_warning(self, mocker): + """Test that using is_txt parameter emits a deprecation warning.""" + self._mock_download_pipeline(mocker) + with pytest.warns(DeprecationWarning, match="is_txt.*deprecated"): bulk_download_parquet("fake-id", is_txt=True) def test_is_txt_false_also_emits_warning(self, mocker): """Test that is_txt=False also emits deprecation warning.""" - mocker.patch( - "oda_reader.download.download_tools._get_temp_file", - return_value=("/fake/path", False), - ) - mocker.patch( - "oda_reader.download.download_tools._save_or_return_parquet_files_from_content", - return_value=[pd.DataFrame({"col": [1, 2]})], - ) + self._mock_download_pipeline(mocker) with pytest.warns(DeprecationWarning, match="is_txt.*deprecated"): bulk_download_parquet("fake-id", is_txt=False) def test_no_warning_when_is_txt_not_provided(self, mocker): """Test that no warning is emitted when is_txt is not provided.""" - mocker.patch( - "oda_reader.download.download_tools._get_temp_file", - return_value=("/fake/path", False), - ) - mocker.patch( - "oda_reader.download.download_tools._save_or_return_parquet_files_from_content", - return_value=[pd.DataFrame({"col": [1, 2]})], - ) + self._mock_download_pipeline(mocker) with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) - # Should not raise any DeprecationWarning bulk_download_parquet("fake-id") @@ -473,7 +476,9 @@ def test_recognized_urls(self, url, expected): assert _extract_dataflow_id_from_flow_url(url) == expected def test_unrecognized_url_returns_none(self): - assert _extract_dataflow_id_from_flow_url("https://example.com/other/path") is None + assert ( + _extract_dataflow_id_from_flow_url("https://example.com/other/path") is None + ) FLOW_URL = "https://sdmx.oecd.org/public/rest/dataflow/OECD.DCD.FSD/DSD_CRS@DF_CRS/" @@ -519,7 
+524,11 @@ def test_explicit_version_fails_then_discovery_rescues(self, mocker): "oda_reader.download.download_tools._get_response_text", side_effect=[ (404, "Not found", False), # explicit version fails - (200, f"{SEARCH_STRING}rescued", False), # discovered version works + ( + 200, + f"{SEARCH_STRING}rescued", + False, + ), # discovered version works ], ) mocker.patch( diff --git a/tests/download/unit/test_version_discovery.py b/tests/download/unit/test_version_discovery.py index 021c273..86e57b7 100644 --- a/tests/download/unit/test_version_discovery.py +++ b/tests/download/unit/test_version_discovery.py @@ -1,5 +1,7 @@ """Unit tests for the version_discovery module.""" +from xml.etree.ElementTree import ParseError + import pytest from oda_reader.download.version_discovery import ( @@ -115,7 +117,7 @@ def test_missing_version_raises(self): _parse_version_from_xml(_MISSING_VERSION_XML) def test_malformed_xml_raises(self): - with pytest.raises(Exception): + with pytest.raises(ParseError): _parse_version_from_xml(">>") @@ -332,5 +334,5 @@ def test_bad_xml_in_200_raises_valueerror(self, _mock_http): """200 with unparseable XML should raise ValueError.""" _mock_http.return_value = (200, " Any: diff --git a/uv.lock b/uv.lock index 0fb0f6c..9fea140 100644 --- a/uv.lock +++ b/uv.lock @@ -340,6 +340,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, ] +[[package]] +name = "freezegun" +version = "1.5.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/dd/23e2f4e357f8fd3bdff613c1fe4466d21bfb00a6177f238079b17f7b1c84/freezegun-1.5.5.tar.gz", hash = 
"sha256:ac7742a6cc6c25a2c35e9292dfd554b897b517d2dec26891a2e8debf205cb94a", size = 35914, upload-time = "2025-08-09T10:39:08.338Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/2e/b41d8a1a917d6581fc27a35d05561037b048e47df50f27f8ac9c7e27a710/freezegun-1.5.5-py3-none-any.whl", hash = "sha256:cd557f4a75cf074e84bc374249b9dd491eaeacd61376b9eb3c423282211619d2", size = 19266, upload-time = "2025-08-09T10:39:06.636Z" }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -864,7 +876,7 @@ wheels = [ [[package]] name = "oda-reader" -version = "1.5.1" +version = "1.6.0" source = { editable = "." } dependencies = [ { name = "filelock" }, @@ -891,6 +903,7 @@ docs = [ { name = "mkdocstrings", extra = ["python"] }, ] test = [ + { name = "freezegun" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mock" }, @@ -923,6 +936,7 @@ docs = [ { name = "mkdocstrings", extras = ["python"], specifier = ">=0.24.0" }, ] test = [ + { name = "freezegun", specifier = ">=1.4.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-cov", specifier = ">=4.1" }, { name = "pytest-mock", specifier = ">=3.12" }, From 8b19cadf40cb7b9ff97286037aa4168c7be52994 Mon Sep 17 00:00:00 2001 From: Jorge Rivera Date: Tue, 28 Apr 2026 20:38:30 -0600 Subject: [PATCH 2/2] Update CHANGELOG.md --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b564989..bf56b27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,6 @@ `DeprecationWarning` for users who also import `oda_data`, pointing at the umbrella `oda_data.cache.*` API. Standalone `oda_reader` users see no warning. The shims continue to work through the `1.x` series and will be removed in `2.0`. - -## 1.5.2 (2026-04-28) - Cache directory is now versioned by the installed package version (via `importlib.metadata`) rather than a hardcoded string, so upgrades automatically invalidate old caches that may contain partial or corrupt downloads from prior versions. 
- Bulk-download cache writes are now atomic: downloads stream into a sibling temp file and are only renamed over the destination on success, so partial downloads no longer pollute the cache on interruption or error. - On `BadZipFile`, the corrupt cached archive is removed so the next call cleanly re-downloads instead of looping on the same poisoned entry.