Add pure-PyTorch reference implementations#1189
Conversation
562e6c2 to
be21049
Compare
|
cc @viiccwen @vvvdwbvvv for review |
There was a problem hiding this comment.
Pull request overview
This PR introduces a pure-PyTorch reference implementation for QDP encoders plus backend detection/selection plumbing, enabling comparisons against the Rust+CUDA path and allowing parts of qumat_qdp to function when _qdp isn’t built.
Changes:
- Added PyTorch reference implementations for amplitude/angle/basis/IQP encodings with a string-based dispatcher.
- Introduced backend detection utilities and exposed backend info via
qumat_qdpexports; added explicit backend selection toQdpBenchmarkandQuantumDataLoader. - Added new tests for the PyTorch reference encoders and for behavior when
_qdpis unavailable; added a benchmark script supporting “encode-only” vs “end-to-end” modes.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
testing/qdp_python/test_torch_ref.py |
New unit tests for PyTorch reference encoders + optional cross-validation vs _qdp. |
testing/qdp_python/test_fallback.py |
Tests for backend detection and explicit PyTorch backend behavior when _qdp is missing. |
testing/conftest.py |
Adjusts skip logic so selected tests can run without _qdp. |
qdp/qdp-python/qumat_qdp/torch_ref.py |
Implements pure-PyTorch reference encoders and an encode() dispatcher. |
qdp/qdp-python/qumat_qdp/loader.py |
Adds explicit `.backend('rust' |
qdp/qdp-python/qumat_qdp/api.py |
Adds `.backend('rust' |
qdp/qdp-python/qumat_qdp/_backend.py |
Adds backend detection (Backend enum, get_backend, force_backend, get_qdp, get_torch). |
qdp/qdp-python/qumat_qdp/__init__.py |
Makes qumat_qdp importable without _qdp; exports BACKEND/Backend and safe _qdp symbols. |
qdp/qdp-python/benchmark/benchmark_pytorch_ref.py |
Adds benchmark script comparing PyTorch vs Mahout with --mode (encode-only/end-to-end). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
update at df90a6e |
|
I already implements the amplitude one and successfully increase the speed to slightly faster than |
- Implemented backend detection and selection logic in _backend.py, prioritizing Rust+CUDA, PyTorch, and fallback to None. - Added pure-PyTorch reference implementations for quantum data encoding methods in torch_ref.py, including amplitude, angle, basis, and IQP encoding. - Created comprehensive tests for fallback mechanisms and pure-PyTorch encodings in test_fallback.py and test_torch_ref.py, ensuring functionality without the Rust extension. - Enhanced error handling and validation across encoding methods to ensure robustness.
df90a6e to
d814d7a
Compare
rich7420
left a comment
There was a problem hiding this comment.
Thanks for the patch! @ryankert01
|
|
||
| # For each state index: amplitude = prod_k (sin if bit else cos) | ||
| # Shape: (batch, state_dim, num_qubits) via broadcasting | ||
| trig = bits.unsqueeze(0) * sin_vals.unsqueeze(1) + ( |
There was a problem hiding this comment.
Nit: We can save GPU VRAM and skip arithmetic steps by using torch.where(bits.bool().unsqueeze(0), sin_vals.unsqueeze(1), cos_vals.unsqueeze(1)) instead of adding and multiplying here. maybe? I think
rich7420
left a comment
There was a problem hiding this comment.
LGTM! thx for the update
Closes #1177
Related to #1227
Summary
This PR adds a pure-PyTorch reference backend to the QDP Python package to compare with our implementation of GPU kernel~
Benchmark Results
All runs: 100 batches x 64 vectors (except 18-qubit: 50 batches x 64), median of 3 trials.
Amplitude Encoding
Angle Encoding
IQP Encoding
Analysis
encodepath still pays per-batch GPU output allocation + D2H norm validation sync overhead.generate_batch_data+torch.tensor+ H2D transfer.Known Limitations
engine.encodeexpects per-sample basis indices; batch input format differs from PyTorch. Requires Rust API change to support.1 << num_qubitsas sample_size regardless of encoding method, causing a mismatch for IQP (which expectsn + n*(n-1)/2). Pre-existing Rust pipeline bug.