Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions test/external/fuzz_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tinygrad import Variable, dtypes
from tinygrad.uop.ops import UOp
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context
from tinygrad.helpers import DEBUG

seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
Expand Down Expand Up @@ -56,8 +56,7 @@ def random_bool_expr(depth=10, expr1=None):
v = [u1,u2,u3]
expr = random_int_expr(6)

with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()

solver = z3.Solver(ctx=z3.Context())
solver.set(timeout=5000) # some expressions take very long verify, but its very unlikely they actually return sat
Expand All @@ -74,10 +73,9 @@ def random_bool_expr(depth=10, expr1=None):
m = solver.model()
n1, n2, n3 = m[v1], m[v2], m[v3]
u1_val, u2_val, u3_val = u1.const_like(n1.as_long()), u2.const_like(n2.as_long()), u3.const_like(n3.as_long())
with Context(CORRECT_DIVMOD_FOLDING=1):
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
num = expr.simplify().substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
rn = expr.substitute({u1:u1_val, u2:u2_val, u3:u3_val}).ssimplify()
if num==rn: print("z3 found a mismatch but the expressions are equal!!")
assert False, f"mismatched {expr.render()} at v1={m[v1]}; v2={m[v2]}; v3={m[v3]} = {num} != {rn}\n" +\
"Reproduce with:\n" +\
f"v1=Variable(\"{u1.arg[0]}\", {u1.arg[1]}, {u1.arg[2]})\n" +\
Expand Down
5 changes: 2 additions & 3 deletions test/external/fuzz_symbolic_symbolic_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import z3
from tinygrad.uop.ops import UOp, Ops
from tinygrad.uop.validate import uops_to_z3
from tinygrad.helpers import DEBUG, Context, colored
from tinygrad.helpers import DEBUG, colored

seed = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(0, 100)
print(f"Seed: {seed}", flush=True)
Expand Down Expand Up @@ -36,8 +36,7 @@ def get_random_expr(ranges, factors):
variable_names += [f"r{i}" for i in range(num_ranges)]
expr = get_random_expr(ranges, factors)

with Context(CORRECT_DIVMOD_FOLDING=1):
simplified_expr = expr.simplify()
simplified_expr = expr.simplify()

if DEBUG>=1:
print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False))
Expand Down
1 change: 1 addition & 0 deletions test/null/test_simplify_valid_idx.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def test_simplify6(self):
load = get_load_image_uop((128, 768, 4), valid, (alu0, alu1))
self.check(load, None, "((((idx1*24)+r3)+(r5*3))+-3)", "(((idx2*2)+r4)+-1)")

@unittest.expectedFailure # TODO: fix with correct optimizations
def test_simplify7(self):
# DEBUG=2 ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 # noqa: E501
# kernel 143
Expand Down
7 changes: 0 additions & 7 deletions test/null/test_symbolic_failures.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import unittest
from tinygrad import Variable
from tinygrad.helpers import Context


class TestFuzzFailure(unittest.TestCase):
def setUp(self):
self.context = Context(CORRECT_DIVMOD_FOLDING=1)
self.context.__enter__()

def tearDown(self):
self.context.__exit__(None, None, None)

def test_fuzz_failure1(self):
v1=Variable('v1', 0, 8)
Expand Down
21 changes: 4 additions & 17 deletions test/null/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import z3

from tinygrad.dtype import dtypes, ConstType, DType, Invalid
from tinygrad.helpers import Context
from test.helpers import get_uops
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
Expand Down Expand Up @@ -441,21 +440,11 @@ def test_sum_combine_num(self):
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(usum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")

@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_div_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(b+-1)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)")

def test_div_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40])//40, -1, 9, "(((a+(b*20))+-20)//20)")

@unittest.expectedFailure # only correct for floordiv, not truncdiv
def test_mod_cancel(self):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, 0, 20, "(a*2)")

