Added CUDA support for all PyTorch MLG operations #25
Conversation
@angadbajwa thanks for doing this man! The automated testing environment doesn't seem to have a GPU, which will cause the unit tests to fail. I suppose you could just use
Pull request overview
This PR extends the PyTorch backend to support running core Matrix Lie Group operations on CUDA, and updates the standard Torch test suite to exercise both CPU and GPU execution when available.
Changes:
- Parameterize Torch standard tests over device (CPU + conditional CUDA).
- Update Torch Lie group implementations/utilities to allocate tensors on the correct device.
- Add/adjust identity/random helpers to better support device-specific operation.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| tests/torch/test_standard_torch.py | Adds device parametrization and a CUDA skip condition; updates debug runner to call tests on CPU/GPU. |
| tests/standard_tests_torch.py | Threads device through the standard torch test helpers and replaces many numpy comparisons with torch equivalents. |
| pymlg/torch/utils.py | Updates batch_eye to allocate identity matrices directly on a specified device. |
| pymlg/torch/so3.py | Adds device consistency checks and updates allocations for CUDA compatibility; adds a device-required identity helper. |
| pymlg/torch/so2.py | Updates random() to accept a target device. |
| pymlg/torch/se3.py | Updates allocations to respect device; adds device checks and a device-required identity helper. |
| pymlg/torch/se23.py | Updates allocations to respect device; adds device checks and a device-required identity helper. |
| pymlg/torch/se2.py | Updates allocations to respect device; adds device checks and updates some allocations for CUDA. |
```python
# generating a (N, 1, 4) batched matrix to append
b1 = torch.tensor([0, 0, 0, 0], device=X.device).reshape(1, 1, 4)
```

b1 is created without specifying a dtype, so it defaults to an int64 tensor and can cause a dtype mismatch when concatenating with Xi (often float64). On CUDA this can raise an error. Construct b1 with the same dtype/device as X (or Xi) to keep the concatenation valid.

Suggested change:

```diff
- b1 = torch.tensor([0, 0, 0, 0], device=X.device).reshape(1, 1, 4)
+ b1 = torch.tensor([0, 0, 0, 0], dtype=Xi.dtype, device=Xi.device).reshape(1, 1, 4)
```
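The default-dtype behaviour is easy to check directly. A minimal sketch (Xi here is a synthetic float64 batch standing in for the library's variable):

```python
import torch

# Stand-in for the float64 batch the review mentions (illustrative data).
Xi = torch.zeros(2, 3, 4, dtype=torch.float64)

# An integer list with no dtype argument yields int64, not Xi's dtype:
b_bad = torch.tensor([0, 0, 0, 0]).reshape(1, 1, 4)
assert b_bad.dtype == torch.int64

# Matching dtype and device to Xi keeps later concatenation consistent:
b1 = torch.tensor([0, 0, 0, 0], dtype=Xi.dtype, device=Xi.device).reshape(1, 1, 4)
assert b1.dtype == Xi.dtype and b1.device == Xi.device
```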
```python
Om = torch.Tensor([[0, -1], [1, 0]], device=T.device).repeat(T.shape[0], 1, 1)
A = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype)
```

The legacy torch.Tensor constructor rejects a non-CPU device argument, so this line errors at runtime when T is on CUDA. A is also created without device=... (i.e. on CPU), which fails when T is on CUDA. Create Om and A on T.device (and with the appropriate dtype) to avoid device mismatches.

Suggested change:

```diff
- Om = torch.Tensor([[0, -1], [1, 0]], device=T.device).repeat(T.shape[0], 1, 1)
- A = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype)
+ Om = torch.tensor([[0, -1], [1, 0]], dtype=T.dtype, device=T.device).repeat(T.shape[0], 1, 1)
+ A = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype, device=T.device)
```
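A sketch of the suggested fix using the torch.tensor factory, which does take dtype/device keywords (T here is a synthetic (1, 2, 2) batch, not the library's transform):

```python
import torch

T = torch.eye(2, dtype=torch.float64).unsqueeze(0)  # stand-in (1, 2, 2) batch

# torch.tensor (the factory function) accepts dtype and device keywords,
# so both allocations inherit T's dtype and device:
Om = torch.tensor([[0.0, -1.0], [1.0, 0.0]],
                  dtype=T.dtype, device=T.device).repeat(T.shape[0], 1, 1)
A = torch.zeros(T.shape[0], 3, 3, dtype=T.dtype, device=T.device)

assert Om.shape == (1, 2, 2) and Om.dtype == T.dtype
assert A.device == T.device
```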
```python
raise ValueError("Batch dimension for SE(3) components don't match.")
# ...
# then, check that both components are on the same device
assert C.device == r.device, "Components must be on the same device for SE3.from_components."
```

Using assert for input validation is unsafe here (asserts are stripped under python -O). Since shape mismatches already raise ValueError, raise a ValueError for device mismatches as well so the check always runs in production.

Suggested change:

```diff
- assert C.device == r.device, "Components must be on the same device for SE3.from_components."
+ if C.device != r.device:
+     raise ValueError("Components must be on the same device for SE3.from_components.")
```
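The -O behaviour can be demonstrated without relaunching the interpreter: compiling with optimize=1 mirrors python -O and removes assert statements entirely, while an explicit raise survives. A minimal sketch (check_same_device is an illustrative helper, not the library's API):

```python
import torch

# compile(..., optimize=1) mimics `python -O`: the assert is stripped,
# so executing the compiled code raises nothing.
stripped = compile("assert False, 'never checked'", "<demo>", "exec", optimize=1)
exec(stripped)  # no AssertionError

# An explicit raise cannot be optimized away:
def check_same_device(C: torch.Tensor, r: torch.Tensor) -> None:
    if C.device != r.device:
        raise ValueError(
            "Components must be on the same device for SE3.from_components."
        )

check_same_device(torch.eye(3), torch.zeros(3, 1))  # same device: passes silently
```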
```python
T = torch.zeros(C.shape[0], 3, 3, dtype=C.dtype)
# ...
# first, confirm that both components are allocated on the same device
assert C.device == r.device, "Components must be on the same device for SE2.from_components"
```

Using assert for input validation is unsafe (asserts are stripped under python -O). Prefer raising a ValueError/RuntimeError for the device mismatch so the check remains enforced in production.

Suggested change:

```diff
- assert C.device == r.device, "Components must be on the same device for SE2.from_components"
+ if C.device != r.device:
+     raise ValueError("Components must be on the same device for SE2.from_components")
```
```diff
 @staticmethod
-def identity(N=1, dtype=torch.float32):
+def identity(device, N=1, dtype=torch.float64):
     return batch_eye(N, 4, 4, dtype=dtype)
```

SE3.identity now requires a "device" argument, which breaks the MatrixLieGroupTorch.identity() interface and existing call sites that expect identity() to be callable with no arguments. Consider making device optional (defaulting to 'cpu' or inferred) to preserve compatibility.

Suggested change:

```diff
-def identity(device, N=1, dtype=torch.float64):
+def identity(device=None, N=1, dtype=torch.float64):
```
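A sketch of the backward-compatible signature, with a stand-in batch_eye helper (the real pymlg helper's signature is assumed here):

```python
import torch

def batch_eye(N, n, m, dtype=torch.float64, device=None):
    # Stand-in for pymlg.torch.utils.batch_eye (assumed signature):
    # N copies of an n x m identity on the requested device.
    return torch.eye(n, m, dtype=dtype, device=device).expand(N, n, m).clone()

def identity(device=None, N=1, dtype=torch.float64):
    # device=None falls through to the default (CPU) allocator, so
    # pre-existing no-argument call sites keep working unchanged.
    return batch_eye(N, 4, 4, dtype=dtype, device=device)

X0 = identity()                   # old call site, still valid
X1 = identity(device="cpu", N=3)  # new device-aware call site
```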
```python
return SO3.Exp(v).to(device)
# ...
@staticmethod
def identity(device, N=1, dtype=torch.float64):
```

This overrides MatrixLieGroupTorch.identity() (which takes no required arguments) with a required "device" parameter, which is a breaking API change and also prevents calling SO3.identity() without arguments. Consider keeping the base signature (e.g., device defaulting to 'cpu' or inferred) to preserve polymorphism and backward compatibility.

Suggested change:

```diff
-def identity(device, N=1, dtype=torch.float64):
+def identity(device='cpu', N=1, dtype=torch.float64):
```
```diff
 @staticmethod
-def identity(N=1):
+def identity(device, N=1, dtype=torch.float64):
     return batch_eye(N, 5, 5)
```

SE23.identity now requires a "device" argument, which breaks the MatrixLieGroupTorch.identity() interface and existing call sites that expect identity() with no arguments. Consider making device optional (defaulting to 'cpu' or inferred) to preserve compatibility.

Suggested change:

```diff
-def identity(device, N=1, dtype=torch.float64):
+def identity(device='cpu', N=1, dtype=torch.float64):
```
```diff
 X = torch.zeros(b.shape[0], 4, 6, dtype=b.dtype, device=b.device)
 X[:, 0:3, 0:3] = SO3.odot(b[0:3])
-X[:, 0:3, 3:6] = b[:, 3] * batch_eye(b.shape[0], 3, 3, dtype=b.dtype)
+X[:, 0:3, 3:6] = b[:, 3] * batch_eye(b.shape[0], 3, 3, dtype=b.dtype, device=b.device)
```

In SE3.odot, b[0:3] slices the batch dimension, not the first three vector components. This gives incorrect results when batch size > 1 (and also passes an (N, 4, 1) tensor into SO3.odot). Use component slicing on the second dimension (e.g., b[:, 0:3]) so the odot operator is computed from the rotational part of each batch element.
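The slicing distinction is easy to verify on a synthetic batch (shapes here are illustrative, matching the (N, 4, 1) points the comment implies):

```python
import torch

b = torch.arange(8, dtype=torch.float64).reshape(2, 4, 1)  # batch of two 4-vectors

# b[0:3] slices the BATCH axis: every component of (up to) 3 batch elements.
assert b[0:3].shape == (2, 4, 1)

# b[:, 0:3] selects the first three components of EVERY batch element,
# which is the rotational part SO3.odot expects.
assert b[:, 0:3].shape == (2, 3, 1)
```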
```python
test.do_tests(SO3, device='cpu')
test.do_tests(SE3, device='cpu')
test.do_tests(SE23, device='cpu')
test.do_tests(SO2, device='cpu')
test.do_tests(SE2, device='cpu')

# if CUDA is available, perform tests on GPU
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
    test.do_tests(SO3, device='cuda')
    test.do_tests(SE3, device='cuda')
    test.do_tests(SE23, device='cuda')
    test.do_tests(SO2, device='cuda')
    test.do_tests(SE2, device='cuda')
```

In the main debug block, do_tests is called with the keyword argument "device", but StandardTestsTorch.do_tests now takes "test_device". Running this file directly will raise a TypeError; update the calls to use the new parameter name (or accept **kwargs / keep a backward-compatible alias).

Suggested change:

```diff
-test.do_tests(SO3, device='cpu')
-test.do_tests(SE3, device='cpu')
-test.do_tests(SE23, device='cpu')
-test.do_tests(SO2, device='cpu')
-test.do_tests(SE2, device='cpu')
+test.do_tests(SO3, test_device='cpu')
+test.do_tests(SE3, test_device='cpu')
+test.do_tests(SE23, test_device='cpu')
+test.do_tests(SO2, test_device='cpu')
+test.do_tests(SE2, test_device='cpu')
 # if CUDA is available, perform tests on GPU
 if torch.cuda.is_available() and torch.cuda.device_count() > 0:
-    test.do_tests(SO3, device='cuda')
-    test.do_tests(SE3, device='cuda')
-    test.do_tests(SE23, device='cuda')
-    test.do_tests(SO2, device='cuda')
-    test.do_tests(SE2, device='cuda')
+    test.do_tests(SO3, test_device='cuda')
+    test.do_tests(SE3, test_device='cuda')
+    test.do_tests(SE23, test_device='cuda')
+    test.do_tests(SO2, test_device='cuda')
+    test.do_tests(SE2, test_device='cuda')
```
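The backward-compatible alias mentioned above could look like this; a minimal sketch (the class and method names match the review, the body is illustrative):

```python
class StandardTestsTorch:
    def do_tests(self, G, test_device="cpu", **kwargs):
        # Accept the legacy keyword `device` as an alias for `test_device`
        # so older call sites don't raise TypeError (illustrative shim).
        if "device" in kwargs:
            test_device = kwargs.pop("device")
        if kwargs:
            raise TypeError(f"unexpected keyword arguments: {sorted(kwargs)}")
        # ... run the standard tests for group G on test_device ...
        return G, test_device

test = StandardTestsTorch()
assert test.do_tests("SO3", device="cpu") == ("SO3", "cpu")       # old spelling
assert test.do_tests("SE3", test_device="cuda") == ("SE3", "cuda")  # new spelling
```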
```python
# change allclose assertion based on whether the inputs are torch tensors or numpy arrays
if x.dtype == torch.float32 or x.dtype == torch.float64:
    assert torch.allclose(x, x_test, 1e-15)
else:
    assert np.allclose(x, x_test, 1e-15)
```

This allclose branch is incorrect for two reasons: (1) torch.allclose(x, x_test, 1e-15) passes 1e-15 as rtol (the third positional argument), not atol; and (2) the else branch calls np.allclose on torch tensors, which breaks for CUDA tensors and non-float dtypes. Prefer a single torch.allclose call with explicit rtol/atol keywords (and, if numpy is truly needed, explicitly move and convert to CPU numpy first).

Suggested change:

```diff
-# change allclose assertion based on whether the inputs are torch tensors or numpy arrays
-if x.dtype == torch.float32 or x.dtype == torch.float64:
-    assert torch.allclose(x, x_test, 1e-15)
-else:
-    assert np.allclose(x, x_test, 1e-15)
+assert torch.allclose(x, x_test, atol=1e-15)
```
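The rtol/atol mix-up is observable with a discrepancy well inside the default atol; a quick sketch:

```python
import torch

x = torch.tensor([1.0], dtype=torch.float64)
y = torch.tensor([1.0 + 1e-10], dtype=torch.float64)

# Signature: torch.allclose(input, other, rtol=1e-05, atol=1e-08).
# Passing 1e-15 positionally sets rtol; atol stays at 1e-08, so a
# 1e-10 discrepancy still "passes" despite the tiny-looking tolerance.
assert torch.allclose(x, y, 1e-15) is True

# An explicit atol (with rtol zeroed) gives the intended strictness:
assert torch.allclose(x, y, rtol=0.0, atol=1e-15) is False
```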
Tests passing on a GeForce GTX 1650 Mobile / Max-Q with CUDA 12.8.