From 175f4877f501206eabaf22871a8278459d6867d5 Mon Sep 17 00:00:00 2001 From: Mo Date: Mon, 9 Mar 2026 19:14:21 +0100 Subject: [PATCH 1/2] fix issue regarding the passing of storage options to the aiohttp --- CHANGELOG.md | 3 +++ src/xarray_prism/_detection.py | 7 ++++++- src/xarray_prism/_version.py | 2 +- src/xarray_prism/backends/cloud.py | 6 +++++- src/xarray_prism/entrypoint.py | 2 +- src/xarray_prism/utils.py | 12 ++++++++++++ 6 files changed, 28 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70cad17..416abd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Changelog All notable changes to this project will be documented in this file. +## [v2603.0.0] +### Fixed +- an issue regarding passing the storage_options to aiohttp ## [v2602.1.0] ### Added diff --git a/src/xarray_prism/_detection.py b/src/xarray_prism/_detection.py index 3672639..852a1cb 100644 --- a/src/xarray_prism/_detection.py +++ b/src/xarray_prism/_detection.py @@ -5,6 +5,8 @@ from functools import lru_cache from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +from .utils import _strip_chaining_options + Detector = Callable[[str], Optional[str]] _custom_detectors: List[Tuple[int, Detector]] = [] @@ -259,7 +261,10 @@ def _detect_engine_impl(uri: str, storage_options: Optional[Dict]) -> str: # 3. 
Filesystem-based detection — lazy import fsspec only when needed import fsspec # noqa: PLC0415 - fs, path = fsspec.core.url_to_fs(uri, **(storage_options or {})) + fs, path = fsspec.core.url_to_fs( + uri, **_strip_chaining_options(storage_options or {}) + ) + lower_path = path.lower() # Check for Zarr directory diff --git a/src/xarray_prism/_version.py b/src/xarray_prism/_version.py index 7fea308..d4d9e4a 100644 --- a/src/xarray_prism/_version.py +++ b/src/xarray_prism/_version.py @@ -1,3 +1,3 @@ """Version information for xarray-prism.""" -__version__ = "2602.1.0" +__version__ = "2603.0.0" diff --git a/src/xarray_prism/backends/cloud.py b/src/xarray_prism/backends/cloud.py index 86db2b7..a1e19bb 100644 --- a/src/xarray_prism/backends/cloud.py +++ b/src/xarray_prism/backends/cloud.py @@ -11,6 +11,8 @@ from typing import Any, Dict, Optional from urllib.parse import urlparse +from ..utils import _strip_chaining_options + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -68,7 +70,9 @@ def _cache_remote_file( logger.warning(f"Remote {fmt} requires full file download") extra_lines = 2 - fs, path = fsspec.core.url_to_fs(uri, **(storage_options or {})) + fs, path = fsspec.core.url_to_fs( + uri, **_strip_chaining_options(storage_options or {}) + ) if show_progress: size = 0 diff --git a/src/xarray_prism/entrypoint.py b/src/xarray_prism/entrypoint.py index 03bf27e..e9ba16a 100644 --- a/src/xarray_prism/entrypoint.py +++ b/src/xarray_prism/entrypoint.py @@ -75,7 +75,7 @@ class PrismBackendEntrypoint(BackendEntrypoint): ) url = "https://github.com/freva-org/xarray-prism" - open_dataset_parameters = ("filename_or_obj", "drop_variables") + open_dataset_parameters = ("filename_or_obj", "drop_variables", "storage_options") ENGINE_MAP: Dict[str, str] = { "zarr": "zarr", diff --git a/src/xarray_prism/utils.py b/src/xarray_prism/utils.py index cb387e3..fcb6dc5 100644 --- a/src/xarray_prism/utils.py +++ b/src/xarray_prism/utils.py @@ -241,3 +241,15 @@ 
def sanitize_dataset_attrs(ds): var.attrs = _clean_attr_obj(dict(var.attrs)) return ds + + +def _strip_chaining_options(storage_options: dict) -> dict: + """Strip fsspec chaining/wrapper protocol keys from storage_options. + These are keys like 'simplecache', 'blockcache', 'filecache' that are + fsspec protocol names used for URL chaining, not valid HTTP/remote FS kwargs. + """ + if not storage_options: + return {} + from fsspec.registry import known_implementations + + return {k: v for k, v in storage_options.items() if k not in known_implementations} From 4ab731135b87f4957adb41da36ee0e619e142d79 Mon Sep 17 00:00:00 2001 From: Mo Date: Mon, 9 Mar 2026 22:42:22 +0100 Subject: [PATCH 2/2] handle the zip data before reading --- src/xarray_prism/_detection.py | 15 ++- src/xarray_prism/backends/cloud.py | 6 +- src/xarray_prism/entrypoint.py | 3 +- src/xarray_prism/utils.py | 42 +++++++++ tests/test_backends.py | 142 +++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 7 deletions(-) diff --git a/src/xarray_prism/_detection.py b/src/xarray_prism/_detection.py index 852a1cb..249ccf6 100644 --- a/src/xarray_prism/_detection.py +++ b/src/xarray_prism/_detection.py @@ -5,7 +5,7 @@ from functools import lru_cache from typing import Any, Callable, Dict, List, Literal, Optional, Tuple -from .utils import _strip_chaining_options +from .utils import _strip_chaining_options, _strip_compression_suffix Detector = Callable[[str], Optional[str]] @@ -147,6 +147,9 @@ def looks_like_opendap_url(uri: str) -> bool: def _detect_from_uri_pattern(lower_uri: str) -> Optional[str]: """Detect engine from URI patterns without I/O.""" + + lower_uri = _strip_compression_suffix(lower_uri) + # Reference URIs -> zarr (Kerchunk) if is_reference_uri(lower_uri): return "zarr" @@ -155,6 +158,10 @@ def _detect_from_uri_pattern(lower_uri: str) -> Optional[str]: if lower_uri.endswith(".zarr") or ".zarr/" in lower_uri: return "zarr" + # GRIB detection by extension + if 
lower_uri.endswith((".grib", ".grb", ".grb2", ".grib2")): + return "cfgrib" + # THREDDS NCSS with explicit accept format (overrides file extension) if "/ncss/" in lower_uri or "/ncss?" in lower_uri: if "accept=netcdf3" in lower_uri: @@ -204,8 +211,10 @@ def _read_magic_bytes(fs: Any, path: str) -> Any: def _detect_from_magic_bytes(header: bytes, lower_path: str) -> Engine: """Detect engine from magic bytes and file extension.""" + bare_path = _strip_compression_suffix(lower_path) + # GRIB detection - if b"GRIB" in header or lower_path.endswith((".grib", ".grb", ".grb2", ".grib2")): + if b"GRIB" in header or bare_path.endswith((".grib", ".grb", ".grb2", ".grib2")): return "cfgrib" # NetCDF3 (Classic) @@ -219,7 +228,7 @@ def _detect_from_magic_bytes(header: bytes, lower_path: str) -> Engine: # GeoTIFF if header.startswith((b"II*\x00", b"MM\x00*")): return "rasterio" - if lower_path.endswith((".tif", ".tiff")): + if bare_path.endswith((".tif", ".tiff")): return "rasterio" return "unknown" diff --git a/src/xarray_prism/backends/cloud.py b/src/xarray_prism/backends/cloud.py index a1e19bb..01703d5 100644 --- a/src/xarray_prism/backends/cloud.py +++ b/src/xarray_prism/backends/cloud.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Optional from urllib.parse import urlparse -from ..utils import _strip_chaining_options +from ..utils import _decompress_if_needed, _strip_chaining_options logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -62,7 +62,7 @@ def _cache_remote_file( sys.stdout.write("\033[A") sys.stdout.write("\033[K") sys.stdout.flush() - return str(local_path) + return _decompress_if_needed(str(local_path)) extra_lines = 0 if show_progress: @@ -98,7 +98,7 @@ def _cache_remote_file( else: fs.get(path, str(local_path)) - return str(local_path) + return _decompress_if_needed(str(local_path)) def open_cloud( diff --git a/src/xarray_prism/entrypoint.py b/src/xarray_prism/entrypoint.py index e9ba16a..336e7f7 100644 --- 
a/src/xarray_prism/entrypoint.py +++ b/src/xarray_prism/entrypoint.py @@ -12,6 +12,7 @@ from xarray.backends import BackendEntrypoint from ._detection import ( + _strip_compression_suffix, detect_engine, detect_uri_type, is_http_url, @@ -242,7 +243,7 @@ def guess_can_open(self, filename_or_obj: Any) -> bool: if not isinstance(filename_or_obj, (str, os.PathLike)): return False - u = str(filename_or_obj).lower() + u = _strip_compression_suffix(str(filename_or_obj).lower()) # Zarr if u.endswith(".zarr") or ".zarr/" in u: diff --git a/src/xarray_prism/utils.py b/src/xarray_prism/utils.py index fcb6dc5..69db7f8 100644 --- a/src/xarray_prism/utils.py +++ b/src/xarray_prism/utils.py @@ -3,11 +3,14 @@ import logging import os import sys +import tempfile from contextlib import contextmanager from typing import Any, Dict, Iterator, Optional logger = logging.getLogger(__name__) +_COMPRESSION_SUFFIXES = (".bz2", ".gz", ".xz", ".zst", ".lz4") + STORAGE_OPTIONS_TO_GDAL: Dict[str, str] = { "key": "AWS_ACCESS_KEY_ID", "secret": "AWS_SECRET_ACCESS_KEY", @@ -253,3 +256,42 @@ def _strip_chaining_options(storage_options: dict) -> dict: from fsspec.registry import known_implementations return {k: v for k, v in storage_options.items() if k not in known_implementations} + + +def _strip_compression_suffix(uri: str) -> str: + """Remove common compression suffixes from URI for + more accurate pattern detection.""" + for suffix in _COMPRESSION_SUFFIXES: + if uri.endswith(suffix): + return uri[: -len(suffix)] + return uri + + +def _decompress_if_needed(path: str, output_dir: Optional[str] = None) -> str: + """Decompress a file if it has a known compression suffix. + Returns the path to the decompressed file (or original if not compressed). + Uses output_dir if provided, otherwise decompresses alongside the source file. 
+ """ + import bz2 + import gzip + import lzma + + # Dict[str, Any] avoids mypy errors from overloaded open() signatures + _DECOMPRESSORS: Dict[str, Any] = { + ".bz2": bz2.open, + ".gz": gzip.open, + ".xz": lzma.open, + } + + for suffix, opener in _DECOMPRESSORS.items(): + if path.endswith(suffix): + bare_name = os.path.basename(path)[: -len(suffix)] + out_dir = output_dir or os.path.dirname(path) or tempfile.gettempdir() + decompressed = os.path.join(out_dir, bare_name) + if not os.path.exists(decompressed): + with opener(path, "rb") as src, open(decompressed, "wb") as dst: + while chunk := src.read(512 * 1024): + dst.write(chunk) + return decompressed + + return path diff --git a/tests/test_backends.py b/tests/test_backends.py index 9aa1beb..2e64bdb 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -4,11 +4,21 @@ import os from pathlib import Path +import bz2 +import gzip +import tempfile import pytest import xarray as xr from xarray_prism.backends import open_cloud, open_posix +from xarray_prism import PrismBackendEntrypoint +from xarray_prism._detection import _detect_from_magic_bytes, _detect_from_uri_pattern +from xarray_prism.utils import ( + _decompress_if_needed, + _strip_chaining_options, + _strip_compression_suffix, +) class TestPosixBackend: @@ -201,3 +211,135 @@ def test_cache_dir_default(self): cache_dir = _get_cache_dir() assert cache_dir.parent == Path(tempfile.gettempdir()) assert "xarray-prism-cache" in str(cache_dir) + + +class TestStripCompressionSuffix: + def test_strips_bz2(self): + assert _strip_compression_suffix("file.grib2.bz2") == "file.grib2" + + def test_strips_gz(self): + assert _strip_compression_suffix("file.nc.gz") == "file.nc" + + def test_strips_xz(self): + assert _strip_compression_suffix("file.nc.xz") == "file.nc" + + def test_no_suffix_unchanged(self): + assert _strip_compression_suffix("file.grib2") == "file.grib2" + + def test_only_strips_last_suffix(self): + assert 
_strip_compression_suffix("file.grib2.bz2.bz2") == "file.grib2.bz2" + + +class TestStripChainingOptions: + def test_strips_simplecache(self): + opts = {"simplecache": {"cache_storage": "/tmp"}, "anon": True} + result = _strip_chaining_options(opts) + assert "simplecache" not in result + assert result["anon"] is True + + def test_strips_blockcache(self): + opts = {"blockcache": {}, "key": "abc"} + result = _strip_chaining_options(opts) + assert "blockcache" not in result + assert "key" in result + + def test_strips_filecache(self): + opts = {"filecache": {"cache_storage": "/tmp"}} + result = _strip_chaining_options(opts) + assert "filecache" not in result + + def test_empty_input(self): + assert _strip_chaining_options({}) == {} + + def test_no_chaining_keys_unchanged(self): + opts = {"anon": True, "key": "abc", "secret": "xyz"} + assert _strip_chaining_options(opts) == opts + + +class TestDetectionWithCompression: + def test_uri_pattern_grib2_bz2(self): + assert _detect_from_uri_pattern("file.grib2.bz2") == "cfgrib" + + def test_uri_pattern_nc_gz(self): + assert _detect_from_uri_pattern("file.nc.gz") is None + + def test_magic_bytes_grib_with_bz2_path(self): + assert _detect_from_magic_bytes(b"GRIB...", "file.grib2.bz2") == "cfgrib" + + def test_magic_bytes_geotiff_with_gz_path(self): + assert _detect_from_magic_bytes(b"II*\x00...", "file.tif.gz") == "rasterio" + + +class TestDecompressIfNeeded: + def test_bz2_decompressed(self): + content = b"GRIB test content" + with tempfile.TemporaryDirectory() as tmpdir: + compressed = os.path.join(tmpdir, "test.grib2.bz2") + with bz2.open(compressed, "wb") as f: + f.write(content) + + result = _decompress_if_needed(compressed) + assert result == os.path.join(tmpdir, "test.grib2") + assert Path(result).read_bytes() == content + + def test_gz_decompressed(self): + content = b"CDF netcdf3 content" + with tempfile.TemporaryDirectory() as tmpdir: + compressed = os.path.join(tmpdir, "test.nc.gz") + with gzip.open(compressed, "wb") 
as f: + f.write(content) + + result = _decompress_if_needed(compressed) + assert result == os.path.join(tmpdir, "test.nc") + assert Path(result).read_bytes() == content + + def test_no_compression_unchanged(self): + with tempfile.NamedTemporaryFile(suffix=".grib2", delete=False) as f: + f.write(b"GRIB content") + path = f.name + try: + assert _decompress_if_needed(path) == path + finally: + os.unlink(path) + + def test_idempotent_second_call(self): + """Second call should not re-decompress.""" + content = b"GRIB test" + with tempfile.TemporaryDirectory() as tmpdir: + compressed = os.path.join(tmpdir, "test.grib2.bz2") + with bz2.open(compressed, "wb") as f: + f.write(content) + + result1 = _decompress_if_needed(compressed) + mtime1 = os.path.getmtime(result1) + result2 = _decompress_if_needed(compressed) + assert os.path.getmtime(result2) == mtime1 # file not rewritten + + def test_custom_output_dir(self): + content = b"GRIB test" + with tempfile.TemporaryDirectory() as src_dir: + with tempfile.TemporaryDirectory() as out_dir: + compressed = os.path.join(src_dir, "test.grib2.bz2") + with bz2.open(compressed, "wb") as f: + f.write(content) + + result = _decompress_if_needed(compressed, output_dir=out_dir) + assert result.startswith(out_dir) + assert Path(result).read_bytes() == content + + +class TestOpenDatasetParametersIncludesStorageOptions: + def test_storage_options_in_parameters(self): + """storage_options must be in open_dataset_parameters so xarray forwards it.""" + entrypoint = PrismBackendEntrypoint() + assert "storage_options" in entrypoint.open_dataset_parameters + + +class TestGuessCanOpenCompressed: + def test_grib2_bz2(self): + ep = PrismBackendEntrypoint() + assert ep.guess_can_open("forecast.grib2.bz2") is True + + def test_nc_gz(self): + ep = PrismBackendEntrypoint() + assert ep.guess_can_open("data.nc.gz") is True