diff --git a/differt/src/differt/plugins/deepmimo.py b/differt/src/differt/plugins/deepmimo.py index e608cda4..d1ba0aff 100644 --- a/differt/src/differt/plugins/deepmimo.py +++ b/differt/src/differt/plugins/deepmimo.py @@ -4,13 +4,13 @@ from collections.abc import Iterable, Iterator, Mapping from dataclasses import KW_ONLY, asdict -from typing import Any, Generic +from typing import Any, Generic, Literal import equinox as eqx import jax import jax.numpy as jnp import numpy as np -from jaxtyping import Array, ArrayLike, Bool, Float, Int +from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped from differt.em import ( InteractionType, @@ -19,6 +19,7 @@ materials, reflection_coefficients, sp_directions, + sp_rotation_matrix, z_0, ) from differt.geometry import ( @@ -34,6 +35,41 @@ from ._deepmimo_types import ArrayType +def _pad_and_concat( + left: Shaped[Array, "num_tx num_rx num_paths_left num_interactions_left ..."], + right: Shaped[Array, "num_tx num_rx num_paths_right num_interactions_right ..."], + fill_value: float, +) -> Shaped[ + Array, + "num_tx num_rx num_paths_left+num_paths_right max(num_interactions_left, num_interactions_right) ...", +]: + max_num_interactions = max(left.shape[3], right.shape[3]) + extra_dims_pad = [(0, 0)] * (left.ndim - 4) + left = jnp.pad( + left, + ( + (0, 0), + (0, 0), + (0, 0), + (0, max_num_interactions - left.shape[3]), + *extra_dims_pad, + ), + constant_values=fill_value, + ) + right = jnp.pad( + right, + ( + (0, 0), + (0, 0), + (0, 0), + (0, max_num_interactions - right.shape[3]), + *extra_dims_pad, + ), + constant_values=fill_value, + ) + return jnp.concatenate((left, right), axis=2) + + class DeepMIMO(eqx.Module, Generic[ArrayType]): """DeepMIMO format data structure. @@ -120,7 +156,10 @@ def _sort( ) if vertices.shape != self.inter_pos.shape: # pragma: no cover - msg = f"Cannot sort based on provided paths: shape mismatch, got {vertices.shape!r} but expected {self.inter_pos.shape!r}." + msg = ( + "Cannot sort based on provided paths: shape mismatch, got " + f"{vertices.shape!r} but expected {self.inter_pos.shape!r}." + ) raise ValueError(msg) max_num_interactions = self.inter.shape[-1] @@ -278,6 +317,7 @@ def export( radio_materials: Mapping[str, Material] | None = None, frequency: Float[ArrayLike, " "], include_primitives: bool = False, + polarization: Literal["V", "H"] | Float[ArrayLike, "3"] = "V", ) -> DeepMIMO[Array]: """ Export a Ray Tracing simulation to the DeepMIMO format. @@ -301,6 +341,8 @@ def export( If not provided, :data:`materials` will be used. frequency: The operating frequency (in Hz). include_primitives: If :data:`True`, include the primitive indices in the output. + polarization: The antenna polarization. + Can be either ``"V"`` (vertical), ``"H"`` (horizontal), or a 3D unit vector. Returns: The exported DeepMIMO data as JAX arrays. @@ -332,7 +374,12 @@ def export( # Fields array fields = jnp.zeros((num_tx, num_rx, 0, 3), dtype=complex) - polarization = jnp.array([0.0, 0.0, 1.0], dtype=float) + if polarization == "V": + polarization = jnp.array([0.0, 0.0, 1.0], dtype=float) + elif polarization == "H": + polarization = jnp.array([1.0, 0.0, 0.0], dtype=float) + else: + polarization = jnp.asarray(polarization) # Direction of departure (DoD) and direction of arrival (DoA) segments k_d = jnp.zeros((num_tx, num_rx, 0, 3), dtype=float) k_a = jnp.zeros_like(k_d) @@ -357,112 +404,30 @@ def export( # [num_tx num_rx num_path_candidates order+1 3] path_segments = jnp.diff(paths.vertices, axis=-2) - max_num_interactions = max(paths.order, inter.shape[-1]) # [num_tx num_rx num_paths max_num_interactions] if primitives is not None: - primitives = jnp.concatenate( - ( - jnp.concatenate( - ( - primitives, - jnp.full( - ( - *primitives.shape[:-1], - max_num_interactions - primitives.shape[-1], - ), - no_interaction, - dtype=primitives.dtype, - ), - ), - axis=-1, - ), - jnp.concatenate( - ( - paths.objects[..., 1:-1], - jnp.full( - ( - *paths.objects.shape[:-1], - max_num_interactions - paths.order, - ), - no_interaction, - dtype=primitives.dtype, - ), - ), - axis=-1, - ), - ), - axis=-2, + primitives = _pad_and_concat( + primitives, + paths.objects[..., 1:-1], + fill_value=float(no_interaction), ) # [num_tx num_rx num_paths max_num_interactions] - inter = jnp.concatenate( - ( - jnp.concatenate( - ( - inter, - jnp.full( - (*inter.shape[:-1], max_num_interactions - inter.shape[-1]), - no_interaction, - dtype=inter.dtype, - ), - ), - axis=-1, - ), - jnp.concatenate( - ( - paths.interaction_types - if paths.interaction_types is not None - else jnp.full_like( - paths.objects[..., 1:-1], - InteractionType.REFLECTION, - dtype=inter.dtype, - ), - jnp.full( - ( - *paths.objects.shape[:-1], - max_num_interactions - paths.order, - ), - no_interaction, - dtype=inter.dtype, - ), - ), - axis=-1, - ), + inter = _pad_and_concat( + inter, + paths.interaction_types + if paths.interaction_types is not None + else jnp.full_like( + paths.objects[..., 1:-1], + InteractionType.REFLECTION, + dtype=inter.dtype, ), - axis=-2, + fill_value=float(no_interaction), ) # [num_tx num_rx num_paths max_num_interactions 3] - inter_pos = jnp.concatenate( - ( - jnp.concatenate( - ( - inter_pos, - jnp.zeros( - ( - *inter_pos.shape[:-2], - max_num_interactions - inter_pos.shape[-2], - 3, - ), - dtype=inter_pos.dtype, - ), - ), - axis=-2, - ), - jnp.concatenate( - ( - paths.vertices[..., 1:-1, :], - jnp.zeros( - ( - *paths.vertices.shape[:-2], - max_num_interactions - paths.order, - 3, - ), - dtype=inter_pos.dtype, - ), - ), - axis=-2, - ), - ), - axis=-3, + inter_pos = _pad_and_concat( + inter_pos, + paths.vertices[..., 1:-1, :], + fill_value=0.0, ) # [num_tx num_rx num_path_candidates order+1 3], # [num_tx num_rx num_path_candidates order+1 1] @@ -493,19 +458,108 @@ def export( cos_theta = jnp.sum(obj_normals * -k_i, axis=-1, keepdims=True) # [num_tx num_rx num_path_candidates order 1] r_s, r_p = reflection_coefficients(obj_n_r[..., None], cos_theta) - # [num_tx num_rx num_path_candidates 1] - r_s = jnp.prod(r_s, axis=-2) - r_p = jnp.prod(r_p, axis=-2) - # [num_tx num_rx num_path_candidates order 3] - (e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(k_i, k_r, obj_normals) - # [num_tx num_rx num_path_candidates 1] - fields_i_s = jnp.sum(fields_i * e_i_s[..., 0, :], axis=-1, keepdims=True) - fields_i_p = jnp.sum(fields_i * e_i_p[..., 0, :], axis=-1, keepdims=True) - # [num_tx num_rx num_path_candidates 1] - fields_r_s = r_s * fields_i_s - fields_r_p = r_p * fields_i_p + + # Compute transmitter and receiver local s-p bases + # based on departure and arrival directions (matching Sionna's approach) + # For departure direction k_d = k[..., 0, :], compute local spherical basis + k_tx = k[..., 0, :] # [num_tx num_rx num_path_candidates 3] + # For arrival direction k_rx = -k[..., -1, :], compute local spherical basis + k_rx = -k[..., -1, :] # [num_tx num_rx num_path_candidates 3] + + # Convert to spherical coordinates to get theta and phi + # [num_tx num_rx num_path_candidates 3] -> (r, theta, phi) + _, theta_tx, phi_tx = jnp.split(cartesian_to_spherical(k_tx), 3, axis=-1) + _, theta_rx, phi_rx = jnp.split(cartesian_to_spherical(k_rx), 3, axis=-1) + + # Compute theta_hat and phi_hat for transmitter (s=theta_hat, p=phi_hat in Sionna) + # theta_hat = [cos(theta)*cos(phi), cos(theta)*sin(phi), -sin(theta)] + e_tx_s = jnp.concatenate([ + jnp.cos(theta_tx) * jnp.cos(phi_tx), + jnp.cos(theta_tx) * jnp.sin(phi_tx), + -jnp.sin(theta_tx), + ], axis=-1) + # phi_hat = [-sin(phi), cos(phi), 0] + e_tx_p = jnp.concatenate([ + -jnp.sin(phi_tx), + jnp.cos(phi_tx), + jnp.zeros_like(phi_tx), + ], axis=-1) + + # Compute theta_hat and phi_hat for receiver + e_rx_s = jnp.concatenate([ + jnp.cos(theta_rx) * jnp.cos(phi_rx), + jnp.cos(theta_rx) * jnp.sin(phi_rx), + -jnp.sin(theta_rx), + ], axis=-1) + e_rx_p = jnp.concatenate([ + -jnp.sin(phi_rx), + jnp.cos(phi_rx), + jnp.zeros_like(phi_rx), + ], axis=-1) + + # Initialize transfer matrix as identity [num_tx num_rx num_path_candidates 2 2] + num_tx = paths.vertices.shape[0] + num_rx = paths.vertices.shape[1] + num_candidates = paths.vertices.shape[2] + mat_t = jnp.eye(2, dtype=fields.dtype) + mat_t = jnp.broadcast_to( + mat_t, (num_tx, num_rx, num_candidates, 2, 2) + ) + + # Initialize last basis with transmitter basis + last_e_s = e_tx_s + last_e_p = e_tx_p + + # Apply reflections sequentially through each interaction (matrix-based approach) + for i in range(paths.order): + # Change of basis: from last basis to current incident basis + # [num_tx num_rx num_path_candidates 2 2] + mat_cob = sp_rotation_matrix( + last_e_s, + last_e_p, + e_i_s[..., i, :], + e_i_p[..., i, :], + ) + # Apply change of basis to transfer matrix + # [num_tx num_rx num_path_candidates 2 2] + mat_t = jnp.einsum("...ij,...jk->...ik", mat_cob, mat_t) + + # Apply reflection coefficients as diagonal matrix + # [num_tx num_rx num_path_candidates 2] + r_diag = jnp.stack([r_s[..., i, 0], r_p[..., i, 0]], axis=-1) + # Broadcast to [num_tx num_rx num_path_candidates 2 2] diagonal + # [num_tx num_rx num_path_candidates 2 2] + mat_t = mat_t * r_diag[..., :, None] + + # Update last basis to current reflected basis + last_e_s = e_r_s[..., i, :] + last_e_p = e_r_p[..., i, :] + + # Final transformation to receiver basis + # [num_tx num_rx num_path_candidates 2 2] + mat_cob_rx = sp_rotation_matrix( + last_e_s, + last_e_p, + e_rx_s, + e_rx_p, + ) + mat_t = jnp.einsum("...ij,...jk->...ik", mat_cob_rx, mat_t) + + # Project transmitter field onto transmitter s-p basis + # [num_tx num_rx num_path_candidates 2] + field_tx_s = jnp.sum(fields_i * e_tx_s, axis=-1, keepdims=False) + field_tx_p = jnp.sum(fields_i * e_tx_p, axis=-1, keepdims=False) + field_tx_sp = jnp.stack([field_tx_s, field_tx_p], axis=-1) + + # Apply transfer matrix to get field at receiver in receiver s-p basis + # [num_tx num_rx num_path_candidates 2] + field_rx_sp = jnp.einsum("...ij,...j->...i", mat_t, field_tx_sp) + + # Project back to Cartesian coordinates using receiver basis # [num_tx num_rx num_path_candidates 3] - fields_r = fields_r_s * e_r_s[..., -1, :] + fields_r_p * e_r_p[..., -1, :] + fields_r = ( + field_rx_sp[..., 0:1] * e_rx_s + field_rx_sp[..., 1:2] * e_rx_p + ) else: # [num_tx num_rx num_path_candidates 3] fields_r = fields_i diff --git a/differt/tests/plugins/test_deepmimo.py b/differt/tests/plugins/test_deepmimo.py index 6e79368f..2c3210ac 100644 --- a/differt/tests/plugins/test_deepmimo.py +++ b/differt/tests/plugins/test_deepmimo.py @@ -1,4 +1,3 @@ -# ruff: noqa: ERA001 from dataclasses import asdict from itertools import chain @@ -10,7 +9,7 @@ import pytest from jaxtyping import PRNGKeyArray -from differt.em import materials +from differt.em import materials, z_0 from differt.geometry import TriangleMesh from differt.plugins import deepmimo from differt.scene import TriangleScene @@ -219,16 +218,12 @@ def test_match_sionna_on_simple_street_canyon() -> None: a = a[:, 0, :, 0, :, :] # Take only the first TX and RX polarization a = a[..., 0] # Take only the first time instant - # TODO: Understand why phase and power are not matching - - del a - - # chex.assert_trees_all_equal( - # dm.phase, - # jnp.angle(a, deg=True) - # ) + chex.assert_trees_all_close( + dm.phase, + jnp.angle(a, deg=True), + ) - # chex.assert_trees_all_equal( - # dm.power, - # 10.0 * jnp.log10(jnp.abs(a)**2 / z_0), - # ) + chex.assert_trees_all_close( + dm.power, + 10.0 * jnp.log10(jnp.abs(a) ** 2 / z_0), + )