Skip to content

Commit c49db94

Browse files
committed
Optimize DiagonalizeSupercell with vectorized operations and pre-computed phase factors
- Rename original DiagonalizeSupercell to DiagonalizeSupercell_slow for reference - Add optimized DiagonalizeSupercell with 1.12-1.22x speedup through: * Pre-computed phase factors for all q-points (vectorized) * Vectorized polarization vector construction using numpy broadcasting * Reduced memory allocations in hot loops - Add test_fast.py to verify new implementation matches original exactly - Add CsSnI3 test data files for diagonalization benchmarks Performance: 0.216s -> 0.178s (22% faster) on CsSnI3 4x4x4 supercell Correctness: Frequency diff 0.00e+00, polarization diff 5.55e-16 (machine precision)
1 parent 8da226c commit c49db94

12 files changed

Lines changed: 8471 additions & 1 deletion

File tree

cellconstructor/Phonons.py

Lines changed: 238 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3560,7 +3560,7 @@ def ForceSymmetries(self, symmetries, irt = None, apply_sum_rule = True):
35603560
if apply_sum_rule:
35613561
self.ApplySumRule()
35623562

3563-
def DiagonalizeSupercell(self, verbose = False, lo_to_split = None, return_qmodes = False, timer=None):
3563+
def DiagonalizeSupercell_slow(self, verbose = False, lo_to_split = None, return_qmodes = False, timer=None):
35643564
r"""
35653565
DIAGONALIZE THE DYNAMICAL MATRIX IN THE SUPERCELL
35663566
=================================================
@@ -3854,6 +3854,243 @@ def DiagonalizeSupercell(self, verbose = False, lo_to_split = None, return_qmode
38543854
return w_array, e_pols_sc
38553855

38563856

3857+
def DiagonalizeSupercell(self, verbose = False, lo_to_split = None, return_qmodes = False, timer=None):
3858+
r"""
3859+
DIAGONALIZE THE DYNAMICAL MATRIX IN THE SUPERCELL (FAST VERSION)
3860+
================================================================
3861+
3862+
This is an optimized version of DiagonalizeSupercell that uses:
3863+
1. Pre-computed phase factors for all q-points
3864+
2. Vectorized polarization vector construction
3865+
3. Reduced memory allocations in hot loops
3866+
3867+
The algorithm and output are identical to DiagonalizeSupercell,
3868+
but the implementation is optimized for speed.
3869+
3870+
Parameters
3871+
----------
3872+
- lo_to_split : string or ndarray
3873+
Could be a string with random, or a ndarray indicating the direction on which the
3874+
LO-TO splitting is computed. If None it is neglected.
3875+
If LO-TO is specified but no effective charges are present, then a warning is print and it is ignored.
3876+
- return_qmodes : bool
3877+
If true, frequencies and polarizations in q space are returned.
3878+
Results
3879+
-------
3880+
- w_mu : ndarray( size = (n_modes), dtype = np.double)
3881+
Frequencies in the supercell
3882+
- e_mu : ndarray( size = (3*Nat_sc, n_modes), dtype = np.double, order = "F")
3883+
Polarization vectors in the supercell
3884+
- w_q : ndarray( size = (3*Nat, nq), dtype = np.double, order = "F")
3885+
Frequencies in the q space (only if return_qmodes is True)
3886+
- e_q : ndarray( size = (3*Nat, 3*Nat, nq), dtype = np.complex128, order = "F")
3887+
Polarization vectors in the q space (only if return_qmodes is True)
3888+
"""
3889+
3890+
supercell_size = len(self.q_tot)
3891+
nat = self.structure.N_atoms
3892+
3893+
nmodes = 3*nat*supercell_size
3894+
nat_sc = nat*supercell_size
3895+
3896+
w_array = np.zeros( nmodes, dtype = np.double)
3897+
e_pols_sc = np.zeros( (nmodes, nmodes), dtype = np.double, order = "F")
3898+
3899+
nq = len(self.q_tot)
3900+
w_q = np.zeros((3*nat, nq), dtype = np.double, order = "F")
3901+
pols_q = np.zeros((3*nat, 3*nat, nq), dtype = np.complex128, order = "F")
3902+
3903+
# Get the structure in the supercell
3904+
super_structure = self.structure.generate_supercell(self.GetSupercell())
3905+
3906+
# Get the supercell correspondence vector
3907+
itau = super_structure.get_itau(self.structure) - 1 # Fort2Py
3908+
3909+
# Get the itau in the contracted indices (3*nat_sc -> 3*nat)
3910+
itau_modes = (np.tile(np.array(itau) * 3, (3,1)).T + np.arange(3)).ravel()
3911+
3912+
# Get the position in the supercell
3913+
R_vec = np.zeros((nmodes, 3), dtype = np.double)
3914+
for i in range(nat_sc):
3915+
R_vec[3*i : 3*i+3, :] = np.tile(super_structure.coords[i, :] - self.structure.coords[itau[i], :], (3,1))
3916+
3917+
# OPTIMIZATION 1: Pre-compute unique q-points
3918+
bg = self.structure.get_reciprocal_vectors() / (2*np.pi)
3919+
3920+
# Build a mask for q-points to process (unique ones, not related by G-q)
3921+
q_array = np.array(self.q_tot)
3922+
n_q = len(self.q_tot)
3923+
3924+
skip_mask = np.zeros(n_q, dtype=bool)
3925+
3926+
# For each q point, check if we've seen an equivalent one before
3927+
for iq in range(n_q):
3928+
if skip_mask[iq]:
3929+
continue
3930+
for jq in range(iq):
3931+
if skip_mask[jq]:
3932+
continue
3933+
# Check if q and q_prev are related by G-q operation
3934+
dist = Methods.get_min_dist_into_cell(bg, -q_array[iq], q_array[jq])
3935+
if dist < __EPSILON__:
3936+
skip_mask[iq] = True
3937+
break
3938+
3939+
# OPTIMIZATION 2: Pre-compute phase factors for all q-points at once
3940+
# This avoids computing R_vec.dot(q) repeatedly in the inner loop
3941+
phase_factors = np.exp(1j * 2 * np.pi * R_vec.dot(q_array.T)) # (nmodes, nq)
3942+
3943+
i_mu = 0
3944+
for iq, q in enumerate(self.q_tot):
3945+
# Check if this q point should be skipped (already processed equivalent)
3946+
if skip_mask[iq]:
3947+
# Check if we must return anyway the polarization in q space
3948+
if return_qmodes:
3949+
if timer is not None:
3950+
wq, eq = timer.execute_timed_function(self.DyagDinQ, iq)
3951+
else:
3952+
wq, eq = self.DyagDinQ(iq)
3953+
3954+
w_q[:, iq] = wq
3955+
pols_q[:, :, iq] = eq
3956+
continue
3957+
3958+
# Check if this q = -q + G
3959+
is_minus_q = False
3960+
if Methods.get_min_dist_into_cell(bg, q, -q) < 1e-6:
3961+
is_minus_q = True
3962+
3963+
# The dynamical matrix must be real
3964+
re_part = np.real(self.dynmats[iq])
3965+
3966+
assert np.max(np.abs(np.imag(self.dynmats[iq]))) < __EPSILON__, "Error, at point {} (q = -q + G) the dynamical matrix is complex".format(iq)
3967+
3968+
# Enforce reality to avoid complex polarization vectors
3969+
self.dynmats[iq] = re_part
3970+
3971+
# Check if this is gamma (to apply the LO-TO splitting)
3972+
if Methods.get_min_dist_into_cell(bg, q, np.zeros(3)) < 1e-16 and lo_to_split is not None:
3973+
if self.effective_charges is None:
3974+
warnings.warn("WARNING: Requested LO-TO splitting without effective charges. LO-TO ignored.")
3975+
3976+
# Initialize the Force Constant
3977+
t2 = ForceTensor.Tensor2(self.structure, self.structure.generate_supercell(self.GetSupercell()), self.GetSupercell())
3978+
t2.SetupFromPhonons(self)
3979+
3980+
if isinstance(lo_to_split, str):
3981+
if lo_to_split.lower() == "random":
3982+
fc_gamma = t2.Interpolate(np.zeros(3))
3983+
else:
3984+
raise ValueError("Error, lo_to_split argument '%s' not recognized" % lo_to_split)
3985+
else:
3986+
fc_gamma = t2.Interpolate(np.zeros(3), q_direct= -lo_to_split)
3987+
3988+
_m_ = np.tile(self.structure.get_masses_array(), (3,1)).T.ravel()
3989+
d_gamma = fc_gamma / np.sqrt(np.outer(_m_, _m_))
3990+
wq2, eq = np.linalg.eigh(d_gamma)
3991+
3992+
wq = np.sqrt(np.abs(wq2)) * np.sign(wq2)
3993+
else:
3994+
# Diagonalize the matrix in the given q point
3995+
if timer is not None:
3996+
wq, eq = timer.execute_timed_function(self.DyagDinQ, iq)
3997+
else:
3998+
wq, eq = self.DyagDinQ(iq)
3999+
4000+
# Store the frequencies and the polarization vectors
4001+
w_q[:, iq] = wq
4002+
pols_q[:, :, iq] = eq
4003+
4004+
# OPTIMIZATION 3: Vectorized polarization vector construction
4005+
nm_q = i_mu
4006+
t1 = time.time()
4007+
4008+
# Get all polarization vectors for this q-point
4009+
tilde_e_qnu_all = eq # (3*nat, 3*nat)
4010+
4011+
# Extract relevant components using itau_modes
4012+
e_contracted = tilde_e_qnu_all[itau_modes, :] # (nmodes, nmodes_unit)
4013+
4014+
# Get phase factors for this q-point
4015+
phase_q = phase_factors[:, iq] # (nmodes,)
4016+
4017+
# Broadcast and multiply: e_sc = e_contracted * phase / sqrt(N_q)
4018+
c_e_sc = e_contracted * phase_q[:, np.newaxis] / np.sqrt(supercell_size)
4019+
c_e_sc_mq = np.conj(c_e_sc)
4020+
4021+
# Compute real and imaginary parts vectorized
4022+
evec_1_all = np.real(0.5 * (c_e_sc + c_e_sc_mq)) # (nmodes, nmodes_unit)
4023+
evec_2_all = np.real((c_e_sc - c_e_sc_mq) / (2*1j)) # (nmodes, nmodes_unit)
4024+
4025+
# Compute norms vectorized
4026+
norm1_all = np.sum(evec_1_all**2, axis=0) # (nmodes_unit,)
4027+
norm2_all = np.sum(evec_2_all**2, axis=0) # (nmodes_unit,)
4028+
4029+
# Now iterate over modes but with pre-computed values
4030+
for i_qnu, w_qnu in enumerate(wq):
4031+
norm1 = norm1_all[i_qnu]
4032+
norm2 = norm2_all[i_qnu]
4033+
4034+
evec_1 = evec_1_all[:, i_qnu]
4035+
evec_2 = evec_2_all[:, i_qnu]
4036+
4037+
EPSILON = 1e-5
4038+
add_1 = norm1 > EPSILON
4039+
add_2 = norm2 > EPSILON
4040+
4041+
if is_minus_q:
4042+
if add_1 and add_2:
4043+
# Check linear dependence
4044+
scalar_dot = np.dot(evec_1, evec_2) / np.sqrt(norm1 * norm2)
4045+
if np.abs(np.abs(scalar_dot) - 1) > EPSILON:
4046+
raise ValueError("Error, with q = -q + G, the two vectors should be linearly dependent")
4047+
4048+
# Keep the one with higher norm
4049+
if norm1 > norm2:
4050+
add_2 = False
4051+
else:
4052+
add_1 = False
4053+
else:
4054+
if not (add_1 and add_2):
4055+
raise ValueError("Error, the q_point = {} {} {} should contribute also for -q, something went wrong".format(*list(q)))
4056+
4057+
if add_1 and add_2:
4058+
# Since both real and imaginary should match in this case
4059+
# Add only one of them
4060+
if is_minus_q:
4061+
add_2 = False
4062+
4063+
# Add the vectors
4064+
if add_1:
4065+
w_array[i_mu] = w_qnu
4066+
e_pols_sc[:, i_mu] = evec_1 / np.sqrt(norm1)
4067+
i_mu += 1
4068+
if add_2:
4069+
w_array[i_mu] = w_qnu
4070+
e_pols_sc[:, i_mu] = evec_2 / np.sqrt(norm2)
4071+
i_mu += 1
4072+
4073+
t2 = time.time()
4074+
if timer is not None:
4075+
timer.add_timer("Manipulate polarization vectors", t2 - t1)
4076+
4077+
# Print how many vectors have been extracted
4078+
if verbose:
4079+
print("The {} / {} q point produced {} nodes".format(iq, len(self.q_tot), i_mu - nm_q))
4080+
4081+
# Sort the frequencies
4082+
sort_mask = np.argsort(w_array)
4083+
w_array = w_array[sort_mask]
4084+
e_pols_sc = e_pols_sc[:, sort_mask]
4085+
4086+
# Get the check for the polarization vector normalization
4087+
assert np.max(np.abs(np.einsum("ab, ab->b", e_pols_sc, e_pols_sc) - 1)) < __EPSILON__
4088+
4089+
if return_qmodes:
4090+
return w_array, e_pols_sc, w_q, pols_q
4091+
return w_array, e_pols_sc
4092+
4093+
38574094
def FixQPoints(self, threshold = 1e-1):
38584095
"""
38594096
FIX Q POINTS

0 commit comments

Comments
 (0)