Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions test/external/fuzz_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context
from tinygrad.helpers import DEBUG

seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
Expand Down Expand Up @@ -56,8 +56,7 @@ def random_bool_expr(depth=10, expr1=None):
v = [u1,u2,u3]
expr = random_int_expr(6)

with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()

solver = z3.Solver(ctx=z3.Context())
solver.set(timeout=5000) # some expressions take very long verify, but its very unlikely they actually return sat
Expand All @@ -74,10 +73,9 @@ def random_bool_expr(depth=10, expr1=None):
m = solver.model()
n1, n2, n3 = m[v1], m[v2], m[v3]
u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
with Context(CORRECT_DIVMOD_FOLDING=1):
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
assert False, f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +\
"Reproduce with:\n" +\
f"v1=Variable(\"{u1.arg[0]}\", {u1.arg[1]}, {u1.arg[2]})\n" +\
Expand Down
5 changes: 2 additions & 3 deletions test/external/fuzz_symbolic_symbolic_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import z3
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context, colored
from tinygrad.helpers import DEBUG, colored

seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
Expand Down Expand Up @@ -36,8 +36,7 @@ def get_random_expr(ranges, factors):
variable_names += [f"r{i}" for i in range(num_ranges)]
expr = get_random_expr(ranges, factors)

with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()

if DEBUG>=1:
print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False))
Expand Down
7 changes: 0 additions & 7 deletions test/null/test_symbolic_failures.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import unittest
from tinygrad import Variable
from tinygrad.helpers import Context


class TestFuzzFailure(unittest.TestCase):
def setUp(self):
self.context = Context(CORRECT_DIVMOD_FOLDING=1)
self.context.__enter__()

def tearDown(self):
self.context.__exit__(None, None, None)

def test_fuzz_failure1(self):
v1=Variable('v1', 0, 8)
Expand Down
21 changes: 4 additions & 17 deletions test/null/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import z3

from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.helpers import Context
from test.helpers import get_uops
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
Expand Down Expand Up @@ -441,21 +440,11 @@ def test_sum_combine_num(self):
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")

@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_div_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)")

def test_div_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)")

@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_mod_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")

def test_mod_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)")

def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
Expand Down Expand Up @@ -542,8 +531,7 @@ def test_cmod_const_evaluation(self):
self.helper_test_variable((-Variable("a", 10, 10))%7, -3, -3, "-3")

def test_div_numerator_negative(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")

def test_nest_div_negative_factor(self):
ridx0=Variable("ridx0", 0, 9)
Expand Down Expand Up @@ -723,8 +711,7 @@ def test_gcd_with_remainder(self):
def test_div_by_factor_tie_break(self):
a = Variable("a", 0, 1)
b = Variable("b", 0, 1)
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")

def test_div_mod_recombine_large_coeff(self):
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def tolist(self, obj=None):
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
FUSE_OPTIM = ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
MAX_KERNEL_BUFFERS = ContextVar("MAX_KERNEL_BUFFERS", 0)
EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "")
Expand Down
11 changes: 6 additions & 5 deletions tinygrad/uop/divandmod.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import functools, itertools, math
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.helpers import cdiv, cmod, unwrap

# NOTE: this cache is only on index UOps
@functools.cache
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
def fold_divmod_general(d: UOp) -> UOp|None:
x, y = d.src

# cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval
Expand Down Expand Up @@ -47,7 +47,8 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
return (y2-y1)*(v-v.vmin) + y1

# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
if not (x.vmin<0 and correct_divmod_folding):
# NOTE: relies on linearity of floor division which doesn't hold for truncation, only safe when x.vmin >= 0
if x.vmin >= 0:
# when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period
rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors]
for rems in itertools.product(*rem_choices):
Expand All @@ -66,7 +67,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
if x.vmin >= 0:
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
if d.op is Ops.IDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))
else:
Expand Down Expand Up @@ -115,7 +116,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),

# ** 2. Slow Rules **
(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d)),

# NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
Expand Down
Loading