Skip to content

Commit e301994

Browse files
Enforce MetPy signature parity lazily
1 parent 71f30cc commit e301994

2 files changed

Lines changed: 144 additions & 3 deletions

File tree

python/metrust/calc/__init__.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
"""metrust.calc -- Drop-in replacement for metpy.calc
1+
"""metrust.calc -- MetPy-compatible calculation layer
22
3-
Every public function accepts and returns Pint Quantity objects, matching
4-
the MetPy API exactly. Internally, units are stripped to the convention
3+
Every public function accepts and returns Pint Quantity objects with a
4+
MetPy-compatible API surface. Internally, units are stripped to the convention
55
expected by the Rust engine (hPa for pressure, Celsius for temperature,
66
m/s for wind, m for height, etc.), the Rust function is called, and
77
appropriate units are attached to the result.
@@ -23,6 +23,8 @@
2323
"""
2424

2525
import importlib
26+
import inspect
27+
import sys
2628
from contextlib import contextmanager
2729
import numpy as np
2830
try:
@@ -6261,3 +6263,89 @@ def __getattr__(name):
62616263

62626264
def __dir__():
62636265
return sorted(set(globals()).union(__all__))
6266+
6267+
6268+
_METPY_SIGNATURE_HOOK = None
6269+
6270+
6271+
def _apply_metpy_signatures():
6272+
"""Mirror MetPy signatures for shared public wrappers when MetPy is available."""
6273+
_metpy_calc = sys.modules.get("metpy.calc")
6274+
if _metpy_calc is None:
6275+
return
6276+
6277+
for name in __all__:
6278+
metpy_obj = getattr(_metpy_calc, name, None)
6279+
if metpy_obj is None or not callable(metpy_obj):
6280+
continue
6281+
6282+
target_name = _COMPAT_ALIASES.get(name, name)
6283+
wrapper = globals().get(target_name)
6284+
if wrapper is None or not callable(wrapper):
6285+
continue
6286+
6287+
try:
6288+
wrapper.__signature__ = inspect.signature(metpy_obj)
6289+
except (TypeError, ValueError):
6290+
continue
6291+
6292+
6293+
class _MetPyCalcSignatureHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
6294+
"""Apply MetPy signature mirroring when metpy.calc imports after metrust.calc."""
6295+
6296+
def __init__(self):
6297+
self._wrapped_loader = None
6298+
6299+
def find_spec(self, fullname, path=None, target=None):
6300+
if fullname != "metpy.calc":
6301+
return None
6302+
6303+
for finder in sys.meta_path:
6304+
if finder is self:
6305+
continue
6306+
find_spec = getattr(finder, "find_spec", None)
6307+
if find_spec is None:
6308+
continue
6309+
spec = find_spec(fullname, path, target)
6310+
if spec is None:
6311+
continue
6312+
self._wrapped_loader = spec.loader
6313+
spec.loader = self
6314+
return spec
6315+
return None
6316+
6317+
def create_module(self, spec):
6318+
if self._wrapped_loader is not None and hasattr(self._wrapped_loader, "create_module"):
6319+
return self._wrapped_loader.create_module(spec)
6320+
return None
6321+
6322+
def exec_module(self, module):
6323+
if self._wrapped_loader is None:
6324+
raise ImportError("metpy.calc signature hook missing wrapped loader")
6325+
self._wrapped_loader.exec_module(module)
6326+
_apply_metpy_signatures()
6327+
_remove_metpy_signature_hook()
6328+
6329+
6330+
def _install_metpy_signature_hook():
6331+
global _METPY_SIGNATURE_HOOK
6332+
if _METPY_SIGNATURE_HOOK is not None or "metpy.calc" in sys.modules:
6333+
return
6334+
_METPY_SIGNATURE_HOOK = _MetPyCalcSignatureHook()
6335+
sys.meta_path.insert(0, _METPY_SIGNATURE_HOOK)
6336+
6337+
6338+
def _remove_metpy_signature_hook():
6339+
global _METPY_SIGNATURE_HOOK
6340+
hook = _METPY_SIGNATURE_HOOK
6341+
if hook is None:
6342+
return
6343+
try:
6344+
sys.meta_path.remove(hook)
6345+
except ValueError:
6346+
pass
6347+
_METPY_SIGNATURE_HOOK = None
6348+
6349+
6350+
_apply_metpy_signatures()
6351+
_install_metpy_signature_hook()

tests/test_signature_parity.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import subprocess
2+
import sys
3+
import textwrap
4+
5+
6+
def test_public_signatures_match_metpy_even_when_metrust_imports_first():
7+
script = textwrap.dedent(
8+
"""
9+
import inspect
10+
import metrust.calc as mrcalc
11+
import metpy.calc as mpcalc
12+
13+
ignore = {"set_module"}
14+
mp = {}
15+
mr = {}
16+
17+
for name in dir(mpcalc):
18+
if name.startswith("_") or name in ignore:
19+
continue
20+
obj = getattr(mpcalc, name)
21+
if callable(obj):
22+
mp[name] = obj
23+
24+
for name in dir(mrcalc):
25+
if name.startswith("_") or name in ignore:
26+
continue
27+
obj = getattr(mrcalc, name)
28+
if callable(obj):
29+
mr[name] = obj
30+
31+
mismatches = []
32+
for name in sorted(set(mp) & set(mr)):
33+
try:
34+
mp_sig = str(inspect.signature(mp[name]))
35+
mr_sig = str(inspect.signature(mr[name]))
36+
except Exception:
37+
continue
38+
if mp_sig != mr_sig:
39+
mismatches.append((name, mp_sig, mr_sig))
40+
41+
if mismatches:
42+
for name, mp_sig, mr_sig in mismatches:
43+
print(f"{name}: {mp_sig} != {mr_sig}")
44+
raise SystemExit(1)
45+
"""
46+
)
47+
result = subprocess.run(
48+
[sys.executable, "-c", script],
49+
capture_output=True,
50+
text=True,
51+
check=False,
52+
)
53+
assert result.returncode == 0, result.stdout + result.stderr

0 commit comments

Comments
 (0)