Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Changelog

All notable changes to this project will be documented in this file.
## [v2603.0.0]
### Fixed
- an issue regarding passing the storage_options to aiohttp

## [v2602.1.0]
### Added
Expand Down
20 changes: 17 additions & 3 deletions src/xarray_prism/_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from .utils import _strip_chaining_options, _strip_compression_suffix

Detector = Callable[[str], Optional[str]]

_custom_detectors: List[Tuple[int, Detector]] = []
Expand Down Expand Up @@ -145,6 +147,9 @@ def looks_like_opendap_url(uri: str) -> bool:

def _detect_from_uri_pattern(lower_uri: str) -> Optional[str]:
"""Detect engine from URI patterns without I/O."""

lower_uri = _strip_compression_suffix(lower_uri)

# Reference URIs -> zarr (Kerchunk)
if is_reference_uri(lower_uri):
return "zarr"
Expand All @@ -153,6 +158,10 @@ def _detect_from_uri_pattern(lower_uri: str) -> Optional[str]:
if lower_uri.endswith(".zarr") or ".zarr/" in lower_uri:
return "zarr"

# GRIB detection by extension
if lower_uri.endswith((".grib", ".grb", ".grb2", ".grib2")):
return "cfgrib"

# THREDDS NCSS with explicit accept format (overrides file extension)
if "/ncss/" in lower_uri or "/ncss?" in lower_uri:
if "accept=netcdf3" in lower_uri:
Expand Down Expand Up @@ -202,8 +211,10 @@ def _read_magic_bytes(fs: Any, path: str) -> Any:

def _detect_from_magic_bytes(header: bytes, lower_path: str) -> Engine:
"""Detect engine from magic bytes and file extension."""
bare_path = _strip_compression_suffix(lower_path)

# GRIB detection
if b"GRIB" in header or lower_path.endswith((".grib", ".grb", ".grb2", ".grib2")):
if b"GRIB" in header or bare_path.endswith((".grib", ".grb", ".grb2", ".grib2")):
return "cfgrib"

# NetCDF3 (Classic)
Expand All @@ -217,7 +228,7 @@ def _detect_from_magic_bytes(header: bytes, lower_path: str) -> Engine:
# GeoTIFF
if header.startswith((b"II*\x00", b"MM\x00*")):
return "rasterio"
if lower_path.endswith((".tif", ".tiff")):
if bare_path.endswith((".tif", ".tiff")):
return "rasterio"

return "unknown"
Expand Down Expand Up @@ -259,7 +270,10 @@ def _detect_engine_impl(uri: str, storage_options: Optional[Dict]) -> str:
# 3. Filesystem-based detection — lazy import fsspec only when needed
import fsspec # noqa: PLC0415

fs, path = fsspec.core.url_to_fs(uri, **(storage_options or {}))
fs, path = fsspec.core.url_to_fs(
uri, **_strip_chaining_options(storage_options or {})
)

lower_path = path.lower()

