Skip to content

Commit 2a60851

Browse files
committed
Refactor acoustic mode masking in Q-space modules
Replace hardcoded acoustic_eps threshold with a unified valid_modes_q mask that combines translation-based detection at Gamma with optional frequency- based masking when ignore_small_w is True. Changes: - QSpaceLanczos.py: Build valid_modes_q mask using translations at Gamma plus small-frequency masking at all q-points when ignore_small_w=True. Pass mask to Julia extension. - QSpaceHessian.py: Replace all acoustic_eps usages with valid_modes_q mask lookups for consistent mode exclusion. - tdscha_qspace.jl: Accept valid_modes_q mask from Python and use it instead of hardcoded 1e-6 threshold for zeroing f_Y/f_psi. All 13 q-space tests pass.
1 parent e249852 commit 2a60851

3 files changed

Lines changed: 42 additions & 43 deletions

File tree

Modules/QSpaceHessian.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ def _find_degenerate_blocks(self, iq):
215215
E.g. [[3], [4, 5, 6], [7, 8, 9]] for a singlet + two triplets.
216216
"""
217217
w_qp = self.w_q[:, iq]
218-
eps = self.qlanc.acoustic_eps
219218

220219
deg_lists = CC.symmetries.get_degeneracies(w_qp)
221220

@@ -224,12 +223,12 @@ def _find_degenerate_blocks(self, iq):
224223
for nu in range(self.n_bands):
225224
if nu in seen:
226225
continue
227-
if w_qp[nu] < eps:
226+
if not self.qlanc.valid_modes_q[nu, iq]:
228227
seen.add(nu)
229228
continue
230229
block = sorted(deg_lists[nu].tolist())
231-
# Filter out any acoustic modes that crept in
232-
block = [b for b in block if w_qp[b] >= eps]
230+
# Filter out any masked modes that crept in
231+
block = [b for b in block if self.qlanc.valid_modes_q[b, iq]]
233232
for b in block:
234233
seen.add(b)
235234
if len(block) > 0:
@@ -281,38 +280,35 @@ def _static_mask(self):
281280
"""Build mask for static psi inner product.
282281
283282
Same structure as FT mask but only one W sector (not a' + b').
284-
Acoustic modes (w < eps) are masked out (set to 0) to keep
283+
Masked modes (acoustic/small-w) are set to 0 to keep
285284
the null space purely in the W sector and avoid inconsistency.
286285
"""
287286
psi_size = self._get_static_psi_size()
288287
mask = np.ones(psi_size, dtype=np.float64)
289288
nb = self.n_bands
290-
eps = self.qlanc.acoustic_eps
291289

292-
# R sector: mask acoustic modes at the perturbation q-point
293-
w_qp = self.w_q[:, self.qlanc.iq_pert]
294-
mask[:nb] = np.where(w_qp < eps, 0.0, 1.0)
290+
# R sector: mask modes at the perturbation q-point
291+
valid_pert = self.qlanc.valid_modes_q[:, self.qlanc.iq_pert]
292+
mask[:nb] = np.where(valid_pert, 1.0, 0.0)
295293

296294
for pair_idx, (iq1, iq2) in enumerate(self.qlanc.unique_pairs):
297295
offset = self.qlanc.get_block_offset(pair_idx, 'a')
298-
w1 = self.w_q[:, iq1]
299-
w2 = self.w_q[:, iq2]
300-
ac1 = w1 < eps
301-
ac2 = w2 < eps
302-
# acoustic_any[i,j] = True if either mode i or j is acoustic
303-
acoustic_any = ac1[:, None] | ac2[None, :]
296+
valid1 = self.qlanc.valid_modes_q[:, iq1]
297+
valid2 = self.qlanc.valid_modes_q[:, iq2]
298+
# invalid_any[i,j] = True if either mode i or j is masked
299+
invalid_any = ~valid1[:, None] | ~valid2[None, :]
304300

305301
if iq1 < iq2:
306-
# Full block: 0 if acoustic, 2 if not
307-
block_mask = np.where(acoustic_any, 0.0, 2.0)
302+
# Full block: 0 if masked, 2 if not
303+
block_mask = np.where(invalid_any, 0.0, 2.0)
308304
mask[offset:offset + nb * nb] = block_mask.ravel()
309305
else:
310306
# Upper triangle storage: pack diag (weight=1) + off-diag (weight=2)
311307
size = self.qlanc.get_block_size(pair_idx)
312-
# Build full matrix of weights: diag=1, off-diag=2, acoustic=0
313-
full_weights = np.where(acoustic_any, 0.0, 2.0)
308+
# Build full matrix of weights: diag=1, off-diag=2, masked=0
309+
full_weights = np.where(invalid_any, 0.0, 2.0)
314310
np.fill_diagonal(full_weights, np.where(
315-
ac1 | ac2, 0.0, 1.0))
311+
valid1 & valid2, 1.0, 0.0))
316312
# Pack upper triangle in row-major order
317313
tri = np.zeros(size, dtype=np.float64)
318314
idx = 0
@@ -350,16 +346,16 @@ def _compute_lambda_q(self, iq1, iq2):
350346
n2 = np.zeros_like(w2)
351347
if T > __EPSILON__:
352348
beta = __RyToK__ / T
353-
valid1 = w1 > self.qlanc.acoustic_eps
354-
valid2 = w2 > self.qlanc.acoustic_eps
349+
valid1 = self.qlanc.valid_modes_q[:, iq1]
350+
valid2 = self.qlanc.valid_modes_q[:, iq2]
355351
n1[valid1] = 1.0 / (np.exp(w1[valid1] * beta) - 1.0)
356352
n2[valid2] = 1.0 / (np.exp(w2[valid2] * beta) - 1.0)
357353

358354
n1_mat = np.tile(n1, (self.n_bands, 1)).T
359355
n2_mat = np.tile(n2, (self.n_bands, 1))
360356

361-
valid_mask = np.outer(w1 > self.qlanc.acoustic_eps,
362-
w2 > self.qlanc.acoustic_eps)
357+
valid_mask = np.outer(self.qlanc.valid_modes_q[:, iq1],
358+
self.qlanc.valid_modes_q[:, iq2])
363359

364360
# (n1 - n2) / (w1 - w2), regularized for w1 ≈ w2
365361
diff_n = np.zeros_like(w1_mat)
@@ -395,18 +391,18 @@ def _compute_lambda_q(self, iq1, iq2):
395391
def _compute_Y_w(self, iq):
396392
"""Compute Y_w = 2*w/(2*n+1) for all bands at q-point iq.
397393
398-
Acoustic modes (w < eps) get Y_w = 0.
394+
Masked modes (acoustic/small-w) get Y_w = 0.
399395
"""
400396
w = self.w_q[:, iq]
401397
T = self.qlanc.T
402398

403399
n = np.zeros_like(w)
404400
if T > __EPSILON__:
405-
valid = w > self.qlanc.acoustic_eps
401+
valid = self.qlanc.valid_modes_q[:, iq]
406402
n[valid] = 1.0 / (np.exp(w[valid] * __RyToK__ / T) - 1.0)
407403

408404
Y_w = np.zeros_like(w)
409-
valid = w > self.qlanc.acoustic_eps
405+
valid = self.qlanc.valid_modes_q[:, iq]
410406
Y_w[valid] = 2.0 * w[valid] / (2.0 * n[valid] + 1.0)
411407
return Y_w
412408

@@ -421,13 +417,13 @@ def _precompute_static_quantities(self):
421417
- _cached_yw_outer[pair_idx]: -2*outer(Y_w1,Y_w2) for anharmonic
422418
"""
423419
nb = self.n_bands
424-
eps = self.qlanc.acoustic_eps
425420

