diff --git a/pyproject.toml b/pyproject.toml index 3dfefff..35e0b79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "qten" -version = "0.4.1" +version = "0.4.2" description = "Torch-based tensors, lattices, symmetries, and band-structure tools for quantum lattice models." readme = "README.md" requires-python = ">=3.11" diff --git a/src/qten/geometries/ops.py b/src/qten/geometries/ops.py index 2b38068..7fefcb9 100644 --- a/src/qten/geometries/ops.py +++ b/src/qten/geometries/ops.py @@ -247,15 +247,14 @@ def get_strip_region_2d( origin: Offset[AffineSpace] | Offset[Lattice] | None = None, ) -> tuple[Offset[Lattice], ...]: r""" - Return a 2D rectangular strip region in primitive-strip lattice coordinates. + Return a 2D rectangular strip region in strip coordinates induced by `direction`. This helper is defined only for 2D lattices. Let `r0` be the supplied [`origin`][qten.geometries.spatials.AffineSpace.origin] (or the lattice origin when omitted). - Let \((d_x, d_y)\) be the supplied direction coordinates. Let - \(p = (p_x, p_y)\) be the associated primitive integer direction, and let - \(n = (-p_y, p_x)\) be the primitive integer normal. `side="lhs"` grows - toward positive \(n\) and `side="rhs"` grows toward negative \(n\). + Let \((d_x, d_y)\) be the supplied direction coordinates and let + \(n = (-d_y, d_x)\) be the rotated transverse direction. `side="lhs"` + grows toward positive \(n\) and `side="rhs"` grows toward negative \(n\). A lattice site belongs to the strip when some periodic image of that site satisfies both of the following: @@ -265,25 +264,21 @@ def get_strip_region_2d( \le d_x(r_x-r_{0x}) + d_y(r_y-r_{0y}) \le (\mathrm{length\_step}-1)(d_x^2+d_y^2)\). - Transverse bound: - \(0 \le s[-p_y(r_x-r_{0x}) + p_x(r_y-r_{0y})] - \le \mathrm{width\_step}-1\). + \(0 \le s[-d_y(r_x-r_{0x}) + d_x(r_y-r_{0y})] + \le (\mathrm{width\_step}-1)(d_x^2+d_y^2)\). where \(s = 1\) for `"lhs"` and \(s = -1\) for `"rhs"`. - For integer directions, \((d_x, d_y) = (p_x, p_y)\). For rational directions, - longitudinal shell spacing is computed from the supplied direction - `(dx, dy)`, while transverse shelling is computed from the primitive - integer direction `p`. - `width_step` counts the transverse shell thickness including the main axis - row. `trim_step` is a tail trimmer only: it advances the strip start along - the longitudinal axis without affecting the transverse width. + row in the rotated, scaled direction induced by `direction`. `trim_step` + is a tail trimmer only: it advances the strip start along the + longitudinal axis without affecting the transverse width. Parameters ---------- direction : Offset[Lattice] - Non-zero lattice translation on a 2D lattice whose primitive direction - defines the strip axis. + Non-zero lattice translation on a 2D lattice whose direction defines + both the strip axis and the scaled transverse direction. length_step : int Number of strip shells from the origin along the primitive direction. width_step : int @@ -292,8 +287,8 @@ def get_strip_region_2d( Number of longitudinal shells trimmed from the tail near the origin. side : Literal["lhs", "rhs"] Side on which transverse width shells are accumulated relative to the - strip direction. `"lhs"` uses the positive lattice normal and `"rhs"` - uses the negative lattice normal. + strip direction. `"lhs"` uses the positive rotated direction and + `"rhs"` uses the negative one. origin : Offset[AffineSpace] | Offset[Lattice] | None Anchor point for the strip coordinates. If omitted, the zero offset in the lattice space is used. When provided, it is rebased into the @@ -326,7 +321,7 @@ def get_strip_region_2d( if length_step == 0 or width_step == 0: return () - lattice, dx, dy, px, py = _strip_direction_data(direction) + lattice, dx, dy, _px, _py = _strip_direction_data(direction) if origin is None: origin = lattice.origin() if origin.dim != lattice.dim: @@ -338,13 +333,14 @@ def get_strip_region_2d( if len(all_sites) == 0: return () - normal_x = -py - normal_y = px + normal_x = -dy + normal_y = dx normal_sign = 1 if side == "lhs" else -1 - longitudinal_min = trim_step * (dx * dx + dy * dy) - longitudinal_max = (length_step - 1) * (dx * dx + dy * dy) + direction_norm_sq = dx * dx + dy * dy + longitudinal_min = trim_step * direction_norm_sq + longitudinal_max = (length_step - 1) * direction_norm_sq transverse_min = 0 - transverse_max = width_step - 1 + transverse_max = (width_step - 1) * direction_norm_sq boundary_basis = np.array(lattice.boundaries.basis.tolist(), dtype=int) image_shifts = [ diff --git a/tests/test_geometry_ops.py b/tests/test_geometry_ops.py index 9845183..1645179 100644 --- a/tests/test_geometry_ops.py +++ b/tests/test_geometry_ops.py @@ -239,7 +239,7 @@ def test_get_strip_region_2d_supports_affine_origin(): ) -def test_get_strip_region_2d_uses_primitive_direction(): +def test_get_strip_region_2d_inherits_direction_scale_transversely(): lattice = Lattice( basis=ImmutableDenseMatrix.eye(2), boundaries=PeriodicBoundary(ImmutableDenseMatrix.diag(8, 8)), @@ -256,9 +256,12 @@ def test_get_strip_region_2d_uses_primitive_direction(): assert tuple(tuple(site.rep) for site in region) == ( (0, 0), (0, 1), + (0, 2), (1, 1), (1, 2), + (1, 3), (2, 2), + (7, 1), ) @@ -277,10 +280,13 @@ def test_get_strip_region_2d_matches_diagonal_examples(): ) assert set(tuple(site.rep) for site in region) == { (0, 0), - (1, 1), - (2, 2), (0, 1), + (0, 2), + (1, 1), (1, 2), + (1, 3), + (2, 2), + (15, 1), } region = get_strip_region_2d( @@ -291,13 +297,18 @@ def test_get_strip_region_2d_matches_diagonal_examples(): ) assert set(tuple(site.rep) for site in region) == { (0, 0), - (1, 1), - (2, 2), (0, 1), - (1, 2), - (15, 1), (0, 2), + (0, 3), + (0, 4), + (1, 1), + (1, 2), (1, 3), + (2, 2), + (14, 2), + (15, 1), + (15, 2), + (15, 3), } region = get_strip_region_2d( @@ -309,10 +320,13 @@ def test_get_strip_region_2d_matches_diagonal_examples(): ) assert set(tuple(site.rep) for site in region) == { (0, 2), + (0, 3), + (0, 4), (1, 1), (2, 2), (1, 2), (1, 3), + (15, 3), } region = get_strip_region_2d( @@ -323,10 +337,13 @@ def test_get_strip_region_2d_matches_diagonal_examples(): ) assert set(tuple(site.rep) for site in region) == { (0, 0), - (1, 1), - (2, 2), (1, 0), + (1, 1), + (1, 15), + (2, 0), (2, 1), + (2, 2), + (3, 1), }