diff --git a/docs/metrics.md b/docs/metrics.md index ee94eb2..e0ae5a6 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -49,6 +49,30 @@ $$ATE = \frac{1}{L}\sum_{i=1}^{L} \|\mathbf{p}_i - \mathbf{p}_i^*\|_2$$ $$RTE = \frac{1}{L-\Delta}\sum_{i=1}^{L-\Delta} \|(\mathbf{p}_{i+\Delta} - \mathbf{p}_i) - (\mathbf{p}_{i+\Delta}^* - \mathbf{p}_i^*)\|_2$$ +#### Dynamic Time Warping (DTW) + +**DTW Distance** measures the minimum-cost temporal alignment between two trajectories[19]. Unlike MSE or ATE which compare trajectories timestep-by-timestep, DTW finds the optimal warping to align sequences that may differ in length or timing. This is critical for evaluating Vision-Language-Action (VLA) models and policies using action chunking (e.g., ACT, Diffusion Policy), where predicted trajectories may be temporally misaligned with demonstrations. + +DTW is computed via dynamic programming on an accumulated cost matrix: +$$D[i,j] = C[i,j] + \min(D[i-1,j], D[i,j-1], D[i-1,j-1])$$ + +where $C[i,j] = \|\mathbf{q}_i - \mathbf{r}_j\|_2$ is the Euclidean distance between predicted point $\mathbf{q}_i$ and reference point $\mathbf{r}_j$. The final DTW distance is $D[n-1, m-1]$. + +**nDTW (Normalized DTW)** maps the raw DTW distance to a [0, 1] score: +$$\text{nDTW} = \exp\left(-\frac{\text{DTW}}{|R| \cdot d}\right)$$ + +where $|R|$ is the reference trajectory length and $d$ is a normalization constant (typically the mean step distance of the reference). Higher nDTW indicates better similarity (1.0 = perfect match). + +**SDTW (Success-weighted DTW)** combines trajectory fidelity with task success: +$$\text{SDTW} = \text{nDTW} \times \text{Success}$$ + +This captures both "did you succeed?" and "did you follow the right path?" If the task failed, SDTW = 0 regardless of trajectory similarity. + +**Use cases:** +- Evaluating VLA models where predicted trajectories may be temporally misaligned +- Comparing policies that use action chunking (ACT, Diffusion Policy) +- Benchmarking across demonstrations with different execution speeds + ### 1.2.3 Vision-Language Alignment Metrics #### BLEU Score @@ -147,3 +171,5 @@ $$MU = \max_t(\text{RAM}_t + \text{VRAM}_t)$$ [17] J. Hartmanis and R. E. Stearns, "On the computational complexity of algorithms," Trans. Am. Math. Soc., vol. 117, p. 285, May 1965. [18] X.-H. Sun and D. Wang, "APC," Perform. Eval. Rev., vol. 40, pp. 125–130, Oct. 2012. + +[19] G. Ilharco, V. Jain, A. Ku, E. Ie, and J. Baldridge, "General Evaluation for Instruction Conditioned Navigation using Dynamic Time Warping," arXiv preprint arXiv:1907.05446, NeurIPS ViGIL Workshop, 2019. diff --git a/examples/dtw_example.py b/examples/dtw_example.py new file mode 100644 index 0000000..0041fef --- /dev/null +++ b/examples/dtw_example.py @@ -0,0 +1,234 @@ +"""Example demonstrating DTW metrics for trajectory evaluation. + +This example shows how to use DTWDistance, NormalizedDTW, and SuccessWeightedDTW +for evaluating trajectories that may have different lengths or temporal alignment. +""" + +import torch + +from robometric_frame import DTWDistance, NormalizedDTW, SuccessWeightedDTW + + +def main() -> None: + """Demonstrate DTW metrics usage.""" + print("=" * 70) + print("FRAME - Dynamic Time Warping (DTW) Metrics Example") + print("=" * 70) + + # Example 1: Identical trajectories + print("\nExample 1: Identical Trajectories") + print("-" * 70) + dtw = DTWDistance() + ndtw = NormalizedDTW() + sdtw = SuccessWeightedDTW() + + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + predicted = reference.clone() + + dtw.update(predicted, reference) + ndtw.update(predicted, reference) + sdtw.update(predicted, reference, success=torch.tensor(True)) + + print(f"Reference trajectory: {reference.shape[0]} points") + print(f"Predicted trajectory: {predicted.shape[0]} points (identical)") + print(f"DTW Distance: {dtw.compute():.4f} (lower = better)") + print(f"Normalized DTW: {ndtw.compute():.4f} (higher = better, range [0,1])") + print(f"Success-weighted DTW: {sdtw.compute():.4f}") + + # Example 2: Different lengths (core use case) + print("\nExample 2: Trajectories of Different Lengths") + print("-" * 70) + dtw.reset() + ndtw.reset() + + # Reference: 4 points along a straight line + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + # Predicted: 7 points along the same line (different density - e.g., action chunking) + predicted = torch.tensor( + [ + [0.0, 0.0], + [0.5, 0.0], + [1.0, 0.0], + [1.5, 0.0], + [2.0, 0.0], + [2.5, 0.0], + [3.0, 0.0], + ] + ) + + dtw.update(predicted, reference) + ndtw.update(predicted, reference) + + print(f"Reference: {reference.shape[0]} points") + print(f"Predicted: {predicted.shape[0]} points (same path, higher density)") + print(f"DTW Distance: {dtw.compute():.4f}") + print(f"Normalized DTW: {ndtw.compute():.4f}") + print("Note: High nDTW because the paths align well despite different lengths") + + # Example 3: Temporal shift (hesitation) + print("\nExample 3: Temporal Shift (Hesitation at Start)") + print("-" * 70) + dtw.reset() + ndtw.reset() + + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + # Predicted: hesitates at start, then catches up + predicted = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + dtw.update(predicted, reference) + ndtw.update(predicted, reference) + + print(f"Reference: {reference.tolist()}") + print(f"Predicted: {predicted.tolist()}") + print(f"DTW Distance: {dtw.compute():.4f}") + print(f"Normalized DTW: {ndtw.compute():.4f}") + print("Note: DTW tolerates hesitation - MSE would heavily penalize this!") + + # Example 4: Success-weighted DTW + print("\nExample 4: Success-weighted DTW") + print("-" * 70) + sdtw = SuccessWeightedDTW() + + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + # Success case + sdtw.update(predicted, reference, success=torch.tensor(True)) + print(f"Task succeeded: SDTW = {sdtw.compute():.4f}") + + # Failure case + sdtw.reset() + sdtw.update(predicted, reference, success=torch.tensor(False)) + print(f"Task failed: SDTW = {sdtw.compute():.4f}") + print("Note: SDTW=0 when task fails, regardless of trajectory quality") + + # Example 5: Comparing models + print("\nExample 5: Comparing Different Models") + print("-" * 70) + + # Demonstration trajectory (what human did) + demo = torch.tensor( + [ + [0.0, 0.0, 0.0], # Start + [0.5, 0.0, 0.0], # Reach phase + [1.0, 0.0, 0.0], + [1.0, 0.5, 0.0], # Adjust phase + [1.0, 1.0, 0.0], # Target position + ] + ) + + # Model A: Good trajectory (follows demonstration) + model_a = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.3, 0.0, 0.0], + [0.6, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.3, 0.0], + [1.0, 0.6, 0.0], + [1.0, 1.0, 0.0], + ] + ) + + # Model B: Poor trajectory (goes wrong direction) + model_b = torch.tensor( + [ + [0.0, 0.0, 0.0], + [-0.5, 0.0, 0.0], + [-1.0, 0.0, 0.0], + [-1.0, -0.5, 0.0], + [-1.0, -1.0, 0.0], + ] + ) + + ndtw_a = NormalizedDTW() + ndtw_b = NormalizedDTW() + + ndtw_a.update(model_a, demo) + ndtw_b.update(model_b, demo) + + print("Demonstration trajectory (3D end-effector position):") + print(f" Points: {demo.shape[0]}") + print("\nModel A (follows demonstration, different timing):") + print(f" Points: {model_a.shape[0]}") + print(f" nDTW: {ndtw_a.compute():.4f}") + print("\nModel B (wrong direction):") + print(f" Points: {model_b.shape[0]}") + print(f" nDTW: {ndtw_b.compute():.4f}") + + # Example 6: 7-DoF robotic arm + print("\nExample 6: 7-DoF Robotic Arm Evaluation") + print("-" * 70) + ndtw = NormalizedDTW() + + # Simulated 7-DoF joint trajectory (40 timesteps) + torch.manual_seed(42) + demo_7dof = torch.cumsum(torch.randn(40, 7) * 0.1, dim=0) + # Model prediction (47 timesteps - hesitated during execution) + pred_7dof = torch.cumsum(torch.randn(47, 7) * 0.1, dim=0) + # Make it similar to demo by adding demo values + pred_7dof[:40] = demo_7dof + torch.randn(40, 7) * 0.05 # Add noise + + ndtw.update(pred_7dof, demo_7dof) + print(f"Demo trajectory: {demo_7dof.shape} (timesteps, DoF)") + print(f"Predicted: {pred_7dof.shape} (different length)") + print(f"nDTW Score: {ndtw.compute():.4f}") + + # Example 7: Multiple trajectory evaluation + print("\nExample 7: Evaluating Multiple Episodes") + print("-" * 70) + sdtw = SuccessWeightedDTW() + + # Simulate 5 episodes with varying success + torch.manual_seed(123) + episodes = [ + (torch.randn(10, 3), torch.randn(10, 3), True), + (torch.randn(12, 3), torch.randn(10, 3), True), + (torch.randn(8, 3), torch.randn(10, 3), False), + (torch.randn(10, 3), torch.randn(10, 3), True), + (torch.randn(11, 3), torch.randn(10, 3), False), + ] + + for i, (pred, ref, success) in enumerate(episodes): + sdtw.update(pred, ref, success=torch.tensor(success)) + print( + f" Episode {i + 1}: pred={pred.shape[0]}pts, ref={ref.shape[0]}pts, success={success}" + ) + + print(f"\nAverage SDTW across all episodes: {sdtw.compute():.4f}") + print("(This combines trajectory quality with task success)") + + # Example 8: Custom normalization factor + print("\nExample 8: Custom Normalization Factor") + print("-" * 70) + + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + predicted = torch.tensor([[0.0, 0.1], [1.0, 0.1], [2.0, 0.1]]) # Small offset + + ndtw_auto = NormalizedDTW() + ndtw_custom = NormalizedDTW(normalization_factor=0.5) + + ndtw_auto.update(predicted, reference) + ndtw_custom.update(predicted, reference) + + print(f"Automatic normalization: nDTW = {ndtw_auto.compute():.4f}") + print(f"Custom normalization (d=0.5): nDTW = {ndtw_custom.compute():.4f}") + print("Note: Custom d allows tuning sensitivity to trajectory differences") + + print("\n" + "=" * 70) + print("Key Takeaways:") + print("-" * 70) + print("1. DTW Distance: Raw alignment cost (lower = more similar)") + print("2. nDTW: Normalized to [0,1] (higher = more similar)") + print("3. SDTW: nDTW weighted by task success") + print("") + print("When to use DTW over MSE/ATE:") + print(" - Trajectories have different lengths") + print(" - Temporal alignment varies (hesitation, speed differences)") + print(" - Using action chunking (ACT, Diffusion Policy)") + print(" - Comparing across demonstrations with different timing") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/src/robometric_frame/__init__.py b/src/robometric_frame/__init__.py index f338e21..d68b044 100644 --- a/src/robometric_frame/__init__.py +++ b/src/robometric_frame/__init__.py @@ -19,9 +19,12 @@ from robometric_frame.trajectory_quality import ( AbsoluteTrajectoryError, CurvatureChange, + DTWDistance, + NormalizedDTW, PathLength, PathSmoothness, RelativeTrajectoryError, + SuccessWeightedDTW, ) __all__ = [ @@ -29,13 +32,16 @@ "ActionAccuracy", "CollisionRate", "CurvatureChange", + "DTWDistance", "InferenceLatency", "MemoryUsage", + "NormalizedDTW", "ObstacleProximity", "PathLength", "PathSmoothness", "RelativeTrajectoryError", "RiskFactor", "SuccessRate", + "SuccessWeightedDTW", "TaskCompletionRate", ] diff --git a/src/robometric_frame/trajectory_quality/__init__.py b/src/robometric_frame/trajectory_quality/__init__.py index 52f60c7..4d9759e 100644 --- a/src/robometric_frame/trajectory_quality/__init__.py +++ b/src/robometric_frame/trajectory_quality/__init__.py @@ -1,11 +1,13 @@ """Trajectory quality metrics for robotics policy evaluation. This module provides metrics for evaluating the quality of robot trajectories, -including path length, smoothness, curvature change, and trajectory errors. +including path length, smoothness, curvature change, trajectory errors, and +dynamic time warping metrics for temporally misaligned trajectories. """ from robometric_frame.trajectory_quality.absolute_trajectory_error import AbsoluteTrajectoryError from robometric_frame.trajectory_quality.curvature_change import CurvatureChange +from robometric_frame.trajectory_quality.dtw import DTWDistance, NormalizedDTW, SuccessWeightedDTW from robometric_frame.trajectory_quality.path_length import PathLength from robometric_frame.trajectory_quality.path_smoothness import PathSmoothness from robometric_frame.trajectory_quality.relative_trajectory_error import RelativeTrajectoryError @@ -13,7 +15,10 @@ __all__ = [ "AbsoluteTrajectoryError", "CurvatureChange", + "DTWDistance", + "NormalizedDTW", "PathLength", "PathSmoothness", "RelativeTrajectoryError", + "SuccessWeightedDTW", ] diff --git a/src/robometric_frame/trajectory_quality/dtw.py b/src/robometric_frame/trajectory_quality/dtw.py new file mode 100644 index 0000000..dfa9fce --- /dev/null +++ b/src/robometric_frame/trajectory_quality/dtw.py @@ -0,0 +1,646 @@ +"""Dynamic Time Warping (DTW) metrics for robotics policy trajectory evaluation. + +DTW-based metrics measure trajectory similarity while allowing for temporal misalignment. +Unlike MSE which requires point-to-point correspondence, DTW finds the optimal warping +between sequences of different lengths or timing. + +Reference: + G. Ilharco, V. Jain, A. Ku, E. Ie, and J. Baldridge, "General Evaluation for + Instruction Conditioned Navigation using Dynamic Time Warping," arXiv:1907.05446, + NeurIPS ViGIL Workshop, 2019. +""" + +from typing import Any, Optional + +import torch +from torch import Tensor +from torchmetrics import Metric + + +def _compute_dtw(predicted: Tensor, reference: Tensor) -> Tensor: + """Compute DTW distance between two trajectories using dynamic programming. + + Args: + predicted: Predicted trajectory tensor of shape (T_pred, D). + reference: Reference trajectory tensor of shape (T_ref, D). + + Returns: + DTW distance as a scalar tensor. + + Note: + Memory complexity is O(T_pred * T_ref) for the cost matrices. + For very long trajectories, this may require significant memory. + """ + # Compute pairwise cost matrix using Euclidean distances + # Shape: (T_pred, T_ref) + cost_matrix = torch.cdist(predicted.unsqueeze(0), reference.unsqueeze(0), p=2.0).squeeze(0) + + t_pred, t_ref = cost_matrix.shape + + # Build accumulated cost matrix using dynamic programming + # D[i,j] = C[i,j] + min(D[i-1,j], D[i,j-1], D[i-1,j-1]) + accumulated = torch.zeros_like(cost_matrix) + + # Initialize first element + accumulated[0, 0] = cost_matrix[0, 0] + + # Initialize first column + for i in range(1, t_pred): + accumulated[i, 0] = accumulated[i - 1, 0] + cost_matrix[i, 0] + + # Initialize first row + for j in range(1, t_ref): + accumulated[0, j] = accumulated[0, j - 1] + cost_matrix[0, j] + + # Fill in the rest of the accumulated cost matrix + for i in range(1, t_pred): + for j in range(1, t_ref): + accumulated[i, j] = cost_matrix[i, j] + torch.minimum( + torch.minimum(accumulated[i - 1, j], accumulated[i, j - 1]), + accumulated[i - 1, j - 1], + ) + + # DTW distance is the final cell + return accumulated[t_pred - 1, t_ref - 1] + + +def _compute_path_length(trajectory: Tensor) -> Tensor: + """Compute the total path length of a trajectory. + + Args: + trajectory: Trajectory tensor of shape (L, D). + + Returns: + Path length as a scalar tensor. + """ + if trajectory.shape[0] < 2: + return torch.tensor(0.0, device=trajectory.device, dtype=trajectory.dtype) + + deltas = trajectory[1:, :] - trajectory[:-1, :] + distances = torch.norm(deltas, p=2, dim=-1) + return distances.sum() + + +class DTWDistance(Metric): + r"""Compute Dynamic Time Warping (DTW) distance for trajectory evaluation. + + DTW distance measures the minimum-cost temporal alignment between predicted + and reference trajectories. Unlike MSE which compares trajectories timestep-by- + timestep, DTW finds the optimal warping to align sequences that may differ in + length or timing. + + DTW is calculated by building an accumulated cost matrix D where: + D[0,0] = C[0,0] + D[i,0] = D[i-1,0] + C[i,0] for i > 0 + D[0,j] = D[0,j-1] + C[0,j] for j > 0 + D[i,j] = C[i,j] + min(D[i-1,j], D[i,j-1], D[i-1,j-1]) for i,j > 0 + + where C[i,j] is the Euclidean distance between predicted[i] and reference[j]. + The final DTW distance is D[n-1, m-1]. + + This metric is particularly useful for evaluating VLA models and policies using + action chunking (e.g., ACT, Diffusion Policy) where predicted trajectories may + be temporally misaligned with demonstrations. + + This metric accumulates DTW distances across multiple trajectory pairs and returns + the average DTW distance when compute() is called. + + Args: + **kwargs: Additional keyword arguments passed to the base Metric class. + + Attributes: + higher_is_better: False - lower DTW distance indicates better similarity. + is_differentiable: False - DTW computation is not differentiable. + full_state_update: False - incremental state updates. + + Note: + Memory complexity is O(T_pred * T_ref) for the cost matrices. + For very long trajectories, this may require significant memory. + + Example: + >>> from robometric_frame.trajectory_quality import DTWDistance + >>> import torch + >>> metric = DTWDistance() + >>> # Identical trajectories (zero distance) + >>> predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(predicted, reference) + >>> metric.compute() + tensor(0.0000) + + Example (different lengths): + >>> # Trajectories of different lengths (the core use case) + >>> metric = DTWDistance() + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + >>> predicted = torch.tensor([[0.0, 0.0], [0.5, 0.0], [1.0, 0.0], [1.5, 0.0], + ... [2.0, 0.0], [2.5, 0.0], [3.0, 0.0]]) + >>> metric.update(predicted, reference) + >>> result = metric.compute() # Small value (same path, different density) + + Example (temporal shift): + >>> # Hesitation at start (same actions, different timing) + >>> metric = DTWDistance() + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> predicted = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], + ... [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(predicted, reference) + >>> result = metric.compute() # Small value (DTW tolerates hesitation) + """ + + higher_is_better: bool = False + is_differentiable: bool = False + full_state_update: bool = False + + # Dynamically added by add_state() in __init__ + total_dtw_distance: Tensor + num_trajectory_pairs: Tensor + + def __init__( + self, + **kwargs: Any, + ) -> None: + """Initialize the DTWDistance metric.""" + super().__init__(**kwargs) + + # Add metric states for distributed computation + self.add_state("total_dtw_distance", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_trajectory_pairs", default=torch.tensor(0), dist_reduce_fx="sum") + + def update( # pylint: disable=arguments-differ + self, predicted: Tensor, reference: Tensor + ) -> None: + """Update metric state with new predicted and reference trajectory pair. + + Args: + predicted: Predicted trajectory tensor of shape (T_pred, D) where: + - T_pred is the number of timesteps (can differ from T_ref) + - D is the spatial dimensionality (e.g., 2 for 2D, 3 for 3D, 7 for 7-DoF) + + Points should be ordered chronologically. + + reference: Reference (ground truth) trajectory tensor of shape (T_ref, D). + T_ref can differ from T_pred - this is the core advantage of DTW. + D must match predicted trajectory dimensionality. + + Raises: + ValueError: If trajectories have invalid shape (< 2 dimensions), + mismatched dimensionality, or insufficient points. + """ + if predicted.ndim != 2: + raise ValueError( + f"Predicted trajectory must have 2 dimensions (T, D), " + f"got {predicted.ndim}D tensor with shape {predicted.shape}" + ) + + if reference.ndim != 2: + raise ValueError( + f"Reference trajectory must have 2 dimensions (T, D), " + f"got {reference.ndim}D tensor with shape {reference.shape}" + ) + + if predicted.shape[-1] != reference.shape[-1]: + raise ValueError( + f"Predicted and reference trajectories must have the same dimensionality D, " + f"got predicted D={predicted.shape[-1]}, reference D={reference.shape[-1]}" + ) + + if predicted.shape[0] < 1: + raise ValueError( + f"Predicted trajectory must have at least 1 point, " + f"got {predicted.shape[0]} point(s)" + ) + + if reference.shape[0] < 1: + raise ValueError( + f"Reference trajectory must have at least 1 point, " + f"got {reference.shape[0]} point(s)" + ) + + # Convert to float for numerical operations + predicted = predicted.float() + reference = reference.float() + + # Compute DTW distance + dtw_distance = _compute_dtw(predicted, reference) + + # Update states + self.total_dtw_distance += dtw_distance # pylint: disable=no-member + self.num_trajectory_pairs += 1 # pylint: disable=no-member + + def compute(self) -> Tensor: + """Compute the average DTW distance across all trajectory pairs. + + Returns: + Average DTW distance as a scalar tensor. Lower values indicate + better trajectory similarity. + + Raises: + RuntimeError: If no trajectory pairs have been recorded. + """ + if self.num_trajectory_pairs == 0: # pylint: disable=no-member + raise RuntimeError( + "Cannot compute DTW distance: no trajectory pairs have been recorded. " + "Call update() with trajectory data before compute()." + ) + + return self.total_dtw_distance / self.num_trajectory_pairs # pylint: disable=no-member + + +class NormalizedDTW(Metric): + r"""Compute Normalized DTW (nDTW) for trajectory evaluation. + + nDTW normalizes the raw DTW distance and maps it to a [0, 1] score: + nDTW = exp(-DTW / (|R| * d)) + + where: + - DTW is the raw DTW distance + - |R| is the length of the reference trajectory (number of points) + - d is a normalization constant (average step distance of the reference, + or a user-specified value) + + Higher nDTW scores indicate better trajectory similarity (1.0 = perfect match, + approaches 0.0 for very dissimilar trajectories). + + This metric is particularly useful for evaluating VLA models and policies using + action chunking where predicted trajectories may be temporally misaligned. + + Args: + normalization_factor: Optional user-specified normalization constant d. + If None (default), automatically computed as the mean step distance + of the reference trajectory: PathLength(reference) / (len(reference) - 1). + **kwargs: Additional keyword arguments passed to the base Metric class. + + Attributes: + higher_is_better: True - higher nDTW indicates better similarity. + is_differentiable: False - DTW computation is not differentiable. + full_state_update: False - incremental state updates. + + Note: + Memory complexity is O(T_pred * T_ref) for the cost matrices. + For very long trajectories, this may require significant memory. + + Example: + >>> from robometric_frame.trajectory_quality import NormalizedDTW + >>> import torch + >>> metric = NormalizedDTW() + >>> # Identical trajectories (perfect score) + >>> predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(predicted, reference) + >>> metric.compute() + tensor(1.0000) + + Example (different lengths): + >>> # Trajectories of different lengths + >>> metric = NormalizedDTW() + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + >>> predicted = torch.tensor([[0.0, 0.0], [0.5, 0.0], [1.0, 0.0], [1.5, 0.0], + ... [2.0, 0.0], [2.5, 0.0], [3.0, 0.0]]) + >>> metric.update(predicted, reference) + >>> result = metric.compute() # High value (same path) + + Example (custom normalization): + >>> # Use custom normalization factor + >>> metric = NormalizedDTW(normalization_factor=0.5) + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> predicted = torch.tensor([[0.0, 0.1], [1.0, 0.1], [2.0, 0.1]]) + >>> metric.update(predicted, reference) + >>> result = metric.compute() + """ + + higher_is_better: bool = True + is_differentiable: bool = False + full_state_update: bool = False + + # Dynamically added by add_state() in __init__ + total_ndtw: Tensor + num_trajectory_pairs: Tensor + + def __init__( + self, + normalization_factor: Optional[float] = None, + **kwargs: Any, + ) -> None: + """Initialize the NormalizedDTW metric. + + Args: + normalization_factor: Optional user-specified normalization constant d. + If None, automatically computed as the mean step distance of the + reference trajectory. + **kwargs: Additional keyword arguments passed to the base Metric class. + """ + super().__init__(**kwargs) + + if normalization_factor is not None and normalization_factor <= 0: + raise ValueError(f"normalization_factor must be positive, got {normalization_factor}") + + self.normalization_factor = normalization_factor + + # Add metric states for distributed computation + self.add_state("total_ndtw", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_trajectory_pairs", default=torch.tensor(0), dist_reduce_fx="sum") + + def update( # pylint: disable=arguments-differ + self, predicted: Tensor, reference: Tensor + ) -> None: + """Update metric state with new predicted and reference trajectory pair. + + Args: + predicted: Predicted trajectory tensor of shape (T_pred, D) where: + - T_pred is the number of timesteps (can differ from T_ref) + - D is the spatial dimensionality + + Points should be ordered chronologically. + + reference: Reference (ground truth) trajectory tensor of shape (T_ref, D). + T_ref can differ from T_pred. D must match predicted dimensionality. + + Raises: + ValueError: If trajectories have invalid shape, mismatched dimensionality, + or insufficient points. + """ + if predicted.ndim != 2: + raise ValueError( + f"Predicted trajectory must have 2 dimensions (T, D), " + f"got {predicted.ndim}D tensor with shape {predicted.shape}" + ) + + if reference.ndim != 2: + raise ValueError( + f"Reference trajectory must have 2 dimensions (T, D), " + f"got {reference.ndim}D tensor with shape {reference.shape}" + ) + + if predicted.shape[-1] != reference.shape[-1]: + raise ValueError( + f"Predicted and reference trajectories must have the same dimensionality D, " + f"got predicted D={predicted.shape[-1]}, reference D={reference.shape[-1]}" + ) + + if predicted.shape[0] < 1: + raise ValueError( + f"Predicted trajectory must have at least 1 point, " + f"got {predicted.shape[0]} point(s)" + ) + + if reference.shape[0] < 1: + raise ValueError( + f"Reference trajectory must have at least 1 point, " + f"got {reference.shape[0]} point(s)" + ) + + # Convert to float for numerical operations + predicted = predicted.float() + reference = reference.float() + + # Compute DTW distance + dtw_distance = _compute_dtw(predicted, reference) + + # Compute normalization factor d + ref_length = reference.shape[0] + if self.normalization_factor is not None: + d = self.normalization_factor + elif ref_length == 1: + # Edge case: single-point reference, use Euclidean distance + d = torch.norm(predicted[0] - reference[0], p=2).item() + if d == 0: + d = 1.0 # Avoid division by zero for identical single points + else: + # Compute mean step distance of reference trajectory + path_length = _compute_path_length(reference) + d = (path_length / (ref_length - 1)).item() + if d == 0: + d = 1.0 # Avoid division by zero for stationary trajectories + + # Compute nDTW = exp(-DTW / (|R| * d)) + ndtw = torch.exp(-dtw_distance / (ref_length * d)) + + # Update states + self.total_ndtw += ndtw # pylint: disable=no-member + self.num_trajectory_pairs += 1 # pylint: disable=no-member + + def compute(self) -> Tensor: + """Compute the average nDTW score across all trajectory pairs. + + Returns: + Average nDTW score as a scalar tensor in range [0, 1]. + Higher values indicate better trajectory similarity. + + Raises: + RuntimeError: If no trajectory pairs have been recorded. + """ + if self.num_trajectory_pairs == 0: # pylint: disable=no-member + raise RuntimeError( + "Cannot compute nDTW: no trajectory pairs have been recorded. " + "Call update() with trajectory data before compute()." + ) + + return self.total_ndtw / self.num_trajectory_pairs # pylint: disable=no-member + + +class SuccessWeightedDTW(Metric): + r"""Compute Success-weighted DTW (SDTW) for trajectory evaluation. + + SDTW combines trajectory fidelity with task success: + SDTW = nDTW * Success + + where: + - nDTW is the normalized DTW score (see NormalizedDTW) + - Success is a binary indicator (1 if task succeeded, 0 if not) + + If the task failed, SDTW = 0 regardless of trajectory similarity. This captures + both "did you succeed?" and "did you follow the right path?" + + This metric is particularly useful for benchmarking policies where both task + completion and trajectory quality matter. + + Args: + normalization_factor: Optional user-specified normalization constant d. + If None (default), automatically computed as the mean step distance + of the reference trajectory. + **kwargs: Additional keyword arguments passed to the base Metric class. + + Attributes: + higher_is_better: True - higher SDTW indicates better performance. + is_differentiable: False - DTW computation is not differentiable. + full_state_update: False - incremental state updates. + + Note: + Memory complexity is O(T_pred * T_ref) for the cost matrices. + For very long trajectories, this may require significant memory. + + Example: + >>> from robometric_frame.trajectory_quality import SuccessWeightedDTW + >>> import torch + >>> metric = SuccessWeightedDTW() + >>> # Successful task with good trajectory + >>> predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(predicted, reference, success=torch.tensor(True)) + >>> metric.compute() + tensor(1.0000) + + Example (failed task): + >>> # Failed task (SDTW = 0 regardless of trajectory quality) + >>> metric = SuccessWeightedDTW() + >>> predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(predicted, reference, success=torch.tensor(False)) + >>> metric.compute() + tensor(0.0000) + + Example (multiple updates): + >>> # Mix of successes and failures + >>> metric = SuccessWeightedDTW() + >>> ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> pred = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> metric.update(pred, ref, success=torch.tensor(True)) # SDTW = 1.0 + >>> metric.update(pred, ref, success=torch.tensor(False)) # SDTW = 0.0 + >>> metric.compute() # Average: 0.5 + tensor(0.5000) + """ + + higher_is_better: bool = True + is_differentiable: bool = False + full_state_update: bool = False + + # Dynamically added by add_state() in __init__ + total_sdtw: Tensor + num_trajectory_pairs: Tensor + + def __init__( + self, + normalization_factor: Optional[float] = None, + **kwargs: Any, + ) -> None: + """Initialize the SuccessWeightedDTW metric. + + Args: + normalization_factor: Optional user-specified normalization constant d. + If None, automatically computed as the mean step distance of the + reference trajectory. + **kwargs: Additional keyword arguments passed to the base Metric class. + """ + super().__init__(**kwargs) + + if normalization_factor is not None and normalization_factor <= 0: + raise ValueError(f"normalization_factor must be positive, got {normalization_factor}") + + self.normalization_factor = normalization_factor + + # Add metric states for distributed computation + self.add_state("total_sdtw", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_trajectory_pairs", default=torch.tensor(0), dist_reduce_fx="sum") + + def update( # pylint: disable=arguments-differ + self, predicted: Tensor, reference: Tensor, success: Tensor + ) -> None: + """Update metric state with new trajectory pair and success indicator. + + Args: + predicted: Predicted trajectory tensor of shape (T_pred, D) where: + - T_pred is the number of timesteps (can differ from T_ref) + - D is the spatial dimensionality + + Points should be ordered chronologically. + + reference: Reference (ground truth) trajectory tensor of shape (T_ref, D). + T_ref can differ from T_pred. D must match predicted dimensionality. + + success: Boolean or 0/1 tensor indicating task success. + If False/0, SDTW will be 0 regardless of trajectory similarity. + + Raises: + ValueError: If trajectories have invalid shape, mismatched dimensionality, + or insufficient points. + """ + if predicted.ndim != 2: + raise ValueError( + f"Predicted trajectory must have 2 dimensions (T, D), " + f"got {predicted.ndim}D tensor with shape {predicted.shape}" + ) + + if reference.ndim != 2: + raise ValueError( + f"Reference trajectory must have 2 dimensions (T, D), " + f"got {reference.ndim}D tensor with shape {reference.shape}" + ) + + if predicted.shape[-1] != reference.shape[-1]: + raise ValueError( + f"Predicted and reference trajectories must have the same dimensionality D, " + f"got predicted D={predicted.shape[-1]}, reference D={reference.shape[-1]}" + ) + + if predicted.shape[0] < 1: + raise ValueError( + f"Predicted trajectory must have at least 1 point, " + f"got {predicted.shape[0]} point(s)" + ) + + if reference.shape[0] < 1: + raise ValueError( + f"Reference trajectory must have at least 1 point, " + f"got {reference.shape[0]} point(s)" + ) + + # Convert to float for numerical operations + predicted = predicted.float() + reference = reference.float() + + # Convert success to float (0.0 or 1.0) + success_value = success.float() if isinstance(success, Tensor) else float(success) + if isinstance(success_value, Tensor): + success_value = success_value.item() + + # If task failed, SDTW is 0 + if success_value == 0: + sdtw = torch.tensor(0.0, device=predicted.device, dtype=predicted.dtype) + else: + # Compute DTW distance + dtw_distance = _compute_dtw(predicted, reference) + + # Compute normalization factor d + ref_length = reference.shape[0] + if self.normalization_factor is not None: + d = self.normalization_factor + elif ref_length == 1: + # Edge case: single-point reference, use Euclidean distance + d = torch.norm(predicted[0] - reference[0], p=2).item() + if d == 0: + d = 1.0 # Avoid division by zero for identical single points + else: + # Compute mean step distance of reference trajectory + path_length = _compute_path_length(reference) + d = (path_length / (ref_length - 1)).item() + if d == 0: + d = 1.0 # Avoid division by zero for stationary trajectories + + # Compute nDTW = exp(-DTW / (|R| * d)) + ndtw = torch.exp(-dtw_distance / (ref_length * d)) + + # SDTW = nDTW * success + sdtw = ndtw * success_value + + # Update states + self.total_sdtw += sdtw # pylint: disable=no-member + self.num_trajectory_pairs += 1 # pylint: disable=no-member + + def compute(self) -> Tensor: + """Compute the average SDTW score across all trajectory pairs. + + Returns: + Average SDTW score as a scalar tensor in range [0, 1]. + Higher values indicate better overall performance (trajectory + quality weighted by task success). + + Raises: + RuntimeError: If no trajectory pairs have been recorded. + """ + if self.num_trajectory_pairs == 0: # pylint: disable=no-member + raise RuntimeError( + "Cannot compute SDTW: no trajectory pairs have been recorded. " + "Call update() with trajectory data before compute()." + ) + + return self.total_sdtw / self.num_trajectory_pairs # pylint: disable=no-member diff --git a/tests/test_dtw.py b/tests/test_dtw.py new file mode 100644 index 0000000..dcbf5ac --- /dev/null +++ b/tests/test_dtw.py @@ -0,0 +1,607 @@ +"""Tests for Dynamic Time Warping (DTW) metrics.""" + +import pytest +import torch + +from robometric_frame.trajectory_quality import DTWDistance, NormalizedDTW, SuccessWeightedDTW + + +class TestDTWDistance: + """Test suite for DTWDistance metric.""" + + def test_identical_trajectories(self) -> None: + """Test DTW distance for identical trajectories (should be zero).""" + metric = DTWDistance() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_different_lengths_same_path(self) -> None: + """Test DTW with trajectories of different lengths covering same path.""" + metric = DTWDistance() + # Reference: 4 points along a straight line + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + # Predicted: 7 points along the same line (different density) + predicted = torch.tensor( + [ + [0.0, 0.0], + [0.5, 0.0], + [1.0, 0.0], + [1.5, 0.0], + [2.0, 0.0], + [2.5, 0.0], + [3.0, 0.0], + ] + ) + metric.update(predicted, reference) + result = metric.compute() + # DTW should be small since paths are aligned + assert result < 1.0 + + def test_temporal_shift_hesitation(self) -> None: + """Test DTW handles temporal shift (hesitation at start).""" + metric = DTWDistance() + # Reference: direct path + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + # Predicted: hesitates at start, then catches up + predicted = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + metric.update(predicted, reference) + result = metric.compute() + # DTW should be small (tolerates hesitation) + assert result < 1.0 + + def test_completely_different_trajectories(self) -> None: + """Test DTW for completely different trajectories.""" + metric = DTWDistance() + # Reference: along x-axis + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + # Predicted: along y-axis, far away + predicted = torch.tensor([[0.0, 10.0], [0.0, 11.0], [0.0, 12.0]]) + metric.update(predicted, reference) + result = metric.compute() + # DTW should be large + assert result > 10.0 + + def test_3d_trajectory(self) -> None: + """Test DTW with 3D trajectories.""" + metric = DTWDistance() + predicted = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_7dof_trajectory(self) -> None: + """Test DTW with 7-DoF action space.""" + metric = DTWDistance() + # Simulating 7-DoF robot arm action space + predicted = torch.randn(10, 7) + reference = predicted.clone() # Identical + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0), atol=1e-5) + + def test_1d_trajectory(self) -> None: + """Test DTW with 1D trajectories.""" + metric = DTWDistance() + predicted = torch.tensor([[0.0], [1.0], [2.0]]) + reference = torch.tensor([[0.0], [1.0], [2.0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_single_point_trajectories(self) -> None: + """Test DTW with single-point trajectories.""" + metric = DTWDistance() + predicted = torch.tensor([[1.0, 2.0]]) + reference = torch.tensor([[0.0, 0.0]]) + metric.update(predicted, reference) + result = metric.compute() + # Should be Euclidean distance between the two points + expected = torch.sqrt(torch.tensor(1.0**2 + 2.0**2)) + assert torch.isclose(result, expected) + + def test_multiple_updates(self) -> None: + """Test metric with multiple trajectory updates.""" + metric = DTWDistance() + # First pair: identical (DTW=0) + pred1 = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + ref1 = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + metric.update(pred1, ref1) + + # Second pair: offset by 1 unit + pred2 = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + ref2 = torch.tensor([[0.0, 1.0], [1.0, 1.0]]) + metric.update(pred2, ref2) + + result = metric.compute() + # Average of 0 and some positive value + assert result > 0 + + def test_reset(self) -> None: + """Test metric reset functionality.""" + metric = DTWDistance() + pred = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + ref = torch.tensor([[0.0, 1.0], [1.0, 1.0]]) + metric.update(pred, ref) + result1 = metric.compute() + + metric.reset() + + # After reset, compute identical trajectories + pred2 = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + ref2 = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + metric.update(pred2, ref2) + result2 = metric.compute() + + assert result2 < result1 # Second should be 0, first was positive + + def test_invalid_shape_1d_raises(self) -> None: + """Test that 1D tensor raises ValueError.""" + metric = DTWDistance() + with pytest.raises(ValueError, match="2 dimensions"): + metric.update(torch.tensor([0.0, 1.0, 2.0]), torch.tensor([[0.0, 0.0]])) + + def test_mismatched_dimensionality_raises(self) -> None: + """Test that mismatched D raises ValueError.""" + metric = DTWDistance() + with pytest.raises(ValueError, match="same dimensionality"): + metric.update( + torch.tensor([[0.0, 0.0]]), # D=2 + torch.tensor([[0.0, 0.0, 0.0]]), # D=3 + ) + + def test_empty_trajectory_raises(self) -> None: + """Test that empty trajectory raises ValueError.""" + metric = DTWDistance() + with pytest.raises(ValueError, match="at least 1 point"): + metric.update(torch.tensor([]).reshape(0, 2), torch.tensor([[0.0, 0.0]])) + + def test_compute_before_update_raises(self) -> None: + """Test that compute before update raises RuntimeError.""" + metric = DTWDistance() + with pytest.raises(RuntimeError, match="no trajectory pairs"): + metric.compute() + + def test_different_dtypes(self) -> None: + """Test with different tensor dtypes.""" + for dtype in [torch.float32, torch.float64]: + metric = DTWDistance() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=dtype) + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=dtype) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_int_dtype(self) -> None: + """Test with integer dtype (should work with automatic conversion).""" + metric = DTWDistance() + predicted = torch.tensor([[0, 0], [1, 0], [2, 0]]) + reference = torch.tensor([[0, 0], [1, 0], [2, 0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_gpu_if_available(self) -> None: + """Test metric on GPU if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + metric = DTWDistance().to("cuda") + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + metric.update(predicted, reference) + result = metric.compute() + assert result.device.type == "cuda" + assert torch.isclose(result, torch.tensor(0.0, device="cuda")) + + +class TestNormalizedDTW: + """Test suite for NormalizedDTW metric.""" + + def test_identical_trajectories(self) -> None: + """Test nDTW for identical trajectories (should be 1.0).""" + metric = NormalizedDTW() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(1.0)) + + def test_ndtw_range(self) -> None: + """Test that nDTW is always between 0 and 1.""" + metric = NormalizedDTW() + # Various trajectory pairs + pairs = [ + ( + torch.tensor([[0.0, 0.0], [1.0, 0.0]]), + torch.tensor([[0.0, 0.0], [1.0, 0.0]]), + ), + ( + torch.tensor([[0.0, 0.0], [1.0, 0.0]]), + torch.tensor([[0.0, 10.0], [1.0, 10.0]]), + ), + ( + torch.randn(5, 3), + torch.randn(7, 3), + ), + ] + for pred, ref in pairs: + metric.reset() + metric.update(pred, ref) + result = metric.compute() + assert 0.0 <= result <= 1.0, f"nDTW {result} out of range [0, 1]" + + def test_different_lengths_high_ndtw(self) -> None: + """Test that similar paths with different lengths have high nDTW.""" + metric = NormalizedDTW() + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + predicted = torch.tensor( + [ + [0.0, 0.0], + [0.5, 0.0], + [1.0, 0.0], + [1.5, 0.0], + [2.0, 0.0], + [2.5, 0.0], + [3.0, 0.0], + ] + ) + metric.update(predicted, reference) + result = metric.compute() + # Should be close to 1.0 since paths align well + assert result > 0.8 + + def test_dissimilar_trajectories_low_ndtw(self) -> None: + """Test that dissimilar trajectories have low nDTW.""" + metric = NormalizedDTW() + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + predicted = torch.tensor([[0.0, 100.0], [1.0, 100.0], [2.0, 100.0]]) + metric.update(predicted, reference) + result = metric.compute() + # Should be close to 0 for very dissimilar trajectories + assert result < 0.1 + + def test_custom_normalization_factor(self) -> None: + """Test nDTW with custom normalization factor.""" + metric_default = NormalizedDTW() + metric_custom = NormalizedDTW(normalization_factor=2.0) + + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.5], [1.0, 0.5], [2.0, 0.5]]) + + metric_default.update(predicted, reference) + metric_custom.update(predicted, reference) + + result_default = metric_default.compute() + result_custom = metric_custom.compute() + + # Results should differ due to different normalization + assert not torch.isclose(result_default, result_custom) + + def test_invalid_normalization_factor_raises(self) -> None: + """Test that non-positive normalization factor raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + NormalizedDTW(normalization_factor=0.0) + with pytest.raises(ValueError, match="must be positive"): + NormalizedDTW(normalization_factor=-1.0) + + def test_single_point_reference(self) -> None: + """Test nDTW with single-point reference (edge case).""" + metric = NormalizedDTW() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0]]) # Single point + metric.update(predicted, reference) + result = metric.compute() + # Should still be in valid range + assert 0.0 <= result <= 1.0 + + def test_identical_single_points(self) -> None: + """Test nDTW with identical single-point trajectories.""" + metric = NormalizedDTW() + predicted = torch.tensor([[1.0, 2.0]]) + reference = torch.tensor([[1.0, 2.0]]) + metric.update(predicted, reference) + result = metric.compute() + assert torch.isclose(result, torch.tensor(1.0)) + + def test_stationary_trajectory(self) -> None: + """Test nDTW with stationary (all same points) reference.""" + metric = NormalizedDTW() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) # Stationary + metric.update(predicted, reference) + result = metric.compute() + # Should be valid (d defaults to 1.0 to avoid division by zero) + assert 0.0 <= result <= 1.0 + + def test_multiple_updates(self) -> None: + """Test nDTW with multiple updates.""" + metric = NormalizedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + # Update with identical (nDTW=1.0) + metric.update(ref.clone(), ref) + # Update with different + metric.update(torch.tensor([[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]), ref) + + result = metric.compute() + # Average should be between 0 and 1 + assert 0.0 < result < 1.0 + + def test_reset(self) -> None: + """Test nDTW reset functionality.""" + metric = NormalizedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + + # First: dissimilar + metric.update(torch.tensor([[0.0, 10.0], [1.0, 10.0]]), ref) + result1 = metric.compute() + + metric.reset() + + # After reset: identical + metric.update(ref.clone(), ref) + result2 = metric.compute() + + assert result2 > result1 + + def test_gpu_if_available(self) -> None: + """Test nDTW on GPU if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + metric = NormalizedDTW().to("cuda") + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + metric.update(predicted, reference) + result = metric.compute() + assert result.device.type == "cuda" + assert torch.isclose(result, torch.tensor(1.0, device="cuda")) + + +class TestSuccessWeightedDTW: + """Test suite for SuccessWeightedDTW metric.""" + + def test_identical_trajectories_success(self) -> None: + """Test SDTW for identical trajectories with success (should be 1.0).""" + metric = SuccessWeightedDTW() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + metric.update(predicted, reference, success=torch.tensor(True)) + result = metric.compute() + assert torch.isclose(result, torch.tensor(1.0)) + + def test_identical_trajectories_failure(self) -> None: + """Test SDTW for identical trajectories with failure (should be 0.0).""" + metric = SuccessWeightedDTW() + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + metric.update(predicted, reference, success=torch.tensor(False)) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_success_as_int(self) -> None: + """Test SDTW with success as integer tensor.""" + metric = SuccessWeightedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + metric.update(ref.clone(), ref, success=torch.tensor(1)) + result = metric.compute() + assert torch.isclose(result, torch.tensor(1.0)) + + metric.reset() + metric.update(ref.clone(), ref, success=torch.tensor(0)) + result = metric.compute() + assert torch.isclose(result, torch.tensor(0.0)) + + def test_success_as_float(self) -> None: + """Test SDTW with success as float tensor.""" + metric = SuccessWeightedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + metric.update(ref.clone(), ref, success=torch.tensor(1.0)) + result = metric.compute() + assert torch.isclose(result, torch.tensor(1.0)) + + def test_sdtw_range(self) -> None: + """Test that SDTW is always between 0 and 1.""" + metric = SuccessWeightedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + # Various scenarios + scenarios = [ + (ref.clone(), torch.tensor(True)), + (ref.clone(), torch.tensor(False)), + (torch.randn(5, 2), torch.tensor(True)), + (torch.randn(7, 2), torch.tensor(False)), + ] + + for pred, success in scenarios: + metric.reset() + metric.update(pred, ref, success=success) + result = metric.compute() + assert 0.0 <= result <= 1.0 + + def test_multiple_updates_mixed_success(self) -> None: + """Test SDTW with mixed success/failure updates.""" + metric = SuccessWeightedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + # Identical + success: SDTW = 1.0 + metric.update(ref.clone(), ref, success=torch.tensor(True)) + # Identical + failure: SDTW = 0.0 + metric.update(ref.clone(), ref, success=torch.tensor(False)) + + result = metric.compute() + # Average of 1.0 and 0.0 = 0.5 + assert torch.isclose(result, torch.tensor(0.5)) + + def test_custom_normalization_factor(self) -> None: + """Test SDTW with custom normalization factor.""" + metric = SuccessWeightedDTW(normalization_factor=1.0) + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + reference = torch.tensor([[0.0, 0.5], [1.0, 0.5], [2.0, 0.5]]) + metric.update(predicted, reference, success=torch.tensor(True)) + result = metric.compute() + assert 0.0 <= result <= 1.0 + + def test_invalid_normalization_factor_raises(self) -> None: + """Test that non-positive normalization factor raises ValueError.""" + with pytest.raises(ValueError, match="must be positive"): + SuccessWeightedDTW(normalization_factor=0.0) + + def test_reset(self) -> None: + """Test SDTW reset functionality.""" + metric = SuccessWeightedDTW() + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + + # First: failure + metric.update(ref.clone(), ref, success=torch.tensor(False)) + result1 = metric.compute() + assert torch.isclose(result1, torch.tensor(0.0)) + + metric.reset() + + # After reset: success + metric.update(ref.clone(), ref, success=torch.tensor(True)) + result2 = metric.compute() + assert torch.isclose(result2, torch.tensor(1.0)) + + def test_compute_before_update_raises(self) -> None: + """Test that compute before update raises RuntimeError.""" + metric = SuccessWeightedDTW() + with pytest.raises(RuntimeError, match="no trajectory pairs"): + metric.compute() + + def test_gpu_if_available(self) -> None: + """Test SDTW on GPU if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + metric = SuccessWeightedDTW().to("cuda") + predicted = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + reference = torch.tensor([[0.0, 0.0], [1.0, 0.0]], device="cuda") + metric.update(predicted, reference, success=torch.tensor(True, device="cuda")) + result = metric.compute() + assert result.device.type == "cuda" + assert torch.isclose(result, torch.tensor(1.0, device="cuda")) + + +class TestDTWIntegration: + """Integration tests for DTW metrics.""" + + def test_all_metrics_consistent(self) -> None: + """Test that DTW, nDTW, and SDTW are consistent.""" + dtw = DTWDistance() + ndtw = NormalizedDTW() + sdtw = SuccessWeightedDTW() + + pred = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + + dtw.update(pred, ref) + ndtw.update(pred, ref) + sdtw.update(pred, ref, success=torch.tensor(True)) + + dtw_result = dtw.compute() + ndtw_result = ndtw.compute() + sdtw_result = sdtw.compute() + + # Identical trajectories: DTW=0, nDTW=1, SDTW=1 + assert torch.isclose(dtw_result, torch.tensor(0.0)) + assert torch.isclose(ndtw_result, torch.tensor(1.0)) + assert torch.isclose(sdtw_result, torch.tensor(1.0)) + + def test_dtw_vs_mse_temporal_shift(self) -> None: + """Test that DTW handles temporal shifts better than MSE would.""" + dtw = DTWDistance() + + # Reference trajectory + ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + + # Predicted: same path but with hesitation (pause at start) + pred_hesitation = torch.tensor([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + + # Predicted: same path but faster + pred_fast = torch.tensor([[0.0, 0.0], [2.0, 0.0], [3.0, 0.0]]) + + dtw.update(pred_hesitation, ref) + dtw_hesitation = dtw.compute() + + dtw.reset() + dtw.update(pred_fast, ref) + dtw_fast = dtw.compute() + + # Both should have relatively low DTW (good alignment) + assert dtw_hesitation < 2.0 + assert dtw_fast < 2.0 + + def test_vla_action_chunking_scenario(self) -> None: + """Test realistic VLA model evaluation scenario with action chunking.""" + ndtw = NormalizedDTW() + + # Demonstration trajectory (what human did) + demo = torch.tensor( + [ + [0.0, 0.0, 0.0], # Start + [0.5, 0.0, 0.0], # Reach + [1.0, 0.0, 0.0], # Reach + [1.0, 0.5, 0.0], # Adjust + [1.0, 1.0, 0.0], # Grasp position + ] + ) + + # Model A: Same path, different chunk boundaries + model_a = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.3, 0.0, 0.0], + [0.6, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.3, 0.0], + [1.0, 0.6, 0.0], + [1.0, 1.0, 0.0], + ] + ) + + # Model B: Wrong direction + model_b = torch.tensor( + [ + [0.0, 0.0, 0.0], + [-0.5, 0.0, 0.0], + [-1.0, 0.0, 0.0], + [-1.0, -0.5, 0.0], + [-1.0, -1.0, 0.0], + ] + ) + + ndtw.update(model_a, demo) + score_a = ndtw.compute() + + ndtw.reset() + ndtw.update(model_b, demo) + score_b = ndtw.compute() + + # Model A should score higher than Model B + assert score_a > score_b + assert score_a > 0.7 # Good trajectory + assert score_b < 0.3 # Bad trajectory + + def test_large_trajectory(self) -> None: + """Test with large trajectories (memory/performance check).""" + dtw = DTWDistance() + # Create trajectories with 100 points each + pred = torch.cumsum(torch.randn(100, 3) * 0.1, dim=0) + ref = torch.cumsum(torch.randn(100, 3) * 0.1, dim=0) + + dtw.update(pred, ref) + result = dtw.compute() + + # Just verify it computes without error + assert result.ndim == 0 + assert result >= 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])