From 4256f6bd20d0001514e807f2518d863e418e067d Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Fri, 12 Jun 2026 10:12:00 +0800 Subject: [PATCH 1/3] add no_grad decorator through NEGF process --- dpnegf/runner/NEGF.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/dpnegf/runner/NEGF.py b/dpnegf/runner/NEGF.py index dee3c4e..16b9132 100644 --- a/dpnegf/runner/NEGF.py +++ b/dpnegf/runner/NEGF.py @@ -35,7 +35,8 @@ # TODO : add common class to set all the dtype and precision. class NEGF(object): - def __init__(self, + @torch.no_grad() + def __init__(self, model: torch.nn.Module, structure: Union[AtomicData, ase.Atoms, str], ele_T: float, @@ -154,13 +155,12 @@ def __init__(self, unit = self.unit, results_path=self.results_path, torch_device = self.torch_device) - with torch.no_grad(): - # if useBloch is None, structure_leads_fold,bloch_sorted_indices,bloch_R_lists = None,None,None - struct_device, struct_leads,structure_leads_fold,bloch_sorted_indices,bloch_R_lists = \ - self.negf_hamiltonian.initialize(kpoints=self.kpoints, - block_tridiagnal=self.block_tridiagonal, plot_blocks=self.plot_blocks,\ - useBloch=self.useBloch,bloch_factor=self.bloch_factor, - use_saved_HS=self.use_saved_HS, saved_HS_path=self.saved_HS_path) + # if useBloch is None, structure_leads_fold,bloch_sorted_indices,bloch_R_lists = None,None,None + struct_device, struct_leads,structure_leads_fold,bloch_sorted_indices,bloch_R_lists = \ + self.negf_hamiltonian.initialize(kpoints=self.kpoints, + block_tridiagnal=self.block_tridiagonal, plot_blocks=self.plot_blocks,\ + useBloch=self.useBloch,bloch_factor=self.bloch_factor, + use_saved_HS=self.use_saved_HS, saved_HS_path=self.saved_HS_path) profiler.stop() output_path = os.path.join(self.results_path, "profile_report_ham_init.html") with open(output_path, 'w') as report_file: @@ -356,7 +356,8 @@ def generate_energy_grid(self): xu = torch.tensor(max(v_list)+8*self.kBT) self.int_grid, 
self.int_weight = gauss_xw(xl=xl, xu=xu, n=int(self.density_options["n_gauss"])) - def compute(self, + @torch.no_grad() + def compute(self, pcond: Optional[Interface3D]=None) -> Optional[Interface3D]: ''' compute the NEGF calculation, can also from the given Poisson From 8bf59758e7bf943b688486ea517ea9f57b786b55 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Fri, 12 Jun 2026 15:09:49 +0800 Subject: [PATCH 2/3] Refactor RGF kernel: 1) removing dead codes 2) add uniform-block path for uniform tri-diagonal matrix --- dpnegf/negf/recursive_green_cal.py | 303 +++++++++++++++++++---------- 1 file changed, 203 insertions(+), 100 deletions(-) diff --git a/dpnegf/negf/recursive_green_cal.py b/dpnegf/negf/recursive_green_cal.py index be8f472..f6dccfe 100644 --- a/dpnegf/negf/recursive_green_cal.py +++ b/dpnegf/negf/recursive_green_cal.py @@ -3,7 +3,8 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, sd, su, sl, s_in=0, s_out=0, eta=1e-5, - need_lesser=False, need_greater=False, need_gr_lc=False): + need_lesser=False, need_greater=False, need_gr_lc=False, + stacked=False): """The recursive Green's function algorithm is taken from M. P. Anantram, M. S. Lundstrom and D. E. Nikonov, Proceedings of the IEEE, 96, 1511 - 1550 (2008) DOI: 10.1109/JPROC.2008.927355 @@ -15,7 +16,6 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, In order to get the electron correlation function output, the parameters s_in has to be set. For the hole correlation function, the parameter s_out has to be set. - By default, the function would return the retarded Green's function blocks. Parameters @@ -23,11 +23,14 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, energy : torch.Tensor (dtype=torch.complex) Energy array of shape ``[B]``. mat_d_list : list of torch.Tensor (dtype=torch.complex) - List of diagonal blocks, each of shape ``[B, n_q, n_q]``. + List of diagonal blocks, each of shape ``[B, n_q, n_q]``. 
When + ``stacked=True``, a single ``[K, B, n, n]`` tensor with K = num blocks. mat_u_list : list of torch.Tensor (dtype=torch.complex) List of upper-diagonal blocks, each of shape ``[B, n_q, n_{q+1}]``. + When ``stacked=True``, a single ``[K-1, B, n, n]`` tensor. mat_l_list : list of torch.Tensor (dtype=torch.complex) List of lower-diagonal blocks, each of shape ``[B, n_{q+1}, n_q]``. + When ``stacked=True``, a single ``[K-1, B, n, n]`` tensor. s_in : (Default value = 0). When ``need_lesser`` is True, a list of ``[B, n_q, n_q]`` tensors. @@ -47,6 +50,12 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, gr_lc is used for lead spectral function A_L/ A_R = G^r * Gamma_L/R * G^a calculation. Although set need_gr_lc to True would not increase the computational cost of the recursive Green's function algorithm, it would increase the memory cost. If the memory cost is a concern, it is recommended to set need_gr_lc to False. + stacked : bool, optional + When True, ``mat_d_list``/``mat_u_list``/``mat_l_list`` and + ``sd``/``su``/``sl`` are single 4-D tensors of shape ``[K,B,n,n]`` + (resp. ``[K-1,B,n,n]``) instead of Python lists. The wrapper sets + this when K >= 2 and all blocks share one `n`. ``s_in``/``s_out`` are + not stacked, and per-block outputs are Python lists in both modes. Returns ------- @@ -57,8 +66,98 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, if need_greater: assert isinstance(s_out, list), "Greater Green's function calculation requires s_out to be a list of coupling matrices" - # energy enters as a [B]-shaped (or broadcastable) complex tensor; reshape - # so it multiplies into [B, n, n] block tensors along the batch dim. 
+ # ------------------------------------------------------------------ + # Uniform-block path + # ------------------------------------------------------------------ + if stacked: + # mat_*/s* are 4-D tensors; do the energy shift in three fused ops + # over the leading-K dim instead of K Python iterations. + e_bcast = (energy + 1j * eta).view(1, -1, 1, 1) + mat_d = mat_d_list - e_bcast * sd + mat_l = mat_l_list - e_bcast * sl + mat_u = mat_u_list - e_bcast * su + + num_of_matrices = mat_d.shape[0] + ref = mat_d + B = ref.shape[1] + n = ref.shape[-1] + # Single identity reused across all K forward solves. + eye_bnn = torch.eye(n, dtype=ref.dtype, device=ref.device).expand(B, n, n) + + # ------------- left-connected retarded Green's function ------------ + gr_left = [None] * num_of_matrices + gr_left[0] = tLA.solve(-mat_d[0], eye_bnn) + for q in range(num_of_matrices - 1): # (B2) + gr_left[q + 1] = tLA.solve( + -mat_d[q + 1] - mat_l[q] @ gr_left[q] @ mat_u[q], + eye_bnn, + ) + + grl = [None] * (num_of_matrices - 1) + gru = [None] * (num_of_matrices - 1) + grd = [g.clone() for g in gr_left] + g_trans = gr_left[-1].clone() + gr_lc = [g_trans] if need_gr_lc else None + for q in range(num_of_matrices - 2, -1, -1): + gU = gr_left[q] @ mat_u[q] # hoisted: used 2-3x below + grl[q] = grd[q + 1] @ mat_l[q] @ gr_left[q] # (B5) + gru[q] = gU @ grd[q + 1] # (B6) + grd[q] = gr_left[q] + gU @ grl[q] # (B4) + g_trans = gU @ g_trans + if need_gr_lc: + gr_lc.append(g_trans) + if need_gr_lc: + gr_lc.reverse() + + gnd = gnl = gnu = gin_left = None + if need_lesser: + gin_left = [None] * num_of_matrices + # G^< = G^r * Sigma^< * G^a + gin_left[0] = gr_left[0] @ s_in[0] @ gr_left[0].mH + for q in range(num_of_matrices - 1): + sla2 = mat_l[q] @ gin_left[q] @ mat_u[q] + gin_left[q + 1] = gr_left[q + 1] @ (s_in[q + 1] + sla2) @ gr_left[q + 1].mH + + gnl = [None] * (num_of_matrices - 1) + gnu = [None] * (num_of_matrices - 1) + gnd = [g.clone() for g in gin_left] + for q in 
range(num_of_matrices - 2, -1, -1): + gLmH = mat_l[q] @ gr_left[q].mH # hoisted: used twice + gnl[q] = grd[q + 1] @ mat_l[q] @ gin_left[q] + gnd[q + 1] @ gLmH # (B10) + gnd[q] = gin_left[q] + \ + gr_left[q] @ mat_u[q] @ gnd[q + 1] @ gLmH + \ + (gin_left[q] @ mat_u[q] @ gru[q].mH) + \ + (gru[q] @ mat_l[q] @ gin_left[q]) # (B11) + gnu[q] = gnl[q].mH + + gpd = gpl = gpu = gip_left = None + if need_greater: + gip_left = [None] * num_of_matrices + gip_left[0] = gr_left[0] @ s_out[0] @ gr_left[0].conj() + for q in range(num_of_matrices - 1): + sla2 = mat_l[q] @ gip_left[q] @ mat_u[q].conj() + gip_left[q + 1] = gr_left[q + 1] @ (s_out[q + 1] + sla2) @ gr_left[q + 1].conj() + + gpl = [None] * (num_of_matrices - 1) + gpu = [None] * (num_of_matrices - 1) + gpd = [g.clone() for g in gip_left] + for q in range(num_of_matrices - 2, -1, -1): + lcgc = mat_l[q].conj() @ gr_left[q].conj() # hoisted: used twice + gpl[q] = grd[q + 1] @ mat_l[q] @ gip_left[q] + gpd[q + 1] @ lcgc + gpd[q] = gip_left[q] + \ + gr_left[q] @ mat_u[q] @ gpd[q + 1] @ lcgc + \ + (gip_left[q] @ mat_u[q].conj() @ grl[q].conj()) + \ + (gru[q] @ mat_l[q] @ gip_left[q]) + gpu[q] = gpl[q].mH + + return _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, + gnd, gnl, gnu, gin_left, + gpd, gpl, gpu, gip_left, + need_lesser, need_greater) + + # ------------------------------------------------------------------ + # Non-uniform-block path + # ------------------------------------------------------------------ e_bcast = (energy + 1j * eta).view(-1, 1, 1) for jj in range(len(mat_d_list)): @@ -73,133 +172,118 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, ref = mat_d_list[0] B = ref.shape[0] + eye_cache = {} def _batched_eye(n): - return torch.eye(n, dtype=ref.dtype, device=ref.device).expand(B, n, n) - - # ------------------------------------------------------------------- - # -------------- compute retarded Green's function ------------------ - # 
------------------------------------------------------------------- - # Firstly calculate the left-connected retarded Green's function - gr_left = [None for _ in range(num_of_matrices)] + e = eye_cache.get(n) + if e is None: + e = torch.eye(n, dtype=ref.dtype, device=ref.device).expand(B, n, n) + eye_cache[n] = e + return e + + # ------------------ retarded Green's function ---------------------- + gr_left = [None] * num_of_matrices gr_left[0] = tLA.solve(-mat_d_list[0], _batched_eye(mat_shapes[0][-1])) - for q in range(num_of_matrices - 1): # Recursive algorithm (B2) + for q in range(num_of_matrices - 1): # (B2) gr_left[q + 1] = tLA.solve( -mat_d_list[q + 1] - mat_l_list[q] @ gr_left[q] @ mat_u_list[q], _batched_eye(mat_shapes[q + 1][-1]), ) - grl = [None for _ in range(num_of_matrices - 1)] - gru = [None for _ in range(num_of_matrices - 1)] + grl = [None] * (num_of_matrices - 1) + gru = [None] * (num_of_matrices - 1) grd = [i.clone() for i in gr_left] - g_trans = gr_left[len(gr_left) - 1].clone() - if need_gr_lc: - gr_lc = [gr_left[len(gr_left) - 1].clone()] - else: - gr_lc = None - for q in range(num_of_matrices - 2, -1, -1): # Recursive algorithm - grl[q] = grd[q + 1] @ mat_l_list[q] @ gr_left[q] # (B5) - gru[q] = gr_left[q] @ mat_u_list[q] @ grd[q + 1] # (B6) - grd[q] = gr_left[q] + gr_left[q] @ mat_u_list[q] @ grl[q] # (B4) - g_trans = gr_left[q] @ mat_u_list[q] @ g_trans + g_trans = gr_left[-1].clone() + gr_lc = [g_trans] if need_gr_lc else None + for q in range(num_of_matrices - 2, -1, -1): + gU = gr_left[q] @ mat_u_list[q] # hoisted + grl[q] = grd[q + 1] @ mat_l_list[q] @ gr_left[q] # (B5) + gru[q] = gU @ grd[q + 1] # (B6) + grd[q] = gr_left[q] + gU @ grl[q] # (B4) + g_trans = gU @ g_trans if need_gr_lc: gr_lc.append(g_trans) if need_gr_lc: gr_lc.reverse() - # ------------------------------------------------------------------- - # ------ compute the electron correlation function ( Lesser Green Function ) if needed -------- - # 
------------------------------------------------------------------- - + gnd = gnl = gnu = gin_left = None if need_lesser: assert isinstance(s_in, list), "need_lesser=True requires s_in to be a list of coupling matrices" - gin_left = [None for _ in range(num_of_matrices)] - # Keldysh formula: G^< = G^r * Sigma^< * G^a ====> (-i * G^<) = G^r * (-i * Sigma^<) * G^a + gin_left = [None] * num_of_matrices + # G^< = G^r * Sigma^< * G^a gin_left[0] = gr_left[0] @ s_in[0] @ gr_left[0].mH for q in range(num_of_matrices - 1): sla2 = mat_l_list[q] @ gin_left[q] @ mat_u_list[q] - prom = s_in[q + 1] + sla2 - gin_left[q + 1] = gr_left[q + 1] @ prom @ gr_left[q + 1].mH + gin_left[q + 1] = gr_left[q + 1] @ (s_in[q + 1] + sla2) @ gr_left[q + 1].mH - gnl = [None for _ in range(num_of_matrices - 1)] - gnu = [None for _ in range(num_of_matrices - 1)] + gnl = [None] * (num_of_matrices - 1) + gnu = [None] * (num_of_matrices - 1) gnd = [i.clone() for i in gin_left] for q in range(num_of_matrices - 2, -1, -1): + gLmH = mat_l_list[q] @ gr_left[q].mH # hoisted gnl[q] = grd[q + 1] @ mat_l_list[q] @ gin_left[q] + \ - gnd[q + 1] @ mat_l_list[q] @ gr_left[q].mH # (B10) + gnd[q + 1] @ gLmH # (B10) gnd[q] = gin_left[q] + \ - gr_left[q] @ mat_u_list[q] @ gnd[q + 1] @ mat_l_list[q] @ gr_left[q].mH + \ - ((gin_left[q] @ mat_u_list[q] @ gru[q].mH) + - (gru[q] @ mat_l_list[q] @ gin_left[q])) # (B11) + gr_left[q] @ mat_u_list[q] @ gnd[q + 1] @ gLmH + \ + (gin_left[q] @ mat_u_list[q] @ gru[q].mH) + \ + (gru[q] @ mat_l_list[q] @ gin_left[q]) # (B11) gnu[q] = gnl[q].mH - # ------------------------------------------------------------------- - # -------- compute the hole correlation function if needed ---------- - # Only used when considering phase-breaking scattering - # ------------------------------------------------------------------- + gpd = gpl = gpu = gip_left = None if need_greater: assert isinstance(s_out, list), "need_greater=True requires s_out to be a list of coupling matrices" - gip_left = [None 
for _ in range(num_of_matrices)] + gip_left = [None] * num_of_matrices gip_left[0] = gr_left[0] @ s_out[0] @ gr_left[0].conj() for q in range(num_of_matrices - 1): sla2 = mat_l_list[q] @ gip_left[q] @ mat_u_list[q].conj() - prom = s_out[q + 1] + sla2 - gip_left[q + 1] = gr_left[q + 1] @ prom @ gr_left[q + 1].conj() + gip_left[q + 1] = gr_left[q + 1] @ (s_out[q + 1] + sla2) @ gr_left[q + 1].conj() - gpl = [None for _ in range(num_of_matrices - 1)] - gpu = [None for _ in range(num_of_matrices - 1)] + gpl = [None] * (num_of_matrices - 1) + gpu = [None] * (num_of_matrices - 1) gpd = [i.clone() for i in gip_left] for q in range(num_of_matrices - 2, -1, -1): + lcgc = mat_l_list[q].conj() @ gr_left[q].conj() # hoisted gpl[q] = grd[q + 1] @ mat_l_list[q] @ gip_left[q] + \ - gpd[q + 1] @ mat_l_list[q].conj() @ gr_left[q].conj() + gpd[q + 1] @ lcgc gpd[q] = gip_left[q] + \ - gr_left[q] @ mat_u_list[q] @ gpd[q + 1] @ mat_l_list[q].conj() @ gr_left[q].conj() + \ - ((gip_left[q] @ mat_u_list[q].conj() @ grl[q].conj()) + - (gru[q] @ mat_l_list[q] @ gip_left[q])) + gr_left[q] @ mat_u_list[q] @ gpd[q + 1] @ lcgc + \ + (gip_left[q] @ mat_u_list[q].conj() @ grl[q].conj()) + \ + (gru[q] @ mat_l_list[q] @ gip_left[q]) gpu[q] = gpl[q].mH - # ------------------------------------------------------------------- - # -- remove energy from the main diagonal of th Hamiltonian matrix -- - # ------------------------------------------------------------------- - - for jj in range(len(mat_d_list)): - mat_d_list[jj] = mat_d_list[jj] + e_bcast * sd[jj] - for jj in range(len(mat_l_list)): - mat_l_list[jj] = mat_l_list[jj] + e_bcast * sl[jj] - for jj in range(len(mat_u_list)): - mat_u_list[jj] = mat_u_list[jj] + e_bcast * su[jj] + return _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, + gnd, gnl, gnu, gin_left, + gpd, gpl, gpu, gip_left, + need_lesser, need_greater) - # ------------------------------------------------------------------- - # ---- choose a proper output depending on the list of 
arguments ---- - # ------------------------------------------------------------------- +def _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, + gnd, gnl, gnu, gin_left, + gpd, gpl, gpu, gip_left, + need_lesser, need_greater): if not need_lesser and not need_greater: return g_trans, gr_lc, \ grd, grl, gru, gr_left, \ None, None, None, None, \ None, None, None, None - - elif need_lesser and not need_greater: + if need_lesser and not need_greater: return g_trans, gr_lc, \ grd, grl, gru, gr_left, \ gnd, gnl, gnu, gin_left, \ None, None, None, None - - elif not need_lesser and need_greater: + if not need_lesser and need_greater: return g_trans, gr_lc, \ grd, grl, gru, gr_left, \ None, None, None, None, \ gpd, gpl, gpu, gip_left - - else: - return g_trans, gr_lc, \ - grd, grl, gru, gr_left, \ - gnd, gnl, gnu, gin_left, \ - gpd, gpl, gpu, gip_left + return g_trans, gr_lc, \ + grd, grl, gru, gr_left, \ + gnd, gnl, gnu, gin_left, \ + gpd, gpl, gpu, gip_left def recursive_gf(energy, hl, hd, hu, sd, su, sl, left_se, right_se, seP=None, E_ref=0.0, s_in=0, s_out=0, @@ -209,7 +293,7 @@ def recursive_gf(energy, hl, hd, hu, sd, su, sl, left_se, right_se, seP=None, E_ M. P. Anantram, M. S. Lundstrom and D. E. Nikonov, Proceedings of the IEEE, 96, 1511 - 1550 (2008) DOI: 10.1109/JPROC.2008.927355 - Obtain various Green's functions for later calculations. + Wrapper of RGF algorithm to obtain various Green's functions. Accepts either a scalar/0-d ``energy`` (legacy callers) or a 1-D ``[B]`` energy tensor. In the scalar case, all inputs are broadcast to B=1, the @@ -220,6 +304,14 @@ def recursive_gf(energy, hl, hd, hu, sd, su, sl, left_se, right_se, seP=None, E_ k-dependent blocks (``hd``, ``sd``, ``hl``, ``hu``, ``sl``, ``su``) may arrive 2-D and will be expanded to ``[B, ...]`` zero-copy. + When every diagonal block shares the same ``n``, the wrapper auto-detects + and stacks the K blocks into a single ``[K, B, n, n]`` tensor before + calling the kernel. 
The per-block energy shift then collapses to three fused ops, + the forward solve reuses one identity matrix, and the backward sweeps + index the stacked L/U tensors directly. Non-uniform geometries (and K = 1) + fall through to the legacy list path. Both paths return per-block Python + lists, so downstream callers see the same shape contract either way. + Parameters ---------- energy : torch.Tensor @@ -250,15 +342,16 @@ def recursive_gf(energy, hl, hd, hu, sd, su, sl, left_se, right_se, seP=None, E_ shift_energy = energy + E_ref if not torch.is_tensor(shift_energy): shift_energy = torch.as_tensor(shift_energy, dtype=torch.complex128) - # Legacy scalar callers pass either a 0-d tensor or a length-1 1-D tensor together with 2-D Hamiltonian / self-energy inputs. - # Batched callers pass a 1-D ``[B]`` energy together with 3-D ``[B, n, n]`` tensors. + # Legacy scalar callers pass either a 0-d tensor or a length-1 1-D tensor together with 2-D Hamiltonian / self-energy inputs. + # Batched callers pass a 1-D ``[B]`` energy together with 3-D ``[B, n, n]`` tensors. # Use the rank of ``left_se`` (or ``right_se``) as the disambiguator so the # wrapper can squeeze the batch dim back out for scalar callers. # se_probe is used to determine whether the self-energy inputs are 2-D (scalar energy case) or 3-D (batched energy case). se_probe = left_se if isinstance(left_se, torch.Tensor) else right_se squeezed = isinstance(se_probe, torch.Tensor) and se_probe.ndim == 2 - # if squeezed = True, the wrapper will squeeze the leading batch dim from every output tensor; + # if squeezed = True, the wrapper will squeeze the leading batch dim from every output tensor; # if False, the wrapper leaves the leading batch dim in place. 
+ if shift_energy.ndim == 0: shift_energy = shift_energy.reshape(1) elif squeezed and shift_energy.ndim == 1 and shift_energy.shape[0] == 1: @@ -281,8 +374,6 @@ def _to_batch(t): seP_b = [_to_batch(seP[i]) if torch.is_tensor(seP[i]) else seP[i] for i in range(len(seP))] for i in range(len(temp_mat_d_list)): temp_mat_d_list[i] = temp_mat_d_list[i] + seP_b[i] - else: - seP_b = None if isinstance(left_se, torch.Tensor): left_se_b = _to_batch(left_se) @@ -290,8 +381,6 @@ def _to_batch(t): se01, se02 = left_se_b.shape[-2], left_se_b.shape[-1] idx0, idy0 = min(s01, se01), min(s02, se02) temp_mat_d_list[0][:, :idx0, :idy0] = temp_mat_d_list[0][:, :idx0, :idy0] + left_se_b[:, :idx0, :idy0] - else: - left_se_b = left_se if isinstance(right_se, torch.Tensor): right_se_b = _to_batch(right_se) @@ -299,8 +388,6 @@ def _to_batch(t): se11, se12 = right_se_b.shape[-2], right_se_b.shape[-1] idx1, idy1 = min(s11, se11), min(s12, se12) temp_mat_d_list[-1][:, -idx1:, -idy1:] = temp_mat_d_list[-1][:, -idx1:, -idy1:] + right_se_b[:, -idx1:, -idy1:] - else: - right_se_b = right_se # s_in / s_out arrive as lists when the lesser/greater paths are active. if isinstance(s_in, list): @@ -312,21 +399,37 @@ def _to_batch(t): else: s_out_b = s_out - ans = recursive_gf_cal(shift_energy, temp_mat_l_list, temp_mat_d_list, temp_mat_u_list, sd_b, su_b, sl_b, - s_in=s_in_b, s_out=s_out_b, eta=eta, - need_lesser=need_lesser, - need_greater=need_greater, - need_gr_lc=need_gr_lc) - - if isinstance(left_se, torch.Tensor): - temp_mat_d_list[0][:, :idx0, :idy0] = temp_mat_d_list[0][:, :idx0, :idy0] - left_se_b[:, :idx0, :idy0] - - if isinstance(right_se, torch.Tensor): - temp_mat_d_list[-1][:, -idx1:, -idy1:] = temp_mat_d_list[-1][:, -idx1:, -idy1:] - right_se_b[:, -idx1:, -idy1:] - - if seP is not None: - for i in range(len(temp_mat_d_list)): - temp_mat_d_list[i] = temp_mat_d_list[i] - seP_b[i] + # Auto-detect the uniform-block case: every D/L/U block must share the + # same n x n footprint. 
Stacking lets the kernel hit a single fused build + # step and reuse one cached identity across all K forward solves. + n0 = temp_mat_d_list[0].shape[-1] + uniform = ( + all(t.shape[-2] == n0 and t.shape[-1] == n0 for t in temp_mat_d_list) + and all(t.shape[-2] == n0 and t.shape[-1] == n0 for t in temp_mat_l_list) + and all(t.shape[-2] == n0 and t.shape[-1] == n0 for t in temp_mat_u_list) + ) + + if uniform and len(temp_mat_d_list) >= 2: + D = torch.stack(temp_mat_d_list, dim=0) # [K, B, n, n] + L = torch.stack(temp_mat_l_list, dim=0) # [K-1, B, n, n] + U = torch.stack(temp_mat_u_list, dim=0) # [K-1, B, n, n] + Sd = torch.stack(sd_b, dim=0) # [K, B, n, n] + Sl = torch.stack(sl_b, dim=0) # [K-1, B, n, n] + Su = torch.stack(su_b, dim=0) # [K-1, B, n, n] + ans = recursive_gf_cal(shift_energy, L, D, U, Sd, Su, Sl, + s_in=s_in_b, s_out=s_out_b, eta=eta, + need_lesser=need_lesser, + need_greater=need_greater, + need_gr_lc=need_gr_lc, + stacked=True) + else: + ans = recursive_gf_cal(shift_energy, temp_mat_l_list, temp_mat_d_list, temp_mat_u_list, + sd_b, su_b, sl_b, + s_in=s_in_b, s_out=s_out_b, eta=eta, + need_lesser=need_lesser, + need_greater=need_greater, + need_gr_lc=need_gr_lc, + stacked=False) if squeezed: ans = _squeeze_ans(ans) From d5f37df7feaf7a0dc821a95fefcf3d3811226de9 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Fri, 12 Jun 2026 15:10:08 +0800 Subject: [PATCH 3/3] update unit test for uniform blocks --- dpnegf/tests/test_recursive_gf_batched.py | 48 +++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/dpnegf/tests/test_recursive_gf_batched.py b/dpnegf/tests/test_recursive_gf_batched.py index 9d8cc64..8ee4ebb 100644 --- a/dpnegf/tests/test_recursive_gf_batched.py +++ b/dpnegf/tests/test_recursive_gf_batched.py @@ -149,6 +149,54 @@ def test_batched_matches_scalar_lesser(): _assert_ans_close(stacked, batched, atol=1e-10) +def test_batched_matches_scalar_retarded_uniform(): + # Uniform block sizes exercise the 
[K, B, n, n] stacked fast path. + B = 16 + block_sizes = [8, 8, 8, 8, 8] + hd, sd, hl, hu, sl, su, left_se, right_se, energies = _make_btd_inputs( + B, block_sizes, seed=3) + + stacked = _stack_scalar(hd, sd, hl, hu, sl, su, left_se, right_se, energies, + need_lesser=False) + + batched = recursive_gf( + energy=energies, + hl=hl, hd=hd, hu=hu, sd=sd, su=su, sl=sl, + left_se=left_se, right_se=right_se, + s_in=0, eta=1e-5, + need_lesser=False, need_gr_lc=True, + ) + + _assert_ans_close(stacked, batched, atol=1e-10) + + +def test_batched_matches_scalar_lesser_uniform(): + B = 8 + block_sizes = [6, 6, 6, 6] + hd, sd, hl, hu, sl, su, left_se, right_se, energies = _make_btd_inputs( + B, block_sizes, seed=4) + + rng = torch.Generator(device="cpu").manual_seed(5) + s_in_batched = [] + for n in block_sizes: + x = (torch.randn(B, n, n, generator=rng, dtype=torch.float64) + + 1j * torch.randn(B, n, n, generator=rng, dtype=torch.float64)).to(torch.complex128) + s_in_batched.append(0.5 * (x + x.mH)) + + stacked = _stack_scalar(hd, sd, hl, hu, sl, su, left_se, right_se, energies, + need_lesser=True, s_in_batched=s_in_batched) + + batched = recursive_gf( + energy=energies, + hl=hl, hd=hd, hu=hu, sd=sd, su=su, sl=sl, + left_se=left_se, right_se=right_se, + s_in=s_in_batched, eta=1e-5, + need_lesser=True, need_gr_lc=True, + ) + + _assert_ans_close(stacked, batched, atol=1e-10) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_batched_cuda_smoke(): B = 16