From ee9266faa708a69412c7ab7272c7cc95aa82700f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Fri, 19 Dec 2025 10:12:56 +0100 Subject: [PATCH] chore(lib): simplify `export`'s implementation and add `polarization` Import changes made in #355. --- CHANGELOG.md | 5 + differt/src/differt/plugins/deepmimo.py | 175 ++++++++++-------------- differt/tests/plugins/test_deepmimo.py | 6 +- 3 files changed, 79 insertions(+), 107 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fe56bec4..3b1ae819 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,10 +20,15 @@ with one *slight* but **important** difference: ## [Unreleased](https://github.com/jeertmans/DiffeRT/compare/v0.6.2...HEAD) +### Added + +- Added `polarization` parameter to {func}`deepmimo.export` (by , in ). + ### Chore - Removed PyOpenGL from macOS dependencies as it is no longer needed to fix VisPy not finding DLL files (by , in ). - Fix anchor link to JAX's documentation (by , in ). +- Simplified {func}`deepmimo.export` to reduce redundant code (by , in ). ### Fixed diff --git a/differt/src/differt/plugins/deepmimo.py b/differt/src/differt/plugins/deepmimo.py index e608cda4..2b6ffd66 100644 --- a/differt/src/differt/plugins/deepmimo.py +++ b/differt/src/differt/plugins/deepmimo.py @@ -4,13 +4,13 @@ from collections.abc import Iterable, Iterator, Mapping from dataclasses import KW_ONLY, asdict -from typing import Any, Generic +from typing import Any, Generic, Literal import equinox as eqx import jax import jax.numpy as jnp import numpy as np -from jaxtyping import Array, ArrayLike, Bool, Float, Int +from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped from differt.em import ( InteractionType, @@ -34,6 +34,41 @@ from ._deepmimo_types import ArrayType +def _pad_and_concat( + left: Shaped[Array, "num_tx num_rx num_paths_left num_interactions_left ..."], + right: Shaped[Array, "num_tx num_rx num_paths_right num_interactions_right ..."], + fill_value: Any, +) -> Shaped[ + Array, + "num_tx num_rx num_paths_left+num_paths_right max(num_interactions_left,num_interactions_right) ...", +]: + max_num_interactions = max(left.shape[3], right.shape[3]) + extra_dims_pad = [(0, 0)] * (left.ndim - 4) + left = jnp.pad( + left, + ( + (0, 0), + (0, 0), + (0, 0), + (0, max_num_interactions - left.shape[3]), + *extra_dims_pad, + ), + constant_values=fill_value, + ) + right = jnp.pad( + right, + ( + (0, 0), + (0, 0), + (0, 0), + (0, max_num_interactions - right.shape[3]), + *extra_dims_pad, + ), + constant_values=fill_value, + ) + return jnp.concatenate((left, right), axis=2) + + class DeepMIMO(eqx.Module, Generic[ArrayType]): """DeepMIMO format data structure. @@ -120,7 +155,10 @@ def _sort( ) if vertices.shape != self.inter_pos.shape: # pragma: no cover - msg = f"Cannot sort based on provided paths: shape mismatch, got {vertices.shape!r} but expected {self.inter_pos.shape!r}." + msg = ( + "Cannot sort based on provided paths: shape mismatch, got " + f"{vertices.shape!r} but expected {self.inter_pos.shape!r}." + ) raise ValueError(msg) max_num_interactions = self.inter.shape[-1] @@ -278,15 +316,17 @@ def export( radio_materials: Mapping[str, Material] | None = None, frequency: Float[ArrayLike, " "], include_primitives: bool = False, + polarization: Literal["V", "H"] | Float[ArrayLike, "3"] = "V", ) -> DeepMIMO[Array]: """ Export a Ray Tracing simulation to the DeepMIMO format. .. note:: - The current implementation assumes far-field propagation in free space, and isotropic antennas with vertical polarization. + The current implementation assumes far-field propagation in free space, and isotropic antennas with a single polarization. While tests show a good match with Sionna's :class:`sionna.rt.PathSolver` for most attributes, - :attr:`DeepMIMO.power` and :attr:`DeepMIMO.phase` are not exactly equal, and we don't know yet if our implementation is 100% correct. If you know how to improve this, please open an issue or a pull-request on GitHub! + :attr:`DeepMIMO.power` and :attr:`DeepMIMO.phase` are not exactly equal, and we don't know yet if our implementation is 100% correct. + If you know how to improve this, please open an issue or a pull-request on GitHub! Args: paths: The geometrical paths. @@ -301,6 +341,8 @@ def export( If not provided, :data:`materials` will be used. frequency: The operating frequency (in Hz). include_primitives: If :data:`True`, include the primitive indices in the output. + polarization: The antenna polarization. + Can be either ``"V"`` (vertical, z-axis up), ``"H"`` (horizontal, x-axis up), or a 3D unit vector. Returns: The exported DeepMIMO data as JAX arrays. @@ -332,7 +374,12 @@ def export( # Fields array fields = jnp.zeros((num_tx, num_rx, 0, 3), dtype=complex) - polarization = jnp.array([0.0, 0.0, 1.0], dtype=float) + if polarization == "V": + polarization = jnp.array([0.0, 0.0, 1.0], dtype=float) + elif polarization == "H": + polarization = jnp.array([1.0, 0.0, 0.0], dtype=float) + else: # pragma: no cover + polarization = jnp.asarray(polarization) # Direction of departure (DoD) and direction of arrival (DoA) segments k_d = jnp.zeros((num_tx, num_rx, 0, 3), dtype=float) k_a = jnp.zeros_like(k_d) @@ -357,112 +404,30 @@ def export( # [num_tx num_rx num_path_candidates order+1 3] path_segments = jnp.diff(paths.vertices, axis=-2) - max_num_interactions = max(paths.order, inter.shape[-1]) # [num_tx num_rx num_paths max_num_interactions] if primitives is not None: - primitives = jnp.concatenate( - ( - jnp.concatenate( - ( - primitives, - jnp.full( - ( - *primitives.shape[:-1], - max_num_interactions - primitives.shape[-1], - ), - no_interaction, - dtype=primitives.dtype, - ), - ), - axis=-1, - ), - jnp.concatenate( - ( - paths.objects[..., 1:-1], - jnp.full( - ( - *paths.objects.shape[:-1], - max_num_interactions - paths.order, - ), - no_interaction, - dtype=primitives.dtype, - ), - ), - axis=-1, - ), - ), - axis=-2, + primitives = _pad_and_concat( + primitives, + paths.objects[..., 1:-1], + fill_value=no_interaction, ) # [num_tx num_rx num_paths max_num_interactions] - inter = jnp.concatenate( - ( - jnp.concatenate( - ( - inter, - jnp.full( - (*inter.shape[:-1], max_num_interactions - inter.shape[-1]), - no_interaction, - dtype=inter.dtype, - ), - ), - axis=-1, - ), - jnp.concatenate( - ( - paths.interaction_types - if paths.interaction_types is not None - else jnp.full_like( - paths.objects[..., 1:-1], - InteractionType.REFLECTION, - dtype=inter.dtype, - ), - jnp.full( - ( - *paths.objects.shape[:-1], - max_num_interactions - paths.order, - ), - no_interaction, - dtype=inter.dtype, - ), - ), - axis=-1, - ), + inter = _pad_and_concat( + inter, + paths.interaction_types + if paths.interaction_types is not None + else jnp.full_like( + paths.objects[..., 1:-1], + InteractionType.REFLECTION, + dtype=inter.dtype, ), - axis=-2, + fill_value=no_interaction, ) # [num_tx num_rx num_paths max_num_interactions 3] - inter_pos = jnp.concatenate( - ( - jnp.concatenate( - ( - inter_pos, - jnp.zeros( - ( - *inter_pos.shape[:-2], - max_num_interactions - inter_pos.shape[-2], - 3, - ), - dtype=inter_pos.dtype, - ), - ), - axis=-2, - ), - jnp.concatenate( - ( - paths.vertices[..., 1:-1, :], - jnp.zeros( - ( - *paths.vertices.shape[:-2], - max_num_interactions - paths.order, - 3, - ), - dtype=inter_pos.dtype, - ), - ), - axis=-2, - ), - ), - axis=-3, + inter_pos = _pad_and_concat( + inter_pos, + paths.vertices[..., 1:-1, :], + fill_value=0.0, ) # [num_tx num_rx num_path_candidates order+1 3], # [num_tx num_rx num_path_candidates order+1 1] diff --git a/differt/tests/plugins/test_deepmimo.py b/differt/tests/plugins/test_deepmimo.py index 6e79368f..3153f208 100644 --- a/differt/tests/plugins/test_deepmimo.py +++ b/differt/tests/plugins/test_deepmimo.py @@ -1,6 +1,7 @@ # ruff: noqa: ERA001 from dataclasses import asdict from itertools import chain +from typing import Literal import chex import equinox as eqx @@ -88,7 +89,8 @@ def test_export(key: PRNGKeyArray) -> None: @pytest.mark.slow -def test_match_sionna_on_simple_street_canyon() -> None: +@pytest.mark.parametrize("polarization", ["V", "H"]) +def test_match_sionna_on_simple_street_canyon(polarization: Literal["V", "H"]) -> None: mi = pytest.importorskip("mitsuba", reason="mitsuba not installed") try: mi.set_variant("llvm_ad_mono_polarized") @@ -106,7 +108,7 @@ def test_match_sionna_on_simple_street_canyon() -> None: vertical_spacing=0.5, horizontal_spacing=0.5, pattern="iso", - polarization="V", + polarization=polarization, ) sionna_scene.rx_array = sionna.rt.PlanarArray(