Skip to content
Draft
284 changes: 169 additions & 115 deletions differt/src/differt/plugins/deepmimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from collections.abc import Iterable, Iterator, Mapping
from dataclasses import KW_ONLY, asdict
from typing import Any, Generic
from typing import Any, Generic, Literal

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, ArrayLike, Bool, Float, Int
from jaxtyping import Array, ArrayLike, Bool, Float, Int, Shaped

from differt.em import (
InteractionType,
Expand All @@ -19,6 +19,7 @@
materials,
reflection_coefficients,
sp_directions,
sp_rotation_matrix,
z_0,
)
from differt.geometry import (
Expand All @@ -34,6 +35,41 @@
from ._deepmimo_types import ArrayType


def _pad_and_concat(
left: Shaped[Array, "num_tx num_rx num_paths_left num_interactions_left ..."],
right: Shaped[Array, "num_tx num_rx num_paths_right num_interactions_right ..."],
fill_value: float,
) -> Shaped[
Array,
"num_tx num_rx num_paths_left+num_paths_right max(num_interactions_left, num_interactions_right) ...",
]:
max_num_interactions = max(left.shape[3], right.shape[3])
extra_dims_pad = [(0, 0)] * (left.ndim - 4)
left = jnp.pad(
left,
(
(0, 0),
(0, 0),
(0, 0),
(0, max_num_interactions - left.shape[3]),
*extra_dims_pad,
),
constant_values=fill_value,
)
right = jnp.pad(
right,
(
(0, 0),
(0, 0),
(0, 0),
(0, max_num_interactions - right.shape[3]),
*extra_dims_pad,
),
constant_values=fill_value,
)
return jnp.concatenate((left, right), axis=2)


class DeepMIMO(eqx.Module, Generic[ArrayType]):
"""DeepMIMO format data structure.

Expand Down Expand Up @@ -120,7 +156,10 @@ def _sort(
)

if vertices.shape != self.inter_pos.shape: # pragma: no cover
msg = f"Cannot sort based on provided paths: shape mismatch, got {vertices.shape!r} but expected {self.inter_pos.shape!r}."
msg = (
"Cannot sort based on provided paths: shape mismatch, got "
f"{vertices.shape!r} but expected {self.inter_pos.shape!r}."
)
raise ValueError(msg)

max_num_interactions = self.inter.shape[-1]
Expand Down Expand Up @@ -278,6 +317,7 @@ def export(
radio_materials: Mapping[str, Material] | None = None,
frequency: Float[ArrayLike, " "],
include_primitives: bool = False,
polarization: Literal["V", "H"] | Float[ArrayLike, "3"] = "V",
) -> DeepMIMO[Array]:
"""
Export a Ray Tracing simulation to the DeepMIMO format.
Expand All @@ -301,6 +341,8 @@ def export(
If not provided, :data:`materials<differt.em._material.materials>` will be used.
frequency: The operating frequency (in Hz).
include_primitives: If :data:`True`, include the primitive indices in the output.
polarization: The antenna polarization.
Can be either ``"V"`` (vertical), ``"H"`` (horizontal), or a 3D unit vector.

Returns:
The exported DeepMIMO data as JAX arrays.
Expand Down Expand Up @@ -332,7 +374,12 @@ def export(

# Fields array
fields = jnp.zeros((num_tx, num_rx, 0, 3), dtype=complex)
polarization = jnp.array([0.0, 0.0, 1.0], dtype=float)
if polarization == "V":
polarization = jnp.array([0.0, 0.0, 1.0], dtype=float)
elif polarization == "H":
polarization = jnp.array([1.0, 0.0, 0.0], dtype=float)
else:
polarization = jnp.asarray(polarization)
# Direction of departure (DoD) and direction of arrival (DoA) segments
k_d = jnp.zeros((num_tx, num_rx, 0, 3), dtype=float)
k_a = jnp.zeros_like(k_d)
Expand All @@ -357,112 +404,30 @@ def export(
# [num_tx num_rx num_path_candidates order+1 3]
path_segments = jnp.diff(paths.vertices, axis=-2)

max_num_interactions = max(paths.order, inter.shape[-1])
# [num_tx num_rx num_paths max_num_interactions]
if primitives is not None:
primitives = jnp.concatenate(
(
jnp.concatenate(
(
primitives,
jnp.full(
(
*primitives.shape[:-1],
max_num_interactions - primitives.shape[-1],
),
no_interaction,
dtype=primitives.dtype,
),
),
axis=-1,
),
jnp.concatenate(
(
paths.objects[..., 1:-1],
jnp.full(
(
*paths.objects.shape[:-1],
max_num_interactions - paths.order,
),
no_interaction,
dtype=primitives.dtype,
),
),
axis=-1,
),
),
axis=-2,
primitives = _pad_and_concat(
primitives,
paths.objects[..., 1:-1],
fill_value=float(no_interaction),
)
# [num_tx num_rx num_paths max_num_interactions]
inter = jnp.concatenate(
(
jnp.concatenate(
(
inter,
jnp.full(
(*inter.shape[:-1], max_num_interactions - inter.shape[-1]),
no_interaction,
dtype=inter.dtype,
),
),
axis=-1,
),
jnp.concatenate(
(
paths.interaction_types
if paths.interaction_types is not None
else jnp.full_like(
paths.objects[..., 1:-1],
InteractionType.REFLECTION,
dtype=inter.dtype,
),
jnp.full(
(
*paths.objects.shape[:-1],
max_num_interactions - paths.order,
),
no_interaction,
dtype=inter.dtype,
),
),
axis=-1,
),
inter = _pad_and_concat(
inter,
paths.interaction_types
if paths.interaction_types is not None
else jnp.full_like(
paths.objects[..., 1:-1],
InteractionType.REFLECTION,
dtype=inter.dtype,
),
axis=-2,
fill_value=float(no_interaction),
)
# [num_tx num_rx num_paths max_num_interactions 3]
inter_pos = jnp.concatenate(
(
jnp.concatenate(
(
inter_pos,
jnp.zeros(
(
*inter_pos.shape[:-2],
max_num_interactions - inter_pos.shape[-2],
3,
),
dtype=inter_pos.dtype,
),
),
axis=-2,
),
jnp.concatenate(
(
paths.vertices[..., 1:-1, :],
jnp.zeros(
(
*paths.vertices.shape[:-2],
max_num_interactions - paths.order,
3,
),
dtype=inter_pos.dtype,
),
),
axis=-2,
),
),
axis=-3,
inter_pos = _pad_and_concat(
inter_pos,
paths.vertices[..., 1:-1, :],
fill_value=0.0,
)
# [num_tx num_rx num_path_candidates order+1 3],
# [num_tx num_rx num_path_candidates order+1 1]
Expand Down Expand Up @@ -493,19 +458,108 @@ def export(
cos_theta = jnp.sum(obj_normals * -k_i, axis=-1, keepdims=True)
# [num_tx num_rx num_path_candidates order 1]
r_s, r_p = reflection_coefficients(obj_n_r[..., None], cos_theta)
# [num_tx num_rx num_path_candidates 1]
r_s = jnp.prod(r_s, axis=-2)
r_p = jnp.prod(r_p, axis=-2)
# [num_tx num_rx num_path_candidates order 3]
(e_i_s, e_i_p), (e_r_s, e_r_p) = sp_directions(k_i, k_r, obj_normals)
# [num_tx num_rx num_path_candidates 1]
fields_i_s = jnp.sum(fields_i * e_i_s[..., 0, :], axis=-1, keepdims=True)
fields_i_p = jnp.sum(fields_i * e_i_p[..., 0, :], axis=-1, keepdims=True)
# [num_tx num_rx num_path_candidates 1]
fields_r_s = r_s * fields_i_s
fields_r_p = r_p * fields_i_p

# Compute transmitter and receiver local s-p bases
# based on departure and arrival directions (matching Sionna's approach)
# For departure direction k_d = k[..., 0, :], compute local spherical basis
k_tx = k[..., 0, :] # [num_tx num_rx num_path_candidates 3]
# For arrival direction k_rx = -k[..., -1, :], compute local spherical basis
k_rx = -k[..., -1, :] # [num_tx num_rx num_path_candidates 3]

# Convert to spherical coordinates to get theta and phi
# [num_tx num_rx num_path_candidates 3] -> (r, theta, phi)
_, theta_tx, phi_tx = jnp.split(cartesian_to_spherical(k_tx), 3, axis=-1)
_, theta_rx, phi_rx = jnp.split(cartesian_to_spherical(k_rx), 3, axis=-1)

# Compute theta_hat and phi_hat for transmitter (s=theta_hat, p=phi_hat in Sionna)
# theta_hat = [cos(theta)*cos(phi), cos(theta)*sin(phi), -sin(theta)]
e_tx_s = jnp.concatenate([
jnp.cos(theta_tx) * jnp.cos(phi_tx),
jnp.cos(theta_tx) * jnp.sin(phi_tx),
-jnp.sin(theta_tx),
], axis=-1)
# phi_hat = [-sin(phi), cos(phi), 0]
e_tx_p = jnp.concatenate([
-jnp.sin(phi_tx),
jnp.cos(phi_tx),
jnp.zeros_like(phi_tx),
], axis=-1)

# Compute theta_hat and phi_hat for receiver
e_rx_s = jnp.concatenate([
jnp.cos(theta_rx) * jnp.cos(phi_rx),
jnp.cos(theta_rx) * jnp.sin(phi_rx),
-jnp.sin(theta_rx),
], axis=-1)
e_rx_p = jnp.concatenate([
-jnp.sin(phi_rx),
jnp.cos(phi_rx),
jnp.zeros_like(phi_rx),
], axis=-1)

# Initialize transfer matrix as identity [num_tx num_rx num_path_candidates 2 2]
num_tx = paths.vertices.shape[0]
num_rx = paths.vertices.shape[1]
num_candidates = paths.vertices.shape[2]
mat_t = jnp.eye(2, dtype=fields.dtype)
mat_t = jnp.broadcast_to(
mat_t, (num_tx, num_rx, num_candidates, 2, 2)
)

# Initialize last basis with transmitter basis
last_e_s = e_tx_s
last_e_p = e_tx_p

# Apply reflections sequentially through each interaction (matrix-based approach)
for i in range(paths.order):
# Change of basis: from last basis to current incident basis
# [num_tx num_rx num_path_candidates 2 2]
mat_cob = sp_rotation_matrix(
last_e_s,
last_e_p,
e_i_s[..., i, :],
e_i_p[..., i, :],
)
# Apply change of basis to transfer matrix
# [num_tx num_rx num_path_candidates 2 2]
mat_t = jnp.einsum("...ij,...jk->...ik", mat_cob, mat_t)

# Apply reflection coefficients as diagonal matrix
# [num_tx num_rx num_path_candidates 2]
r_diag = jnp.stack([r_s[..., i, 0], r_p[..., i, 0]], axis=-1)
# Broadcast to [num_tx num_rx num_path_candidates 2 2] diagonal
# [num_tx num_rx num_path_candidates 2 2]
mat_t = mat_t * r_diag[..., :, None]

# Update last basis to current reflected basis
last_e_s = e_r_s[..., i, :]
last_e_p = e_r_p[..., i, :]

# Final transformation to receiver basis
# [num_tx num_rx num_path_candidates 2 2]
mat_cob_rx = sp_rotation_matrix(
last_e_s,
last_e_p,
e_rx_s,
e_rx_p,
)
mat_t = jnp.einsum("...ij,...jk->...ik", mat_cob_rx, mat_t)

# Project transmitter field onto transmitter s-p basis
# [num_tx num_rx num_path_candidates 2]
field_tx_s = jnp.sum(fields_i * e_tx_s, axis=-1, keepdims=False)
field_tx_p = jnp.sum(fields_i * e_tx_p, axis=-1, keepdims=False)
field_tx_sp = jnp.stack([field_tx_s, field_tx_p], axis=-1)

# Apply transfer matrix to get field at receiver in receiver s-p basis
# [num_tx num_rx num_path_candidates 2]
field_rx_sp = jnp.einsum("...ij,...j->...i", mat_t, field_tx_sp)

# Project back to Cartesian coordinates using receiver basis
# [num_tx num_rx num_path_candidates 3]
fields_r = fields_r_s * e_r_s[..., -1, :] + fields_r_p * e_r_p[..., -1, :]
fields_r = (
field_rx_sp[..., 0:1] * e_rx_s + field_rx_sp[..., 1:2] * e_rx_p
)
else:
# [num_tx num_rx num_path_candidates 3]
fields_r = fields_i
Expand Down
23 changes: 9 additions & 14 deletions differt/tests/plugins/test_deepmimo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# ruff: noqa: ERA001
from dataclasses import asdict
from itertools import chain

Expand All @@ -10,7 +9,7 @@
import pytest
from jaxtyping import PRNGKeyArray

from differt.em import materials
from differt.em import materials, z_0
from differt.geometry import TriangleMesh
from differt.plugins import deepmimo
from differt.scene import TriangleScene
Expand Down Expand Up @@ -219,16 +218,12 @@ def test_match_sionna_on_simple_street_canyon() -> None:
a = a[:, 0, :, 0, :, :] # Take only the first TX and RX polarization
a = a[..., 0] # Take only the first time instant

# TODO: Understand why phase and power are not matching

del a

# chex.assert_trees_all_equal(
# dm.phase,
# jnp.angle(a, deg=True)
# )
chex.assert_trees_all_close(
dm.phase,
jnp.angle(a, deg=True),
)

# chex.assert_trees_all_equal(
# dm.power,
# 10.0 * jnp.log10(jnp.abs(a)**2 / z_0),
# )
chex.assert_trees_all_close(
dm.power,
10.0 * jnp.log10(jnp.abs(a) ** 2 / z_0),
)