diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8213b1cc9a012..8c33161217eda 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -742,7 +742,7 @@ jobs: strategy: fail-fast: false matrix: - backend: [llvm, cpu, opencl, lvp] + backend: [llvm, cpu, opencl, lvp, x86] name: Linux (${{ matrix.backend }}) runs-on: ubuntu-22.04 @@ -759,7 +759,7 @@ jobs: llvm: ${{ matrix.backend == 'llvm' || matrix.backend == 'lvp' }} mesa: ${{ matrix.backend == 'lvp' && 'true' }} - name: Set env - run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'CL=1' || matrix.backend == 'lvp' && 'CPU=1\nCPU_LVP=1' }}" >> $GITHUB_ENV + run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'CL=1' || matrix.backend == 'lvp' && 'CPU=1\nCPU_LVP=1' || matrix.backend == 'x86' && 'CPU=1\nCPU_X86=1' }}" >> $GITHUB_ENV - name: Check Device.DEFAULT and print some source run: | python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT" @@ -910,7 +910,7 @@ jobs: strategy: fail-fast: false matrix: - backend: [llvm, cpu, webgpu] + backend: [llvm, cpu, webgpu, x86] name: Windows (${{ matrix.backend }}) runs-on: windows-latest @@ -926,7 +926,7 @@ jobs: pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }} - name: Set env shell: bash - run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV + run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1' || matrix.backend == 'x86' && 'CPU=1\nCPU_X86=1' }}" >> $GITHUB_ENV - name: Run unit tests if: matrix.backend=='llvm' 
# test_newton_schulz hits RecursionError @@ -938,7 +938,7 @@ jobs: - name: Run pytest (${{ matrix.backend }}) shell: bash run: | - python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT" + python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU', 'X86':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT" python -m pytest -n=auto test/test_tiny.py test/backend/test_ops.py --durations=20 # ****** Compile-only Tests ****** diff --git a/test/backend/test_linearizer.py b/test/backend/test_linearizer.py index 4e7e0e108de16..fe08613cb2623 100644 --- a/test/backend/test_linearizer.py +++ b/test/backend/test_linearizer.py @@ -12,6 +12,7 @@ from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.cstyle import CUDARenderer +from tinygrad.renderer.isa.x86 import X86Renderer MOCKGPU = getenv("MOCKGPU") from tinygrad.uop.ops import print_uops # noqa: F401 # pylint: disable=unused-import @@ -376,6 +377,7 @@ def test_assign_fold(self): np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.]) @unittest.skipIf(MOCKGPU and isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, CUDARenderer)), "PTX indexes differently. 
might be ok?") + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "this will work once cast to bool becomes cmpne 0") def test_where_fold(self): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink(((1, 2), None)).pad(((1, 2), None)) diff --git a/test/backend/test_tensor_variable.py b/test/backend/test_tensor_variable.py index b05529c71c55e..528a2676d54d5 100644 --- a/test/backend/test_tensor_variable.py +++ b/test/backend/test_tensor_variable.py @@ -1,6 +1,7 @@ import unittest import numpy as np -from tinygrad import Tensor, Variable +from tinygrad import Tensor, Variable, Device +from tinygrad.renderer.isa.x86 import X86Renderer class TestTensorVariable(unittest.TestCase): def test_add_tvar(self): diff --git a/test/backend/test_uops.py b/test/backend/test_uops.py index ee95d371969c0..272ec0f40623f 100644 --- a/test/backend/test_uops.py +++ b/test/backend/test_uops.py @@ -12,6 +12,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.codegen.opt import Opt, OptOps from tinygrad.renderer.ptx import PTXRenderer +from tinygrad.renderer.isa.x86 import X86Renderer from test.helpers import to_uops_list from dataclasses import replace @@ -267,6 +268,7 @@ def test_use_cmpeq(self): self.assertNotIn(Ops.CMPNE, ops) class TestZeroRange(unittest.TestCase): + @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "range check is done at the end so 1 iter always happens, skip for now") def test_reduce_variable(self): for i in range(3,-1,-1): v = UOp.variable("i", 0, 5).bind(i) diff --git a/test/null/test_opts.py b/test/null/test_opts.py index 359441cbf1d69..fe0e96f6e7b12 100644 --- a/test/null/test_opts.py +++ b/test/null/test_opts.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Tensor, Device -from tinygrad.helpers import CPU_LLVM, CPU_LVP +from tinygrad.helpers import CPU_LLVM, CPU_LVP, CPU_X86 from tinygrad.codegen.opt import Opt, OptOps from tinygrad.engine.realize import get_program @@ -12,7 +12,7 @@ 
def test_opt_upcast(self): out = (a+b).contiguous(arg=opts) s = out.schedule() self.assertEqual(s[-1].ast.arg.opts_to_apply, opts) - if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP: + if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP and not CPU_X86: prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer) self.assertIn('float4', prg.src) diff --git a/test/unit/test_encodings.py b/test/unit/test_encodings.py new file mode 100644 index 0000000000000..b82e2c87c156b --- /dev/null +++ b/test/unit/test_encodings.py @@ -0,0 +1,147 @@ +import unittest +from tinygrad.uop.ops import UOp, Ops +from tinygrad.dtype import dtypes +from tinygrad.helpers import SPEC +from tinygrad.renderer.isa.x86 import X86Ops, X86Renderer, RBP, RDI, RSP, RSI, RAX, RDX, XMM, GPR, imm, def_reg + +@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec") +class TestEncodingsX86(unittest.TestCase): + # NOTE: x86 supports a single displacement as memory address and index without base memory address + # these have no use cases so they aren't supported + def encode(self, u:UOp): return X86Renderer().render([u], lower=False) + + # displacement of 0 isn't emitted + def test_base_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RDI) + # mov edi, dword ptr [rdi] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 3F")) + + # rsp/r12 require a sib byte when used as base memory address + def test_rsp_base_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RSP), UOp(Ops.NOOP), imm(dtypes.int8, 0)), RSP) + # mov esp, dword ptr [rsp] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 24 24")) + + # rbp/r13 require a displacement when used as base memory address + def test_rbp_base_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RBP), UOp(Ops.NOOP), 
imm(dtypes.int8, 0)), RBP) + # mov ebp, dword ptr [rbp + 0] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 6D 00")) + + # test [base + index*scale] + def test_base_index_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RDX), imm(dtypes.int8, 0)), RAX) + # mov eax, dword ptr [rax + rdx*4] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 04 90")) + + # rsp as index means no index + def test_rsp_index_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, RSP), imm(dtypes.int8, 0)), RAX) + # mov eax, dword ptr [rax] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 00")) + + # however r12 is a valid index + def test_r12_index_address(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RAX), def_reg(dtypes.int32, GPR[12]), imm(dtypes.int8, 0)), RAX) + # mov eax, dword ptr [rax + r12*4] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("42 8B 04 A0")) + + # test [base + index*scale + 8bit disp] + def test_complex_address_8bit_disp(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10)), RDI) + # mov edi, dword ptr [rdi + rsi*4 + 0xa] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B 7C B7 0A")) + + # test [base + index*scale + 32bit disp] + def test_complex_address_32bit_disp(self): + load = UOp(X86Ops.MOV, dtypes.int32, (def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10000)), RDI) + # mov edi, dword ptr [rdi + rsi*4 + 0x2710] + self.assertEqual(bytes.fromhex(self.encode(load)), bytes.fromhex("8B BC B7 10 27 00 00")) + + # 8bit variants of legacy instructions subtract 1 from opcode + def test_8bit_legacy_encoding(self): + cast = UOp(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDX),), RAX) + # movsx eax, dl + 
self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("0F BE C2")) + + # accessing lower 8 bits of rsp, rbp, rsi, rdi requires rex prefix + def test_lower_8bits_reg(self): + cast = UOp(X86Ops.MOVSX, dtypes.int32, (def_reg(dtypes.int8, RDI),), RAX) + # movsx eax, dil + self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("40 0F BE C7")) + + # test 16 bit variant of legacy instruction + def test_16bit_legacy_encoding(self): + cast = UOp(X86Ops.MOVSX, dtypes.int16, (def_reg(dtypes.int8, RDX),), RAX) + # movsx ax, dl + self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("66 0F BE C2")) + + # test 64 bit variant of legacy instruction + def test_64bit_legacy_encoding(self): + cast = UOp(X86Ops.MOVSX, dtypes.int64, (def_reg(dtypes.int8, RDX),), RAX) + # movsx rax, dl + self.assertEqual(bytes.fromhex(self.encode(cast)), bytes.fromhex("48 0F BE C2")) + + # test compact vex encoding + def test_compact_vex_encoding(self): + xmm0, xmm1 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1]) + add = UOp(X86Ops.VADDSS, dtypes.float32, (xmm0, xmm1), XMM[0]) + # vaddss xmm0, xmm0, xmm1 + self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C5 FA 58 C1")) + + # test long vex encoding + def test_long_vex_encoding(self): + xmm0, xmm8 = def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[8]) + add = UOp(X86Ops.VADDSS, dtypes.float32, (xmm0, xmm8), XMM[0]) + # vaddss xmm0, xmm0, xmm8 + self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C4 C1 7A 58 C0")) + + # test ymm encoding + def test_ymm_encoding(self): + xmm0, xmm1 = def_reg(dtypes.float32.vec(8), XMM[0]), def_reg(dtypes.float32.vec(8), XMM[1]) + add = UOp(X86Ops.VADDPS, dtypes.float32.vec(8), (xmm0, xmm1), XMM[0]) + # vaddps ymm0, ymm0, ymm1 + self.assertEqual(bytes.fromhex(self.encode(add)), bytes.fromhex("C5 FC 58 C1")) + + # test encoding where register is in the immediate field + def test_reg_in_imm_field(self): + xmm0, xmm1, xmm2 = 
def_reg(dtypes.float32, XMM[0]), def_reg(dtypes.float32, XMM[1]), def_reg(dtypes.float32, XMM[2]) + blend = UOp(X86Ops.VBLENDVPS, dtypes.float32, (xmm0, xmm1, xmm2), XMM[0]) + # vblendvps xmm0, xmm0, xmm1, xmm2 + self.assertEqual(bytes.fromhex(self.encode(blend)), bytes.fromhex("C4 E3 79 4A C1 20")) + + # when writting to mem the uop takes the store form where dtype is void and there's no definition + def test_write_mem(self): + base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10) + xmm0 = def_reg(dtypes.float32, XMM[0]) + extr = UOp(X86Ops.VPEXTRD, dtypes.void, (base, index, disp, xmm0, imm(dtypes.uint8, 0))) + # vpextrd dword ptr [rdi + rsi*4 + 0xa], xmm0, 0 + self.assertEqual(bytes.fromhex(self.encode(extr)), bytes.fromhex("C4 E3 79 16 44 B7 0A 00")) + + # test two address instruction with fused load works + def test_two_address_load(self): + base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int8, 10) + cmove = UOp(X86Ops.CMOVE, dtypes.int32, (base, index, disp), RAX) + # cmove eax, dword ptr [rdi + rsi*4 + 0xa] + self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 44 B7 0A")) + + # test instruction where displacement and imm have the same value + def test_disp_imm_same_value(self): + base, index, disp = def_reg(dtypes.int8.ptr(), RDI), def_reg(dtypes.int8, RSI), imm(dtypes.int8, 10) + mov = UOp(X86Ops.MOVi, dtypes.void, (base, index, disp, disp)) + # mov byte ptr [rdi + rsi + 0xa], 0xa + self.assertEqual(bytes.fromhex(self.encode(mov)), bytes.fromhex("40 C6 44 37 0A 0A")) + + base, index, disp = def_reg(dtypes.int32.ptr(), RDI), def_reg(dtypes.int32, RSI), imm(dtypes.int32, 10) + imul = UOp(X86Ops.IMULi, dtypes.int32, (base, index, disp) + (imm(dtypes.int32, 10),), RDI) + # imul edi, dword ptr [rdi + rsi*4 + 0xa], 0xa + self.assertEqual(bytes.fromhex(self.encode(imul)), bytes.fromhex("69 BC B7 0A 00 00 00 0A 00 00 00")) + + # cmoves have the cmp as 
the last src even though it is not explicitly used, the cmp doesn't define a reg and is ignored in the encoding + def test_cmove_ignore_cmp(self): + cmove = UOp(X86Ops.CMOVE, dtypes.int32, (def_reg(dtypes.int32, RAX), UOp(X86Ops.CMP)), RDX) + # cmove edx, eax + self.assertEqual(bytes.fromhex(self.encode(cmove)), bytes.fromhex("0F 44 D0")) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/unit/test_isel.py b/test/unit/test_isel.py new file mode 100644 index 0000000000000..457bac50f802e --- /dev/null +++ b/test/unit/test_isel.py @@ -0,0 +1,113 @@ +import unittest +from tinygrad.uop import Ops +from tinygrad.uop.ops import UOp, dtypes, graph_rewrite +from tinygrad.renderer.isa import X86Ops +from tinygrad.renderer.isa.x86 import X86Renderer +from tinygrad.renderer.isa.isa import IselContext, Register +from tinygrad.helpers import SPEC + +@unittest.skipIf(SPEC > 1, "x86 spec not supported in full_spec") +class TestIselX86(unittest.TestCase): + def isel_rewrite(self, x:UOp): return graph_rewrite(x, X86Renderer().isel_matcher, IselContext(x), bottom_up=True) + + def test_cmove(self): + a = UOp.variable("a", 0, 0, dtypes.int32) + b = UOp.variable("b", 0, 0, dtypes.int32) + c = (a < b).where(a, b) + d = (a != b).where(a, b) + f = c + d + n = self.isel_rewrite(f) + self.assertTrue(n.src[0].op is X86Ops.CMOVL and n.src[1].op is X86Ops.CMOVNE) + # both comparisons become the same instruction + self.assertTrue(n.src[0].src[2] == n.src[1].src[2] and n.src[0].src[2].op is X86Ops.CMP) + + def test_cmove_and_blend_with_float_cmp(self): + a = UOp.variable("a", 0, 0, dtypes.float32) + b = UOp.variable("b", 0, 0, dtypes.float32) + c = a < b + d = c.where(a.cast(dtypes.int32), b.cast(dtypes.int32)) + e = c.where(a, b) + f = d + e + n = self.isel_rewrite(f) + # the comparison instruction depends on the user, int cmove uses flag while float cmove uses mask + # so both flag producing and mask producing comparisons must be present + 
self.assertTrue(n.src[0].op is X86Ops.CMOVB and n.src[0].src[2].op is X86Ops.VUCOMISS) + self.assertTrue(n.src[1].op is X86Ops.VBLENDVPS and n.src[1].src[2].op is X86Ops.VCMPSS and n.src[1].src[2].src[2].arg == 1) + + # lower 2 32 bits must come from the same register and upper 2 32 bits must come from the same register + def test_vshufps(self): + a = UOp.variable("a", 0, 0, dtypes.float32.vec(4)) + b = UOp.variable("b", 0, 0, dtypes.float32.vec(4)) + c = UOp.variable("c", 0, 0, dtypes.float32) + d = UOp.variable("d", 0, 0, dtypes.float32) + # shuffle between 2 vectors + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), b.gep(2), b.gep(3)))) + self.assertTrue(n.op is X86Ops.VSHUFPS) + # shuffle between 2 scalars + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (c, c, d, d))) + self.assertTrue(n.op is X86Ops.VSHUFPS) + # shuffle between vector and scalar + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(0), a.gep(1), c, c))) + self.assertTrue(n.op is X86Ops.VSHUFPS) + # shuffle between 1 vector + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(1), a.gep(2), a.gep(3), a.gep(0)))) + self.assertTrue(n.op is X86Ops.VSHUFPS and n.src[0] is n.src[1]) + # a shuffle between 1 scalar is just a broadcast and matches X86Ops.VBROADCASTSS to allow for load fusion + + # this is the fallback slow VECTORIZE, 1 vinsertps per src in VECTORIZE + def test_vinsertps(self): + a = UOp.variable("a", 0, 0, dtypes.float32.vec(4)) + b = UOp.variable("b", 0, 0, dtypes.float32.vec(4)) + c = UOp.variable("c", 0, 0, dtypes.float32.vec(4)) + d = UOp.variable("e", 0, 0, dtypes.float32) + # pack 1 from vector and 1 from scalar, moving 0th element to position 0 does nothing so only 1 vinsertps is generated + n = self.isel_rewrite(UOp(Ops.VECTORIZE, dtypes.float32.vec(2), (a.gep(0), d))) + self.assertTrue(n.op is X86Ops.VINSERTPS and n.src[0].op is X86Ops.DEFINE_REG) + # interleaved shuffle between 2 vectors + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, 
(a.gep(0), b.gep(1), a.gep(2), b.gep(3)))) + self.assertTrue(n.op is X86Ops.VINSERTPS) + # shuffle between 4 sources + n = self.isel_rewrite(UOp(Ops.VECTORIZE, a.dtype, (a.gep(3), b.gep(2), c.gep(1), d))) + self.assertTrue(n.op is X86Ops.VINSERTPS) + + # complex address is [base + index*scale + displacement] + def test_complex_address(self): + a = UOp.variable("a", 0, 0, dtypes.int32) + load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(a + 1, ptr=True).load() + n = self.isel_rewrite(load) + # base is PARAM, index is "a" + self.assertTrue(n.src[0].op is X86Ops.DEFINE_REG and n.src[1].op is X86Ops.DEFINE_REG) + # displacement is the constant in "a" scaled to the buffer element size, dtype is int8 when the value fits otherwise int32 + self.assertTrue(n.src[2].op is X86Ops.IMM and n.src[2].dtype is dtypes.int8 and n.src[2].arg == 4) + + def test_fuse_load(self): + load1 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load() + load2 = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 1), ptr=True).load() + n = self.isel_rewrite(load1 + load2) + self.assertTrue(len(n.src) == 4) + + # don't fuse when used multiple times + def test_dont_fuse_load_diff_users(self): + load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load() + add = load + 1 + n = self.isel_rewrite(add + load) + self.assertTrue(len(n.src) == 2) + + def test_dont_fuse_load_same_user(self): + load = UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0).index(UOp.const(dtypes.int32, 0), ptr=True).load() + n = self.isel_rewrite(load * load) + self.assertTrue(len(n.src) == 2) + + # test noop has same reg as src, this is because noops aren't instructions but still need to be part of the graph + # as they may have different dtype from src and the correct dtype is required to encode the correct instruction + # by giving them the same reg as src we ensure they share the same live range + @unittest.skip("hmmm") + def 
test_noop(self): + noop = UOp(Ops.NOOP, dtypes.int32, (UOp(Ops.PARAM, dtypes.int32.ptr(), arg=0),)) + n = self.isel_rewrite(noop) + self.assertTrue(isinstance(n.arg, Register) and n.arg == n.src[0].arg) + + # TODO: might want to check that load isn't part of another range when fusing + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/unit/test_x86op_values.py b/test/unit/test_x86op_values.py new file mode 100644 index 0000000000000..15a8486c2d079 --- /dev/null +++ b/test/unit/test_x86op_values.py @@ -0,0 +1,25 @@ +import unittest +from tinygrad.uop import Ops, GroupOp +from tinygrad.renderer.isa import X86Ops, X86GroupOp + +class TestX86OpValues(unittest.TestCase): + def test_values(self): + # ADD is added in X86Ops + assert X86Ops.ADD != Ops.ADD + assert X86Ops.ADD is not Ops.ADD + assert Ops.ADD not in X86GroupOp.All + assert X86Ops.ADD not in GroupOp.All + assert X86Ops.ADD in X86GroupOp.All + assert not isinstance(Ops.ADD, X86Ops) + assert isinstance(X86Ops.ADD, Ops) + assert isinstance(X86Ops.ADD, X86Ops) + # SINK is not added in X86Ops, this is now possible but behavior doesn't change + assert X86Ops.SINK == Ops.SINK + assert X86Ops.SINK is Ops.SINK + assert X86Ops.SINK in GroupOp.All + assert X86Ops.SINK not in X86GroupOp.All + assert not isinstance(X86Ops.SINK, X86Ops) + assert max(op.value for op in GroupOp.All) + 1 == min(op.value for op in X86GroupOp.All) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index d1496b793e7f5..e7aa66cbed1c4 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -120,12 +120,12 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - ]) # requires lst be toposorted. 
like graph rewrite, but for lines -def line_rewrite(lst:list[UOp], pm:PatternMatcher) -> list[UOp]: +def line_rewrite(lst:list[UOp], pm:PatternMatcher, ctx=None) -> list[UOp]: newlst = [] replaced: dict[UOp, UOp] = {} for u in lst: - nu = u.replace(src=tuple([replaced[x] for x in u.src])) - ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu)) or (nu, [nu]) + nu = u.replace(src=tuple([replaced.get(x, x) for x in u.src])) + ret: tuple[UOp, list[UOp]] = cast(tuple[UOp, list[UOp]]|None, pm.rewrite(nu, ctx)) or (nu, [nu]) replaced[u] = ret[0] newlst.extend(ret[1]) return newlst diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index dd1baa7e6a60d..c1e54601a2f80 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -5,7 +5,7 @@ from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate -from tinygrad.helpers import getenv, flatten, AMX, prod, ceildiv, IMAGE +from tinygrad.helpers import getenv, flatten, AMX, CPU_X86, prod, ceildiv, IMAGE from tinygrad.renderer import Renderer # ***** image load valid simplification ***** @@ -156,6 +156,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): pass elif isinstance(buf.dtype, ImageDType): lengths = [4] + elif ctx is not None and CPU_X86: + lengths = [4,2] if buf.dtype.base == dtypes.float32 else [] + #must_divide = False elif ctx is not None and ctx.supports_float4: # TODO: a better way to get this than ctx lengths = [8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]) diff --git a/tinygrad/codegen/late/regalloc.py b/tinygrad/codegen/late/regalloc.py new file mode 100644 index 0000000000000..e18c0325fef8d --- /dev/null +++ b/tinygrad/codegen/late/regalloc.py @@ -0,0 +1,146 @@ 
+from __future__ import annotations +import itertools +from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat +from tinygrad.renderer.isa import X86GroupOp +from tinygrad.dtype import dtypes, DType, PtrDType +from dataclasses import dataclass, field + +@dataclass(frozen=True) +class Register: + name: str + index: int + cons: tuple[Register, ...] = field(default_factory=tuple) + + def __str__(self): return self.name + +# loosely based on: https://bernsteinbear.com/assets/img/register-spilling-range-splitting-ssa.pdf +class RegallocContext: + def __init__(self, uops:list[UOp], isel:PatternMatcher, stack_ptr:UOp, stack_size:int=0): + self.live_range: dict[Register, list[int]] = {} + self.live: dict[Register, Register] = {} + self.spills: dict[Register, UOp] = {} + self.rewrite_to_vreg: dict[UOp, Register] = {} + self.vreg_to_rewrite: dict[Register, UOp] = {} + self.live_ins: list[dict[Register, Register]] = [] + self.idx = itertools.count() + self.isel = isel + self.stack_ptr = stack_ptr + self.stack_size = stack_size + # compute live ranges + lr, ranges = self.live_range, [] + for i,u in enumerate(reversed(uops)): + if u.op in (Ops.NOOP, Ops.AFTER): continue + for v in {s.arg for s in (u,) + u.src if isinstance(s.arg, Register)}: lr.setdefault(v, []).insert(0, len(uops) - 1 - i) + # a var defined before a range and used inside it is needed for the whole range + if u.arg in lr and (n:=max((lr[rng][-1] for rng in ranges if lr[rng][0] < lr[u.arg][-1] < lr[rng][-1]), default=None)): lr[u.arg].append(n) + if u.op is Ops.RANGE: ranges.append(u.arg) + +# TODO: rm pointers +# nasty hacks to deal with pointers +def assign(ctx:RegallocContext, x:UOp, reg:Register): + dt = dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype + ret = ctx.isel.rewrite(UOp(Ops.ASSIGN, dt, (x,), reg)) + assert ret is not None + return ret.replace(dtype=x.dtype) +def load(ctx:RegallocContext, dt:DType, disp:UOp, reg:Register): + ndt = dtypes.uint64 if isinstance(dt, PtrDType) else dt + 
ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).load(dtype=ndt, arg=reg)) + assert ret is not None + return ret.replace(dtype=dt) +def store(ctx:RegallocContext, disp:UOp, x:UOp): + nx = x.replace(dtype=dtypes.uint64 if isinstance(x.dtype, PtrDType) else x.dtype) + ret = ctx.isel.rewrite(ctx.stack_ptr.index(disp).store(nx)) + assert ret is not None + return ret.replace(src=(s if s is not nx else x for s in ret.src)) + +def alloc(ctx:RegallocContext, cons:tuple[Register, ...], i:int) -> Register: + live_inv = {v:k for k,v in ctx.live.items()} + # allocate the best register. Registers not in live or not used again are free and have priority, + # otherwise pick the one with the furthest next use. Regs that appear first in cons have priority in case of a tie + reg,vreg = max(((r,live_inv.get(r)) for r in cons), + key=lambda rv: next((j-i for j in ([] if rv[1] is None else ctx.live_range[rv[1]]) if j >= i), float('inf'))) + if vreg is not None and vreg not in ctx.spills and ctx.live_range[vreg][-1] >= i: + sz = ctx.vreg_to_rewrite[vreg].dtype.itemsize if not isinstance(ctx.vreg_to_rewrite[vreg].dtype, PtrDType) else 8 + assert sz > 0 + offset = ctx.stack_size + (sz - ctx.stack_size % sz) % sz + ctx.spills[vreg] = UOp.const(dtypes.int32, offset) + ctx.stack_size = offset + sz + return ctx.live.pop(vreg) if vreg is not None else reg + +def regalloc(ctx:RegallocContext, x:UOp, i:int) -> tuple[UOp, list[UOp]]: + nsrc, loads = [], [] + for s in x.src: + # allocate srcs, if src was spilled it's replaced by a load, if it's live the load was already emited otherwise alloc and emit one + if isinstance(s.arg, Register) and (v:=ctx.rewrite_to_vreg[s]) in ctx.spills: + # TODO: the constraints only apply to the definition, you need to insert moves in the graph to "cleanse" the constraint + # then those moves are removed after regalloc if they move to the same register. 
I think this is the llvm approach + # alternatively you could beef up the register class to include constraints on the srcs, then you check those here + if v not in ctx.live: + ctx.live[v] = alloc(ctx, v.cons or (v,), i) + s = load(ctx, s.dtype, ctx.spills[v], ctx.live[v]) + loads.append(s) + else: s = load(ctx, s.dtype, ctx.spills[v], ctx.live[v]) + nsrc.append(s) + # allocate destination + if isinstance(v:=x.arg, Register) and v not in ctx.live: + # if no cons it's a real register, so it can only be assigned to itself + cons = v.cons or (v,) + # two address instructions (src is used in dest) can only coalesce reused src. reused src goes first to get priority in case of a tiebreak + # TODO: make this backend independent + if x.op in X86GroupOp.TwoAddress1st: + ins = tuple(ctx.live.get(ctx.rewrite_to_vreg[s]) for s in x.src) + cons = ((ins[0],) if ins[0] in cons else ()) + tuple(r for r in cons if r not in ins) + assert cons + ctx.live[v] = alloc(ctx, cons, i+1) + + nx = x.replace(src=tuple(nsrc), arg=ctx.live.get(v, v)) + # TODO: this check exists because of a hack in x86, rm once multiple outputs are supported + if nx not in ctx.rewrite_to_vreg: ctx.rewrite_to_vreg[nx] = v + if v not in ctx.vreg_to_rewrite: ctx.vreg_to_rewrite[v] = nx + return nx, loads + [nx] + +# move uops to registers before the loop to avoid loading inside the loop +def loop_prologue(ctx:RegallocContext, x:UOp, i:int): + assert isinstance(x.arg, Register) + nx, lst = regalloc(ctx, x, i) + # we move to register vars used in the loop sorted by next use, vars not used in the loop will not be reloaded in the epilogue + used_in_loop = [v for v in ctx.live.keys() | ctx.spills.keys() if any(i <= l < ctx.live_range[x.arg][-1] for l in ctx.live_range[v])] + sorted_uses = sorted(used_in_loop, key=lambda k: next(l-i for l in ctx.live_range[k] if l >= i)) + live_in: dict[Register, Register] = {} + loads = [] + for v in sorted_uses: + # if all the possible registers are already in live_in there's no space 
for this var + if set(v.cons or (v,)).issubset(live_in.values()): continue + if v not in ctx.live: + ctx.live[v] = alloc(ctx, v.cons or (v,), i) + s = ctx.vreg_to_rewrite[v] + loads.append(load(ctx, s.dtype, ctx.spills[v], ctx.live[v])) + assert ctx.live[v] not in live_in.values() + live_in |= {v: ctx.live[v]} + ctx.live_ins.append(live_in) + return nx, loads + lst + +# reload registers that were live at loop entry +def loop_epilogue(ctx:RegallocContext, x:UOp, i:int): + # TODO: if a uop is in a different reg in live out vs live in move between registers instead of loading + # TODO: don't reload if first use in loop is a load + loads = [] + for k,v in ctx.live_ins.pop().items(): + if k not in ctx.live or ctx.live[k] != v: + ctx.live[k] = alloc(ctx, (v,), i) + s = ctx.vreg_to_rewrite[k] + loads.append(load(ctx, s.dtype, ctx.spills[k], ctx.live[k])) + return x, loads + [x] + +pm_regalloc = PatternMatcher([ + (UPat(Ops.RANGE, name="x"), lambda ctx,x: loop_prologue(ctx, x, next(ctx.idx))), + (UPat(Ops.END, name="x"), lambda ctx,x: loop_epilogue(ctx, x, next(ctx.idx))), + (UPat(X86GroupOp.All | {Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER}, name="x"), lambda ctx,x: regalloc(ctx, x, next(ctx.idx))), +]) + +# annoying that this is another pm +pm_insert_spills = PatternMatcher([ + # insert spill after definition + (UPat(X86GroupOp.All | {Ops.RANGE}, name="x"), lambda ctx,x: + (x, [x, store(ctx, y, x)]) if (y:=ctx.spills.get(ctx.rewrite_to_vreg.get(x))) is not None else None), +]) diff --git a/tinygrad/device.py b/tinygrad/device.py index 88238cfaf243d..0066d285c5f4f 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -5,7 +5,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, 
ProfileEvent, ProfilePointEvent, dedup, ContextVar -from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK +from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, CPU_X86, NV_PTX, CUDA_PTX, NV_NAK from tinygrad.helpers import EMULATED_DTYPES, TracingKey from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype if TYPE_CHECKING: from tinygrad.renderer import Renderer @@ -349,7 +349,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: if device == "METAL": return not CI if device == "CUDA": return not CI and not CUDA_PTX if device == "NV": return not CI and not NV_PTX and not NV_NAK - if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP + if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not CPU_LVP and not CPU_X86 return device in {"AMD", "CL", "PYTHON", "NULL"} if dtype in dtypes.fp8s: if device == "CUDA": return not CI and not CUDA_PTX diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e7d5673613c9f..7ca84fc39f705 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -187,7 +187,7 @@ def tolist(self, obj=None): CAPTURE_PROCESS_REPLAY = ContextVar("CAPTURE_PROCESS_REPLAY", 0) CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1))) # Compilers -CPU_CC, CPU_LLVM, CPU_LVP = ContextVar("CPU_CC", ""), ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0) +CPU_CC, CPU_LLVM, CPU_LVP, CPU_X86 = ContextVar("CPU_CC", ""), ContextVar("CPU_LLVM", 0), ContextVar("CPU_LVP", 0), ContextVar("CPU_X86", 0) NV_CC, NV_PTX, NV_NAK, NV_NVCC = ContextVar("NV_CC", ""), ContextVar("NV_PTX", 0), ContextVar("NV_NAK", 0), ContextVar("NV_NVCC", 0) CUDA_CC, CUDA_PTX, CUDA_NVCC = 
ContextVar("CUDA_CC", ""), ContextVar("CUDA_PTX", 0), ContextVar("CUDA_NVCC", 0) NULL_IR3, NULL_NAK, NULL_ALLOW_COPYOUT = ContextVar("NULL_IR3", 0), ContextVar("NULL_NAK", 0), ContextVar("NULL_ALLOW_COPYOUT", 0) diff --git a/tinygrad/renderer/isa/__init__.py b/tinygrad/renderer/isa/__init__.py new file mode 100644 index 0000000000000..1de9607c40154 --- /dev/null +++ b/tinygrad/renderer/isa/__init__.py @@ -0,0 +1,125 @@ +# flake8: noqa: E702 +# allow semicolons to put multiple ops on one line +from tinygrad.uop.ops import Ops, auto + +# ***** X86 ***** + +# NOTE: mypy doesn't allow extending enums even with our wrapper, it also doesn't allow overriding i.e. Ops.ADD to X86Ops.ADD +# we ignore it in both cases +class X86Ops(Ops): # type: ignore[misc] + # NOTE: X86Ops with i suffix are variants that take an immediate, m suffix are variants that can write to memory instead of read from + # register, not an instruction. FRAME_INDEX is used when the function arg is on the stack and is rewritten to IMM when stack size is known + DEFINE_REG = auto(); FRAME_INDEX = auto() # type: ignore[misc] + # const + IMM = auto() + # index + LEA = auto() + # register / memory / immediate moves + MOV = auto(); MOVm = auto(); MOVi = auto(); MOVABS = auto() + VMOVSS = auto(); VMOVSD = auto(); VMOVUPS = auto() + VMOVSSm = auto(); VMOVSDm = auto(); VMOVUPSm = auto() + # casts + MOVZX = auto(); MOVSX = auto(); MOVSXD = auto() + VPMOVZXBW = auto(); VPMOVZXBD = auto(); VPMOVZXBQ = auto() + VPMOVZXWD = auto(); VPMOVZXWQ = auto(); VPMOVZXDQ = auto() + VPMOVSXBW = auto(); VPMOVSXBD = auto(); VPMOVSXBQ = auto() + VPMOVSXWD = auto(); VPMOVSXWQ = auto(); VPMOVSXDQ = auto() + VCVTDQ2PS = auto(); VCVTDQ2PD = auto(); VCVTTPS2DQ = auto(); VCVTTPD2DQ = auto() + VCVTPH2PS = auto(); VCVTPS2PH = auto(); VCVTPS2PD = auto(); VCVTPD2PS = auto() + VCVTSS2SD = auto(); VCVTSD2SS = auto(); VCVTSI2SS = auto(); VCVTSI2SD = auto() + VCVTTSS2SI = auto(); VCVTTSD2SI = auto() + # bitcasts + VMOVD = auto(); VMOVQ = auto(); 
VMOVDm = auto(); VMOVQm = auto() + # comparisons + VUCOMISS = auto(); VUCOMISD = auto() + VCMPSS = auto(); VCMPSD = auto(); VCMPPS = auto(); VCMPPD = auto() + VPCMPGTB = auto(); VPCMPGTW = auto(); VPCMPGTD = auto(); VPCMPGTQ = auto() + VPCMPEQB = auto(); VPCMPEQW = auto(); VPCMPEQD = auto(); VPCMPEQQ = auto() + SETNE = auto(); SETE = auto(); SETL = auto(); SETB = auto() + # where + CMOVNE = auto(); CMOVE = auto(); CMOVL = auto(); CMOVB = auto() + VPBLENDVB = auto(); VBLENDVPS = auto(); VBLENDVPD = auto() + # jumps + JNE = auto(); JE = auto(); JL = auto(); JB = auto() + # vectorize / gep + VSHUFPS = auto(); VINSERTPS = auto() + VPEXTRB = auto(); VPEXTRW = auto(); VPEXTRD = auto(); VPEXTRQ = auto() + VPINSRB = auto(); VPINSRW = auto(); VPINSRD = auto(); VPINSRQ = auto() + VPBROADCASTB = auto(); VPBROADCASTW = auto(); VPBROADCASTD = auto(); VPBROADCASTQ = auto() + VBROADCASTSS = auto() # TODO: VBROADCASTSD is ymm only, add once they are supported + # int binary + IDIV = auto(); DIV = auto() # type: ignore[misc] + ADD = auto(); ADDi = auto(); SUB = auto(); SUBi = auto(); IMUL = auto(); IMULi = auto() # type: ignore[misc] + AND = auto(); ANDi = auto(); XOR = auto(); XORi = auto(); OR = auto(); ORi = auto() # type: ignore[misc] + SHL = auto(); SHLi = auto(); SHR = auto(); SHRi = auto(); SAR = auto(); SARi = auto(); CMP = auto(); CMPi = auto() # type: ignore[misc] + # float unary (sometimes not unary) + VROUNDSS = auto(); VROUNDSD = auto(); VROUNDPS = auto(); VROUNDPD = auto() + VSQRTSS = auto(); VSQRTSD = auto(); VSQRTPS = auto(); VSQRTPD = auto() + # float scalar / vector binary + VADDSS = auto(); VADDSD = auto(); VADDPS = auto(); VADDPD = auto() + VSUBSS = auto(); VSUBSD = auto(); VSUBPS = auto(); VSUBPD = auto() + VMULSS = auto(); VMULSD = auto(); VMULPS = auto(); VMULPD = auto() + VDIVSS = auto(); VDIVSD = auto(); VDIVPS = auto(); VDIVPD = auto() + VMAXSS = auto(); VMAXSD = auto(); VMAXPS = auto(); VMAXPD = auto() + VMINSS = auto(); VMINSD = auto(); VMINPS = auto(); 
VMINPD = auto() + # int vector binary + VPADDB = auto(); VPADDW = auto(); VPADDD = auto(); VPADDQ = auto() + VPSUBB = auto(); VPSUBW = auto(); VPSUBD = auto(); VPSUBQ = auto() + VPMULLW = auto(); VPMULLD = auto() + # packed bitwise TODO: might also want vandp cause of different execution ports + VPAND = auto(); VPOR = auto(); VPXOR = auto() + # packed variable shifts + VPSLLVD = auto(); VPSLLVQ = auto(); VPSRLVD = auto(); VPSRLVQ = auto(); VPSRAVD = auto() + # fused multiply add TODO: add other variants to fuse more loads + VFMADD213SS = auto(); VFMADD213SD = auto(); VFMADD213PS = auto(); VFMADD213PD = auto() + # return + RET = auto() + +# TODO: add commutative groupop to fuse more loads +class X86GroupOp: + # X86Ops whose first src is also the destination + TwoAddress1st = {X86Ops.ADD, X86Ops.ADDi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, X86Ops.OR, X86Ops.ORi, X86Ops.IMUL, + X86Ops.SUB, X86Ops.SUBi, X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, + X86Ops.IDIV, X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD, + X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB} + + # X86Ops whose first src can read from memory + ReadMem1st = {X86Ops.MOV, X86Ops.VMOVSS, X86Ops.VMOVSD, X86Ops.VMOVUPS, X86Ops.MOVZX, X86Ops.MOVSX, X86Ops.MOVSXD, X86Ops.VMOVD, X86Ops.VMOVQ, + X86Ops.VPMOVZXBW, X86Ops.VPMOVZXBD, X86Ops.VPMOVZXBQ, X86Ops.VPMOVZXWD, X86Ops.VPMOVZXWQ, X86Ops.VPMOVZXDQ, + X86Ops.VPMOVSXBW, X86Ops.VPMOVSXBD, X86Ops.VPMOVSXBQ, X86Ops.VPMOVSXWD, X86Ops.VPMOVSXWQ, X86Ops.VPMOVSXDQ, + X86Ops.VCVTDQ2PS, X86Ops.VCVTDQ2PD, X86Ops.VCVTTPS2DQ, X86Ops.VCVTTPD2DQ, X86Ops.VCVTTSS2SI, X86Ops.VCVTTSD2SI, + X86Ops.VCVTPH2PS, X86Ops.VCVTPS2PD, X86Ops.VCVTPD2PS, X86Ops.VROUNDPS, X86Ops.VROUNDPD, X86Ops.VSQRTPS, X86Ops.VSQRTPD, + X86Ops.VPBROADCASTB, X86Ops.VPBROADCASTW, X86Ops.VPBROADCASTD, X86Ops.VPBROADCASTQ, X86Ops.VBROADCASTSS, + X86Ops.CMPi, X86Ops.IMULi, X86Ops.DIV, X86Ops.LEA} + + # X86Ops whose second src can 
read from memory NOTE: some of these are TwoAddress1st so the second src is actually the first + ReadMem2nd = {X86Ops.ADD, X86Ops.SUB, X86Ops.AND, X86Ops.OR, X86Ops.XOR, X86Ops.SHL, X86Ops.SHR, X86Ops.SAR, X86Ops.IMUL, X86Ops.CMP, + X86Ops.VADDSS, X86Ops.VADDSD, X86Ops.VADDPS, X86Ops.VADDPD, X86Ops.VSUBSS, X86Ops.VSUBSD, X86Ops.VSUBPS, X86Ops.VSUBPD, + X86Ops.VMULSS, X86Ops.VMULSD, X86Ops.VMULPS, X86Ops.VMULPD, X86Ops.VDIVSS, X86Ops.VDIVSD, X86Ops.VDIVPS, X86Ops.VDIVPD, + X86Ops.VPADDB, X86Ops.VPADDW, X86Ops.VPADDD, X86Ops.VPADDQ, X86Ops.VPSUBB, X86Ops.VPSUBW, X86Ops.VPSUBD, X86Ops.VPSUBQ, + X86Ops.VPCMPEQB, X86Ops.VPCMPEQW, X86Ops.VPCMPEQD, X86Ops.VPCMPEQQ, X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD, + X86Ops.VPCMPGTB, X86Ops.VPCMPGTW, X86Ops.VPCMPGTD, X86Ops.VPCMPGTQ, X86Ops.VCMPSS, X86Ops.VCMPSD, X86Ops.VCMPPS, X86Ops.VCMPPD, + X86Ops.VPMULLW, X86Ops.VPMULLD, X86Ops.VROUNDSS, X86Ops.VROUNDSD, X86Ops.VSQRTSS, X86Ops.VSQRTSD, X86Ops.VSHUFPS, X86Ops.VINSERTPS, + X86Ops.VPINSRB, X86Ops.VPINSRW, X86Ops.VPINSRD, X86Ops.VPINSRQ, X86Ops.VPAND, X86Ops.VPOR, X86Ops.VPXOR, X86Ops.VPSLLVD, + X86Ops.VPSLLVQ, X86Ops.VPSRLVD, X86Ops.VPSRLVQ, X86Ops.VPSRAVD, X86Ops.CMOVNE, X86Ops.CMOVE, X86Ops.CMOVL, X86Ops.CMOVB, + X86Ops.VMAXSS, X86Ops.VMAXSD, X86Ops.VMAXPS, X86Ops.VMAXPD, X86Ops.VMINSS, X86Ops.VMINSD, X86Ops.VMINPS, X86Ops.VMINPD, + X86Ops.VCVTSI2SS, X86Ops.VCVTSI2SD, X86Ops.VCVTSS2SD, X86Ops.VCVTSD2SS, X86Ops.VUCOMISS, X86Ops.VUCOMISD, X86Ops.IDIV} + + # X86Ops whose third src can read from memory NOTE: these are TwoAddress1st so the third src is actually the second + ReadMem3rd = {X86Ops.VFMADD213SS, X86Ops.VFMADD213SD, X86Ops.VFMADD213PS, X86Ops.VFMADD213PD} + + # X86Ops that can write to memory + WriteMem = {X86Ops.MOVm, X86Ops.MOVi, X86Ops.VMOVSSm, X86Ops.VMOVSDm, X86Ops.VMOVUPSm, X86Ops.VMOVDm, X86Ops.VMOVQm, + X86Ops.ADDi, X86Ops.SUBi, X86Ops.ANDi, X86Ops.ORi, X86Ops.XORi, X86Ops.SHLi, X86Ops.SHRi, X86Ops.SARi, X86Ops.SETNE, + X86Ops.SETE, X86Ops.SETL, 
X86Ops.SETB, X86Ops.VCVTPS2PH, X86Ops.VPEXTRB, X86Ops.VPEXTRW, X86Ops.VPEXTRD, X86Ops.VPEXTRQ} + + # X86Ops that read flags + ReadFlags = {X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE, X86Ops.SETB, X86Ops.SETL, X86Ops.SETE, X86Ops.SETNE, X86Ops.JB, X86Ops.JL, + X86Ops.JE, X86Ops.JNE} + + # X86Ops that write flags or can modify flags to undefined values + WriteFlags = {X86Ops.CMP, X86Ops.CMPi, X86Ops.ADD, X86Ops.ADDi, X86Ops.SUB, X86Ops.SUBi, X86Ops.IMUL, X86Ops.IMULi, X86Ops.IDIV, X86Ops.DIV, + X86Ops.SHL, X86Ops.SHLi, X86Ops.SHR, X86Ops.SHRi, X86Ops.SAR, X86Ops.SARi, X86Ops.AND, X86Ops.ANDi, X86Ops.XOR, X86Ops.XORi, + X86Ops.OR, X86Ops.ORi, X86Ops.VUCOMISS, X86Ops.VUCOMISD} + + All = set(X86Ops) diff --git a/tinygrad/renderer/isa/isa.py b/tinygrad/renderer/isa/isa.py new file mode 100644 index 0000000000000..7911708c7d1c5 --- /dev/null +++ b/tinygrad/renderer/isa/isa.py @@ -0,0 +1,131 @@ +import itertools, heapq +from typing import Any +from collections import defaultdict +from tinygrad.renderer import Renderer +from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, UPat, Ops +from tinygrad.codegen import line_rewrite +from tinygrad.codegen.late.regalloc import RegallocContext, pm_regalloc, pm_insert_spills, Register +from tinygrad.uop.spec import type_verify +from tinygrad.helpers import SPEC, DEBUG, prod + +def print_uop_asm(uops:list[UOp]): + for i,u in enumerate(uops): + formatted_srcs = [f"{x.arg}" for x in u.src if x.arg is not None] + print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):40s} " f"{str(u.arg):32s} {str(formatted_srcs)}") + +class IselContext: + def __init__(self, sink:UOp): + self.uses = sink.get_consumer_map() + self.reg_n = itertools.count() + self.stack_size = 0 + arg_order = {Ops.PARAM: 0, Ops.DEFINE_VAR: 1, Ops.SPECIAL: 2} + self.func_args = sorted([u for u in self.uses if u.op in arg_order], key=lambda k: (arg_order[k.op], k.arg)) + + def inc_stack(self, amt:int): + ret = self.stack_size + self.stack_size += amt + return 
ret + + def vreg(self, cons:tuple[Register, ...]|Register|None=None): + return Register(f"v{next(self.reg_n)}", 0, cons=cons if isinstance(cons, tuple) else (cons,) if cons is not None else ()) + +isel_fixup = PatternMatcher([ + # NOOP / AFTER have the same register as first src + (UPat((Ops.NOOP, Ops.AFTER), name="x"), lambda x: x.replace(arg=x.src[0].arg) if x.src and x.arg is None else None), +]) + +# TODO: this will eventually be a proper scheduler +def isa_linearize(sink:UOp) -> list[UOp]: + from tinygrad.renderer.isa.x86 import RSP, X86Ops, X86GroupOp + # this is a toposort with priority + lst = list(sink.toposort()) + out_degree:defaultdict[UOp, int] = defaultdict(int) + priorities:dict[UOp, tuple[int, int]] = {} + + # get consumers and assign priorities + # NOTE: this requires the lst be locally toposorted + for u in reversed(lst): + for s in u.src: out_degree[s] += 1 + + # we place UOps with higher run_counts later + run_count = prod([int(r.vmax)+1 for r in u.ranges]) + + # simple priority override. 
this is all bottom up now, smaller numbers will be closer to the top + match u.op: + case Ops.RANGE: priority = 5 # placing RANGE is good + case Ops.END: priority = -5 # placing END is bad + # stack pointer needs to be scheduled at the top of the kernel + case X86Ops.DEFINE_REG: priority = -21 if u.arg == RSP else -20 + case X86Ops.IMM: priority = -10 + case _: priority = 0 # everything else has priority 0 + priorities[u] = (run_count, priority) + + # number the uops in "ideal" order + nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: priorities[x]))} + + # then force them to be toposorted in as close to the ideal order as possible + heap = [(-nkey[sink], sink)] + newlst = [] + lock: UOp|None = None + stupid: int = 0 + clobbers: set[UOp] = set() + while heap or clobbers: + # if heap is empty we have a cycle and the flag producer must be rematerialized + # we schedule the flag producer and free the clobbers + if not heap: + assert lock is not None and clobbers + newlst.append(lock) + for c in clobbers: heapq.heappush(heap, (-nkey[c],c)) + clobbers.clear() + lock, stupid = None, 0 + + u = heapq.heappop(heap)[1] + + # flags introduce state that must be dealt with, can't overwrite the flag until all its users and producer are scheduled + if lock is not None: + # if this is the flag producer we free the flag clobbers and release the lock + if lock is u: + for c in clobbers: heapq.heappush(heap, (-nkey[c],c)) + clobbers.clear() + lock, stupid = None, 0 + # if this is a user of another flag producer or is itself another flag producer it can't be scheduled + # if this is a loop boundary or has a lower run count than the flag user that introduced the lock we also don't schedule + # loop boundaries do clobber but we also don't want to insert stuff from outside the loop into the loop + # if there's no loop we also don't want to add IMM and DEFINE_REG in the middle of the kernel + elif u.op in X86GroupOp.ReadFlags and lock is not u.src[-1] or u.op in X86GroupOp.WriteFlags or \ + u.op in
{Ops.RANGE, Ops.END, X86Ops.IMM, X86Ops.DEFINE_REG} or priorities[u][0] < stupid: + clobbers.add(u) + continue + # if there's no lock and this is a flag user its flag producer becomes the lock + elif u.op in X86GroupOp.ReadFlags: lock, stupid = u.src[-1], priorities[u][0] + + newlst.append(u) + + for v in u.src: + out_degree[v] -= 1 + if out_degree[v] == 0: heapq.heappush(heap, (-nkey[v],v)) + return newlst[::-1] + +class ISARenderer(Renderer): + isa_spec: PatternMatcher + pre_isel_matcher: PatternMatcher + isel_matcher: PatternMatcher + post_regalloc_matcher: PatternMatcher + + def stack_pointer(self) -> UOp: raise NotImplementedError("arch specific") + # TODO: these should go with the other rewrites after we know what to do with ProgramSpec and Estimates + def lower(self, sink:UOp): + sink = graph_rewrite(sink, self.pre_isel_matcher, name="pre instruction selection", bottom_up=True) + isel_ctx = IselContext(sink) + sink = graph_rewrite(sink, self.isel_matcher, ctx=isel_ctx, name="instruction selection", bottom_up=True) + # TODO: remove, annoying needed for noops + sink = graph_rewrite(sink, isel_fixup, name="instruction selection fixup") + lst = isa_linearize(sink) + if DEBUG >= 8: print_uop_asm(lst) + regalloc_ctx = RegallocContext(lst, self.isel_matcher, self.stack_pointer(), isel_ctx.stack_size) + lst = line_rewrite(lst, pm_regalloc, regalloc_ctx) + lst = line_rewrite(lst, pm_insert_spills, regalloc_ctx) + lst = line_rewrite(lst, self.post_regalloc_matcher, regalloc_ctx) + if DEBUG >= 7: print_uop_asm(lst) + if SPEC: type_verify(lst, self.isa_spec) + return lst diff --git a/tinygrad/renderer/isa/x86.py b/tinygrad/renderer/isa/x86.py new file mode 100644 index 0000000000000..892025bbd6ef3 --- /dev/null +++ b/tinygrad/renderer/isa/x86.py @@ -0,0 +1,672 @@ +import sys, struct, functools +from typing import cast +from tinygrad.dtype import dtypes, PtrDType, DType, truncate +from tinygrad.uop.ops import Ops, GroupOp, UOp, UPat, PatternMatcher +from 
tinygrad.renderer.isa import X86Ops, X86GroupOp +from tinygrad.renderer.isa.isa import ISARenderer, IselContext +from tinygrad.codegen.late.regalloc import Register, assign +from tinygrad.helpers import getenv, CPU_COUNT + +# ***** X86 legalization ***** + +extra_matcher = PatternMatcher([ + # bool CMPNE is XOR, bool CMPEQ is XOR+XOR, bool CMPLT is XOR+AND + # TODO: how does this work for vector dtypes? + (UPat.var('x', dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y), + (UPat.var('x', dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True), + (UPat.var('x', dtypes.bool)> 1).cast(dtypes.int64).cast(x.dtype) * 2 + (y & 1).cast(dtypes.int64).cast(x.dtype)), + # TODO: these should be removed once max is canonicalized + # no max for scalar ints + (UPat(Ops.MAX, dtypes.ints, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0]) if m.dtype.count == 1 else None), + # even with Ops.MAX in decompositions this pattern still hits + ((UPat.var("a") < UPat.var("b")).where(UPat.var("b", dtypes.floats), UPat.var("a")), lambda a,b: UOp(Ops.MAX, b.dtype, (a, b))), + # no int8 mul or cmove, cast to int16 + (UPat.var("a", dtypes.int8s) * UPat.var("b"), lambda a,b: (a.cast(dtypes.int16) * b.cast(dtypes.int16)).cast(a.dtype)), + (UPat.var("m").where(UPat.var("a", (dtypes.bool,)+dtypes.int8s), UPat.var("b")), + lambda m,a,b: m.where(a.cast(dtypes.int16), b.cast(dtypes.int16)).cast(a.dtype) if a.dtype.count == 1 else None), + # float16 alus are done in float32 + (UPat(GroupOp.ALU, dtypes.float16, name="x"), lambda x: UOp(x.op, dtypes.float.vec(x.dtype.count), + tuple(s.cast(dtypes.float) if s.dtype != dtypes.bool else s for s in x.src)).cast(x.dtype)), + (UPat(GroupOp.Comparison, src=(UPat.var("a", dtypes.float16), UPat.var("b")), name="x"), + lambda x,a,b: UOp(x.op, x.dtype, (a.cast(dtypes.float32), b.cast(dtypes.float32))).cast(x.dtype)), + # no cmpne for packed ints, y != x => !(y==x) + (UPat(Ops.CMPNE, src=(UPat.var("y", dtypes.ints), UPat.var("x")), 
name="cmp"), + lambda y,x,cmp: UOp(Ops.CMPEQ, cmp.dtype, (y,x))^True if y.dtype.count > 1 else None), + # float where expects a mask TODO: handle float64 cmp to float32 where + (UPat.var("m", dtypes.bool).where(UPat.var("a", dtypes.floats), UPat.var("b")), + lambda m,a,b: m.cast(a.dtype).ne(0).where(a, b) if m.src[0].dtype not in dtypes.floats else None), + # TODO: do we want this? Kinda not needed if DEVECTORIZE=0. If yes make it general + (UPat(Ops.VECTORIZE, dtypes.float16, name="x"), lambda x: x.replace(dtype=dtypes.float32.vec(x.dtype.count), + src=tuple(s.src[0] for s in x.src)).cast(x.dtype) if all(s.op is Ops.CAST for s in x.src) else None), +]) + +# ***** X86 pre instruction selection ***** + +# these must be done in a separate matcher because they violate the spec +pre_isel_matcher = PatternMatcher([ + # cast from pointer is a noop + (UPat.var("y").cast(name="x"), lambda y,x: x.replace(op=Ops.NOOP) if isinstance(y.dtype, PtrDType) else None), + # zero extending scalar 32bit int is a noop + (UPat.var("y", dtypes.uint32).cast(dtypes.int64s, name="x"), lambda y,x: x.replace(op=Ops.NOOP) if y.dtype.count == 1 else None), + # cast between signed and unsigned int is a noop + (UPat.var("y", dtypes.ints+(dtypes.bool,)).cast(dtypes.ints, name="x"), + lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize == y.dtype.itemsize else None), + # cast to < scalar int is a noop + (UPat.var("y", dtypes.ints).cast(dtypes.ints, name="x"), + lambda y,x: x.replace(op=Ops.NOOP) if x.dtype.itemsize < y.dtype.itemsize and y.dtype.count == 1 else None), + # bitcasts between scalar floats and ints are real, rest are noops + (UPat.var("y").bitcast().named("x"), lambda y,x: None if y.dtype in dtypes.floats and x.dtype in dtypes.ints or \ + y.dtype in dtypes.ints and x.dtype in dtypes.floats else x.replace(op=Ops.NOOP)), + # noop of a noop is removed + (UPat(Ops.NOOP, src=(UPat(Ops.NOOP),), name="x"), lambda x: x.replace(src=x.src[0].src)), + # moving elements of a single register to 
another without shuffling is a noop + (UPat(Ops.VECTORIZE, src=(UPat.var("y"),), allow_any_len=True, name="x"), + lambda y,x: UOp(Ops.NOOP, x.dtype, y.src) if all(s.op is Ops.GEP and s.src == y.src and s.arg[0] == i for i,s in enumerate(x.src)) else None), + # gated index becomes a conditional move on the index, the load/store are unconditional + (UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).load(UPat.var("alt"), name="x"), lambda base,idx,gate,alt,x: + gate.where(base.index(idx, ptr=True), (l:=UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(x.dtype.count), arg=0)).after(l.store(alt)) + .index(UOp.const(dtypes.int32, 0), ptr=True)).load(dtype=x.dtype)), + (UPat.var("base").index(UPat.var("idx"), UPat.var("gate")).store(UPat.var("val")), lambda base,idx,gate,val: + gate.where(base.index(idx, ptr=True), UOp(Ops.DEFINE_LOCAL, base.dtype.base.ptr(val.dtype.count), arg=0) + .index(UOp.const(dtypes.int32, 0), ptr=True)).store(val)), + # TODO: remove this once we allow all flag producing ops in cmove + # if gate in scalar int cmove is not a comparison need to add one to set the flag + (UPat.var("m", dtypes.bool).where(UPat.var("a"), UPat.var("b")), + lambda m,a,b: m.ne(0).where(a,b) if m.op not in GroupOp.Comparison and a.dtype.count == 1 else None), +]) + +# ***** X86 registers ***** + +RAX = Register("rax", 0) # {4:"eax", 2:"ax", 1:"al"} +RCX = Register("rcx", 1) # {4:"ecx", 2:"cx", 1:"cl"} +RDX = Register("rdx", 2) # {4:"edx", 2:"dx", 1:"dl"} +RBX = Register("rbx", 3) # {4:"ebx", 2:"bx", 1:"bl"} +RSP = Register("rsp", 4) # {4:"esp", 2:"sp", 1:"spl"} +RBP = Register("rbp", 5) # {4:"ebp", 2:"bp", 1:"bpl"} +RSI = Register("rsi", 6) # {4:"esi", 2:"si", 1:"sil"} +RDI = Register("rdi", 7) # {4:"edi", 2:"di", 1:"dil"} +GPR = (RAX, RCX, RDX, RBX, RSP, RBP, RSI, RDI) + tuple(Register(f"r{i}", i) for i in range(8, 16)) +XMM = tuple(Register(f"ymm{i}", i) for i in range(16)) +# gprs you can write to +WGPR = tuple(r for r in GPR if r != RSP) + +# ***** X86 instruction 
selection ***** + +def def_reg(dt:DType, reg:Register|None=None) -> UOp: return UOp(X86Ops.DEFINE_REG, dt, arg=reg) +def imm(dt:DType, v:int|float) -> UOp: return UOp(X86Ops.IMM, dt, arg=v) +def to_imm(c:UOp) -> UOp|None: + if c.op is not Ops.CONST: return None + if c.dtype is dtypes.int64: return imm(dtypes.int32, c.arg) if not c.overflows(dtypes.int32) else None + if c.dtype is dtypes.uint64: return imm(dtypes.uint32, c.arg) if not c.overflows(dtypes.uint32) else None + if c.dtype in dtypes.ints+(dtypes.bool,): return imm(c.dtype, c.arg) + return None +def cmp(x:UOp) -> UOp: + if x.src[0].dtype is dtypes.float32: return UOp(X86Ops.VUCOMISS, src=x.src) + if x.src[0].dtype is dtypes.float64: return UOp(X86Ops.VUCOMISD, src=x.src) + return UOp(X86Ops.CMP, src=x.src) if (i:=to_imm(x.src[1])) is None else UOp(X86Ops.CMPi, src=(x.src[0], i)) + +# vshufps xmm2, xmm0, xmm1, imm +# xmm2 selects its lower two 32-bit elements from xmm0 and its upper two 32-bit elements from xmm1 according to imm +def vshufps(x:UOp) -> UOp|None: + def _in(i:int) -> UOp: return s.src[0] if (s:=x.src[i]).op is Ops.GEP else s + if len(x.src) != 4 or _in(0) is not _in(1) or _in(2) is not _in(3): return None + return UOp(X86Ops.VSHUFPS, x.dtype, (_in(0), _in(2), + imm(dtypes.uint8, sum(s.arg[0] << 2*i if s.op is Ops.GEP else 0 for i,s in enumerate(x.src))))) + +# vinsertps xmm2, xmm0, xmm1, imm +# inserts any 32 bit element in xmm1 into any position in xmm0 according to imm, result is written to xmm2 +# this is the fallback slow case for when you can't match a more powerful shuffle +def vinsertps(x:UOp) -> UOp: + def _insert(ret:UOp, i:int) -> UOp: + s, v = x.src[i], 0 + if s.op is Ops.GEP: s, v = s.src[0], s.arg[0] + # moving the 0th element into the 0th position does nothing + return s if i == v == 0 else UOp(X86Ops.VINSERTPS, x.dtype, (ret, s, imm(dtypes.uint8, v << 6 | i << 4))) + return functools.reduce(_insert, range(len(x.src)), def_reg(x.dtype)) + +# vpinsrq xmm2, xmm0, rax, imm +# inserts element in rax
into any position in xmm0, result is written to xmm2 according to imm +def vpins(x:UOp) -> UOp: + op = {1: X86Ops.VPINSRB, 2: X86Ops.VPINSRW, 4: X86Ops.VPINSRD, 8: X86Ops.VPINSRQ}[x.dtype.scalar().itemsize] + return functools.reduce(lambda ret,i: UOp(op, x.dtype, (ret, x.src[i], imm(dtypes.uint8, i))), range(len(x.src)), def_reg(x.dtype)) + +def div(ctx:IselContext, x:UOp): + # zero extend or move src[0] to x + move1 = UOp(X86Ops.MOV, x.dtype, (x.src[0],), ctx.vreg(RAX)) + zero = UOp(X86Ops.MOVi, x.dtype, (imm(min(dtypes.uint32, x.dtype), 0),), ctx.vreg(RDX)) + move2 = UOp(X86Ops.MOV, x.dtype, (x.src[1],), ctx.vreg(tuple(r for r in WGPR if r not in (RAX, RDX)))) + div = UOp(X86Ops.DIV, x.dtype, (move2, zero, move1), ctx.vreg(RAX)) + return UOp(X86Ops.MOV, x.dtype, (div,)) + +# TODO: you don't want to call ctx.vreg here because it can duplicate instructions, you instead assign the tuple of valid registers +# for the instruction and a rewrite will add the vreg, this ensures a duplicate isn't created. 
+# However vreg(RDX) is assigned here because IDIV also writes to RDX and regalloc isn't aware of that, +# the correct fix is to model IDIV as multi output (RAX, RDX) so regalloc is aware of RDX being overwritten and rm vreg from here +def idiv(ctx:IselContext, x:UOp): + ext = UOp(X86Ops.MOVSX, dtypes.int16, (x.src[0],), ctx.vreg(RAX)) if x.dtype is dtypes.int8 else \ + UOp(X86Ops.SARi, x.dtype, (x.src[0], imm(dtypes.uint8, x.dtype.itemsize * 8 - 1)), ctx.vreg(RDX)) + move = UOp(X86Ops.MOV, x.dtype, (x.src[1],), tuple(r for r in WGPR if r not in (RAX, RDX))) + idiv = UOp(X86Ops.IDIV, x.dtype, (x.src[0], move, ext), (RAX,)) + # this move "cleanses" the register constraint (rax) of idiv, this is because the constraint only applies on definition and not on the uses of idiv + return UOp(X86Ops.MOV, x.dtype, (idiv,)) + +def fuse_address(x:UOp) -> tuple[UOp, ...]: + def _disp(v:int) -> UOp: return imm(dtypes.int32 if abs(v) > dtypes.max(dtypes.int8) else dtypes.int8, v) + def _cast(v:UOp) -> UOp: return v.cast(dtypes.int64) if v.vmin < 0 else v + if x.op is not Ops.INDEX: return (x, UOp(Ops.NOOP), _disp(0)) + base, idx = x.src + disp_scale = base.dtype.itemsize if isinstance(base.dtype, PtrDType) else 1 + if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: return (base, _cast(idx.src[0]), _disp(idx.src[1].arg * disp_scale)) + if idx.op is Ops.CONST: return (base, UOp(Ops.NOOP), _disp(idx.arg * disp_scale)) + return (base, _cast(idx), _disp(0)) + +def fuse_load(ctx:IselContext, x:UOp, i:int) -> UOp|None: + # if the load is used multiple times we don't fuse + return x.replace(src=x.src[:i] + fuse_address(x.src[i].src[0]) + x.src[i+1:]) if len(ctx.uses[x.src[i]]) == x.src.count(x.src[i]) == 1 else None + +def abi(ctx:IselContext, x:UOp): + i = ctx.func_args.index(x) + def _stack_arg(disp:int): + return UOp(X86Ops.MOV, x.dtype, (def_reg(dtypes.uint64, RSP), UOp(Ops.NOOP), UOp(X86Ops.FRAME_INDEX, dtypes.int32, arg=disp))) + if sys.platform == "win32": return def_reg(x.dtype, 
(RCX, RDX, GPR[8], GPR[9])[i]) if i < 4 else _stack_arg((i-3)*8+32) + return def_reg(x.dtype, (RDI, RSI, RDX, RCX, GPR[8], GPR[9])[i]) if i < 6 else _stack_arg((i-5)*8) + +dts = dtypes.ints + (dtypes.bool, dtypes.float16, dtypes.float32, dtypes.float64) +dt_16bit = tuple(dt.vec(l) for dt in dts for l in [2,1] if dt.vec(l).itemsize == 2 and dt.vec(l) not in dtypes.int16s) +dt_32bit = tuple(dt.vec(l) for dt in dts for l in [4,2,1] if dt.vec(l).itemsize == 4 and dt.vec(l) not in dtypes.int32s) +dt_64bit = tuple(dt.vec(l) for dt in dts for l in [8,4,2,1] if dt.vec(l).itemsize == 8 and dt.vec(l) not in dtypes.int64s) +dt_128bit = tuple(dt.vec(l) for dt in dts for l in [16,8,4,2,1] if dt.vec(l).itemsize == 16) + +isel_matcher = PatternMatcher([ + # **** Op -> Op **** + # rewrite -x -> 0 - x, this is done here because NEG is useful for MIN + (UPat(Ops.NEG, name="x"), lambda x: UOp(Ops.SUB, x.dtype, (x.const_like(0),) + x.src)), + # TODO: RANGE and END is tricky. Both linearizer and regalloc need them so they stay as Ops and get rewritten post regalloc + # control flow ops need a refactor in general + (UPat(Ops.RANGE, src=(UPat.cvar("c"),), allow_any_len=True, name="x"), lambda c,x: x.replace(src=(imm(c.dtype, c.arg),) + x.src[1:])), + (UPat(Ops.RANGE, name="x"), lambda ctx,x: x.replace(arg=ctx.vreg(WGPR)) if not isinstance(x.arg, Register) else None), + # **** Op -> X86Op **** + # add callee saved registers to the RET, these will be scheduled at the top of the kernel and will be saved/restored if they are used in regalloc + # so regalloc builds the prologue/epilogue naturally + (UPat(Ops.SINK, name="x"), lambda x: x.replace(op=X86Ops.RET, src=x.src + tuple(def_reg(dtypes.uint64, r) for r in [RSP, RBP]))), + # function abi constraints + (UPat((Ops.PARAM, Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), abi), + # these are treated the same for now + (UPat((Ops.DEFINE_REG, Ops.DEFINE_LOCAL), name="x"), lambda ctx,x: + x.replace(op=X86Ops.LEA, src=(def_reg(dtypes.uint64, RSP), 
UOp(Ops.NOOP), imm(dtypes.int32, ctx.inc_stack(x.dtype.nbytes()))), arg=None)), + # constants that can't be immediates, move them to registers + (UPat(Ops.CONST, dtypes.float16, name="x"), lambda x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), UOp(X86Ops.MOVi, dtypes.int16, (imm(x.dtype, x.arg),)), imm(dtypes.uint8, 0)))), # noqa: E501 + (UPat(Ops.CONST, dtypes.float32, name="x"), lambda x: UOp(X86Ops.VMOVD, x.dtype, (UOp(X86Ops.MOVi, dtypes.int32, (imm(x.dtype, x.arg),)),))), + (UPat(Ops.CONST, dtypes.float64, name="x"), lambda x: UOp(X86Ops.VMOVQ, x.dtype, (UOp(X86Ops.MOVABS, dtypes.int64, (imm(x.dtype, x.arg),)),))), + (UPat(Ops.CONST, dtypes.int64s, name="x"), lambda x: UOp(X86Ops.MOVABS, x.dtype, (imm(x.dtype, x.arg),))), + (UPat(Ops.CONST, dtypes.ints+(dtypes.bool,), name="x"), lambda x: UOp(X86Ops.MOVi, x.dtype, (imm(x.dtype, x.arg),))), + # conditional moves that use masks NOTE: these currently assume a mask producing cmp exists + (UPat.var("m").where(UPat.var("a", dtypes.ints), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VPBLENDVB, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype))) if a.dtype.count > 1 else None), # noqa: E501 + (UPat.var("m").where(UPat.var("a", dtypes.float32), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPS, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501 + (UPat.var("m").where(UPat.var("a", dtypes.float64), UPat.var("b")), lambda m,a,b: UOp(X86Ops.VBLENDVPD, a.dtype, (b, a, m.replace(dtype=m.src[0].dtype)))), # noqa: E501 + # in this case we have a mask producing comparison whose user expects a bool, so we convert to bool + (UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int32).bitwise_and(1).f(Ops.NOOP, dtype=dtypes.bool)), # noqa: E501 + (UPat(GroupOp.Comparison, dtypes.bool, (UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(dtype=x.src[0].dtype).bitcast(dtypes.int64).bitwise_and(1).f(Ops.NOOP, 
dtype=dtypes.bool)), # noqa: E501 + # conditional moves that use flags + (UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.sints), UPat()), name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVL, a.dtype, src=(b, a, cmp(m)))), # noqa: E501 + (UPat(Ops.CMPLT, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVB, a.dtype, src=(b, a, cmp(m)))), + (UPat(Ops.CMPEQ, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVE, a.dtype, src=(b, a, cmp(m)))), + (UPat(Ops.CMPNE, name="m").where(UPat.var("a"), UPat.var("b")), lambda m,a,b: UOp(X86Ops.CMOVNE, a.dtype, src=(b, a, cmp(m)))), + # jumps, use flags + (UPat(Ops.IF, src=(UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="y"),), name="x"), lambda y,x: UOp(X86Ops.JB, x.dtype, (cmp(y),))), # noqa: E501 + (UPat(Ops.IF, src=(UPat(Ops.CMPLT, name="y"),)), lambda y: UOp(X86Ops.JL, src=(cmp(y),))), + (UPat(Ops.IF, src=(UPat(Ops.CMPEQ, name="y"),)), lambda y: UOp(X86Ops.JE, src=(cmp(y),))), + (UPat(Ops.IF, src=(UPat(Ops.CMPNE, name="y"),)), lambda y: UOp(X86Ops.JNE, src=(cmp(y),))), + # comparisons whose user doesn't use the flag, move flag result to register + (UPat(Ops.CMPLT, dtypes.bool, (UPat(dtype=dtypes.uints), UPat()), name="x"), lambda x: UOp(X86Ops.SETB, x.dtype, (cmp(x),))), + (UPat(Ops.CMPLT, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETL, x.dtype, (cmp(x),))), + (UPat(Ops.CMPEQ, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETE, x.dtype, (cmp(x),))), + (UPat(Ops.CMPNE, dtypes.bool, name="x"), lambda x: UOp(X86Ops.SETNE, x.dtype, (cmp(x),))), + # comparisons that produce masks (these aren't bool dtype) + (UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501 + (UPat(Ops.CMPLT, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else 
X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 1),))), # noqa: E501 + (UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501 + (UPat(Ops.CMPNE, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 4),))), # noqa: E501 + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float32), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSS if x.dtype.count == 1 else X86Ops.VCMPPS, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501 + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.float64), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VCMPSD if x.dtype.count == 1 else X86Ops.VCMPPD, src=x.src + (imm(dtypes.uint8, 0),))), # noqa: E501 + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int8s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQB)), + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int16s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQW)), + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int32s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQD)), + (UPat(Ops.CMPEQ, src=(UPat(dtype=dtypes.int64s), UPat()), name="x"), lambda x: x.replace(op=X86Ops.VPCMPEQQ)), + (UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int8s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTB, src=(b, a))), + (UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int16s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTW, src=(b, a))), + (UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int32s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTD, src=(b, a))), + (UPat(Ops.CMPLT, src=(UPat.var("a", dtypes.int64s), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.VPCMPGTQ, src=(b, a))), + # float unary + (UPat.var("y", dtypes.float32).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSS, x.dtype, (y, 
y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPS)), # noqa: E501 + (UPat.var("y", dtypes.float64).sqrt().named("x"), lambda y,x: UOp(X86Ops.VSQRTSD, x.dtype, (y, y)) if x.dtype.count == 1 else x.replace(op=X86Ops.VSQRTPD)), # noqa: E501 + (UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSS, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501 + (UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDSD, x.dtype, (y, y, imm(dtypes.uint8, 3))) if x.dtype.count == 1 else None), # noqa: E501 + (UPat.var("y", dtypes.float32).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPS, x.dtype, (y, imm(dtypes.uint8, 3)))), + (UPat.var("y", dtypes.float64).trunc().named("x"), lambda y,x: UOp(X86Ops.VROUNDPD, x.dtype, (y, imm(dtypes.uint8, 3)))), + # broadcasts TODO: not quite right, what about load fusion? Also, bitcast should be x86op and reg is xmm? + (UPat.var("y", dtypes.int8s+(dtypes.bool,)).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTB, x.dtype, (y.bitcast(dtypes.float32),))), + (UPat.var("y", dtypes.int16s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTW, x.dtype, (y.bitcast(dtypes.float32),))), + (UPat.var("y", dtypes.int32s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTD, x.dtype, (y.bitcast(dtypes.float32),))), + (UPat.var("y", dtypes.int64s).broadcast(name="x"), lambda y,x: UOp(X86Ops.VPBROADCASTQ, x.dtype, (y.bitcast(dtypes.float64),))), + (UPat.var("y", dtypes.float32).broadcast(name="x"), lambda y,x: UOp(X86Ops.VBROADCASTSS, x.dtype, (y,))), + # shufles + (UPat.var("y", dtypes.int16s).bitcast(dtypes.float16).named("x"), lambda y,x: UOp(X86Ops.VPINSRW, x.dtype, (def_reg(x.dtype), y, imm(dtypes.uint8, 0)))), # noqa: E501 + (UPat(Ops.VECTORIZE, dtypes.ints+(dtypes.bool,), name="x"), vpins), + (UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vshufps), + (UPat(Ops.VECTORIZE, dtypes.float32, name="x"), vinsertps), + (UPat.var("y", 
dtypes.float32).gep(name="x"), lambda y,x: UOp(X86Ops.VINSERTPS, x.dtype, (y, y, imm(dtypes.uint8, x.arg[0] << 6)))), + # extract + (UPat.var("y", dtypes.float16).bitcast(dtypes.int16s).named("x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, 0)))), + (UPat.var("y", dtypes.int8s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRB, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))), + (UPat.var("y", dtypes.int16s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRW, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))), + (UPat.var("y", dtypes.int32s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRD, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))), + (UPat.var("y", dtypes.int64s).gep(name="x"), lambda y,x: UOp(X86Ops.VPEXTRQ, x.dtype, (y, imm(dtypes.uint8, x.arg[0])))), + # fused multiply add TODO: don't fuse if mul used several times + (UPat.var('a', dtypes.float32) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SS if a.dtype.count == 1 else X86Ops.VFMADD213PS, b, c)), # noqa: E501 + (UPat.var('a', dtypes.float64) * UPat.var('b') + UPat.var('c'), lambda a,b,c: a.alu(X86Ops.VFMADD213SD if a.dtype.count == 1 else X86Ops.VFMADD213PD, b, c)), # noqa: E501 + # packed bitwise + ((UPat() & UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPAND) if x.dtype.count > 1 else None), + ((UPat() | UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPOR) if x.dtype.count > 1 else None), + ((UPat() ^ UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPXOR) if x.dtype.count > 1 else None), + # packed int binary + ((UPat(dtype=dtypes.int32s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVD) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int64s) << UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSLLVQ) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.uint32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVD) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.uint64) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRLVQ) if 
x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int32) >> UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPSRAVD) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int8s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDB) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int16s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDW) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int32s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDD) if x.dtype.count > 1 else None), + ((UPat(dtype=dtypes.int64s) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VPADDQ) if x.dtype.count > 1 else None), + (UPat(Ops.SUB, dtypes.int8s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBB) if x.dtype.count > 1 else None), + (UPat(Ops.SUB, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBW) if x.dtype.count > 1 else None), + (UPat(Ops.SUB, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBD) if x.dtype.count > 1 else None), + (UPat(Ops.SUB, dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPSUBQ) if x.dtype.count > 1 else None), + (UPat(Ops.MUL, dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLW) if x.dtype.count > 1 else None), + (UPat(Ops.MUL, dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMULLD) if x.dtype.count > 1 else None), + # scalar int binary + ((UPat(dtype=dtypes.uints) // UPat()).named("x"), div), + ((UPat(dtype=dtypes.sints) // UPat()).named("x"), idiv), + ((UPat.var("a", dtypes.ints) << UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHLi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHL)), # noqa: E501 + ((UPat.var("a", dtypes.uints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SHRi, src=(a, imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501 + ((UPat.var("a", dtypes.sints) >> UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.SARi, src=(a, 
imm(dtypes.uint8, b.arg))) if b.op is Ops.CONST else x.replace(op=X86Ops.SHR)), # noqa: E501 + ((UPat.var("a", dtypes.ints+(dtypes.bool,)) & UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.AND) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ANDi, src=(a, i))), # noqa: E501 + ((UPat.var("a", dtypes.ints+(dtypes.bool,)) | UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.OR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ORi, src=(a, i))), # noqa: E501 + ((UPat.var("a", dtypes.ints+(dtypes.bool,)) ^ UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.XOR) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.XORi, src=(a, i))), # noqa: E501 + ((UPat.var("a", dtypes.ints) * UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.IMUL) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.IMULi, src=(a, i))), # noqa: E501 + ((UPat.var("a", dtypes.ints) + UPat.var("b")).named("x"), lambda a,b,x: x.replace(op=X86Ops.ADD) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.ADDi, src=(a, i))), # noqa: E501 + (UPat(Ops.SUB, dtypes.ints, (UPat.var("a"), UPat.var("b")), name="x"), lambda a,b,x: x.replace(op=X86Ops.SUB) if (i:=to_imm(b)) is None else x.replace(op=X86Ops.SUBi, src=(a, i))), # noqa: E501 + # float binary + ((UPat(dtype=dtypes.float32) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSS if x.dtype.count == 1 else X86Ops.VADDPS)), + ((UPat(dtype=dtypes.float64) + UPat()).named("x"), lambda x: x.replace(op=X86Ops.VADDSD if x.dtype.count == 1 else X86Ops.VADDPD)), + ((UPat(dtype=dtypes.float32) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSS if x.dtype.count == 1 else X86Ops.VMULPS)), + ((UPat(dtype=dtypes.float64) * UPat()).named("x"), lambda x: x.replace(op=X86Ops.VMULSD if x.dtype.count == 1 else X86Ops.VMULPD)), + (UPat(Ops.SUB, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VSUBSS if x.dtype.count == 1 else X86Ops.VSUBPS)), + (UPat(Ops.SUB, dtypes.float64, name="x"), lambda x: 
x.replace(op=X86Ops.VSUBSD if x.dtype.count == 1 else X86Ops.VSUBPD)), + (UPat(Ops.FDIV, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VDIVSS if x.dtype.count == 1 else X86Ops.VDIVPS)), + (UPat(Ops.FDIV, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VDIVSD if x.dtype.count == 1 else X86Ops.VDIVPD)), + # TODO: these should use a.maximum(b) / a.minimum(b) + (UPat(Ops.MAX, dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VMAXSS if x.dtype.count == 1 else X86Ops.VMAXPS)), + (UPat(Ops.MAX, dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VMAXSD if x.dtype.count == 1 else X86Ops.VMAXPD)), + ((UPat.var("a", dtypes.float32) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSS if a.dtype.count == 1 else X86Ops.VMINPS, a.dtype, (a, b))), # noqa: E501 + ((UPat.var("a", dtypes.float64) < UPat.var("b")).where(UPat.var("a"), UPat.var("b")), lambda a,b: UOp(X86Ops.VMINSD if a.dtype.count == 1 else X86Ops.VMINPD, a.dtype, (a, b))), # noqa: E501 + # casts + (UPat(dtype=dtypes.int32).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PS) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTDQ2PD) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.float32).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPS2DQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.float64).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTPD2DQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.float32).cast(dtypes.float64, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PD) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.float64).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPD2PS) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.float32).cast(dtypes.float16, name="x"), lambda x: x.replace(op=X86Ops.VCVTPS2PH, src=x.src + (imm(dtypes.uint8, 4),))), + 
(UPat(dtype=dtypes.float16).cast(dtypes.float32, name="x"), lambda x: x.replace(op=X86Ops.VCVTPH2PS)), + (UPat(dtype=dtypes.float32).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSS2SI)), + (UPat(dtype=dtypes.float64).cast(dtypes.int32s+dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VCVTTSD2SI)), + (UPat.var("y", dtypes.float32).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSS2SD, src=(y, y))), + (UPat.var("y", dtypes.float64).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSD2SS, src=(y, y))), + (UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float32, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SS, src=(def_reg(x.dtype), y))), # noqa: E501 + (UPat.var("y", (dtypes.int32, dtypes.int64)).cast(dtypes.float64, name="x"), lambda y,x: x.replace(op=X86Ops.VCVTSI2SD, src=(def_reg(x.dtype), y))), # noqa: E501 + (UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBW) if x.dtype.count > 1 else None), + (UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBD) if x.dtype.count > 1 else None), + (UPat(dtype=(dtypes.uint8, dtypes.bool)).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXBQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.uint16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWD) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.uint16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXWQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.uint32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVZXDQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int8).cast(dtypes.int16s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBW) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int8).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBD) if 
x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int8).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXBQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int16).cast(dtypes.int32s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWD) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int16).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXWQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.VPMOVSXDQ) if x.dtype.count > 1 else None), + (UPat(dtype=dtypes.uints+(dtypes.bool,)).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVZX)), + (UPat(dtype=dtypes.int32).cast(dtypes.int64s, name="x"), lambda x: x.replace(op=X86Ops.MOVSXD)), + (UPat(dtype=dtypes.sints).cast(dtypes.ints, name="x"), lambda x: x.replace(op=X86Ops.MOVSX)), + # bitcasts + (UPat(dtype=dtypes.int32s).bitcast(dtypes.float32).named("x"), lambda x: x.replace(op=X86Ops.VMOVD)), + (UPat(dtype=dtypes.int64s).bitcast(dtypes.float64).named("x"), lambda x: x.replace(op=X86Ops.VMOVQ)), + (UPat(dtype=dtypes.float32).bitcast(dtypes.int32s).named("x"), lambda x: x.replace(op=X86Ops.VMOVDm)), + (UPat(dtype=dtypes.float64).bitcast(dtypes.int64s).named("x"), lambda x: x.replace(op=X86Ops.VMOVQm)), + # index + (UPat(Ops.INDEX, name="x"), lambda x: x.replace(op=X86Ops.LEA, src=fuse_address(x))), + # TODO: fuse stores, very few cases -- store cmp becomes setcc, store gep int becomes vpextr, store bitcast to int becomes vmovd/q + # assign, load, store + # NOTE: assign here violates the spec, it only happens in register allocation when a reg to reg move needs to be inserted + (UPat(Ops.ASSIGN, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS)), + (UPat(Ops.ASSIGN, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD)), + (UPat(Ops.ASSIGN, dt_32bit+dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS)), + (UPat(Ops.ASSIGN, dtypes.ints+(dtypes.bool,), name="x"), lambda x: 
x.replace(op=X86Ops.MOV)), + (UPat(Ops.LOAD, dt_128bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVUPS, src=fuse_address(x.src[0]))), + (UPat(Ops.LOAD, dt_64bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSD, src=fuse_address(x.src[0]))), + (UPat(Ops.LOAD, dt_32bit, name="x"), lambda x: x.replace(op=X86Ops.VMOVSS, src=fuse_address(x.src[0]))), + (UPat(Ops.LOAD, dt_16bit, name="x"), lambda x: x.replace(op=X86Ops.VPINSRW, src=(def_reg(x.dtype, x.arg),) + fuse_address(x.src[0]) + (imm(dtypes.uint8, 0),))), # noqa: E501 + (UPat(Ops.LOAD, dtypes.ints+(dtypes.bool,), name="x"), lambda x: x.replace(op=X86Ops.MOV, src=fuse_address(x.src[0]))), + (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_128bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVUPSm, src=fuse_address(x.src[0]) + (x.src[1],))), # noqa: E501 + (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_64bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSDm, src=fuse_address(x.src[0]) + (x.src[1],))), + (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_32bit)), name="x"), lambda x: x.replace(op=X86Ops.VMOVSSm, src=fuse_address(x.src[0]) + (x.src[1],))), + (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dt_16bit)), name="x"), lambda x: x.replace(op=X86Ops.VPEXTRW, src=fuse_address(x.src[0]) + (x.src[1], imm(dtypes.uint8, 0)))), # noqa: E501 + (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.ints+(dtypes.bool,))), name="x"), lambda x: + x.replace(op=X86Ops.MOVm, src=fuse_address(x.src[0]) + (x.src[1],)) if (i:=to_imm(x.src[1])) is None else x.replace(op=X86Ops.MOVi, src=fuse_address(x.src[0]) + (i,))), # noqa: E501 + # **** X86Op -> X86Op **** + # fuse loads into X86Ops that allow it, if beneficial + (UPat(X86GroupOp.ReadMem1st, src=(UPat(Ops.LOAD),), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 0)), + (UPat(X86GroupOp.ReadMem2nd, src=(UPat(), UPat(Ops.LOAD)), allow_any_len=True, name="x"), lambda ctx,x: fuse_load(ctx, x, 1)), + (UPat(X86GroupOp.ReadMem3rd, src=(UPat(), UPat(), UPat(Ops.LOAD)), name="x"), lambda 
ctx,x: fuse_load(ctx, x, 2)),
+  # allocate virtual register to X86Op with special constraints
+  # (arg is a tuple in this case; it is passed to ctx.vreg and replaced by what that returns)
+  (UPat(X86GroupOp.All, dtypes.ints+dtypes.floats+(dtypes.bool,), name="x"), lambda ctx,x:
+   x.replace(arg=ctx.vreg(x.arg)) if isinstance(x.arg, tuple) else None),
+  # allocate virtual register to X86Op without special constraints
+  # (arg is None and dtype is not void: ctx.vreg gets XMM for float dtypes or count > 1, WGPR otherwise)
+  (UPat(X86GroupOp.All, name="x"), lambda ctx,x:
+   x.replace(arg=ctx.vreg(XMM if x.dtype in dtypes.floats or x.dtype.count > 1 else WGPR)) if x.arg is None and x.dtype != dtypes.void else None),
+])
+
+# ***** post register allocation *****
+
+# TODO: rm after,group
+# final rewrite to match the isa spec
+# Every rule below returns either None (leave the uop alone) or a pair (uop, [uops]).
+# NOTE(review): the pair looks like (replacement, instructions to emit in its place) -- confirm against the caller.
+# Rules that take ctx read ctx.stack_size; the last rule also hands ctx to assign().
+post_regalloc_matcher = PatternMatcher([
+  # alloc stack space
+  # only when ctx.stack_size > 0: the RSP DEFINE_REG is followed by a SUBi of stack_size with RSP as its arg
+  (UPat(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP, name="x"), lambda ctx,x:
+   (x, [x, UOp(X86Ops.SUBi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP)]) if ctx.stack_size > 0 else None),
+  # dealloc stack space
+  # mirror of the rule above: an ADDi of stack_size with RSP as its arg is placed before RET
+  (UPat(X86Ops.RET, name="x"), lambda ctx,x:
+   (x, [UOp(X86Ops.ADDi, dtypes.uint64, (imm(dtypes.uint32, ctx.stack_size),), RSP), x]) if ctx.stack_size > 0 else None),
+  # rewrite FRAME_INDEX to IMM now that the stack size is known
+  # the new immediate value is ctx.stack_size + the FRAME_INDEX's own arg
+  (UPat(X86Ops.FRAME_INDEX, name="x"), lambda ctx,x: (nx:=x.replace(op=X86Ops.IMM, arg=ctx.stack_size + x.arg), [nx])),
+  # rewrite RANGE to MOV reg, 0. Terrible HACK to pass the CONST to the END
+  # x.src[0].arg is stashed in tag; the END rule below reads x.src[1].tag
+  # NOTE(review): presumably END's src[1] is this rewritten RANGE -- confirm
+  (UPat(Ops.RANGE, name="x"), lambda x: (nx:=x.replace(op=X86Ops.MOVi, src=(imm(x.dtype, 0),), tag=x.src[0].arg), [nx])),
+  # rewrite END to ADD 1 -> CMPLT -> JUMP
+  # add: ADDi of 1 with x.src[1].arg as its arg
+  # cmp: CMPi against imm(tag) when x.src[1].tag is an int, otherwise CMP against def_reg(dtype, tag)
+  # jl:  the END itself turned into JL with src (x.src[1], cmp)
+  # add and cmp are bound by walrus while jl is being built, so they exist when [add, cmp, jl] is evaluated
+  # NOTE: the walrus target `cmp` is local to this lambda; it does not rebind any module-level cmp
+  (UPat(Ops.END, name="x"), lambda x:
+   (jl:=x.replace(op=X86Ops.JL, src=(x.src[1], cmp:=UOp(X86Ops.CMPi if isinstance(x.src[1].tag, int) else X86Ops.CMP,
+    src=(add:=UOp(X86Ops.ADDi, x.src[1].dtype, (imm(x.src[1].dtype, 1),), x.src[1].arg),
+     imm(x.src[1].dtype, x.src[1].tag) if isinstance(x.src[1].tag, int) else def_reg(x.src[1].dtype, x.src[1].tag))))), [add, cmp, jl])),
+  # TODO: need a generic way to model clobbers, idiv and flags should be handled the same way, maybe add clobber field to Register?
+  # fixup div, zero rdx again because scheduling constraint isn't being respected
+  # the DIV keeps only its first src; a MOVi of 0 with RDX as its arg is emitted right before it
+  (UPat(X86Ops.DIV, name="x"), lambda x:
+   (nx:=x.replace(src=x.src[:1]), [UOp(X86Ops.MOVi, x.dtype, (imm(min(dtypes.uint32, x.dtype), 0),), RDX), nx])),
+  # rewrite two address instructions to two address form, if reused src wasn't coalesced insert a move
+  # src[0] is dropped from the instruction; when x.arg != x.src[0].arg the result of assign(ctx, x.src[0], x.arg) is emitted first
+  (UPat(X86GroupOp.TwoAddress1st, name="x"), lambda ctx,x:
+   (nx:=x.replace(src=x.src[1:]), [assign(ctx, x.src[0], x.arg), nx] if x.arg != x.src[0].arg else [nx])),
+])
+
+# ***** X86 spec *****
+# TODO: do we even want this?
+# Rule table describing which uops are acceptable in the final instruction stream.
+# Each callback returns True, except the blend rule, which returns the result of a dtype check.
+# NOTE(review): how a False result or a non-match is treated depends on the caller of this matcher -- confirm there.
+isa_spec = PatternMatcher([
+  # these are the only non X86Ops allowed
+  (UPat((Ops.NOOP, Ops.GROUP, Ops.AFTER, Ops.BARRIER)), lambda: True),
+  # vblends take a mask which is float or int dtype
+  # the result and both value operands must have the same dtype; the mask only has to match in itemsize
+  (UPat((X86Ops.VPBLENDVB, X86Ops.VBLENDVPS, X86Ops.VBLENDVPD), src=(UPat.var("a"), UPat.var("b"), UPat.var("m")), name="x"),
+   lambda a,b,m,x: x.dtype == a.dtype == b.dtype and x.dtype.itemsize == m.dtype.itemsize),
+  # cmoves take a flag producing instruction
+  # this rule only matches cmoves of dtype bool whose third src is in X86GroupOp.WriteFlags
+  # NOTE(review): the lowering rules build CMOV* with a.dtype, and the catch-all below accepts whatever this rule
+  # does not match (assuming cmoves are in X86GroupOp.All) -- confirm the bool dtype here is intended
+  (UPat((X86Ops.CMOVB, X86Ops.CMOVL, X86Ops.CMOVE, X86Ops.CMOVNE), dtypes.bool, (UPat(), UPat(), UPat(X86GroupOp.WriteFlags))), lambda: True),
+  # catch-all: any op in X86GroupOp.All is accepted
+  (UPat(X86GroupOp.All), lambda: True),
+])
+
+# ***** X86 instruction encoding ***** + +def to_bytes(dt:DType, v:int|float): + v = truncate[dt](v) + if dt in dtypes.floats: return struct.pack({dtypes.float16: " 3: address, rest = x.src[:3], x.src[3:] + else: address, rest = (x, None, None), x.src + + elif x.op in X86GroupOp.ReadMem1st or x.op in X86GroupOp.ReadMem2nd and x.op in X86GroupOp.TwoAddress1st: + if len(x.src) > 2: address, rest = x.src[:3], (x,) + x.src[3:] + else: address, rest = (x.src[0], None, None), (x,) + x.src[1:] + + elif x.op in X86GroupOp.ReadMem2nd or x.op in X86GroupOp.ReadMem3rd and x.op in X86GroupOp.TwoAddress1st: + if len(x.src) > 3: address, rest = x.src[1:4], x.src[:1] + x.src[4:] + else: address, rest = (x.src[1], None, None), x.src[:1] + x.src[2:] + if x.dtype is not dtypes.void: rest = (x,) + rest + + else: return None + + # get the encoding values of the different fields + reg_sz = (rest[0].dtype.itemsize if not isinstance(rest[0].dtype, PtrDType) else 8) if reg is None else 0 + reg = cast(Register, rest[0].arg).index if reg is None else reg + vvvv = rest[1].arg.index if len(rest) > 1 and isinstance(rest[1].arg, Register) else 0 + rm = cast(Register, address[0].arg).index + idx = cast(Register, address[1].arg).index if address[1] is not None and address[1].arg is not None else 4 + disp_uop = address[2] + imm_uop = rest[-1] if rest[-1].op is
X86Ops.IMM or len(rest) == 3 else None + # TODO: another reason to get rid of ptrs, if we access memory the size should be in scale uop otherwise size is in rm + rm_sz = 8 if isinstance(address[0].dtype, PtrDType) and disp_uop is None else address[0].dtype.itemsize + + # encode instruction + inst = bytes([]) + # PREFIX byte + # there's other uses for this like atomic operations but setting 16bit variant of legacy op is currently the only one + if sel == 0 and (reg_sz == 2 if reg_sz != 0 else rm_sz == 2): inst += bytes([0x66]) + # VEX bytes + assert 0 <= reg <= 15 and 0 <= idx <= 15 and 0 <= rm <= 15 + # r extends reg field, x extends index field, b extends rm or base field + r, _x, b = reg >> 3, idx >> 3, rm >> 3 + if sel: + l = (max(reg_sz, rm_sz) > 16) & 0b1 + if sel == 1 and _x == b == we == 0: inst += bytes([0xC5, (~r & 0b1) << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp]) + else: inst += bytes([0xC4, (~r & 0b1) << 7 | (~_x & 0b1) << 6 | (~b & 0b1) << 5 | sel, we << 7 | (~vvvv & 0b1111) << 3 | l << 2 | pp]) + # REX byte + else: + # bit signaling 64 bit variant of instruction + w = reg_sz == 8 if reg_sz != 0 else rm_sz == 8 + # rex prefix is required when an extended reg is used (index 8 - 15) or lower 8 bits of (rsp, rbp, rsi, rdi) are accessed + if w | r | _x | b | (reg_sz == 1 & reg >> 2) | (rm_sz == 1 & rm >> 2): inst += bytes([0b0100 << 4 | w << 3 | r << 2 | _x << 1 | b]) + # OPCODE byte + # legacy 8bit opcodes are 1 less than 16-64bit versions, with these exceptions + real_opc = opc-1 if (rm_sz == 1 or reg_sz == 1) and x.op not in {X86Ops.SETB, X86Ops.SETE, X86Ops.SETL, X86Ops.SETNE, X86Ops.LEA} else opc + inst += real_opc.to_bytes((real_opc.bit_length() + 7) // 8, 'big') + # MODRM byte + # now we only care about the lower 3 bits + idx, rm, reg = idx & 0b111, rm & 0b111, reg & 0b111 + # 0b00 -- signals memory access with no displacement + # 0b01 -- signals memory access with 8bit displacement + # 0b10 -- signals memory access with 32bit displacement + # 0b11 
-- signals no memory access + if disp_uop is not None: + assert disp_uop.dtype in (dtypes.int8, dtypes.int32), "displacement can only be 1 or 4 byte signed int" + # rbp/r13 always require a displacement + if disp_uop.arg != 0 or rm == 0b101: mod = 0b01 if disp_uop.dtype.itemsize == 1 else 0b10 + else: mod = 0b00 + else: mod = 0b11 + # x 0b0 and idx 0b100 means rsp which means no index exists + # rm 0b100 (rsp/r12) signals a sib byte is required, rm then is encoded in the base field of SIB + _rm = rm if idx == 0b100 and _x == 0b0 else 0b100 + inst += bytes([mod << 6 | reg << 3 | _rm]) + # SIB byte + if _rm == 0b100 and mod != 0b11: + scale = {1: 0b00, 2: 0b01, 4: 0b10, 8: 0b11}[1 if idx == 0b100 and _x == 0b0 else rm_sz] + inst += bytes([scale << 6 | idx << 3 | rm]) + # DISP byte + if mod == 0b01 or mod == 0b10: + assert disp_uop is not None + inst += disp_uop.arg.to_bytes(disp_uop.dtype.itemsize, 'little', signed=True) + # IMM byte + if imm_uop is not None: + if isinstance(imm_uop.arg, Register): inst += bytes([(imm_uop.arg.index & 0b1111) << 4 | 0b0000]) + else: inst += to_bytes(imm_uop.dtype, imm_uop.arg) + return inst + +# https://www.felixcloutier.com/x86/ +# NOTE: LEGACY prefix == VEX prefix +# pp field: None == 0, 66 == 1, F3 == 2, F2 == 3 +# map select: 0F == 1, 0F38 == 2, 0F3A == 3 +encodings = PatternMatcher([ + # moves + (UPat(X86Ops.MOVABS, name="x"), lambda x: + bytes([0b0100 << 4 | 0b1 << 3 | 0b00 << 2 | x.arg.index >> 3, 0xB8 + (x.arg.index & 0b111)]) + to_bytes(x.src[0].dtype, x.src[0].arg)), + (UPat(X86Ops.MOV, name="x"), lambda x: encode(x, 0x8B)), (UPat(X86Ops.MOVi, name="x"), lambda x: encode(x, 0xC7, reg=0)), + (UPat(X86Ops.MOVm, name="x"), lambda x: encode(x, 0x89)), (UPat(X86Ops.LEA, name="x"), lambda x: encode(x, 0x8D)), + (UPat(X86Ops.VMOVSS, name="x"), lambda x: encode(x, 0x10, pp=2, sel=1)), (UPat(X86Ops.VMOVSSm, name="x"), lambda x: encode(x, 0x11, pp=2, sel=1)), + (UPat(X86Ops.VMOVSD, name="x"), lambda x: encode(x, 0x10, pp=3, sel=1)), 
(UPat(X86Ops.VMOVSDm, name="x"), lambda x: encode(x, 0x11, pp=3, sel=1)), + (UPat(X86Ops.VMOVUPS, name="x"), lambda x: encode(x, 0x10, pp=0, sel=1)), (UPat(X86Ops.VMOVUPSm, name="x"), lambda x: encode(x, 0x11, pp=0, sel=1)), # noqa: E501 + (UPat(X86Ops.VMOVD, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1)), (UPat(X86Ops.VMOVQ, name="x"), lambda x: encode(x, 0x6E, pp=1, sel=1, we=1)), # noqa: E501 + (UPat(X86Ops.VMOVDm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1)), (UPat(X86Ops.VMOVQm, name="x"), lambda x: encode(x, 0x7E, pp=1, sel=1, we=1)), # noqa: E501 + # casts + (UPat(X86Ops.MOVZX, name="x"), lambda x: encode(x, 0x0FB7)), + (UPat(X86Ops.MOVSX, name="x"), lambda x: encode(x, 0x0FBF)), (UPat(X86Ops.MOVSXD, name="x"), lambda x: encode(x, 0x63)), + (UPat(X86Ops.VPMOVZXBW, name="x"), lambda x: encode(x, 0x30, pp=1, sel=2)), (UPat(X86Ops.VPMOVZXBD, name="x"), lambda x: encode(x, 0x31, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMOVZXBQ, name="x"), lambda x: encode(x, 0x32, pp=1, sel=2)), (UPat(X86Ops.VPMOVZXWD, name="x"), lambda x: encode(x, 0x33, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMOVZXWQ, name="x"), lambda x: encode(x, 0x34, pp=1, sel=2)), (UPat(X86Ops.VPMOVZXDQ, name="x"), lambda x: encode(x, 0x35, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMOVSXBW, name="x"), lambda x: encode(x, 0x20, pp=1, sel=2)), (UPat(X86Ops.VPMOVSXBD, name="x"), lambda x: encode(x, 0x21, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMOVSXBQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=2)), (UPat(X86Ops.VPMOVSXWD, name="x"), lambda x: encode(x, 0x23, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMOVSXWQ, name="x"), lambda x: encode(x, 0x24, pp=1, sel=2)), (UPat(X86Ops.VPMOVSXDQ, name="x"), lambda x: encode(x, 0x25, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VCVTSS2SD, name="x"), lambda x: encode(x, 0x5A, pp=2, sel=1)), (UPat(X86Ops.VCVTSD2SS, name="x"), lambda x: encode(x, 0x5A, pp=3, sel=1)), # noqa: E501 + (UPat(X86Ops.VCVTPH2PS, name="x"), lambda x: encode(x, 0x13, 
pp=1, sel=2)), (UPat(X86Ops.VCVTPS2PH, name="x"), lambda x: encode(x, 0x1D, pp=1, sel=3)), # noqa: E501 + (UPat(X86Ops.VCVTDQ2PS, name="x"), lambda x: encode(x, 0x5B, pp=0, sel=1)), (UPat(X86Ops.VCVTDQ2PD, name="x"), lambda x: encode(x, 0xE6, pp=2, sel=1)), # noqa: E501 + (UPat(X86Ops.VCVTPS2PD, name="x"), lambda x: encode(x, 0x5A, pp=0, sel=1)), (UPat(X86Ops.VCVTPD2PS, name="x"), lambda x: encode(x, 0x5A, pp=1, sel=1)), # noqa: E501 + (UPat(X86Ops.VCVTTPS2DQ, name="x"), lambda x: encode(x, 0x5B, pp=2, sel=1)), (UPat(X86Ops.VCVTTPD2DQ, name="x"), lambda x: encode(x, 0xE6, pp=1, sel=1)), # noqa: E501 + (UPat(X86Ops.VCVTSI2SS, name="x"), lambda x: encode(x, 0x2A, pp=2, sel=1, we=x.src[1].dtype.base is dtypes.int64)), + (UPat(X86Ops.VCVTSI2SD, name="x"), lambda x: encode(x, 0x2A, pp=3, sel=1, we=x.src[1].dtype.base is dtypes.int64)), + (UPat(X86Ops.VCVTTSS2SI, name="x"), lambda x: encode(x, 0x2C, pp=2, sel=1, we=x.dtype in dtypes.int64s)), + (UPat(X86Ops.VCVTTSD2SI, name="x"), lambda x: encode(x, 0x2C, pp=3, sel=1, we=x.dtype in dtypes.int64s)), + # int division + (UPat(X86Ops.IDIV, name="x"), lambda x: encode(x, 0xF7, reg=7)), (UPat(X86Ops.DIV, name="x"), lambda x: encode(x, 0xF7, reg=6)), + # scalar int binary + (UPat(X86Ops.SHLi, name="x"), lambda x: encode(x, 0xC1, reg=4)), + (UPat(X86Ops.SHRi, name="x"), lambda x: encode(x, 0xC1, reg=5)), (UPat(X86Ops.SARi, name="x"), lambda x: encode(x, 0xC1, reg=7)), + (UPat(X86Ops.ADD, name="x"), lambda x: encode(x, 0x03)), (UPat(X86Ops.ADDi, name="x"), lambda x: encode(x, 0x81, reg=0)), + (UPat(X86Ops.SUB, name="x"), lambda x: encode(x, 0x2B)), (UPat(X86Ops.SUBi, name="x"), lambda x: encode(x, 0x81, reg=5)), + (UPat(X86Ops.AND, name="x"), lambda x: encode(x, 0x23)), (UPat(X86Ops.ANDi, name="x"), lambda x: encode(x, 0x81, reg=4)), + (UPat(X86Ops.XOR, name="x"), lambda x: encode(x, 0x33)), (UPat(X86Ops.XORi, name="x"), lambda x: encode(x, 0x81, reg=6)), + (UPat(X86Ops.OR, name="x"), lambda x: encode(x, 0x0B)), (UPat(X86Ops.ORi, 
name="x"), lambda x: encode(x, 0x81, reg=1)), + (UPat(X86Ops.CMP, name="x"), lambda x: encode(x, 0x3B)), (UPat(X86Ops.CMPi, name="x"), lambda x: encode(x, 0x81, reg=7)), + (UPat(X86Ops.IMUL, name="x"), lambda x: encode(x, 0x0FAF)), (UPat(X86Ops.IMULi, name="x"), lambda x: encode(x, 0x69)), + (UPat(X86Ops.SETB, name="x"), lambda x: encode(x, 0x0F92, reg=0)), (UPat(X86Ops.SETL, name="x"), lambda x: encode(x, 0x0F9C, reg=0)), + (UPat(X86Ops.SETE, name="x"), lambda x: encode(x, 0x0F94, reg=0)), (UPat(X86Ops.SETNE, name="x"), lambda x: encode(x, 0x0F95, reg=0)), + # packed bitwise NOTE: only bitwise and packed + (UPat(X86Ops.VPAND, name="x"), lambda x: encode(x, 0xDB, pp=1, sel=1)), (UPat(X86Ops.VPXOR, name="x"), lambda x: encode(x, 0xEF, pp=1, sel=1)), + (UPat(X86Ops.VPOR, name="x"), lambda x: encode(x, 0xEB, pp=1, sel=1)), + # unary + (UPat(X86Ops.VSQRTSS, name="x"), lambda x: encode(x, 0x51, pp=2, sel=1)), (UPat(X86Ops.VSQRTPS, name="x"), lambda x: encode(x, 0x51, pp=0, sel=1)), + (UPat(X86Ops.VSQRTSD, name="x"), lambda x: encode(x, 0x51, pp=3, sel=1)), (UPat(X86Ops.VSQRTPD, name="x"), lambda x: encode(x, 0x51, pp=1, sel=1)), + (UPat(X86Ops.VROUNDSS, name="x"), lambda x: encode(x, 0x0A, pp=1, sel=3)), (UPat(X86Ops.VROUNDPS, name="x"), lambda x: encode(x, 0x08, pp=1, sel=3)), # noqa: E501 + (UPat(X86Ops.VROUNDSD, name="x"), lambda x: encode(x, 0x0B, pp=1, sel=3)), (UPat(X86Ops.VROUNDPD, name="x"), lambda x: encode(x, 0x09, pp=1, sel=3)), # noqa: E501 + # packed int binary + (UPat(X86Ops.VPSLLVD, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2)), (UPat(X86Ops.VPSLLVQ, name="x"), lambda x: encode(x, 0x47, pp=1, sel=2, we=1)), # noqa: E501 + (UPat(X86Ops.VPSRLVD, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2)), (UPat(X86Ops.VPSRLVQ, name="x"), lambda x: encode(x, 0x45, pp=1, sel=2, we=1)), # noqa: E501 + (UPat(X86Ops.VPCMPGTB, name="x"), lambda x: encode(x, 0x64, pp=1, sel=1)), (UPat(X86Ops.VPCMPGTW, name="x"), lambda x: encode(x, 0x65, pp=1, sel=1)), # noqa: E501 + 
(UPat(X86Ops.VPCMPGTD, name="x"), lambda x: encode(x, 0x66, pp=1, sel=1)), (UPat(X86Ops.VPCMPGTQ, name="x"), lambda x: encode(x, 0x37, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPCMPEQB, name="x"), lambda x: encode(x, 0x74, pp=1, sel=1)), (UPat(X86Ops.VPCMPEQW, name="x"), lambda x: encode(x, 0x75, pp=1, sel=1)), # noqa: E501 + (UPat(X86Ops.VPCMPEQD, name="x"), lambda x: encode(x, 0x76, pp=1, sel=1)), (UPat(X86Ops.VPCMPEQQ, name="x"), lambda x: encode(x, 0x29, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPMULLW, name="x"), lambda x: encode(x, 0xD5, pp=1, sel=1)), (UPat(X86Ops.VPMULLD, name="x"), lambda x: encode(x, 0x40, pp=1, sel=2)), + (UPat(X86Ops.VPADDB, name="x"), lambda x: encode(x, 0xFC, pp=1, sel=1)), (UPat(X86Ops.VPADDW, name="x"), lambda x: encode(x, 0xFD, pp=1, sel=1)), + (UPat(X86Ops.VPADDD, name="x"), lambda x: encode(x, 0xFE, pp=1, sel=1)), (UPat(X86Ops.VPADDQ, name="x"), lambda x: encode(x, 0xD4, pp=1, sel=1)), + (UPat(X86Ops.VPSUBB, name="x"), lambda x: encode(x, 0xF8, pp=1, sel=1)), (UPat(X86Ops.VPSUBW, name="x"), lambda x: encode(x, 0xF9, pp=1, sel=1)), + (UPat(X86Ops.VPSUBD, name="x"), lambda x: encode(x, 0xFA, pp=1, sel=1)), (UPat(X86Ops.VPSUBQ, name="x"), lambda x: encode(x, 0xFB, pp=1, sel=1)), + (UPat(X86Ops.VPSRAVD, name="x"), lambda x: encode(x, 0x46, pp=1, sel=2)), + # float cmp + (UPat(X86Ops.VUCOMISS, name="x"), lambda x: encode(x, 0x2E, pp=0, sel=1)), (UPat(X86Ops.VUCOMISD, name="x"), lambda x: encode(x, 0x2E, pp=1, sel=1)), # noqa: E501 + # scalar / packed float binary + (UPat(X86Ops.VADDSS, name="x"), lambda x: encode(x, 0x58, pp=2, sel=1)), (UPat(X86Ops.VADDPS, name="x"), lambda x: encode(x, 0x58, pp=0, sel=1)), + (UPat(X86Ops.VADDSD, name="x"), lambda x: encode(x, 0x58, pp=3, sel=1)), (UPat(X86Ops.VADDPD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=1)), + (UPat(X86Ops.VSUBSS, name="x"), lambda x: encode(x, 0x5C, pp=2, sel=1)), (UPat(X86Ops.VSUBPS, name="x"), lambda x: encode(x, 0x5C, pp=0, sel=1)), + (UPat(X86Ops.VSUBSD, 
name="x"), lambda x: encode(x, 0x5C, pp=3, sel=1)), (UPat(X86Ops.VSUBPD, name="x"), lambda x: encode(x, 0x5C, pp=1, sel=1)), + (UPat(X86Ops.VMULSS, name="x"), lambda x: encode(x, 0x59, pp=2, sel=1)), (UPat(X86Ops.VMULPS, name="x"), lambda x: encode(x, 0x59, pp=0, sel=1)), + (UPat(X86Ops.VMULSD, name="x"), lambda x: encode(x, 0x59, pp=3, sel=1)), (UPat(X86Ops.VMULPD, name="x"), lambda x: encode(x, 0x59, pp=1, sel=1)), + (UPat(X86Ops.VDIVSS, name="x"), lambda x: encode(x, 0x5E, pp=2, sel=1)), (UPat(X86Ops.VDIVPS, name="x"), lambda x: encode(x, 0x5E, pp=0, sel=1)), + (UPat(X86Ops.VDIVSD, name="x"), lambda x: encode(x, 0x5E, pp=3, sel=1)), (UPat(X86Ops.VDIVPD, name="x"), lambda x: encode(x, 0x5E, pp=1, sel=1)), + (UPat(X86Ops.VCMPSS, name="x"), lambda x: encode(x, 0xC2, pp=2, sel=1)), (UPat(X86Ops.VCMPPS, name="x"), lambda x: encode(x, 0xC2, pp=0, sel=1)), + (UPat(X86Ops.VCMPSD, name="x"), lambda x: encode(x, 0xC2, pp=3, sel=1)), (UPat(X86Ops.VCMPPD, name="x"), lambda x: encode(x, 0xC2, pp=1, sel=1)), + (UPat(X86Ops.VMAXSS, name="x"), lambda x: encode(x, 0x5F, pp=2, sel=1)), (UPat(X86Ops.VMAXPS, name="x"), lambda x: encode(x, 0x5F, pp=0, sel=1)), + (UPat(X86Ops.VMAXSD, name="x"), lambda x: encode(x, 0x5F, pp=3, sel=1)), (UPat(X86Ops.VMAXPD, name="x"), lambda x: encode(x, 0x5F, pp=1, sel=1)), + (UPat(X86Ops.VMINSS, name="x"), lambda x: encode(x, 0x5D, pp=2, sel=1)), (UPat(X86Ops.VMINPS, name="x"), lambda x: encode(x, 0x5D, pp=0, sel=1)), + (UPat(X86Ops.VMINSD, name="x"), lambda x: encode(x, 0x5D, pp=3, sel=1)), (UPat(X86Ops.VMINPD, name="x"), lambda x: encode(x, 0x5D, pp=1, sel=1)), + # ternary + (UPat(X86Ops.CMOVB, name="x"), lambda x: encode(x, 0x0F42)), (UPat(X86Ops.CMOVL, name="x"), lambda x: encode(x, 0x0F4C)), + (UPat(X86Ops.CMOVE, name="x"), lambda x: encode(x, 0x0F44)), (UPat(X86Ops.CMOVNE, name="x"), lambda x: encode(x, 0x0F45)), + (UPat(X86Ops.VFMADD213SS, name="x"), lambda x: encode(x, 0xA9, pp=1, sel=2)), (UPat(X86Ops.VFMADD213SD, name="x"), lambda x: 
encode(x, 0xA9, pp=1, sel=2, we=1)), # noqa: E501 + (UPat(X86Ops.VFMADD213PS, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2)), (UPat(X86Ops.VFMADD213PD, name="x"), lambda x: encode(x, 0xA8, pp=1, sel=2, we=1)), # noqa: E501 + (UPat(X86Ops.VBLENDVPS, name="x"), lambda x: encode(x, 0x4A, pp=1, sel=3)), (UPat(X86Ops.VBLENDVPD, name="x"), lambda x: encode(x, 0x4B, pp=1, sel=3)), # noqa: E501 + (UPat(X86Ops.VPBLENDVB, name="x"), lambda x: encode(x, 0x4C, pp=1, sel=3)), + # shuffles + (UPat(X86Ops.VPBROADCASTB, name="x"), lambda x: encode(x, 0x78, pp=1, sel=2)), (UPat(X86Ops.VPBROADCASTW, name="x"), lambda x: encode(x, 0x79, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VPBROADCASTD, name="x"), lambda x: encode(x, 0x58, pp=1, sel=2)), (UPat(X86Ops.VPBROADCASTQ, name="x"), lambda x: encode(x, 0x59, pp=1, sel=2)), # noqa: E501 + (UPat(X86Ops.VBROADCASTSS, name="x"), lambda x: encode(x, 0x18, pp=1, sel=2)), + (UPat(X86Ops.VPINSRB, name="x"), lambda x: encode(x, 0x20, pp=1, sel=3)), (UPat(X86Ops.VPINSRW, name="x"), lambda x: encode(x, 0xC4, pp=1, sel=1)), + (UPat(X86Ops.VPINSRD, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3)), (UPat(X86Ops.VPINSRQ, name="x"), lambda x: encode(x, 0x22, pp=1, sel=3, we=1)), # noqa: E501 + (UPat(X86Ops.VSHUFPS, name="x"), lambda x: encode(x, 0xC6, pp=0, sel=1)), (UPat(X86Ops.VINSERTPS, name="x"), lambda x: encode(x, 0x21, pp=1, sel=3)), # noqa: E501 + # extract + (UPat(X86Ops.VPEXTRB, name="x"), lambda x: encode(x, 0x14, pp=1, sel=3)), (UPat(X86Ops.VPEXTRW, name="x"), lambda x: encode(x, 0x15, pp=1, sel=3)), + (UPat(X86Ops.VPEXTRD, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3)), (UPat(X86Ops.VPEXTRQ, name="x"), lambda x: encode(x, 0x16, pp=1, sel=3, we=1)), # noqa: E501 + # jumps are encoded with a placeholder which gets patched later once the real offset is known + (UPat(X86Ops.JE), lambda: bytes([0x0F, 0x84]) + int(0).to_bytes(4, 'little', signed=True)), + (UPat(X86Ops.JNE), lambda: bytes([0x0F, 0x85]) + int(0).to_bytes(4, 'little', 
signed=True)), + (UPat(X86Ops.JL), lambda: bytes([0x0F, 0x8C]) + int(0).to_bytes(4, 'little', signed=True)), + (UPat(X86Ops.JB), lambda: bytes([0x0F, 0x82]) + int(0).to_bytes(4, 'little', signed=True)), + # return + (UPat(X86Ops.RET), lambda: bytes([0xC3])), +]) + +class X86Renderer(ISARenderer): + device = "CPU" + has_local = False + has_threads = bool(getenv("THREADS", 1)) + global_max = (CPU_COUNT.value, 0, 0) + extra_matcher = extra_matcher + pre_isel_matcher = pre_isel_matcher + isel_matcher = isel_matcher + post_regalloc_matcher = post_regalloc_matcher + isa_spec = isa_spec + code_for_op = {x: lambda: None for x in (Ops.SQRT, Ops.AND, Ops.OR, Ops.SHL, Ops.SHR, Ops.NEG, Ops.SUB, Ops.FDIV, Ops.CMPLT, Ops.CMPEQ, Ops.MAX)} + def __init__(self): + from tinygrad.runtime.support.compiler_cpu import X86Compiler + self.compiler = X86Compiler() + def stack_pointer(self) -> UOp: return UOp(X86Ops.DEFINE_REG, dtypes.uint64, arg=RSP) + def render(self, uops:list[UOp], lower:bool=True) -> str: + if lower: uops = self.lower(uops[-1]) + targets: set[UOp] = set() + target_loc: list[int] = [] + binary = bytearray() + for u in uops: + if u.op in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE): targets.add(u.src[0]) + for u in uops: + if u.op in (Ops.GROUP, Ops.NOOP, Ops.AFTER, Ops.BARRIER): continue + if u.op in (X86Ops.IMM, X86Ops.DEFINE_REG): continue + if (l:=cast(bytes|None, encodings.rewrite(u))) is None: + raise RuntimeError(f"failed to encode {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") + binary.extend(l) + if u in targets: target_loc.append(len(binary)) + elif u.op in (X86Ops.JL, X86Ops.JB, X86Ops.JE, X86Ops.JNE): + binary[-4:] = (target_loc.pop() - len(binary)).to_bytes(4, 'little', signed=True) + return binary.hex() \ No newline at end of file diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 682266e858633..d6e9576ffe7d8 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -1,13 +1,14 @@ from __future__ 
import annotations import platform, sys, ctypes, functools, time, mmap, threading, queue from tinygrad.helpers import to_mv, OSX, WIN, mv_address, wait_cond, suppress_finalizing, unwrap, data64_le -from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM +from tinygrad.helpers import CPU_CC, CPU_LVP, CPU_LLVM, CPU_X86 from tinygrad.device import BufferSpec, DMACPURef, CompilerSet from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface from tinygrad.runtime.support.hcq import CLikeArgsState from tinygrad.renderer.cstyle import ClangJITRenderer from tinygrad.renderer.llvmir import CPULLVMRenderer from tinygrad.renderer.nir import LVPRenderer +from tinygrad.renderer.isa.x86 import X86Renderer from tinygrad.runtime.support.elf import jit_loader from tinygrad.uop.ops import sint @@ -133,5 +134,5 @@ class CPUDevice(HCQCompiled): def __init__(self, device:str=""): self.tasks:queue.Queue = queue.Queue() CPUWorker(self, self.tasks, thread_id=0).start() - compilers = CompilerSet([(ClangJITRenderer, None), (CPULLVMRenderer, CPU_LLVM), (LVPRenderer, CPU_LVP)], ctrl_var=CPU_CC) + compilers = CompilerSet([(ClangJITRenderer, None), (CPULLVMRenderer, CPU_LLVM), (LVPRenderer, CPU_LVP), (X86Renderer, CPU_X86)], ctrl_var=CPU_CC) super().__init__(device, CPUAllocator(self), compilers, functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue) diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index 5dd8450e46ca1..53c5958e135a8 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -91,3 +91,8 @@ def __init__(self, cache_key=None): # +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures()) 
super().__init__(cpu.decode(), feats.decode(), cache_key) + +class X86Compiler(Compiler): + def __init__(self): super().__init__(None) + def compile(self, src:str) -> bytes: return bytes.fromhex(src) + def disassemble(self, lib:bytes): return capstone_flatdump(lib) \ No newline at end of file diff --git a/tinygrad/uop/__init__.py b/tinygrad/uop/__init__.py index 77722880737db..e2b3cad907e5f 100644 --- a/tinygrad/uop/__init__.py +++ b/tinygrad/uop/__init__.py @@ -1,9 +1,14 @@ # flake8: noqa: E702 # allow semicolons to put multiple ops on one line -from enum import auto, IntEnum, Enum +from enum import auto, IntEnum, Enum, EnumType + +# wrapper around EnumType to allow extending enums with members +class ExtensibleEnumType(EnumType): + @classmethod + def _check_for_existing_members_(mcls, class_name, bases): return # wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses -class FastEnum(IntEnum): +class FastEnum(IntEnum, metaclass=ExtensibleEnumType): def __str__(self): return Enum.__str__(self) def __repr__(x): return str(x) @staticmethod diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 19b6b75c4b47d..66c314395fd27 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -295,6 +295,9 @@ def _shape(self) -> tuple[sint, ...]|None: if not all_same(input_shapes): raise RuntimeError(f"shape mismatch at {self.op}: {input_shapes}") return input_shapes[0] + # backend ops don't have a shape + if self.op not in GroupOp.All: return None + # all Ops must be explicitly handled raise NotImplementedError(f"no shape handling for {self.op} with {self.dtype}") @@ -894,11 +897,11 @@ def get_location() -> tuple[str, int]: class UPat(OpMixin): __slots__ = ("op", "match_dtype", "arg", "name", "src", "is_any") - def __init__(self, op:Ops|tuple[Ops, ...]|set[Ops]|None=None, dtype:DType|tuple[DType, ...]|set[DType]|None=None, + def __init__(self, op:Ops|Iterable[Ops]|None=None, dtype:DType|tuple[DType, 
...]|set[DType]|None=None, src:tuple[UPat, ...]|list[UPat]|UPat|None=None, arg:Any=None, name:str|None=None, allow_any_len:bool=False, custom_early_reject:set[Ops]|None=None, location=None, is_any:bool=False): assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops" - self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op) + self.op: tuple[Ops, ...]|None = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, Iterable) else op) self.match_dtype: tuple[DType, ...]|None = (dtype,) if isinstance(dtype, DType) else (tuple(dtype) if isinstance(dtype, set) else dtype) self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject self.src: Any = None