From ca0917fc35f59017bea63c311631809e6536126c Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Thu, 26 Mar 2026 13:54:11 +0800 Subject: [PATCH] remove CORRECT_DIVMOD_FOLDING --- test/external/fuzz_symbolic.py | 12 +++++------- test/external/fuzz_symbolic_symbolic_div.py | 5 ++--- test/null/test_symbolic_failures.py | 7 ------- test/null/test_uop_symbolic.py | 21 ++++----------------- tinygrad/helpers.py | 2 +- tinygrad/uop/divandmod.py | 11 ++++++----- 6 files changed, 18 insertions(+), 40 deletions(-) diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py index 641d5eed6eb16..e97840d0cf255 100644 --- a/test/external/fuzz_symbolic.py +++ b/test/external/fuzz_symbolic.py @@ -7,7 +7,7 @@ from tinygrad import Variable, dtypes from tinygrad.uop.ops import UOp from tinygrad.uop.validate import uops_to_z3 -from tinygrad.helpers import DEBUG, Context +from tinygrad.helpers import DEBUG seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100) print(f"Seed: {seed}", flush=True) @@ -56,8 +56,7 @@ def random_bool_expr(depth=10, expr1=None): v = [u1,u2,u3] expr = random_int_expr(6) - with Context(CORRECT_DIVMOD_FOLDING=1): - simplified_expr = expr.simplify() + simplified_expr = expr.simplify() solver = z3.Solver(ctx=z3.Context()) solver.set(timeout=5000) # some expressions take very long verify, but its very unlikely they actually return sat @@ -74,10 +73,9 @@ def random_bool_expr(depth=10, expr1=None): m = solver.model() n1, n2, n3 = m[v1], m[v2], m[v3] u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long()) - with Context(CORRECT_DIVMOD_FOLDING=1): - num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() - rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() - if num==rn: print("z3 found a mismatch but the expressions are equal!!") + num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() + rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify() + if num==rn: print("z3 found a mismatch but the expressions are equal!!") assert False, f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +\ "Reproduce with:\n" +\ f"v1=Variable(\"{u1.arg[0]}\", {u1.arg[1]}, {u1.arg[2]})\n" +\ diff --git a/test/external/fuzz_symbolic_symbolic_div.py b/test/external/fuzz_symbolic_symbolic_div.py index 7a70ee34cede6..deca50fdb89c8 100644 --- a/test/external/fuzz_symbolic_symbolic_div.py +++ b/test/external/fuzz_symbolic_symbolic_div.py @@ -2,7 +2,7 @@ import z3 from tinygrad.uop.ops import UOp, Ops from tinygrad.uop.validate import uops_to_z3 -from tinygrad.helpers import DEBUG, Context, colored +from tinygrad.helpers import DEBUG, colored seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100) print(f"Seed: {seed}", flush=True) @@ -36,8 +36,7 @@ def get_random_expr(ranges, factors): variable_names += [f"r{i}" for i in range(num_ranges)] expr = get_random_expr(ranges, factors) - with Context(CORRECT_DIVMOD_FOLDING=1): - simplified_expr = expr.simplify() + simplified_expr = expr.simplify() if DEBUG>=1: print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False)) diff --git a/test/null/test_symbolic_failures.py b/test/null/test_symbolic_failures.py index 8587bf2659756..3f349da344937 100644 --- a/test/null/test_symbolic_failures.py +++ b/test/null/test_symbolic_failures.py @@ -1,15 +1,8 @@ import unittest from tinygrad import Variable -from tinygrad.helpers import Context class TestFuzzFailure(unittest.TestCase): - def setUp(self): - self.context = Context(CORRECT_DIVMOD_FOLDING=1) - self.context.__enter__() - - def tearDown(self): - self.context.__exit__(None, None, None) def test_fuzz_failure1(self): v1=Variable('v1', 0, 8) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index ddda8e615290e..6412dd21a8c05 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -3,7 +3,6 @@ import z3 from tinygrad.dtype import dtypes, ConstType, DType, Invalid -from tinygrad.helpers import Context from test.helpers import get_uops from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid @@ -441,21 +440,11 @@ def test_sum_combine_num(self): def test_sum_num_hoisted_and_factors_cancel_out(self): self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") - @unittest.expectedFailure # only correct for floordiv, not truncdiv def test_div_cancel(self): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)") + self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)") - def test_div_cancel_correct(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)") - - @unittest.expectedFailure # only correct for floordiv, not truncdiv def test_mod_cancel(self): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)") - - def test_mod_cancel_correct(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)") + self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)") def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") @@ -542,8 +531,7 @@ def test_cmod_const_evaluation(self): self.helper_test_variable((-Variable("a", 10, 10))%7, -3, -3, "-3") def test_div_numerator_negative(self): - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") + self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)") def test_nest_div_negative_factor(self): ridx0=Variable("ridx0", 0, 9) @@ -723,8 +711,7 @@ def test_gcd_with_remainder(self): def test_div_by_factor_tie_break(self): a = Variable("a", 0, 1) b = Variable("b", 0, 1) - with Context(CORRECT_DIVMOD_FOLDING=1): - self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)") + self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)") def test_div_mod_recombine_large_coeff(self): # recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e92f4d8b11596..564d00bebd83f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -186,7 +186,7 @@ def tolist(self, obj=None): RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0) CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0) -CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0) +FUSE_OPTIM = ContextVar("FUSE_OPTIM", 0) ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0) MAX_KERNEL_BUFFERS = ContextVar("MAX_KERNEL_BUFFERS", 0) EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "") diff --git a/tinygrad/uop/divandmod.py b/tinygrad/uop/divandmod.py index 086b937b8d014..fa9e7e9178e9f 100644 --- a/tinygrad/uop/divandmod.py +++ b/tinygrad/uop/divandmod.py @@ -1,11 +1,11 @@ import functools, itertools, math from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp from tinygrad.dtype import dtypes -from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap +from tinygrad.helpers import cdiv, cmod, unwrap # NOTE: this cache is only on index UOps @functools.cache -def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: +def fold_divmod_general(d: UOp) -> UOp|None: x, y = d.src # cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval @@ -47,7 +47,8 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: return (y2-y1)*(v-v.vmin) + y1 # fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c - if not (x.vmin<0 and correct_divmod_folding): + # NOTE: relies on linearity of floor division which doesn't hold for truncation, only safe when x.vmin >= 0 + if x.vmin >= 0: # when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors] for rems in itertools.product(*rem_choices): @@ -66,7 +67,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: if x.vmin >= 0: results = [] for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}: - if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0: + if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0: if d.op is Ops.IDIV: results.append((len(newxs.backward_slice), newxs // (c // div))) else: @@ -115,7 +116,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None: lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None), # ** 2. Slow Rules ** - (UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))), + (UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d)), # NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops (UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),