added cuda support for all pytorch mlg operations #25
@@ -18,7 +18,7 @@ class SE23(MatrixLieGroupTorch):
     matrix_size = 5

     @staticmethod
-    def random(N=1):
+    def random(N=1, device='cpu'):
         """
         Generates a random batch of SE_2(3) matricies.

@@ -39,7 +39,7 @@ def random(N=1):

         C = SO3.Exp(phi)

-        return SE23.from_components(C, v, r)
+        return SE23.from_components(C, v, r).to(device)

     @staticmethod
     def from_components(C: torch.Tensor, v: torch.Tensor, r: torch.Tensor):

@@ -65,8 +65,11 @@ def from_components(C: torch.Tensor, v: torch.Tensor, r: torch.Tensor):
         # firstly, check that batch dimension for all 3 components matches
         if not (C.shape[0] == v.shape[0] == r.shape[0]):
             raise ValueError("Batch dimension for SE_2(3) components don't match.")

+        # check that all components are on the same device
+        assert C.device == v.device == r.device, "Components must be on the same device for SE23.from_components."
Suggested change:
-        assert C.device == v.device == r.device, "Components must be on the same device for SE23.from_components."
+        if not (C.device == v.device == r.device):
+            raise ValueError("Components must be on the same device for SE23.from_components.")
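For context, a minimal usage sketch of the device argument these hunks introduce. The import path, batch size, and printed shape are assumptions, CUDA is only exercised when a GPU is available, and this is not code taken from the PR itself.

```python
import torch
from pymlg.torch import SE23  # hypothetical import path, not confirmed by the diff

device = "cuda" if torch.cuda.is_available() else "cpu"

# random() builds the batch on CPU and moves it with .to(device),
# so the returned group elements live on the requested device.
X = SE23.random(N=4, device=device)
print(X.shape, X.device)  # expected: torch.Size([4, 5, 5]) on `device`
```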
Copilot AI · Apr 30, 2026
SE23.identity now requires a "device" argument, which breaks the MatrixLieGroupTorch.identity() interface and existing call sites that expect identity() with no args. Consider making device optional (defaulting to 'cpu' or inferred) to preserve compatibility.
Suggested change:
-    def identity(device, N=1, dtype=torch.float64):
+    def identity(device='cpu', N=1, dtype=torch.float64):
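A sketch of what that default could look like in full, so that identity() stays callable with no arguments. The body below is an assumption about how the identity batch is built; the PR's actual implementation is not shown in this hunk.

```python
import torch

class SE23Sketch:
    """Stand-in for the SE23 class; only the identity() signature is illustrated."""

    matrix_size = 5

    @staticmethod
    def identity(device="cpu", N=1, dtype=torch.float64):
        # torch.eye accepts dtype/device directly, so no extra .to() copy is needed;
        # expand + clone turns the single identity into a writable batch of N copies.
        eye = torch.eye(SE23Sketch.matrix_size, dtype=dtype, device=device)
        return eye.expand(N, -1, -1).clone()

# Existing call sites such as identity() with no args keep working,
# while GPU users can pass device="cuda".
X = SE23Sketch.identity(N=3)
```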
@@ -50,7 +50,7 @@ def _left_jacobian_Q_matrix(xi_phi, xi_rho):
         return Q.squeeze_()

     @staticmethod
-    def random(N=1):
+    def random(N=1, device='cpu'):
         """
         Generates a random batch of SE_(3) matricies.

@@ -70,7 +70,7 @@ def random(N=1):

         C = SO3.Exp(phi)

-        return SE3.from_components(C, r)
+        return SE3.from_components(C, r).to(device)

     @staticmethod
     def from_components(C: torch.Tensor, r: torch.Tensor):

@@ -94,8 +94,11 @@ def from_components(C: torch.Tensor, r: torch.Tensor):
         # firstly, check that batch dimension for all 3 components matches
         if not (C.shape[0] == r.shape[0]):
             raise ValueError("Batch dimension for SE(3) components don't match.")

+        # then, check that both components are on the same device
+        assert C.device == r.device, "Components must be on the same device for SE3.from_components."
Suggested change:
-        assert C.device == r.device, "Components must be on the same device for SE3.from_components."
+        if C.device != r.device:
+            raise ValueError("Components must be on the same device for SE3.from_components.")
Copilot AI · Apr 30, 2026
b1 is created without specifying dtype, so it defaults to an integer (int64) tensor and can cause a dtype mismatch when concatenated with Xi (often float64). On CUDA this can raise an error. Construct b1 with the same dtype/device as Xi to keep the concatenation valid.
Suggested change:
-        b1 = torch.tensor([0, 0, 0, 0], device=X.device).reshape(1, 1, 4)
+        b1 = torch.tensor([0, 0, 0, 0], dtype=Xi.dtype, device=Xi.device).reshape(1, 1, 4)
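To make the dtype point concrete, here is a small self-contained sketch of the pattern the suggestion describes. The function name and shapes are hypothetical stand-ins for the surrounding code, not the repository's implementation.

```python
import torch

def append_homogeneous_row(Xi: torch.Tensor) -> torch.Tensor:
    """Append a constant bottom row to a batch of (N, 3, 4) matrices."""
    N = Xi.shape[0]
    # Without dtype=..., torch.tensor([0, 0, 0, 0]) defaults to int64, which can
    # fail (or silently promote) when concatenated with a float64 Xi, notably on CUDA.
    b1 = torch.tensor([0, 0, 0, 0], dtype=Xi.dtype, device=Xi.device).reshape(1, 1, 4)
    return torch.cat([Xi, b1.expand(N, 1, 4)], dim=1)

Xi = torch.zeros(2, 3, 4, dtype=torch.float64)
print(append_homogeneous_row(Xi).shape)  # torch.Size([2, 4, 4])
```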
Copilot AI · Apr 30, 2026
In SE3.odot, b[0:3] slices the batch dimension, not the first three vector components. This gives incorrect results when batch size > 1 (and also passes a (N,4,1) tensor into SO3.odot). Use component slicing on the second dimension (e.g., b[:, 0:3]) so the odot operator is computed from the rotational part of each batch element.
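A toy illustration of the slicing difference, with shapes chosen for demonstration only (they are assumptions, not taken from SE3.odot itself):

```python
import torch

# A batch of N = 2 column vectors, each of shape (4, 1).
b = torch.arange(8, dtype=torch.float64).reshape(2, 4, 1)

wrong = b[0:3]     # slices the *batch* dimension -> shape (2, 4, 1) here
right = b[:, 0:3]  # first three components of every batch element -> shape (2, 3, 1)

print(wrong.shape, right.shape)
```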
Copilot AI · Apr 30, 2026
SE3.identity now requires a "device" argument, which breaks the MatrixLieGroupTorch.identity() interface and existing call sites that expect identity() to be callable with no args. Consider making device optional (defaulting to 'cpu' or inferred) to preserve compatibility.
Suggested change:
-    def identity(device, N=1, dtype=torch.float64):
+    def identity(device=None, N=1, dtype=torch.float64):
Using assert for input validation is unsafe (asserts can be stripped with python -O). Prefer raising a ValueError/RuntimeError for device mismatch so this remains enforced in production.
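A minimal sketch of that suggestion as a shared helper; the helper name and signature are assumptions, not part of the PR.

```python
import torch

def _check_same_device(*tensors: torch.Tensor, caller: str) -> None:
    # Unlike an assert, this check survives `python -O` because it raises explicitly.
    devices = {t.device for t in tensors}
    if len(devices) > 1:
        raise ValueError(
            f"Components must be on the same device for {caller}: got {sorted(map(str, devices))}."
        )

# Example: raises ValueError if C and r live on different devices.
C = torch.eye(3, dtype=torch.float64).expand(1, 3, 3)
r = torch.zeros(1, 3, 1, dtype=torch.float64)
_check_same_device(C, r, caller="SE3.from_components")
```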