@@ -3560,7 +3560,7 @@ def ForceSymmetries(self, symmetries, irt = None, apply_sum_rule = True):
35603560 if apply_sum_rule :
35613561 self .ApplySumRule ()
35623562
3563- def DiagonalizeSupercell (self , verbose = False , lo_to_split = None , return_qmodes = False , timer = None ):
3563+ def DiagonalizeSupercell_slow (self , verbose = False , lo_to_split = None , return_qmodes = False , timer = None ):
35643564 r"""
35653565 DIAGONALIZE THE DYNAMICAL MATRIX IN THE SUPERCELL
35663566 =================================================
@@ -3854,6 +3854,243 @@ def DiagonalizeSupercell(self, verbose = False, lo_to_split = None, return_qmode
38543854 return w_array , e_pols_sc
38553855
38563856
3857+ def DiagonalizeSupercell (self , verbose = False , lo_to_split = None , return_qmodes = False , timer = None ):
3858+ r"""
3859+ DIAGONALIZE THE DYNAMICAL MATRIX IN THE SUPERCELL (FAST VERSION)
3860+ ================================================================
3861+
3862+ This is an optimized version of DiagonalizeSupercell that uses:
3863+ 1. Pre-computed phase factors for all q-points
3864+ 2. Vectorized polarization vector construction
3865+ 3. Reduced memory allocations in hot loops
3866+
3867+ The algorithm and output are identical to DiagonalizeSupercell,
3868+ but the implementation is optimized for speed.
3869+
3870+ Parameters
3871+ ----------
3872+ - lo_to_split : string or ndarray
3873+ Could be a string with random, or a ndarray indicating the direction on which the
3874+ LO-TO splitting is computed. If None it is neglected.
3875+ If LO-TO is specified but no effective charges are present, then a warning is print and it is ignored.
3876+ - return_qmodes : bool
3877+ If true, frequencies and polarizations in q space are returned.
3878+ Results
3879+ -------
3880+ - w_mu : ndarray( size = (n_modes), dtype = np.double)
3881+ Frequencies in the supercell
3882+ - e_mu : ndarray( size = (3*Nat_sc, n_modes), dtype = np.double, order = "F")
3883+ Polarization vectors in the supercell
3884+ - w_q : ndarray( size = (3*Nat, nq), dtype = np.double, order = "F")
3885+ Frequencies in the q space (only if return_qmodes is True)
3886+ - e_q : ndarray( size = (3*Nat, 3*Nat, nq), dtype = np.complex128, order = "F")
3887+ Polarization vectors in the q space (only if return_qmodes is True)
3888+ """
3889+
3890+ supercell_size = len (self .q_tot )
3891+ nat = self .structure .N_atoms
3892+
3893+ nmodes = 3 * nat * supercell_size
3894+ nat_sc = nat * supercell_size
3895+
3896+ w_array = np .zeros ( nmodes , dtype = np .double )
3897+ e_pols_sc = np .zeros ( (nmodes , nmodes ), dtype = np .double , order = "F" )
3898+
3899+ nq = len (self .q_tot )
3900+ w_q = np .zeros ((3 * nat , nq ), dtype = np .double , order = "F" )
3901+ pols_q = np .zeros ((3 * nat , 3 * nat , nq ), dtype = np .complex128 , order = "F" )
3902+
3903+ # Get the structure in the supercell
3904+ super_structure = self .structure .generate_supercell (self .GetSupercell ())
3905+
3906+ # Get the supercell correspondence vector
3907+ itau = super_structure .get_itau (self .structure ) - 1 # Fort2Py
3908+
3909+ # Get the itau in the contracted indices (3*nat_sc -> 3*nat)
3910+ itau_modes = (np .tile (np .array (itau ) * 3 , (3 ,1 )).T + np .arange (3 )).ravel ()
3911+
3912+ # Get the position in the supercell
3913+ R_vec = np .zeros ((nmodes , 3 ), dtype = np .double )
3914+ for i in range (nat_sc ):
3915+ R_vec [3 * i : 3 * i + 3 , :] = np .tile (super_structure .coords [i , :] - self .structure .coords [itau [i ], :], (3 ,1 ))
3916+
3917+ # OPTIMIZATION 1: Pre-compute unique q-points
3918+ bg = self .structure .get_reciprocal_vectors () / (2 * np .pi )
3919+
3920+ # Build a mask for q-points to process (unique ones, not related by G-q)
3921+ q_array = np .array (self .q_tot )
3922+ n_q = len (self .q_tot )
3923+
3924+ skip_mask = np .zeros (n_q , dtype = bool )
3925+
3926+ # For each q point, check if we've seen an equivalent one before
3927+ for iq in range (n_q ):
3928+ if skip_mask [iq ]:
3929+ continue
3930+ for jq in range (iq ):
3931+ if skip_mask [jq ]:
3932+ continue
3933+ # Check if q and q_prev are related by G-q operation
3934+ dist = Methods .get_min_dist_into_cell (bg , - q_array [iq ], q_array [jq ])
3935+ if dist < __EPSILON__ :
3936+ skip_mask [iq ] = True
3937+ break
3938+
3939+ # OPTIMIZATION 2: Pre-compute phase factors for all q-points at once
3940+ # This avoids computing R_vec.dot(q) repeatedly in the inner loop
3941+ phase_factors = np .exp (1j * 2 * np .pi * R_vec .dot (q_array .T )) # (nmodes, nq)
3942+
3943+ i_mu = 0
3944+ for iq , q in enumerate (self .q_tot ):
3945+ # Check if this q point should be skipped (already processed equivalent)
3946+ if skip_mask [iq ]:
3947+ # Check if we must return anyway the polarization in q space
3948+ if return_qmodes :
3949+ if timer is not None :
3950+ wq , eq = timer .execute_timed_function (self .DyagDinQ , iq )
3951+ else :
3952+ wq , eq = self .DyagDinQ (iq )
3953+
3954+ w_q [:, iq ] = wq
3955+ pols_q [:, :, iq ] = eq
3956+ continue
3957+
3958+ # Check if this q = -q + G
3959+ is_minus_q = False
3960+ if Methods .get_min_dist_into_cell (bg , q , - q ) < 1e-6 :
3961+ is_minus_q = True
3962+
3963+ # The dynamical matrix must be real
3964+ re_part = np .real (self .dynmats [iq ])
3965+
3966+ assert np .max (np .abs (np .imag (self .dynmats [iq ]))) < __EPSILON__ , "Error, at point {} (q = -q + G) the dynamical matrix is complex" .format (iq )
3967+
3968+ # Enforce reality to avoid complex polarization vectors
3969+ self .dynmats [iq ] = re_part
3970+
3971+ # Check if this is gamma (to apply the LO-TO splitting)
3972+ if Methods .get_min_dist_into_cell (bg , q , np .zeros (3 )) < 1e-16 and lo_to_split is not None :
3973+ if self .effective_charges is None :
3974+ warnings .warn ("WARNING: Requested LO-TO splitting without effective charges. LO-TO ignored." )
3975+
3976+ # Initialize the Force Constant
3977+ t2 = ForceTensor .Tensor2 (self .structure , self .structure .generate_supercell (self .GetSupercell ()), self .GetSupercell ())
3978+ t2 .SetupFromPhonons (self )
3979+
3980+ if isinstance (lo_to_split , str ):
3981+ if lo_to_split .lower () == "random" :
3982+ fc_gamma = t2 .Interpolate (np .zeros (3 ))
3983+ else :
3984+ raise ValueError ("Error, lo_to_split argument '%s' not recognized" % lo_to_split )
3985+ else :
3986+ fc_gamma = t2 .Interpolate (np .zeros (3 ), q_direct = - lo_to_split )
3987+
3988+ _m_ = np .tile (self .structure .get_masses_array (), (3 ,1 )).T .ravel ()
3989+ d_gamma = fc_gamma / np .sqrt (np .outer (_m_ , _m_ ))
3990+ wq2 , eq = np .linalg .eigh (d_gamma )
3991+
3992+ wq = np .sqrt (np .abs (wq2 )) * np .sign (wq2 )
3993+ else :
3994+ # Diagonalize the matrix in the given q point
3995+ if timer is not None :
3996+ wq , eq = timer .execute_timed_function (self .DyagDinQ , iq )
3997+ else :
3998+ wq , eq = self .DyagDinQ (iq )
3999+
4000+ # Store the frequencies and the polarization vectors
4001+ w_q [:, iq ] = wq
4002+ pols_q [:, :, iq ] = eq
4003+
4004+ # OPTIMIZATION 3: Vectorized polarization vector construction
4005+ nm_q = i_mu
4006+ t1 = time .time ()
4007+
4008+ # Get all polarization vectors for this q-point
4009+ tilde_e_qnu_all = eq # (3*nat, 3*nat)
4010+
4011+ # Extract relevant components using itau_modes
4012+ e_contracted = tilde_e_qnu_all [itau_modes , :] # (nmodes, nmodes_unit)
4013+
4014+ # Get phase factors for this q-point
4015+ phase_q = phase_factors [:, iq ] # (nmodes,)
4016+
4017+ # Broadcast and multiply: e_sc = e_contracted * phase / sqrt(N_q)
4018+ c_e_sc = e_contracted * phase_q [:, np .newaxis ] / np .sqrt (supercell_size )
4019+ c_e_sc_mq = np .conj (c_e_sc )
4020+
4021+ # Compute real and imaginary parts vectorized
4022+ evec_1_all = np .real (0.5 * (c_e_sc + c_e_sc_mq )) # (nmodes, nmodes_unit)
4023+ evec_2_all = np .real ((c_e_sc - c_e_sc_mq ) / (2 * 1j )) # (nmodes, nmodes_unit)
4024+
4025+ # Compute norms vectorized
4026+ norm1_all = np .sum (evec_1_all ** 2 , axis = 0 ) # (nmodes_unit,)
4027+ norm2_all = np .sum (evec_2_all ** 2 , axis = 0 ) # (nmodes_unit,)
4028+
4029+ # Now iterate over modes but with pre-computed values
4030+ for i_qnu , w_qnu in enumerate (wq ):
4031+ norm1 = norm1_all [i_qnu ]
4032+ norm2 = norm2_all [i_qnu ]
4033+
4034+ evec_1 = evec_1_all [:, i_qnu ]
4035+ evec_2 = evec_2_all [:, i_qnu ]
4036+
4037+ EPSILON = 1e-5
4038+ add_1 = norm1 > EPSILON
4039+ add_2 = norm2 > EPSILON
4040+
4041+ if is_minus_q :
4042+ if add_1 and add_2 :
4043+ # Check linear dependence
4044+ scalar_dot = np .dot (evec_1 , evec_2 ) / np .sqrt (norm1 * norm2 )
4045+ if np .abs (np .abs (scalar_dot ) - 1 ) > EPSILON :
4046+ raise ValueError ("Error, with q = -q + G, the two vectors should be linearly dependent" )
4047+
4048+ # Keep the one with higher norm
4049+ if norm1 > norm2 :
4050+ add_2 = False
4051+ else :
4052+ add_1 = False
4053+ else :
4054+ if not (add_1 and add_2 ):
4055+ raise ValueError ("Error, the q_point = {} {} {} should contribute also for -q, something went wrong" .format (* list (q )))
4056+
4057+ if add_1 and add_2 :
4058+ # Since both real and imaginary should match in this case
4059+ # Add only one of them
4060+ if is_minus_q :
4061+ add_2 = False
4062+
4063+ # Add the vectors
4064+ if add_1 :
4065+ w_array [i_mu ] = w_qnu
4066+ e_pols_sc [:, i_mu ] = evec_1 / np .sqrt (norm1 )
4067+ i_mu += 1
4068+ if add_2 :
4069+ w_array [i_mu ] = w_qnu
4070+ e_pols_sc [:, i_mu ] = evec_2 / np .sqrt (norm2 )
4071+ i_mu += 1
4072+
4073+ t2 = time .time ()
4074+ if timer is not None :
4075+ timer .add_timer ("Manipulate polarization vectors" , t2 - t1 )
4076+
4077+ # Print how many vectors have been extracted
4078+ if verbose :
4079+ print ("The {} / {} q point produced {} nodes" .format (iq , len (self .q_tot ), i_mu - nm_q ))
4080+
4081+ # Sort the frequencies
4082+ sort_mask = np .argsort (w_array )
4083+ w_array = w_array [sort_mask ]
4084+ e_pols_sc = e_pols_sc [:, sort_mask ]
4085+
4086+ # Get the check for the polarization vector normalization
4087+ assert np .max (np .abs (np .einsum ("ab, ab->b" , e_pols_sc , e_pols_sc ) - 1 )) < __EPSILON__
4088+
4089+ if return_qmodes :
4090+ return w_array , e_pols_sc , w_q , pols_q
4091+ return w_array , e_pols_sc
4092+
4093+
38574094 def FixQPoints (self , threshold = 1e-1 ):
38584095 """
38594096 FIX Q POINTS
0 commit comments