From 358a4591da876dc4eded8cc371c366f5fbebae6d Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 10:05:46 +0800 Subject: [PATCH 1/8] Refactor RGF kernel to reduce memory usage --- dpnegf/negf/recursive_green_cal.py | 47 +++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/dpnegf/negf/recursive_green_cal.py b/dpnegf/negf/recursive_green_cal.py index f6dccfe..452f0b5 100644 --- a/dpnegf/negf/recursive_green_cal.py +++ b/dpnegf/negf/recursive_green_cal.py @@ -4,7 +4,7 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, sd, su, sl, s_in=0, s_out=0, eta=1e-5, need_lesser=False, need_greater=False, need_gr_lc=False, - stacked=False): + stacked=False, keep_gr_left=True): """The recursive Green's function algorithm is taken from M. P. Anantram, M. S. Lundstrom and D. E. Nikonov, Proceedings of the IEEE, 96, 1511 - 1550 (2008) DOI: 10.1109/JPROC.2008.927355 @@ -92,10 +92,13 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, -mat_d[q + 1] - mat_l[q] @ gr_left[q] @ mat_u[q], eye_bnn, ) + # mat_d is dead after the forward sweep — backward sweep only reads mat_l/mat_u. 
+ del mat_d grl = [None] * (num_of_matrices - 1) gru = [None] * (num_of_matrices - 1) - grd = [g.clone() for g in gr_left] + grd = [None] * num_of_matrices + grd[-1] = gr_left[-1].clone() g_trans = gr_left[-1].clone() gr_lc = [g_trans] if need_gr_lc else None for q in range(num_of_matrices - 2, -1, -1): @@ -120,7 +123,8 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, gnl = [None] * (num_of_matrices - 1) gnu = [None] * (num_of_matrices - 1) - gnd = [g.clone() for g in gin_left] + gnd = [None] * num_of_matrices + gnd[-1] = gin_left[-1].clone() for q in range(num_of_matrices - 2, -1, -1): gLmH = mat_l[q] @ gr_left[q].mH # hoisted: used twice gnl[q] = grd[q + 1] @ mat_l[q] @ gin_left[q] + gnd[q + 1] @ gLmH # (B10) @@ -140,7 +144,8 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, gpl = [None] * (num_of_matrices - 1) gpu = [None] * (num_of_matrices - 1) - gpd = [g.clone() for g in gip_left] + gpd = [None] * num_of_matrices + gpd[-1] = gip_left[-1].clone() for q in range(num_of_matrices - 2, -1, -1): lcgc = mat_l[q].conj() @ gr_left[q].conj() # hoisted: used twice gpl[q] = grd[q + 1] @ mat_l[q] @ gip_left[q] + gpd[q + 1] @ lcgc @@ -150,6 +155,8 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, (gru[q] @ mat_l[q] @ gip_left[q]) gpu[q] = gpl[q].mH + if not keep_gr_left: + gr_left = None return _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, gnd, gnl, gnu, gin_left, gpd, gpl, gpu, gip_left, @@ -161,7 +168,9 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, e_bcast = (energy + 1j * eta).view(-1, 1, 1) for jj in range(len(mat_d_list)): - mat_d_list[jj] = mat_d_list[jj] - e_bcast * sd[jj] + # In-place: mat_d_list is a fresh tensor (wrapper's `* 1.` copy on D), + # so we can fuse the energy shift without the e_bcast*sd transient. 
+ mat_d_list[jj].addcmul_(sd[jj], e_bcast, value=-1) for jj in range(len(mat_l_list)): mat_l_list[jj] = mat_l_list[jj] - e_bcast * sl[jj] for jj in range(len(mat_u_list)): @@ -183,16 +192,19 @@ def _batched_eye(n): # ------------------ retarded Green's function ---------------------- gr_left = [None] * num_of_matrices gr_left[0] = tLA.solve(-mat_d_list[0], _batched_eye(mat_shapes[0][-1])) + mat_d_list[0] = None # consumed; free immediately for q in range(num_of_matrices - 1): # (B2) gr_left[q + 1] = tLA.solve( -mat_d_list[q + 1] - mat_l_list[q] @ gr_left[q] @ mat_u_list[q], _batched_eye(mat_shapes[q + 1][-1]), ) + mat_d_list[q + 1] = None # consumed; backward sweep only reads mat_l/mat_u. grl = [None] * (num_of_matrices - 1) gru = [None] * (num_of_matrices - 1) - grd = [i.clone() for i in gr_left] + grd = [None] * num_of_matrices + grd[-1] = gr_left[-1].clone() g_trans = gr_left[-1].clone() gr_lc = [g_trans] if need_gr_lc else None for q in range(num_of_matrices - 2, -1, -1): @@ -219,7 +231,8 @@ def _batched_eye(n): gnl = [None] * (num_of_matrices - 1) gnu = [None] * (num_of_matrices - 1) - gnd = [i.clone() for i in gin_left] + gnd = [None] * num_of_matrices + gnd[-1] = gin_left[-1].clone() for q in range(num_of_matrices - 2, -1, -1): gLmH = mat_l_list[q] @ gr_left[q].mH # hoisted @@ -243,7 +256,8 @@ def _batched_eye(n): gpl = [None] * (num_of_matrices - 1) gpu = [None] * (num_of_matrices - 1) - gpd = [i.clone() for i in gip_left] + gpd = [None] * num_of_matrices + gpd[-1] = gip_left[-1].clone() for q in range(num_of_matrices - 2, -1, -1): lcgc = mat_l_list[q].conj() @ gr_left[q].conj() # hoisted @@ -255,6 +269,8 @@ def _batched_eye(n): (gru[q] @ mat_l_list[q] @ gip_left[q]) gpu[q] = gpl[q].mH + if not keep_gr_left: + gr_left = None return _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, gnd, gnl, gnu, gin_left, gpd, gpl, gpu, gip_left, @@ -287,7 +303,8 @@ def _pack_ans(g_trans, gr_lc, grd, grl, gru, gr_left, def recursive_gf(energy, hl, hd, hu, sd, su, sl, 
left_se, right_se, seP=None, E_ref=0.0, s_in=0, s_out=0, - eta=1e-5, need_lesser=False, need_greater=False, need_gr_lc=False): + eta=1e-5, need_lesser=False, need_greater=False, need_gr_lc=False, + keep_gr_left=True): """The recursive Green's function algorithm is taken from M. P. Anantram, M. S. Lundstrom and D. E. Nikonov, Proceedings of the IEEE, 96, 1511 - 1550 (2008) @@ -364,8 +381,10 @@ def _to_batch(t): return t temp_mat_d_list = [_to_batch(hd[i]) * 1. for i in range(len(hd))] - temp_mat_l_list = [_to_batch(hl[i]) * 1. for i in range(len(hl))] - temp_mat_u_list = [_to_batch(hu[i]) * 1. for i in range(len(hu))] + # L and U are only subtracted out-of-place inside the kernel; the expanded + # view is fine, and skipping the copy saves K x B x n^2 per list. + temp_mat_l_list = [_to_batch(hl[i]) for i in range(len(hl))] + temp_mat_u_list = [_to_batch(hu[i]) for i in range(len(hu))] sd_b = [_to_batch(sd[i]) for i in range(len(sd))] sl_b = [_to_batch(sl[i]) for i in range(len(sl))] su_b = [_to_batch(su[i]) for i in range(len(su))] @@ -421,7 +440,8 @@ def _to_batch(t): need_lesser=need_lesser, need_greater=need_greater, need_gr_lc=need_gr_lc, - stacked=True) + stacked=True, + keep_gr_left=keep_gr_left) else: ans = recursive_gf_cal(shift_energy, temp_mat_l_list, temp_mat_d_list, temp_mat_u_list, sd_b, su_b, sl_b, @@ -429,7 +449,8 @@ def _to_batch(t): need_lesser=need_lesser, need_greater=need_greater, need_gr_lc=need_gr_lc, - stacked=False) + stacked=False, + keep_gr_left=keep_gr_left) if squeezed: ans = _squeeze_ans(ans) From 26955a246936e818a41786c6578c1f5ea742e900 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 14:06:47 +0800 Subject: [PATCH 2/8] set need_gr_lc default to false --- dpnegf/negf/device_property.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnegf/negf/device_property.py b/dpnegf/negf/device_property.py index b82a165..333d1d1 100644 --- a/dpnegf/negf/device_property.py +++ 
b/dpnegf/negf/device_property.py @@ -135,7 +135,7 @@ def set_leadLR(self, lead_L, lead_R): def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=True, Vbias=None, - HS_inmem:bool=True, need_lesser:bool=False, need_greater:bool=False, need_gr_lc:bool=True): + HS_inmem:bool=True, need_lesser:bool=False, need_greater:bool=False, need_gr_lc:bool=False): ''' computes the Green's function for a given energy and k-point in device. the tags used here to identify different Green's functions follows the NEGF theory From 2d46fc9054b3e81c7b7350e70bb0eab34b76929d Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 14:48:04 +0800 Subject: [PATCH 3/8] add 'release_greenfuncs' --- dpnegf/negf/device_property.py | 18 ++++++++++-------- dpnegf/runner/NEGF.py | 13 ++++++++----- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/dpnegf/negf/device_property.py b/dpnegf/negf/device_property.py index 333d1d1..ad1067a 100644 --- a/dpnegf/negf/device_property.py +++ b/dpnegf/negf/device_property.py @@ -290,6 +290,16 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr # self.green = update_temp_file(update_fn=fn, file_path=GFpath, ee=ee, tags=tags, info="Computing Green's Function") + def release_greenfuncs(self): + '''Drop the Green's-function dict so the underlying rgf_device storage + can be freed before the next energy chunk. H/S blocks are kept resident + (they are k,V-dependent, not energy-dependent). 
The runner is + responsible for restoring scalar lead.se references before calling + this, so any batched [B,n,n] copies become collectable too.''' + self.greenfuncs = 0 + if isinstance(self.rgf_device, torch.device) and self.rgf_device.type == "cuda": + torch.cuda.empty_cache() + def _cal_current_(self, espacing): '''calculate the current based on the voltage difference @@ -312,14 +322,6 @@ def _cal_current_(self, espacing): xl = min(v_L, v_R)-4*self.kBT xu = max(v_L, v_R)+4*self.kBT - def fcn(e): - self.cal_green_function() - - cc = leggauss(fcn=self._cal_tc_) - - int_grid, int_weight = gauss_xw(xl=xl, xu=xu, n=int((xu-xl)/espacing)) - - self.__CURRENT__ = simpson(y=(self.lead_L.fermi_dirac(self.ee+self.E_ref) - self.lead_R.fermi_dirac(self.ee+self.E_ref)) * self.tc, x=self.ee) diff --git a/dpnegf/runner/NEGF.py b/dpnegf/runner/NEGF.py index cad05cc..a847072 100644 --- a/dpnegf/runner/NEGF.py +++ b/dpnegf/runner/NEGF.py @@ -741,19 +741,22 @@ def negf_compute(self,scf_require=False,Vbias=None): ) if self.out_dos: - self.out.setdefault('DOS', {}).setdefault(str(k), []).append(self.compute_DOS(k).reshape(-1)) + self.out.setdefault('DOS', {}).setdefault(str(k), []).append(self.compute_DOS(k).reshape(-1).cpu()) if self.out_tc or self.out_current_nscf: - self.out.setdefault('T_k', {}).setdefault(str(k), []).append(self.compute_TC(k).reshape(-1)) + self.out.setdefault('T_k', {}).setdefault(str(k), []).append(self.compute_TC(k).reshape(-1).cpu()) if self.out_ldos: ldos_chunk = self.compute_LDOS(k) if ldos_chunk.ndim == 1: # scalar-E chunk → [na] ldos_chunk = ldos_chunk.unsqueeze(0) - self.out.setdefault('LDOS', {}).setdefault(str(k), []).append(ldos_chunk) + self.out.setdefault('LDOS', {}).setdefault(str(k), []).append(ldos_chunk.cpu()) - # Restore lead.se to scalar [n,n] so downstream scalar callers - # (density modules, lcurrent loop, future SCF re-entry) see the expected shape. 
+ # Restore lead.se to scalar [n,n] before releasing the GF dict so + # the batched [B,n,n] GPU copies become collectable, and so any + # subsequent scalar caller (density modules, lcurrent loop, future + # SCF re-entry) sees the expected shape. self.deviceprop.lead_L.se = seL_list[-1] self.deviceprop.lead_R.se = seR_list[-1] + self.deviceprop.release_greenfuncs() # over energy loop in uni_gird # The following code is for output properties before NEGF ends From 57a70c9ce8b362ff40a49ccdd5f0f7ed7f437429 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 21:55:57 +0800 Subject: [PATCH 4/8] remove dead slots --- dpnegf/negf/device_property.py | 7 ++++++- dpnegf/negf/recursive_green_cal.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/dpnegf/negf/device_property.py b/dpnegf/negf/device_property.py index ad1067a..8146d9b 100644 --- a/dpnegf/negf/device_property.py +++ b/dpnegf/negf/device_property.py @@ -273,11 +273,16 @@ def cal_green_function(self, energy, kpoint, eta_device=0., block_tridiagonal=Tr else: s_in = 0 + # gr_left is only consumed inside the lesser/greater forward pass of the + # kernel. If neither is active, the per-block list would sit on the GPU + # unread; ask the kernel to drop it so its slots are freed mid-sweep. 
+ keep_gr_left = bool(need_lesser or need_greater) ans = recursive_gf(energy, hl=self.hl, hd=self.hd, hu=self.hu, sd=self.sd, su=self.su, sl=self.sl, left_se=seL, right_se=seR, seP=None, s_in=s_in, s_out=None, eta=eta_device, E_ref=self.E_ref, - need_lesser=need_lesser, need_greater=need_greater, need_gr_lc=need_gr_lc) + need_lesser=need_lesser, need_greater=need_greater, + need_gr_lc=need_gr_lc, keep_gr_left=keep_gr_left) # green shape [[g_trans, grd, grl,...],[g_trans, ...]] for t in range(len(tags)): diff --git a/dpnegf/negf/recursive_green_cal.py b/dpnegf/negf/recursive_green_cal.py index 452f0b5..3cefcca 100644 --- a/dpnegf/negf/recursive_green_cal.py +++ b/dpnegf/negf/recursive_green_cal.py @@ -109,8 +109,13 @@ def recursive_gf_cal(energy, mat_l_list, mat_d_list, mat_u_list, g_trans = gU @ g_trans if need_gr_lc: gr_lc.append(g_trans) + del gU if need_gr_lc: gr_lc.reverse() + if not need_lesser and not need_greater: + # Stacked path: mat_l/mat_u are 4-D and can't be slice-freed mid-loop; + # they are dead now if no lesser/greater pass will read them. + del mat_l, mat_u gnd = gnl = gnu = gin_left = None if need_lesser: @@ -207,6 +212,14 @@ def _batched_eye(n): grd[-1] = gr_left[-1].clone() g_trans = gr_left[-1].clone() gr_lc = [g_trans] if need_gr_lc else None + # Slots that go dead at the end of iteration q: + # - mat_l_list[q], mat_u_list[q]: only re-read by the lesser/greater branches. + # - gr_left[q]: dead unless the lesser/greater branch will consume it OR + # the caller asked us to keep the list intact. + # Nulling per slot lets the caching allocator coalesce its free list inside the + # loop instead of holding a long fragmented tail until the sweep ends. 
+ drop_lu = not need_lesser and not need_greater + drop_gl = drop_lu and not keep_gr_left for q in range(num_of_matrices - 2, -1, -1): gU = gr_left[q] @ mat_u_list[q] # hoisted grl[q] = grd[q + 1] @ mat_l_list[q] @ gr_left[q] # (B5) @@ -215,6 +228,12 @@ def _batched_eye(n): g_trans = gU @ g_trans if need_gr_lc: gr_lc.append(g_trans) + del gU + if drop_lu: + mat_l_list[q] = None + mat_u_list[q] = None + if drop_gl: + gr_left[q] = None if need_gr_lc: gr_lc.reverse() From 5b883a703690880b7fc6f56abc1c6c827daefc8f Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 21:58:19 +0800 Subject: [PATCH 5/8] add _auto_chunk_size --- dpnegf/runner/NEGF.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/dpnegf/runner/NEGF.py b/dpnegf/runner/NEGF.py index a847072..7e99440 100644 --- a/dpnegf/runner/NEGF.py +++ b/dpnegf/runner/NEGF.py @@ -585,6 +585,37 @@ def prepare_self_energy(self, scf_require: bool) -> None: + def _auto_chunk_size(self, n_grid): + """Pick a chunk size from free CUDA memory when the user didn't set + ``e_batch_size``. Returns the full grid length on CPU / when the + device geometry can't be probed yet. + + Per-energy peak (post per-slot-release, complex128) approximated as + bytes_per_E ~= C * K * n_max**2 * 16 + with C bundling the live tensors in the worst backward-sweep slot + (grd full + grl + gru full + decaying gr_left tail + gU + transients). + C=10 with a 0.7x free-memory budget; deliberately conservative because + without expandable_segments the allocator can't defragment on demand. 
+ """ + rgf_dev = self.rgf_device + if not (isinstance(rgf_dev, torch.device) and rgf_dev.type == "cuda"): + return n_grid + try: + free_bytes, _total = torch.cuda.mem_get_info(rgf_dev) + n_max = max(int(b.shape[-1]) for b in self.deviceprop.hd) + K = len(self.deviceprop.hd) + except Exception: + return n_grid + per_e = 10 * K * (n_max ** 2) * 16 + if per_e <= 0: + return n_grid + b = max(1, min(n_grid, int(0.7 * free_bytes) // per_e)) + log.info( + f"auto e_batch_size={b} (free={free_bytes/2**30:.2f} GiB, " + f"per_E~={per_e/2**20:.1f} MiB, K={K}, n_max={n_max})" + ) + return b + def negf_compute(self,scf_require=False,Vbias=None): assert scf_require is not None, "scf_require should be set to True or False" @@ -704,7 +735,10 @@ def negf_compute(self,scf_require=False,Vbias=None): self.out.setdefault('LDOS', {}).setdefault(str(k), []).append(self.compute_LDOS(k)) else: # Non-SCF: solve a whole chunk of energies in one batched recursive_gf call. - chunk = self.e_batch_size if self.e_batch_size is not None else len(self.uni_grid) + if self.e_batch_size is not None: + chunk = self.e_batch_size + else: + chunk = self._auto_chunk_size(len(self.uni_grid)) for e_chunk in torch.split(self.uni_grid, chunk): e_batch_size = len(e_chunk) log.info( From 245ebc84437f90c398f5214a2a2f6be42d2b7fab Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 21:59:02 +0800 Subject: [PATCH 6/8] remove unnecessary self-energy terms --- dpnegf/runner/NEGF.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/dpnegf/runner/NEGF.py b/dpnegf/runner/NEGF.py index 7e99440..6ac2d07 100644 --- a/dpnegf/runner/NEGF.py +++ b/dpnegf/runner/NEGF.py @@ -784,12 +784,21 @@ def negf_compute(self,scf_require=False,Vbias=None): ldos_chunk = ldos_chunk.unsqueeze(0) self.out.setdefault('LDOS', {}).setdefault(str(k), []).append(ldos_chunk.cpu()) - # Restore lead.se to scalar [n,n] before releasing the GF dict so - # the batched 
[B,n,n] GPU copies become collectable, and so any - # subsequent scalar caller (density modules, lcurrent loop, future - # SCF re-entry) sees the expected shape. - self.deviceprop.lead_L.se = seL_list[-1] - self.deviceprop.lead_R.se = seR_list[-1] + # Restore lead.se to a scalar [n,n] before releasing the GF + # dict. For B>1 we clone the last per-E tensor so the new + # lead.se doesn't share storage with anything still + # referenced through seL_list/seR_list, then drop both + # lists so release_greenfuncs's empty_cache() has the per-E + # and stacked [B,n,n] copies to release. + if e_batch_size > 1: + self.deviceprop.lead_L.se = seL_list[-1].detach().clone() + self.deviceprop.lead_R.se = seR_list[-1].detach().clone() + else: + # B=1 path: lead.se already IS the per-E [n,n] tensor; + # preserve byte-identical behavior for the scalar case. + self.deviceprop.lead_L.se = seL_list[-1] + self.deviceprop.lead_R.se = seR_list[-1] + del seL_list, seR_list self.deviceprop.release_greenfuncs() # over energy loop in uni_gird From 8dfb109a135c58c706bb84029b500750e432afbb Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 21:59:21 +0800 Subject: [PATCH 7/8] add log warning for rgf_device --- dpnegf/runner/NEGF.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/dpnegf/runner/NEGF.py b/dpnegf/runner/NEGF.py index 6ac2d07..77520bd 100644 --- a/dpnegf/runner/NEGF.py +++ b/dpnegf/runner/NEGF.py @@ -69,6 +69,20 @@ def __init__(self, self.rgf_device = rgf_device self.n_cpus = n_cpus self.e_batch_size = e_batch_size + + # The RGF q-loop allocates/frees many small slabs; with the default + # cudaMalloc-backed caching allocator this fragments quickly on long + # energy grids. expandable_segments avoids that, but must be set before + # torch initializes its CUDA context — by the time we get here it's + # already live, so we can only nudge the user. 
+ if isinstance(self.rgf_device, torch.device) and self.rgf_device.type == "cuda": + if "expandable_segments" not in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""): + log.warning( + "RGF on CUDA can fragment the caching allocator on long energy " + "grids. Consider exporting " + "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True BEFORE invoking " + "dpnegf (must be set before torch's CUDA context initializes)." + ) # get the parameters self.ele_T = ele_T From 7c1cc1876e7951c9dd0b2d74576c079e264b1032 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Mon, 15 Jun 2026 21:59:47 +0800 Subject: [PATCH 8/8] update batched rgf --- dpnegf/tests/test_recursive_gf_batched.py | 90 +++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/dpnegf/tests/test_recursive_gf_batched.py b/dpnegf/tests/test_recursive_gf_batched.py index 8ee4ebb..f561b9f 100644 --- a/dpnegf/tests/test_recursive_gf_batched.py +++ b/dpnegf/tests/test_recursive_gf_batched.py @@ -230,3 +230,93 @@ def test_batched_cuda_smoke(): for cq, gq in zip(c, g): assert gq.device.type == "cuda" assert torch.allclose(cq, gq.cpu(), atol=1e-8) + + +def _ans_equal_skipping_gr_left(a, b): + """Compare two recursive_gf outputs slot-by-slot. Position 5 (gr_left) is + expected to differ when one call uses keep_gr_left=False — skip it. + Everything else must be bit-identical.""" + for i, (x, y) in enumerate(zip(a, b)): + if i == 5: # gr_left slot + continue + if x is None: + assert y is None + continue + if torch.is_tensor(x): + assert torch.equal(x, y), f"tuple slot {i} mismatch" + else: + for q, (xq, yq) in enumerate(zip(x, y)): + if xq is None: + assert yq is None + continue + assert torch.equal(xq, yq), f"tuple slot {i}, block {q} mismatch" + + +def test_keep_gr_left_false_outputs_match_true(): + """Lever 1 + 2 release path: with need_lesser=False, need_greater=False, the + keep_gr_left=False call must produce the same tensors (modulo gr_left) as + the keep_gr_left=True call. 
Runs on CPU; correctness, not memory.""" + B = 8 + block_sizes = [8, 6, 8] + hd, sd, hl, hu, sl, su, left_se, right_se, energies = _make_btd_inputs( + B, block_sizes, seed=11) + + common = dict(energy=energies, hl=hl, hd=hd, hu=hu, sd=sd, su=su, sl=sl, + left_se=left_se, right_se=right_se, + s_in=0, eta=1e-5, + need_lesser=False, need_greater=False, need_gr_lc=False) + ans_keep = recursive_gf(**common, keep_gr_left=True) + ans_drop = recursive_gf(**common, keep_gr_left=False) + + _ans_equal_skipping_gr_left(ans_keep, ans_drop) + assert ans_drop[5] is None # gr_left dropped + assert isinstance(ans_keep[5], list) # gr_left populated + + +def test_keep_gr_left_false_outputs_match_true_uniform(): + """Same equivalence check but on the uniform [K,B,n,n] fast path (Lever 1b).""" + B = 8 + block_sizes = [8, 8, 8, 8] + hd, sd, hl, hu, sl, su, left_se, right_se, energies = _make_btd_inputs( + B, block_sizes, seed=12) + + common = dict(energy=energies, hl=hl, hd=hd, hu=hu, sd=sd, su=su, sl=sl, + left_se=left_se, right_se=right_se, + s_in=0, eta=1e-5, + need_lesser=False, need_greater=False, need_gr_lc=False) + ans_keep = recursive_gf(**common, keep_gr_left=True) + ans_drop = recursive_gf(**common, keep_gr_left=False) + + _ans_equal_skipping_gr_left(ans_keep, ans_drop) + assert ans_drop[5] is None + assert isinstance(ans_keep[5], list) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_peak_memory_decreases_with_keep_gr_left_false(): + """Lever 1 (per-slot release) + Lever 2 (keep_gr_left wiring) regression. 
+ On a synthetic batched problem, keep_gr_left=False must use strictly less + CUDA peak memory than keep_gr_left=True, and outputs (excluding gr_left) + must be bit-identical.""" + B, block_sizes = 8, [16, 16, 16, 16, 16, 16, 16, 16] + hd, sd, hl, hu, sl, su, left_se, right_se, energies = _make_btd_inputs( + B, block_sizes, seed=13, device="cuda") + + common = dict(energy=energies, hl=hl, hd=hd, hu=hu, sd=sd, su=su, sl=sl, + left_se=left_se, right_se=right_se, + s_in=0, eta=1e-5, + need_lesser=False, need_greater=False, need_gr_lc=False) + + torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats() + ans_keep = recursive_gf(**common, keep_gr_left=True) + peak_keep = torch.cuda.max_memory_allocated() + + torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats() + ans_drop = recursive_gf(**common, keep_gr_left=False) + peak_drop = torch.cuda.max_memory_allocated() + + _ans_equal_skipping_gr_left(ans_keep, ans_drop) + assert peak_drop < peak_keep, ( + f"keep_gr_left=False peak {peak_drop} >= keep_gr_left=True peak {peak_keep}; " + "per-slot release did not reduce memory." + )