Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ with one *slight* but **important** difference:

## [Unreleased](https://github.com/jeertmans/DiffeRT/compare/v0.6.2...HEAD)

### Added

- Added `polarization` parameter to {func}`deepmimo.export<differt.plugins.deepmimo.export>` (by <gh-user:jeertmans>, in <gh-pr:356>).

### Chore

- Removed PyOpenGL from macOS dependencies as it is no longer needed to fix VisPy not finding DLL files (by <gh-user:jeertmans>, in <gh-pr:345>).
- Fix anchor link to JAX's documentation (by <gh-user:jeertmans>, in <gh-pr:346>).
- Simplified {func}`deepmimo.export<differt.plugins.deepmimo.export>` to reduce redundant code (by <gh-user:jeertmans>, in <gh-pr:356>).

### Fixed

Expand Down
175 changes: 70 additions & 105 deletions differt/src/differt/plugins/deepmimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import KW_ONLY, asdict
from typing import Any, Generic
from typing import Any, Generic, Literal

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Float, Int
from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped

from differt.em import (
InteractionType,
Expand All @@ -34,6 +34,41 @@
from ._deepmimo_types import ArrayType


def _pad_and_concat(
left: Shaped[Array, "num_tx num_rx num_paths_left num_interactions_left ..."],
right: Shaped[Array, "num_tx num_rx num_paths_right num_interactions_right ..."],
fill_value: Any,
) -> Shaped[
Array,
"num_tx num_rx num_paths_left+num_paths_right max(num_interactions_left,num_interactions_right) ...",
]:
max_num_interactions = max(left.shape[3], right.shape[3])
extra_dims_pad = [(0, 0)] * (left.ndim - 4)
left = jnp.pad(
left,
(
(0, 0),
(0, 0),
(0, 0),
(0, max_num_interactions - left.shape[3]),
*extra_dims_pad,
),
constant_values=fill_value,
)
right = jnp.pad(
right,
(
(0, 0),
(0, 0),
(0, 0),
(0, max_num_interactions - right.shape[3]),
*extra_dims_pad,
),
constant_values=fill_value,
)
return jnp.concatenate((left, right), axis=2)


class DeepMIMO(eqx.Module, Generic[ArrayType]):
"""DeepMIMO format data structure.

Expand Down Expand Up @@ -120,7 +155,10 @@ def _sort(
)

if vertices.shape != self.inter_pos.shape: # pragma: no cover
msg = f"Cannot sort based on provided paths: shape mismatch, got {vertices.shape!r} but expected {self.inter_pos.shape!r}."
msg = (
"Cannot sort based on provided paths: shape mismatch, got "
f"{vertices.shape!r} but expected {self.inter_pos.shape!r}."
)
raise ValueError(msg)

max_num_interactions = self.inter.shape[-1]
Expand Down Expand Up @@ -278,15 +316,17 @@ def export(
radio_materials: Mapping[str, Material] | None = None,
frequency: Float[ArrayLike, " "],
include_primitives: bool = False,
polarization: Literal["V", "H"] | Float[ArrayLike, "3"] = "V",
) -> DeepMIMO[Array]:
"""
Export a Ray Tracing simulation to the DeepMIMO format.

.. note::
The current implementation assumes far-field propagation in free space, and isotropic antennas with vertical polarization.
The current implementation assumes far-field propagation in free space, and isotropic antennas with a single polarization.

While tests show a good match with Sionna's :class:`sionna.rt.PathSolver` for most attributes,
:attr:`DeepMIMO.power` and :attr:`DeepMIMO.phase` are not exactly equal, and we don't know yet if our implementation is 100% correct. If you know how to improve this, please open an issue or a pull-request on GitHub!
:attr:`DeepMIMO.power` and :attr:`DeepMIMO.phase` are not exactly equal, and we don't know yet if our implementation is 100% correct.
If you know how to improve this, please open an issue or a pull-request on GitHub!

Args:
paths: The geometrical paths.
Expand All @@ -301,6 +341,8 @@ def export(
If not provided, :data:`materials<differt.em._material.materials>` will be used.
frequency: The operating frequency (in Hz).
include_primitives: If :data:`True`, include the primitive indices in the output.
polarization: The antenna polarization.
Can be either ``"V"`` (vertical, z-axis up), ``"H"`` (horizontal, x-axis up), or a 3D unit vector.

Returns:
The exported DeepMIMO data as JAX arrays.
Expand Down Expand Up @@ -332,7 +374,12 @@ def export(

# Fields array
fields = jnp.zeros((num_tx, num_rx, 0, 3), dtype=complex)
polarization = jnp.array([0.0, 0.0, 1.0], dtype=float)
if polarization == "V":
polarization = jnp.array([0.0, 0.0, 1.0], dtype=float)
elif polarization == "H":
polarization = jnp.array([1.0, 0.0, 0.0], dtype=float)
else: # pragma: no cover
polarization = jnp.asarray(polarization)
# Direction of departure (DoD) and direction of arrival (DoA) segments
k_d = jnp.zeros((num_tx, num_rx, 0, 3), dtype=float)
k_a = jnp.zeros_like(k_d)
Expand All @@ -357,112 +404,30 @@ def export(
# [num_tx num_rx num_path_candidates order+1 3]
path_segments = jnp.diff(paths.vertices, axis=-2)

max_num_interactions = max(paths.order, inter.shape[-1])
# [num_tx num_rx num_paths max_num_interactions]
if primitives is not None:
primitives = jnp.concatenate(
(
jnp.concatenate(
(
primitives,
jnp.full(
(
*primitives.shape[:-1],
max_num_interactions - primitives.shape[-1],
),
no_interaction,
dtype=primitives.dtype,
),
),
axis=-1,
),
jnp.concatenate(
(
paths.objects[..., 1:-1],
jnp.full(
(
*paths.objects.shape[:-1],
max_num_interactions - paths.order,
),
no_interaction,
dtype=primitives.dtype,
),
),
axis=-1,
),
),
axis=-2,
primitives = _pad_and_concat(
primitives,
paths.objects[..., 1:-1],
fill_value=no_interaction,
)
# [num_tx num_rx num_paths max_num_interactions]
inter = jnp.concatenate(
(
jnp.concatenate(
(
inter,
jnp.full(
(*inter.shape[:-1], max_num_interactions - inter.shape[-1]),
no_interaction,
dtype=inter.dtype,
),
),
axis=-1,
),
jnp.concatenate(
(
paths.interaction_types
if paths.interaction_types is not None
else jnp.full_like(
paths.objects[..., 1:-1],
InteractionType.REFLECTION,
dtype=inter.dtype,
),
jnp.full(
(
*paths.objects.shape[:-1],
max_num_interactions - paths.order,
),
no_interaction,
dtype=inter.dtype,
),
),
axis=-1,
),
inter = _pad_and_concat(
inter,
paths.interaction_types
if paths.interaction_types is not None
else jnp.full_like(
paths.objects[..., 1:-1],
InteractionType.REFLECTION,
dtype=inter.dtype,
),
axis=-2,
fill_value=no_interaction,
)
# [num_tx num_rx num_paths max_num_interactions 3]
inter_pos = jnp.concatenate(
(
jnp.concatenate(
(
inter_pos,
jnp.zeros(
(
*inter_pos.shape[:-2],
max_num_interactions - inter_pos.shape[-2],
3,
),
dtype=inter_pos.dtype,
),
),
axis=-2,
),
jnp.concatenate(
(
paths.vertices[..., 1:-1, :],
jnp.zeros(
(
*paths.vertices.shape[:-2],
max_num_interactions - paths.order,
3,
),
dtype=inter_pos.dtype,
),
),
axis=-2,
),
),
axis=-3,
inter_pos = _pad_and_concat(
inter_pos,
paths.vertices[..., 1:-1, :],
fill_value=0.0,
)
# [num_tx num_rx num_path_candidates order+1 3],
# [num_tx num_rx num_path_candidates order+1 1]
Expand Down
6 changes: 4 additions & 2 deletions differt/tests/plugins/test_deepmimo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: ERA001
from dataclasses import asdict
from itertools import chain
from typing import Literal

import chex
import equinox as eqx
Expand Down Expand Up @@ -88,7 +89,8 @@ def test_export(key: PRNGKeyArray) -> None:


@pytest.mark.slow
def test_match_sionna_on_simple_street_canyon() -> None:
@pytest.mark.parametrize("polarization", ["V", "H"])
def test_match_sionna_on_simple_street_canyon(polarization: Literal["V", "H"]) -> None:
mi = pytest.importorskip("mitsuba", reason="mitsuba not installed")
try:
mi.set_variant("llvm_ad_mono_polarized")
Expand All @@ -106,7 +108,7 @@ def test_match_sionna_on_simple_street_canyon() -> None:
vertical_spacing=0.5,
horizontal_spacing=0.5,
pattern="iso",
polarization="V",
polarization=polarization,
)

sionna_scene.rx_array = sionna.rt.PlanarArray(
Expand Down
Loading