426421
# R sector
427422
w_qp = self.w_q[:, self.qlanc.iq_pert]
428423
self._cached_w_qp_sq = w_qp ** 2
424+
valid_pert = self.qlanc.valid_modes_q[:, self.qlanc.iq_pert]
429425
self._cached_inv_w_qp_sq = np.where(
430-
w_qp > eps, 1.0 / np.where(w_qp > eps, w_qp, 1.0) ** 2, 0.0)
426+
valid_pert, 1.0 / np.where(valid_pert, w_qp, 1.0) ** 2, 0.0)
431427

432428
# W sector: per-pair Lambda and Y_w caches
433429
n_pairs = len(self.qlanc.unique_pairs)
@@ -643,7 +639,7 @@ def apply_M_tilde(x_tilde):
643639
# 5. Identify non-acoustic bands and degenerate blocks
644640
w_qp = self.w_q[:, iq]
645641
non_acoustic = [nu for nu in range(nb)
646-
if w_qp[nu] > self.qlanc.acoustic_eps]
642+
if self.qlanc.valid_modes_q[nu, iq]]
647643

648644
# Build solve schedule: list of (band_to_solve, block_members)
649645
if use_mode_symmetry:

Modules/QSpaceLanczos.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(self, ensemble, **kwargs):
142142
# -- Add the q-space attributes --
143143
qspace_attrs = [
144144
'q_points', 'n_q', 'n_bands', 'w_q', 'pols_q',
145-
'acoustic_eps', 'valid_modes_q', 'X_q', 'Y_q',
145+
'valid_modes_q', 'X_q', 'Y_q',
146146
'iq_pert', 'q_pair_map', 'unique_pairs',
147147
'_psi_size', '_block_offsets_a', '_block_offsets_b', '_block_sizes',
148148
'_qspace_sym_data', '_qspace_sym_q_map', 'n_syms_qspace',
@@ -161,17 +161,19 @@ def __init__(self, ensemble, **kwargs):
161161
# The masses needs to be restricted to the primitive cell only
162162
self.m = self.m[:self.n_bands]
163163

164-
# Small frequency threshold for acoustic mode masking
165-
self.acoustic_eps = 1e-6
166-
167-
# Build translation-based valid mask for acoustic modes
168-
# At Gamma (iq=0), identify translations from polarization vectors
169-
# At q != 0, all modes are valid (acoustic modes only have zero freq at Gamma)
164+
# Build valid_modes_q mask for acoustic mode exclusion
165+
# Always apply translation-based mask at Gamma (iq=0)
170166
masses_uc = self.dyn.structure.get_masses_array()
171167
self.valid_modes_q = np.ones((self.n_bands, self.n_q), dtype=bool)
172168
trans_mask = CC.Methods.get_translations(np.real(self.pols_q[:, :, 0]), masses_uc)
173169
self.valid_modes_q[:, 0] = ~trans_mask
174170

171+
# If ignore_small_w is True, also mask small frequencies at ALL q-points
172+
if ensemble.ignore_small_w:
173+
for iq in range(self.n_q):
174+
small_freq_mask = np.abs(self.w_q[:, iq]) < CC.Phonons.__EPSILON_W__
175+
self.valid_modes_q[:, iq] &= ~small_freq_mask
176+
175177
# == 2. Bloch transform ensemble data ==
176178
self._bloch_transform_ensemble()
177179

@@ -780,7 +782,8 @@ def get_combined(start_end):
780782
np.int64(self.iq_pert + 1),
781783
self.q_pair_map + 1, # 1-indexed
782784
unique_pairs_arr,
783-
int(start_end[0]), int(start_end[1])
785+
int(start_end[0]), int(start_end[1]),
786+
self.valid_modes_q # Pass mask to Julia
784787
)
785788