def test_mod_cancel_correct(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)")
self.helper_test_variable(usum([uconst(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) % 40, -38, 38, "((((a+(b*20))+-20)%20)*2)")

def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
Expand Down Expand Up @@ -542,8 +531,7 @@ def test_cmod_const_evaluation(self):
self.helper_test_variable((-Variable("a", 10, 10))%7, -3, -3, "-3")

def test_div_numerator_negative(self):
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -8, 0, "(((idx*10)//11)*-1)")

def test_nest_div_negative_factor(self):
ridx0=Variable("ridx0", 0, 9)
Expand Down Expand Up @@ -723,8 +711,7 @@ def test_gcd_with_remainder(self):
def test_div_by_factor_tie_break(self):
a = Variable("a", 0, 1)
b = Variable("b", 0, 1)
with Context(CORRECT_DIVMOD_FOLDING=1):
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")
self.helper_test_variable((a*2+b*3+2)//6, 0, 1, "((a+b+1)//3)")

def test_div_mod_recombine_large_coeff(self):
# recombine must work even when coeff > divisor: both mod and div reduce the coeff the same way
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def tolist(self, obj=None):
RING, ALL2ALL = ContextVar("RING", 1), ContextVar("ALL2ALL", 0)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
FUSE_OPTIM = ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
MAX_KERNEL_BUFFERS = ContextVar("MAX_KERNEL_BUFFERS", 0)
EMULATE, EMULATED_DTYPES = ContextVar("EMULATE", ""), ContextVar("EMULATED_DTYPES", "")
Expand Down
11 changes: 6 additions & 5 deletions tinygrad/uop/divandmod.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import functools, itertools, math
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap
from tinygrad.helpers import cdiv, cmod, unwrap

# NOTE: this cache is only on index UOps
@functools.cache
def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
def fold_divmod_general(d: UOp) -> UOp|None:
x, y = d.src

# cancel_divmod: simple cancel div/mod case when the range of the numerator lies within a single denominator interval
Expand Down Expand Up @@ -47,7 +47,8 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
return (y2-y1)*(v-v.vmin) + y1

# fold_divmod_congruence: fold if a is congruent to an expression whose range is between 0 and c
if not (x.vmin<0 and correct_divmod_folding):
# NOTE: uses floordiv linearity which only holds for non-negative numerators
if x.vmin >= 0:
# when f%c == c//2, abs(r) == abs(r-c) is a tie, try both signs since either may fit in one period
rem_choices = [(r, r-c) if (r:=f%c)*2 == c else (min(r, r-c, key=abs),) for f in factors]
for rems in itertools.product(*rem_choices):
Expand All @@ -66,7 +67,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
if x.vmin >= 0:
results = []
for div in {abs(f) for u, f in zip(uops_no_const, factors) if u.op not in (Ops.CONST, Ops.VCONST) and 1 < abs(f) < c and (c%f)==0}:
if (newxs := fold_divmod_general(x//div, correct_divmod_folding)) is not None and newxs.vmin >= 0:
if (newxs := fold_divmod_general(x//div)) is not None and newxs.vmin >= 0:
if d.op is Ops.IDIV:
results.append((len(newxs.backward_slice), newxs // (c // div)))
else:
Expand Down Expand Up @@ -115,7 +116,7 @@ def fold_divmod_general(d: UOp, correct_divmod_folding: bool) -> UOp|None:
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),

# ** 2. Slow Rules **
(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d, bool(CORRECT_DIVMOD_FOLDING))),
(UPat((Ops.IDIV, Ops.MOD), dtypes.weakint, name="d"), lambda d: fold_divmod_general(d)),

# NOTE: these have to go at the bottom or TestSymbolicOps.test_var loops
(UPat.var("x", dtypes.weakint) % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <= 0 else None),
Expand Down
18 changes: 16 additions & 2 deletions tinygrad/uop/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# all of symbolic lives here now
import math, operator, struct, functools
import math, operator, struct, functools, itertools
from collections import defaultdict
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType, can_lossless_cast, Invalid
Expand Down Expand Up @@ -316,6 +316,7 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:

# simplify uop given that valid is True
all_candidates = []
simplex_groups = []
for i,(expr,v) in enumerate(bounds.items()):
v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1])
# try checking the whole clause
Expand All @@ -326,7 +327,9 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
candidates = [[all_candidates[-1]]]
if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable(f"fake{i}", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)])
simplex = [(Xi, UOp.variable(f"fake{i}", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]
candidates.append(simplex)
simplex_groups.append(simplex)

for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
Expand All @@ -338,6 +341,17 @@ def uop_given_valid(valid:UOp, uop:UOp, try_simplex=True) -> UOp:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))

# cross-product of simplex branches: when individual constraints fail, try all combinations together
if try_simplex and len(simplex_groups) > 1:
cross_results = []
for combo in itertools.product(*simplex_groups):
sub_dict = dict(combo)
newuop = uop.substitute(sub_dict)
if newuop is uop: break
cross_results.append(newuop.simplify().substitute({v:k for k,v in sub_dict.items()}).simplify())
else:
if cross_results and all_same(cross_results): uop = cross_results[0]

# try all the valids together (but only the whole expressions)
if (s_uop:=uop.substitute(sub_dict:=dict(all_candidates))) is not uop:
uop = s_uop.simplify().substitute({newX:X for X,newX in sub_dict.items()}).simplify()
Expand Down
Loading