# Check for Zarr directory
Expand Down
2 changes: 1 addition & 1 deletion src/xarray_prism/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version information for xarray-prism."""

__version__ = "2602.1.0"
__version__ = "2603.0.0"
10 changes: 7 additions & 3 deletions src/xarray_prism/backends/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import Any, Dict, Optional
from urllib.parse import urlparse

from ..utils import _decompress_if_needed, _strip_chaining_options

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -60,15 +62,17 @@ def _cache_remote_file(
sys.stdout.write("\033[A")
sys.stdout.write("\033[K")
sys.stdout.flush()
return str(local_path)
return _decompress_if_needed(str(local_path))

extra_lines = 0
if show_progress:
fmt = "GRIB" if engine == "cfgrib" else "NetCDF3"
logger.warning(f"Remote {fmt} requires full file download")
extra_lines = 2

fs, path = fsspec.core.url_to_fs(uri, **(storage_options or {}))
fs, path = fsspec.core.url_to_fs(
uri, **_strip_chaining_options(storage_options or {})
)

if show_progress:
size = 0
Expand All @@ -94,7 +98,7 @@ def _cache_remote_file(
else:
fs.get(path, str(local_path))

return str(local_path)
return _decompress_if_needed(str(local_path))


def open_cloud(
Expand Down
5 changes: 3 additions & 2 deletions src/xarray_prism/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray.backends import BackendEntrypoint

from ._detection import (
_strip_compression_suffix,
detect_engine,
detect_uri_type,
is_http_url,
Expand Down Expand Up @@ -75,7 +76,7 @@ class PrismBackendEntrypoint(BackendEntrypoint):
)

url = "https://github.com/freva-org/xarray-prism"
open_dataset_parameters = ("filename_or_obj", "drop_variables")
open_dataset_parameters = ("filename_or_obj", "drop_variables", "storage_options")

ENGINE_MAP: Dict[str, str] = {
"zarr": "zarr",
Expand Down Expand Up @@ -242,7 +243,7 @@ def guess_can_open(self, filename_or_obj: Any) -> bool:
if not isinstance(filename_or_obj, (str, os.PathLike)):
return False

u = str(filename_or_obj).lower()
u = _strip_compression_suffix(str(filename_or_obj).lower())

# Zarr
if u.endswith(".zarr") or ".zarr/" in u:
Expand Down
54 changes: 54 additions & 0 deletions src/xarray_prism/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import logging
import os
import sys
import tempfile
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional

logger = logging.getLogger(__name__)

_COMPRESSION_SUFFIXES = (".bz2", ".gz", ".xz", ".zst", ".lz4")

STORAGE_OPTIONS_TO_GDAL: Dict[str, str] = {
"key": "AWS_ACCESS_KEY_ID",
"secret": "AWS_SECRET_ACCESS_KEY",
Expand Down Expand Up @@ -241,3 +244,54 @@ def sanitize_dataset_attrs(ds):
var.attrs = _clean_attr_obj(dict(var.attrs))

return ds


def _strip_chaining_options(storage_options: dict) -> dict:
"""Strip fsspec chaining/wrapper protocol keys from storage_options.
These are keys like 'simplecache', 'blockcache', 'filecache' that are
fsspec protocol names used for URL chaining, not valid HTTP/remote FS kwargs.
"""
if not storage_options:
return {}
from fsspec.registry import known_implementations

return {k: v for k, v in storage_options.items() if k not in known_implementations}


def _strip_compression_suffix(uri: str) -> str:
    """Return ``uri`` with one trailing compression suffix removed.

    The first entry of ``_COMPRESSION_SUFFIXES`` that ``uri`` ends with is
    cut off; at most one suffix is removed. A URI without any of these
    suffixes is returned unchanged.
    """
    matched = next(
        (candidate for candidate in _COMPRESSION_SUFFIXES if uri.endswith(candidate)),
        None,
    )
    if matched is None:
        return uri
    return uri[: -len(matched)]


def _decompress_if_needed(path: str, output_dir: Optional[str] = None) -> str:
"""Decompress a file if it has a known compression suffix.
Returns the path to the decompressed file (or original if not compressed).
Uses output_dir if provided, otherwise decompresses alongside the source file.
"""
import bz2
import gzip
import lzma

# Dict[str, Any] avoids mypy errors from overloaded open() signatures
_DECOMPRESSORS: Dict[str, Any] = {
".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}

for suffix, opener in _DECOMPRESSORS.items():
if path.endswith(suffix):
bare_name = os.path.basename(path)[: -len(suffix)]
out_dir = output_dir or os.path.dirname(path) or tempfile.gettempdir()
decompressed = os.path.join(out_dir, bare_name)
if not os.path.exists(decompressed):
with opener(path, "rb") as src, open(decompressed, "wb") as dst:
while chunk := src.read(512 * 1024):
dst.write(chunk)
return decompressed

return path
142 changes: 142 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@

import os
from pathlib import Path
import bz2
import gzip
import tempfile

import pytest
import xarray as xr

from xarray_prism.backends import open_cloud, open_posix
from xarray_prism import PrismBackendEntrypoint
from xarray_prism._detection import _detect_from_magic_bytes, _detect_from_uri_pattern
from xarray_prism.utils import (
_decompress_if_needed,
_strip_chaining_options,
_strip_compression_suffix,
)


class TestPosixBackend:
Expand Down Expand Up @@ -201,3 +211,135 @@ def test_cache_dir_default(self):
cache_dir = _get_cache_dir()
assert cache_dir.parent == Path(tempfile.gettempdir())
assert "xarray-prism-cache" in str(cache_dir)


class TestStripCompressionSuffix:
    """Unit checks for ``_strip_compression_suffix``."""

    def test_strips_bz2(self):
        stripped = _strip_compression_suffix("file.grib2.bz2")
        assert stripped == "file.grib2"

    def test_strips_gz(self):
        stripped = _strip_compression_suffix("file.nc.gz")
        assert stripped == "file.nc"

    def test_strips_xz(self):
        stripped = _strip_compression_suffix("file.nc.xz")
        assert stripped == "file.nc"

    def test_no_suffix_unchanged(self):
        name = "file.grib2"
        assert _strip_compression_suffix(name) == name

    def test_only_strips_last_suffix(self):
        # A doubled suffix loses just one layer per call.
        stripped = _strip_compression_suffix("file.grib2.bz2.bz2")
        assert stripped == "file.grib2.bz2"


class TestStripChainingOptions:
    """Unit checks for ``_strip_chaining_options``."""

    def test_strips_simplecache(self):
        cleaned = _strip_chaining_options(
            {"simplecache": {"cache_storage": "/tmp"}, "anon": True}
        )
        assert "simplecache" not in cleaned
        assert cleaned["anon"] is True

    def test_strips_blockcache(self):
        cleaned = _strip_chaining_options({"blockcache": {}, "key": "abc"})
        assert "blockcache" not in cleaned
        assert "key" in cleaned

    def test_strips_filecache(self):
        cleaned = _strip_chaining_options(
            {"filecache": {"cache_storage": "/tmp"}}
        )
        assert "filecache" not in cleaned

    def test_empty_input(self):
        cleaned = _strip_chaining_options({})
        assert cleaned == {}

    def test_no_chaining_keys_unchanged(self):
        plain = {"anon": True, "key": "abc", "secret": "xyz"}
        cleaned = _strip_chaining_options(plain)
        assert cleaned == plain


class TestDetectionWithCompression:
    """Engine detection on paths that carry a compression suffix."""

    def test_uri_pattern_grib2_bz2(self):
        engine = _detect_from_uri_pattern("file.grib2.bz2")
        assert engine == "cfgrib"

    def test_uri_pattern_nc_gz(self):
        engine = _detect_from_uri_pattern("file.nc.gz")
        assert engine is None

    def test_magic_bytes_grib_with_bz2_path(self):
        engine = _detect_from_magic_bytes(b"GRIB...", "file.grib2.bz2")
        assert engine == "cfgrib"

    def test_magic_bytes_geotiff_with_gz_path(self):
        engine = _detect_from_magic_bytes(b"II*\x00...", "file.tif.gz")
        assert engine == "rasterio"


class TestDecompressIfNeeded:
    """Unit checks for ``_decompress_if_needed``."""

    def test_bz2_decompressed(self):
        payload = b"GRIB test content"
        with tempfile.TemporaryDirectory() as workdir:
            archive = os.path.join(workdir, "test.grib2.bz2")
            with bz2.open(archive, "wb") as handle:
                handle.write(payload)

            expected = os.path.join(workdir, "test.grib2")
            produced = _decompress_if_needed(archive)
            assert produced == expected
            assert Path(produced).read_bytes() == payload

    def test_gz_decompressed(self):
        payload = b"CDF netcdf3 content"
        with tempfile.TemporaryDirectory() as workdir:
            archive = os.path.join(workdir, "test.nc.gz")
            with gzip.open(archive, "wb") as handle:
                handle.write(payload)

            expected = os.path.join(workdir, "test.nc")
            produced = _decompress_if_needed(archive)
            assert produced == expected
            assert Path(produced).read_bytes() == payload

    def test_no_compression_unchanged(self):
        with tempfile.NamedTemporaryFile(suffix=".grib2", delete=False) as handle:
            handle.write(b"GRIB content")
            plain_path = handle.name
        try:
            produced = _decompress_if_needed(plain_path)
            assert produced == plain_path
        finally:
            os.unlink(plain_path)

    def test_idempotent_second_call(self):
        """A repeated call reuses the existing output instead of rewriting it."""
        payload = b"GRIB test"
        with tempfile.TemporaryDirectory() as workdir:
            archive = os.path.join(workdir, "test.grib2.bz2")
            with bz2.open(archive, "wb") as handle:
                handle.write(payload)

            first = _decompress_if_needed(archive)
            first_mtime = os.path.getmtime(first)
            second = _decompress_if_needed(archive)
            second_mtime = os.path.getmtime(second)
            assert second_mtime == first_mtime  # file not rewritten

    def test_custom_output_dir(self):
        payload = b"GRIB test"
        with tempfile.TemporaryDirectory() as src_dir, tempfile.TemporaryDirectory() as out_dir:
            archive = os.path.join(src_dir, "test.grib2.bz2")
            with bz2.open(archive, "wb") as handle:
                handle.write(payload)

            produced = _decompress_if_needed(archive, output_dir=out_dir)
            assert produced.startswith(out_dir)
            assert Path(produced).read_bytes() == payload


class TestOpenDatasetParametersIncludesStorageOptions:
    """Guards the ``open_dataset_parameters`` declaration of the entrypoint."""

    def test_storage_options_in_parameters(self):
        """storage_options must be in open_dataset_parameters so xarray forwards it."""
        declared = PrismBackendEntrypoint().open_dataset_parameters
        assert "storage_options" in declared


class TestGuessCanOpenCompressed:
    """``guess_can_open`` on file names that carry a compression suffix."""

    def test_grib2_bz2(self):
        accepted = PrismBackendEntrypoint().guess_can_open("forecast.grib2.bz2")
        assert accepted is True

    def test_nc_gz(self):
        accepted = PrismBackendEntrypoint().guess_can_open("data.nc.gz")
        assert accepted is True
Loading