diff --git a/pyproject.toml b/pyproject.toml index 002720e..2785a60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ authors = [ { name = "eyjafjallac" }, ] dependencies = [ - "qten", + "qten>=0.4.2", ] [project.optional-dependencies] cpu = [ @@ -112,7 +112,7 @@ nb = [ "ipykernel>=6.29.0", "jupyterlab>=4.2.0", "notebook>=7.2.0", - "qten-plots>=0.1.0", + "qten-plots>=0.3.0", ] [tool.pytest.ini_options] @@ -135,7 +135,3 @@ files = ["src", "tests"] strict = true warn_unused_ignores = true disallow_any_generics = true - -[[tool.mypy.overrides]] -module = ["sympy", "sympy.*"] -ignore_missing_imports = true diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py deleted file mode 100644 index bf3dff7..0000000 --- a/src/qrg/wannier.py +++ /dev/null @@ -1,137 +0,0 @@ -import warnings -from typing import Any, cast - -from qten.geometries.fourier import fourier_transform -from qten.linalg.decompose import svd -from qten.linalg.tensors import Tensor -from qten.symbolics.hilbert_space import HilbertSpace -from qten.symbolics.state_space import MomentumSpace - - -def wannierize_k( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: - """ - Perform projective wannierization on the target bands with the seeding states in momentum space. - - Parameters - ---------- - eigenvectors : Tensor - Target bands/eigenvectors. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. - seeds : Tensor - Seed states in momentum space. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. - svd_threshold : float - Warn if the minimum singular value drops below this, indicating linearly dependent seeds - or poor overlap with target bands. - - Returns - ------- - Tensor - Wannierized states with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. - """ - if eigenvectors.rank() != 3 or seeds.rank() != 3: - raise ValueError("Both eigenvectors and seeds must be rank-3 Tensors.") - - # 1. Compute the overlap matrix for each momentum sector - # P_k = \psi_k^\dagger S_k - # Resulting shape: (MomentumSpace, IndexSpace_bands, IndexSpace_seeds) - overlap = eigenvectors.h(-2, -1) @ seeds - - # 2. Perform SVD on the overlap matrix - U, S, Vh = svd(overlap) - - # Check for linear dependence / poor projection - min_svd_val = S.data.min().item() - if min_svd_val < svd_threshold: - warnings.warn( - f"Precarious wannier projection with minimum svd value of {min_svd_val:.4g}", - UserWarning, - stacklevel=2, - ) - - # 3. Construct the unitary transformation matrix - # M_k = U_k V_k^\dagger - unitary = U @ Vh - - # 4. Rotate the target bands into the Wannier gauge - # W_k = \psi_k M_k - wannier_states = eigenvectors @ unitary - - return cast(Tensor[Any], wannier_states) - - -def wannierize_r( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: - """ - Perform projective wannierization using real-space localized seed states. - - Parameters - ---------- - eigenvectors : Tensor - Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. - seeds : Tensor - Seed states localized in real space with shape `(HilbertSpace_local, IndexSpace)`. - svd_threshold : float - SVD warning threshold. - - Returns - ------- - Tensor - Wannierized states in momentum space. - """ - if not isinstance(eigenvectors.dims[0], MomentumSpace): - raise TypeError("The first dimension of the eigenvectors must be a MomentumSpace.") - - kspace = eigenvectors.dims[0] - outspace = eigenvectors.dims[1] - inspace_local = seeds.dims[0] - if not isinstance(outspace, HilbertSpace) or not isinstance(inspace_local, HilbertSpace): - raise TypeError( - "The second dimension of eigenvectors and first dimension " - "of seeds must be HilbertSpace." - ) - - # Perform Fourier transform on local seeds to move them to momentum space - # f shape: (MomentumSpace, HilbertSpace_out, HilbertSpace_in_local) - f = fourier_transform(kspace, outspace, inspace_local, device=eigenvectors.device) - - # Map the seeds to crystal momentum seeds - # f @ local_seeds -> (MomentumSpace, HilbertSpace_out, IndexSpace) - crystal_seeds = f @ seeds - - return wannierize_k(eigenvectors, crystal_seeds, svd_threshold) - - -def projective_wannierization( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: - """ - Perform projective wannierization with automatic seed-space dispatch. - - Parameters - ---------- - eigenvectors : Tensor - Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. - seeds : Tensor - Either crystal-momentum seeds `(MomentumSpace, HilbertSpace, IndexSpace)` - or local real-space seeds `(HilbertSpace_local, IndexSpace)`. - svd_threshold : float - SVD warning threshold. - - Returns - ------- - Tensor - Wannierized states in momentum space. - """ - if seeds.rank() == 3: - if not isinstance(seeds.dims[0], MomentumSpace): - raise TypeError("Rank-3 seeds must have MomentumSpace as the first dimension.") - return wannierize_k(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) - - if seeds.rank() == 2: - if not isinstance(seeds.dims[0], HilbertSpace): - raise TypeError("Rank-2 seeds must have HilbertSpace as the first dimension.") - return wannierize_r(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) - - raise ValueError("Seeds must be rank-2 (local seeds) or rank-3 (momentum seeds).") diff --git a/src/qrg/wannier.pyi b/src/qrg/wannier.pyi deleted file mode 100644 index baff027..0000000 --- a/src/qrg/wannier.pyi +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any - -from qten.linalg.tensors import Tensor - -def wannierize_k( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: ... -def wannierize_r( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: ... -def projective_wannierization( - eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 -) -> Tensor[Any]: ... diff --git a/tests/test_smoke.py b/tests/test_smoke.py new file mode 100644 index 0000000..1a62f38 --- /dev/null +++ b/tests/test_smoke.py @@ -0,0 +1,5 @@ +import qrg + + +def test_import_qrg() -> None: + assert qrg.__version__ diff --git a/tests/test_wannier.py b/tests/test_wannier.py deleted file mode 100644 index 298f012..0000000 --- a/tests/test_wannier.py +++ /dev/null @@ -1,254 +0,0 @@ -from dataclasses import dataclass -from typing import Any - -import pytest -import sympy as sy -import torch -from qten.geometries.boundary import PeriodicBoundary -from qten.geometries.fourier import fourier_transform -from qten.geometries.spatials import Lattice, Offset -from qten.linalg.tensors import Tensor -from qten.symbolics.hilbert_space import HilbertSpace, U1Basis -from qten.symbolics.state_space import IndexSpace, brillouin_zone -from sympy import ImmutableDenseMatrix - -from qrg.wannier import projective_wannierization, wannierize_k, wannierize_r - - -@dataclass(frozen=True) -class Orb: - name: str - - -def _state(r: Offset[Any], orb: str = "s") -> U1Basis: - return U1Basis(coef=sy.Integer(1), base=(r, Orb(orb))) - - -def _build_1d_spaces() -> tuple[Lattice, Any, HilbertSpace]: - lattice = Lattice( - basis=ImmutableDenseMatrix([[1]]), - boundaries=PeriodicBoundary(ImmutableDenseMatrix.diag(2)), - unit_cell={"r": ImmutableDenseMatrix([0])}, - ) - k_space = brillouin_zone(lattice.dual) - r0 = Offset(rep=ImmutableDenseMatrix([0]), space=lattice.affine) - r_half = Offset(rep=ImmutableDenseMatrix([sy.Rational(1, 2)]), space=lattice.affine) - bloch_space = HilbertSpace.new([_state(r0, "a"), _state(r_half, "b")]) - return lattice, k_space, bloch_space - - -def test_wannierize_r_matches_explicit_crystal_seed_pipeline() -> None: - """Test local-seed projection matches explicit crystal-seed workflow.""" - # Minimal 1D lattice and Brillouin zone, similar to the notebook flow. - _, k_space, bloch_space = _build_1d_spaces() - local_space = bloch_space - - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - # (K, B, I): one target band at each k with deterministic phase structure. - eigenvectors = Tensor( - data=torch.tensor( - [ - [[2**-0.5], [2**-0.5]], - [[2**-0.5], [-(2**-0.5)]], - ], - dtype=torch.complex128, - ), - dims=(k_space, bloch_space, band_space), - ) - - # (B_local, I): one local seed orbital. - local_seeds = Tensor( - data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), - dims=(local_space, seed_space), - ) - - # Notebook-like pathway: local seeds -> Fourier seeds -> projective wannierization. - crystal_seeds = fourier_transform(k_space, bloch_space, local_space) @ local_seeds - expected = wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds) - actual = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) - - assert actual.dims == expected.dims - assert torch.allclose(actual.data, expected.data) - - # Result should remain orthonormal within the selected band subspace. - overlap = actual.h(-2, -1) @ actual - assert torch.allclose( - overlap.data, - torch.ones((k_space.dim, 1, 1), dtype=torch.complex128), - ) - - -def test_wannierize_k_rejects_non_rank3_tensors() -> None: - """Test rank validation raises when inputs are not rank-3 tensors.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - eigenvectors_rank2 = Tensor( - data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), - dims=(bloch_space, band_space), - ) - seeds_rank3 = Tensor( - data=torch.ones((k_space.dim, bloch_space.dim, seed_space.dim), dtype=torch.complex128), - dims=(k_space, bloch_space, seed_space), - ) - with pytest.raises(ValueError, match="rank-3"): - wannierize_k(eigenvectors=eigenvectors_rank2, seeds=seeds_rank3) - - eigenvectors_rank3 = Tensor( - data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), - dims=(k_space, bloch_space, band_space), - ) - seeds_rank2 = Tensor( - data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), - dims=(bloch_space, seed_space), - ) - with pytest.raises(ValueError, match="rank-3"): - wannierize_k(eigenvectors=eigenvectors_rank3, seeds=seeds_rank2) - - -def test_wannierize_r_rejects_non_momentum_first_dimension() -> None: - """Test wannierize_r rejects non-MomentumSpace first tensor dim.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - bad_k_space = IndexSpace.linear(2) - eigenvectors = Tensor( - data=torch.ones((bad_k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), - dims=(bad_k_space, bloch_space, band_space), - ) - local_seeds = Tensor( - data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), - dims=(bloch_space, seed_space), - ) - with pytest.raises(TypeError, match="MomentumSpace"): - wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) - - good_eigenvectors = Tensor( - data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), - dims=(k_space, bloch_space, band_space), - ) - wannierize_r(eigenvectors=good_eigenvectors, seeds=local_seeds) - - -def test_wannierize_k_warns_on_poor_overlap() -> None: - """Test poor seed-band overlap emits the precarious projection warning.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - # Build nearly orthogonal eigenvector/seed overlap to trigger warning. - eigenvectors = Tensor( - data=torch.tensor( - [ - [[1.0], [0.0]], - [[1.0], [0.0]], - ], - dtype=torch.complex128, - ), - dims=(k_space, bloch_space, band_space), - ) - tiny = 1.0e-8 - seeds = Tensor( - data=torch.tensor( - [ - [[tiny], [1.0]], - [[tiny], [1.0]], - ], - dtype=torch.complex128, - ), - dims=(k_space, bloch_space, seed_space), - ) - - with pytest.warns(UserWarning, match="Precarious wannier projection"): - _ = wannierize_k( - eigenvectors=eigenvectors, - seeds=seeds, - svd_threshold=1.0e-3, - ) - - -def test_wannierize_r_projector_is_gauge_invariant() -> None: - """Test projector is invariant between equivalent seed construction routes.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - eigenvectors = Tensor( - data=torch.tensor( - [ - [[2**-0.5], [2**-0.5]], - [[2**-0.5], [-(2**-0.5)]], - ], - dtype=torch.complex128, - ), - dims=(k_space, bloch_space, band_space), - ) - local_seeds = Tensor( - data=torch.tensor([[0.0], [1.0]], dtype=torch.complex128), - dims=(bloch_space, seed_space), - ) - - w_local = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) - w_crystal = wannierize_k( - eigenvectors=eigenvectors, - seeds=fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds, - ) - - p_local = w_local @ w_local.h(-2, -1) - p_crystal = w_crystal @ w_crystal.h(-2, -1) - assert torch.allclose(p_local.data, p_crystal.data) - - -def test_projective_wannierization_matches_explicit_paths() -> None: - """Test auto-dispatch reproduces explicit k-space and local-space APIs.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - eigenvectors = Tensor( - data=torch.tensor( - [ - [[2**-0.5], [2**-0.5]], - [[2**-0.5], [-(2**-0.5)]], - ], - dtype=torch.complex128, - ), - dims=(k_space, bloch_space, band_space), - ) - local_seeds = Tensor( - data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), - dims=(bloch_space, seed_space), - ) - crystal_seeds = fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds - - assert torch.allclose( - projective_wannierization(eigenvectors=eigenvectors, seeds=local_seeds).data, - wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds).data, - ) - assert torch.allclose( - projective_wannierization(eigenvectors=eigenvectors, seeds=crystal_seeds).data, - wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds).data, - ) - - -def test_projective_wannierization_rejects_invalid_seed_rank() -> None: - """Test auto-dispatch raises when seeds rank is neither 2 nor 3.""" - _, k_space, bloch_space = _build_1d_spaces() - band_space = IndexSpace.linear(1) - seed_space = IndexSpace.linear(1) - - eigenvectors = Tensor( - data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), - dims=(k_space, bloch_space, band_space), - ) - bad_seeds = Tensor( - data=torch.ones((seed_space.dim,), dtype=torch.complex128), - dims=(seed_space,), - ) - - with pytest.raises(ValueError, match="rank-2|rank-3"): - projective_wannierization(eigenvectors=eigenvectors, seeds=bad_seeds)