786789
combined = Parallel.GoParallel(get_combined, indices, "+")

Modules/tdscha_qspace.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,8 @@ function get_perturb_averages_qspace(
649649
q_pair_map::Vector{Int32},
650650
unique_pairs::Matrix{Int32},
651651
start_index::Int64,
652-
end_index::Int64
652+
end_index::Int64,
653+
valid_modes_q::Matrix{Bool} # Mask from Python: false for acoustic/small-w modes
653654
)
654655
n_q = size(X_q, 1)
655656
n_bands = size(X_q, 3)
@@ -662,19 +663,18 @@ function get_perturb_averages_qspace(
662663
end
663664

664665
# Precompute occupation numbers and scaling factors
665-
# Acoustic modes (w < threshold) get f_Y=0, f_psi=0 to avoid NaN/Inf
666+
# Masked modes (valid_modes_q == false) get f_Y=0, f_psi=0 to avoid NaN/Inf
666667
f_Y = zeros(Float64, n_bands, n_q)
667668
f_psi = zeros(Float64, n_bands, n_q)
668-
acoustic_eps = 1e-6
669669

670670
for iq in 1:n_q
671671
for nu in 1:n_bands
672-
w = w_q[nu, iq]
673-
if w < acoustic_eps
672+
if !valid_modes_q[nu, iq] # Masked mode -> zero out
674673
f_Y[nu, iq] = 0.0
675674
f_psi[nu, iq] = 0.0
676675
continue
677676
end
677+
w = w_q[nu, iq]
678678
if temperature > 0
679679
nw = 1.0 / (exp(w * RY_TO_K_Q / temperature) - 1.0)
680680
else

0 commit comments

Comments
 (0)