diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSADialect.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSADialect.td index e3ddd344acb7..d8b8b594a70c 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSADialect.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSADialect.td @@ -30,6 +30,11 @@ def DXSADialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + + let extraClassDeclaration = [{ + // Defined in DXSAOperand.cpp where GET_ATTRDEF_CLASSES is instantiated. + void registerAttributes(); + }]; } #endif // DXSA_DIALECT_TD diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td new file mode 100644 index 000000000000..3f5501dbaaca --- /dev/null +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAFPArithOps.td @@ -0,0 +1,59 @@ +//===- DXSAFPArithOps.td - DXSA float arithmetic ops ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Floating-point arithmetic instructions of the DXSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_DXSA_IR_DXSAFPARITHOPS +#define MLIR_DIALECT_DXSA_IR_DXSAFPARITHOPS + +include "mlir/Dialect/DXSA/IR/DXSAOpBase.td" + +//===----------------------------------------------------------------------===// +// dxsa.add +//===----------------------------------------------------------------------===// + +def DXSA_Add : DXSA_BinaryOp<"add"> { + let summary = "component-wise floating-point add"; + let description = [{ + The `dxsa.add` operation computes the component-wise floating-point + sum `$dst = $lhs + $rhs`. + + Example: + + ```mlir + dxsa.add r<0>, r<1>, r<2> + dxsa.add r<0>, -r<1>, -|r<2>| + dxsa.add r<0, >, r<1, >, |r<2, >| + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// dxsa.add_sat +//===----------------------------------------------------------------------===// + +def DXSA_AddSat : DXSA_BinaryOp<"add_sat"> { + let summary = "component-wise floating-point add, saturated to [0, 1]"; + let description = [{ + The `dxsa.add_sat` operation computes the component-wise floating-point + sum of `$lhs` and `$rhs`, clamps each component of the result to + `[0.0, 1.0]`, and writes it to `$dst`. + + Example: + + ```mlir + dxsa.add_sat r<0>, r<1>, r<2> + dxsa.add_sat r<0>, -r<1>, -|r<2>| + dxsa.add_sat r<0, >, r<1, >, |r<2, >| + ``` + }]; +} + +#endif // MLIR_DIALECT_DXSA_IR_DXSAFPARITHOPS diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td new file mode 100644 index 000000000000..0435e4c88e25 --- /dev/null +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOpBase.td @@ -0,0 +1,41 @@ +//===- DXSAOpBase.td - DXSA op base classes --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Base classes shared by the DXSA operation families. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_DXSA_IR_DXSAOPBASE +#define MLIR_DIALECT_DXSA_IR_DXSAOPBASE + +include "mlir/Dialect/DXSA/IR/DXSADialect.td" +include "mlir/Dialect/DXSA/IR/DXSAOperand.td" + +//===----------------------------------------------------------------------===// +// DXSA op base class +//===----------------------------------------------------------------------===// + +// Base class for all operations in this dialect. +class DXSA_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// DXSA shared bases for ops with inline operands +//===----------------------------------------------------------------------===// + +class DXSA_BinaryOp : DXSA_Op { + let arguments = (ins + DXSA_DstOperandAttr:$dst, + DXSA_SrcOperandAttr:$lhs, + DXSA_SrcOperandAttr:$rhs, + OptionalAttr:$precise); + let results = (outs); + let assemblyFormat = "$dst `,` $lhs `,` $rhs attr-dict"; +} + +#endif // MLIR_DIALECT_DXSA_IR_DXSAOPBASE diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOperand.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOperand.td new file mode 100644 index 000000000000..3089518a29b8 --- /dev/null +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOperand.td @@ -0,0 +1,563 @@ +//===- DXSAOperand.td - DXSA operand attributes ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attributes describing a fully decoded dialect operand. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_DXSA_IR_DXSAOPERAND +#define MLIR_DIALECT_DXSA_IR_DXSAOPERAND + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/DXSA/IR/DXSADialect.td" + +//===----------------------------------------------------------------------===// +// DXSA operand type +//===----------------------------------------------------------------------===// + +def DXSA_OperandType_Temp : I32EnumAttrCase<"r", 0>; +def DXSA_OperandType_Input : I32EnumAttrCase<"v", 1>; +def DXSA_OperandType_Output : I32EnumAttrCase<"o", 2>; +def DXSA_OperandType_IndexableTemp : I32EnumAttrCase<"x", 3>; +def DXSA_OperandType_Immediate32 : I32EnumAttrCase<"l", 4>; +def DXSA_OperandType_Immediate64 : I32EnumAttrCase<"d", 5>; +def DXSA_OperandType_Sampler : I32EnumAttrCase<"s", 6>; +def DXSA_OperandType_Resource : I32EnumAttrCase<"t", 7>; +def DXSA_OperandType_ConstantBuffer : I32EnumAttrCase<"cb", 8>; +def DXSA_OperandType_ImmediateConstantBuffer : I32EnumAttrCase<"icb", 9>; +def DXSA_OperandType_Label : I32EnumAttrCase<"label", 10>; +def DXSA_OperandType_InputPrimitiveId : I32EnumAttrCase<"vPrim", 11>; +def DXSA_OperandType_OutputDepth : I32EnumAttrCase<"oDepth", 12>; +def DXSA_OperandType_Null : I32EnumAttrCase<"null", 13>; +def DXSA_OperandType_Rasterizer : I32EnumAttrCase<"rasterizer", 14>; +def DXSA_OperandType_OutputCoverageMask : I32EnumAttrCase<"oMask", 15>; +def DXSA_OperandType_Stream : I32EnumAttrCase<"m", 16>; +def DXSA_OperandType_FunctionBody : I32EnumAttrCase<"fb", 17>; +def DXSA_OperandType_FunctionTable : I32EnumAttrCase<"ft", 18>; +def DXSA_OperandType_Interface : I32EnumAttrCase<"fp", 19>; +def DXSA_OperandType_FunctionInput : I32EnumAttrCase<"funcInput", 20>; +def DXSA_OperandType_FunctionOutput : I32EnumAttrCase<"funcOutput", 21>; +def DXSA_OperandType_OutputControlPointId : I32EnumAttrCase<"vOutputControlPointID", 22>; +def DXSA_OperandType_InputForkInstanceId : I32EnumAttrCase<"vForkInstanceID", 23>; +def DXSA_OperandType_InputJoinInstanceId : I32EnumAttrCase<"vJoinInstanceID", 24>; +def DXSA_OperandType_InputControlPoint : I32EnumAttrCase<"vicp", 25>; +def DXSA_OperandType_OutputControlPoint : I32EnumAttrCase<"vocp", 26>; +def DXSA_OperandType_InputPatchConstant : I32EnumAttrCase<"vpc", 27>; +def DXSA_OperandType_InputDomainPoint : I32EnumAttrCase<"vDomain", 28>; +def DXSA_OperandType_ThisPointer : I32EnumAttrCase<"thisPtr", 29>; +def DXSA_OperandType_Uav : I32EnumAttrCase<"u", 30>; +def DXSA_OperandType_ThreadGroupSharedMemory : I32EnumAttrCase<"g", 31>; +def DXSA_OperandType_InputThreadId : I32EnumAttrCase<"vThreadID", 32>; +def DXSA_OperandType_InputThreadGroupId : I32EnumAttrCase<"vThreadGroupID", 33>; +def DXSA_OperandType_InputThreadIdInGroup : I32EnumAttrCase<"vThreadIDInGroup", 34>; +def DXSA_OperandType_InputCoverageMask : I32EnumAttrCase<"vCoverage", 35>; +def DXSA_OperandType_InputThreadIdInGroupFlattened : I32EnumAttrCase<"vThreadIDInGroupFlattened", 36>; +def DXSA_OperandType_InputGsInstanceId : I32EnumAttrCase<"vGSInstanceID", 37>; +def DXSA_OperandType_OutputDepthGe : I32EnumAttrCase<"oDepthGE", 38>; +def DXSA_OperandType_OutputDepthLe : I32EnumAttrCase<"oDepthLE", 39>; +def DXSA_OperandType_CycleCounter : I32EnumAttrCase<"cycleCounter", 40>; +def DXSA_OperandType_OutputStencilRef : I32EnumAttrCase<"oStencilRef", 41>; +def DXSA_OperandType_InnerCoverage : I32EnumAttrCase<"vInnerCoverage", 42>; + +def DXSA_OperandType : I32EnumAttr< + "OperandType", "operand type", [ + DXSA_OperandType_Temp, + DXSA_OperandType_Input, + DXSA_OperandType_Output, + DXSA_OperandType_IndexableTemp, + DXSA_OperandType_Immediate32, + DXSA_OperandType_Immediate64, + DXSA_OperandType_Sampler, + DXSA_OperandType_Resource, + DXSA_OperandType_ConstantBuffer, + DXSA_OperandType_ImmediateConstantBuffer, + DXSA_OperandType_Label, + DXSA_OperandType_InputPrimitiveId, + DXSA_OperandType_OutputDepth, + DXSA_OperandType_Null, + DXSA_OperandType_Rasterizer, + DXSA_OperandType_OutputCoverageMask, + DXSA_OperandType_Stream, + DXSA_OperandType_FunctionBody, + DXSA_OperandType_FunctionTable, + DXSA_OperandType_Interface, + DXSA_OperandType_FunctionInput, + DXSA_OperandType_FunctionOutput, + DXSA_OperandType_OutputControlPointId, + DXSA_OperandType_InputForkInstanceId, + DXSA_OperandType_InputJoinInstanceId, + DXSA_OperandType_InputControlPoint, + DXSA_OperandType_OutputControlPoint, + DXSA_OperandType_InputPatchConstant, + DXSA_OperandType_InputDomainPoint, + DXSA_OperandType_ThisPointer, + DXSA_OperandType_Uav, + DXSA_OperandType_ThreadGroupSharedMemory, + DXSA_OperandType_InputThreadId, + DXSA_OperandType_InputThreadGroupId, + DXSA_OperandType_InputThreadIdInGroup, + DXSA_OperandType_InputCoverageMask, + DXSA_OperandType_InputThreadIdInGroupFlattened, + DXSA_OperandType_InputGsInstanceId, + DXSA_OperandType_OutputDepthGe, + DXSA_OperandType_OutputDepthLe, + DXSA_OperandType_CycleCounter, + DXSA_OperandType_OutputStencilRef, + DXSA_OperandType_InnerCoverage + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +//===----------------------------------------------------------------------===// +// DXSA component mask (destination writemask) +//===----------------------------------------------------------------------===// + +def DXSA_ComponentMask_X : I32BitEnumAttrCaseBit<"x", 0>; +def DXSA_ComponentMask_Y : I32BitEnumAttrCaseBit<"y", 1>; +def DXSA_ComponentMask_Z : I32BitEnumAttrCaseBit<"z", 2>; +def DXSA_ComponentMask_W : I32BitEnumAttrCaseBit<"w", 3>; + +def DXSA_ComponentMask : I32BitEnumAttr< + "ComponentMask", "destination writemask (subset of x, y, z, w)", [ + DXSA_ComponentMask_X, + DXSA_ComponentMask_Y, + DXSA_ComponentMask_Z, + DXSA_ComponentMask_W + ]> { + let separator = ", "; + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_ComponentMaskAttr : + EnumAttr { + let summary = "destination writemask (subset of x, y, z, w)"; + let description = [{ + The `#dxsa.component_mask` attribute is the destination writemask + of a 4-component operand. It records which components of the + 4-component register are written by the instruction, as a + non-empty subset of `{x, y, z, w}` in canonical order. + + Each entry names one component of the register: + + - `x` enables writing the first (X) component. + - `y` enables writing the second (Y) component. + - `z` enables writing the third (Z) component. + - `w` enables writing the fourth (W) component. + + Combinations are written as a comma-separated list and always + print in canonical `x, y, z, w` order, regardless of how the + bits were set in the source binary. + + Example: + + ```mlir + // write all components + // write three components, fourth left unchanged + // scalar write into the first component + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// DXSA Swizzle attribute (ordered component selection) +//===----------------------------------------------------------------------===// + +def DXSA_SwizzleAttr : AttrDef { + let mnemonic = "swizzle"; + let summary = "source component selector (1 or 4 of x, y, z, w)"; + let description = [{ + The `#dxsa.swizzle` attribute is an ordered component selector on + a 4-component source operand. It is a list of 1 or 4 components, + each one of `x`, `y`, `z`, `w`. Entries may repeat. + + Each entry names which source component is read at that position: + + - `x` reads the first (X) source component. + - `y` reads the second (Y) source component. + - `z` reads the third (Z) source component. + - `w` reads the fourth (W) source component. + + The size of the list encodes the component-selection mode of the + source: + + - 4 entries form a full swizzle. The list defines the mapping of + destination components `x, y, z, w` onto source components, in + that order. + - 1 entry selects a single component. Inside source operand + a 1-component list is a select and a 4-component list is a swizzle. + + Example: + + ```mlir + // identity swizzle (no rearrangement) + // swap pairs + // broadcast first component to all destination components + // single-component select + ``` + }]; + let parameters = (ins ArrayRefParameter<"int64_t">:$components); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +//===----------------------------------------------------------------------===// +// DXSA source operand modifier +//===----------------------------------------------------------------------===// + +def DXSA_OperandModifier_Neg : I32EnumAttrCase<"neg", 1>; +def DXSA_OperandModifier_Abs : I32EnumAttrCase<"abs", 2>; +def DXSA_OperandModifier_AbsNeg : I32EnumAttrCase<"abs_neg", 3>; + +def DXSA_OperandModifier : I32EnumAttr< + "OperandModifier", "source operand modifier", [ + DXSA_OperandModifier_Neg, + DXSA_OperandModifier_Abs, + DXSA_OperandModifier_AbsNeg + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_OperandModifierAttr : + EnumAttr { + let summary = "source operand modifier (abs, neg, or both)"; + let description = [{ + The `#dxsa.operand_modifier` attribute is the source-operand + modifier carried by an arithmetic instruction. Each case has a + fixed spelling that wraps the source operand: + + - `|OPERAND|` (abs) yields the absolute value of the source. + For float operations the sign is forced positive, including on + INF; NaN inputs remain NaN, but the resulting NaN bit pattern + is unspecified. + - `-OPERAND` (neg) flips the sign of the source. For float + operations the sign is flipped including on INF, and NaN is + preserved. For integer operations the modifier yields the two's + complement of the source. + - `-|OPERAND|` (abs_neg) is the documented combination of the + two: absolute is applied first, then negated, so the result is + always negative. + + Example: + + ```mlir + dxsa.add r<0>, -r<1>, r<2> + dxsa.add r<0>, |r<1>|, r<2> + dxsa.add r<0>, -|r<1>|, r<2> + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// DXSA operand minimum precision +//===----------------------------------------------------------------------===// + +def DXSA_OperandMinPrecision_Float16 : I32EnumAttrCase<"min16f", 1>; +def DXSA_OperandMinPrecision_Float2_8 : I32EnumAttrCase<"min2_8f", 2>; +def DXSA_OperandMinPrecision_SInt16 : I32EnumAttrCase<"min16i", 4>; +def DXSA_OperandMinPrecision_UInt16 : I32EnumAttrCase<"min16u", 5>; + +def DXSA_OperandMinPrecision : I32EnumAttr< + "OperandMinPrecision", "minimum precision hint on an operand", [ + DXSA_OperandMinPrecision_Float16, + DXSA_OperandMinPrecision_Float2_8, + DXSA_OperandMinPrecision_SInt16, + DXSA_OperandMinPrecision_UInt16 + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_OperandMinPrecisionAttr : + EnumAttr { + let summary = "minimum precision hint on an operand"; + let description = [{ + The `#dxsa.operand_min_precision` attribute is the minimum + precision hint carried by an operand. It declares the lowest + precision at which the operation may execute; an implementation + is free to run at any equal or higher precision. + + Each case also encodes the element type (float, signed integer, + unsigned integer), so that type-neutral instructions like `mov` + are unambiguous when a size change is involved. The four cases + are: + + - `min16f` is at least 16-bit-per-component float. + - `min2_8f` is at least 10-bit-per-component float, laid out as a + 2.8 mantissa/exponent split. + - `min16i` is at least 16-bit-per-component signed integer. + - `min16u` is at least 16-bit-per-component unsigned integer. + + The default precision for the shader model is encoded by + the attribute being absent on the parent destination or source operand. + + Example: + + ```mlir + dxsa.dcl_input v<0, min16f, > + dxsa.add r<0>, v<0, min16f, >, r<2> + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// DXSA operand component count +//===----------------------------------------------------------------------===// + +def DXSA_OperandComponents_None : I32EnumAttrCase<"none", 0>; +def DXSA_OperandComponents_Scalar : I32EnumAttrCase<"scalar", 1>; +def DXSA_OperandComponents_Vector : I32EnumAttrCase<"vector", 2>; +def DXSA_OperandComponents_Reserved : I32EnumAttrCase<"reserved", 3>; + +def DXSA_OperandComponents : I32EnumAttr< + "OperandComponents", "operand component count", [ + DXSA_OperandComponents_None, + DXSA_OperandComponents_Scalar, + DXSA_OperandComponents_Vector, + DXSA_OperandComponents_Reserved + ]> { + let cppNamespace = "::mlir::dxsa"; + let genSpecializedAttr = 0; +} + +def DXSA_OperandComponentsAttr : + EnumAttr { + let summary = "component count declared on an operand"; + let description = [{ + The `#dxsa.operand_components` attribute records the + component-count declaration of an operand, i.e. how many + components the operand exposes. The four cases are: + + - `none` is a zero-component operand, used for handles and + built-in operands that carry no component data, such as + `null`, `oDepth`, or `label`. + - `scalar` is a one-component operand, used for system-value + scalars such as `vPrim` or `vCoverage`. + - `vector` is a four-component operand, the common case for + general-purpose registers (`r`, `v`, `o`, etc.). + - `reserved` is an "N-component" placeholder, documented as + unused but retained for round-trip fidelity in case a binary + uses it. + + Every operand type has a canonical component count. The attribute + is omitted on the printed form of destination or source operand + when its value matches the canonical one for that operand type. + It is shown only when the operand overrides the canonical count. + + Example: + + ```mlir + dxsa.dcl_input v<0, scalar> + dxsa.dcl_input vPrim + ``` + }]; + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// DXSA operand index slot (one entry of an operand's index list) +//===----------------------------------------------------------------------===// + +def DXSA_IndexAttr : AttrDef { + let mnemonic = "index"; + let summary = "operand index slot (immediate, relative, or both)"; + let description = [{ + The `#dxsa.index` attribute is one slot of an operand's index + list, i.e. one dimension of indexing into a register file or + resource binding array. + + There are several ways to express the value of an index: + + - A 32-bit immediate constant. + - A 64-bit immediate constant. + - A relative operand: another register (typically `r#` or a + statically-indexed `x#`) holding the value at run time. + - A 32-bit immediate plus a relative operand, added at run time. + - A 64-bit immediate plus a relative operand, added at run time. + + These five representations are stored here as a pair of + optional fields: an integer immediate (`i32` or `i64`) and a + relative nested source operand. At least one of the two is present. + + Example: + + ```mlir + 5 + 5 : i64 + r<1, > + 2 + r<1, > + 2 : i64 + r<1, > + ``` + }]; + let parameters = (ins + OptionalParameter<"IntegerAttr">:$imm, + OptionalParameter<"SrcOperandAttr">:$relative); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def DXSA_OperandIndexAttr : + ArrayOfAttr { + let summary = "ordered list of operand index slots, one per dimension"; + let description = [{ + The `#dxsa.operand_index` attribute is the ordered list of + index slots making up the full index expression of an + operand. There is one slot per index dimension. Arrays of + 1, 2, or 3 slots correspond to operands with one, two, or three + levels of indexing. + + Example: + + ```mlir + [0] + [0, r<1, >] + [r<1, >, 2, 3] + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// DXSA destination operand attribute +//===----------------------------------------------------------------------===// + +def DXSA_DstOperandAttr : AttrDef { + let mnemonic = "dst_operand"; + let summary = "decoded destination operand"; + let description = [{ + The `#dxsa.dst_operand` attribute describes an operand that is + written by an instruction. + + The textual form is the operand-type keyword, optionally + followed by an angle-bracketed body. The parameters print in + this fixed order, with absent or canonical fields omitted: + + - `$type` is the operand-type keyword (`r`, `o`, `oDepth`, + `null`, ...). + - `$index` is the list of index slots, one per dimension. + A single-entry list omits the brackets. + - `$components` is the component-count declaration. Omitted + when it equals the canonical count for the operand type. + - `$minPrecision` is the minimum precision hint on the operand. + - `$mask` is the destination writemask. 1 to 4 unique + components are accepted, duplicates are rejected. + + When every body field is absent or canonical the entire body is + dropped from the printed form, leaving just the operand-type + keyword (e.g. `oDepth`, `null`). + + Example: + + ```mlir + dxsa.dcl_output oDepth + dxsa.dcl_output o<0> + dxsa.dcl_input v<0, min16f> + dxsa.add null, r<1>, r<2> + dxsa.add r<0, >, r<1>, r<2> + ``` + }]; + let parameters = (ins + EnumParameter:$type, + OptionalParameter<"OperandIndexAttr", + "index slot list, one entry per dimension">:$index, + AttrParameter<"OperandComponentsAttr", + "component-count declaration">:$components, + OptionalParameter<"OperandMinPrecisionAttr", + "minimum precision hint">:$minPrecision, + OptionalParameter<"ComponentMaskAttr", + "destination writemask">:$mask); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +//===----------------------------------------------------------------------===// +// DXSA source operand attribute +//===----------------------------------------------------------------------===// + +def DXSA_SrcOperandAttr : AttrDef { + let mnemonic = "src_operand"; + let summary = "decoded source operand"; + let description = [{ + The `#dxsa.src_operand` attribute describes an operand that is + read by an instruction. + + The textual form is the operand-type keyword, optionally + preceded by a unary modifier and followed by an angle-bracketed + body or, for the `l` and `d` operand types, a parenthesised + immediate literal list. The parameters print in this fixed + order, with absent or canonical fields omitted: + + - `$type` is the operand-type keyword (`r`, `v`, `cb`, `l`, + `d`, ...). + - `$index` is the list of index slots, one per dimension. + A single-entry list omits the brackets. + - `$components` is the component-count declaration. Omitted + when it equals the canonical count for the operand type. + - `$minPrecision` is the minimum precision hint on the operand. + - `$nonUniform` marks a resource handle whose index is not + uniform across the lockstep execution of the draw, printed + as `nonuniform`. + - `$swizzle` is the source component selection. A 1-component + list means a single-component select; a 4-component list is an + ordered swizzle (repetition allowed). + - `$modifier`, when set, decorates the printed operand text: + `neg` adds a leading minus (`-OPERAND`), `abs` wraps in pipes + (`|OPERAND|`), `abs_neg` combines both (`-|OPERAND|`). + - `$values` carries a 32-bit immediate literal payload for the + `l` operand type, printed as `l(...)`. Each value is stored + as raw bits and printed in the shortest form that round-trips + (float, signed integer, or hex). + - `$values64` is the 64-bit variant used by the `d` operand + type, printed as `d(...)`. + + When every body field is absent or canonical the entire body is + dropped from the printed form, leaving just the operand-type + keyword (e.g. `vPrim`, `null`). + + Example: + + ```mlir + dxsa.add r<0>, vPrim, l(1.0, 2.0, 3.0, 4.0) + dxsa.add r<0>, -r<1, >, r<2, > + dxsa.add r<0>, -|r<1, min16f, nonuniform>|, d(1.0, 2.0) + ``` + }]; + let parameters = (ins + EnumParameter:$type, + OptionalParameter<"OperandIndexAttr", + "index slot list, one entry per dimension">:$index, + AttrParameter<"OperandComponentsAttr", + "component-count declaration">:$components, + OptionalParameter<"OperandMinPrecisionAttr", + "minimum precision hint">:$minPrecision, + OptionalParameter<"UnitAttr", + "nonuniform-index marker for resource handles">:$nonUniform, + OptionalParameter<"SwizzleAttr", + "source component selection: 1-component select or 4-component swizzle" + >:$swizzle, + OptionalParameter<"OperandModifierAttr", + "neg, abs or abs_neg modifier for the operand" + >:$modifier, + OptionalParameter<"DenseI32ArrayAttr", + "32-bit immediate literal payload (for the `l` operand type)" + >:$values, + OptionalParameter<"DenseI64ArrayAttr", + "64-bit immediate literal payload (for the `d` operand type)" + >:$values64); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +#endif // MLIR_DIALECT_DXSA_IR_DXSAOPERAND diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index e497b53c7abd..9212aa21a023 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -9,20 +9,13 @@ #ifndef DXSA_OPS #define DXSA_OPS -include "mlir/Dialect/DXSA/IR/DXSADialect.td" +include "mlir/Dialect/DXSA/IR/DXSAOpBase.td" include "mlir/Dialect/DXSA/IR/DXSATypes.td" +include "mlir/Dialect/DXSA/IR/DXSAFPArithOps.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" -//===----------------------------------------------------------------------===// -// DXSA op base class -//===----------------------------------------------------------------------===// - -// Base class for all operations in this dialect. -class DXSA_Op traits = []> : - Op; - //===----------------------------------------------------------------------===// // DXSA module - top-level container op for a DXBC tokenized program //===----------------------------------------------------------------------===// @@ -419,32 +412,6 @@ def DXSA_ResourceReturnTypeAttr : let assemblyFormat = "$value"; } -//===----------------------------------------------------------------------===// -// DXSA ComponentMask bit-enum (mask field of operand, normalized to bits 0..3) -//===----------------------------------------------------------------------===// - -def DXSA_ComponentMask_X : I32BitEnumAttrCaseBit<"x", 0>; -def DXSA_ComponentMask_Y : I32BitEnumAttrCaseBit<"y", 1>; -def DXSA_ComponentMask_Z : I32BitEnumAttrCaseBit<"z", 2>; -def DXSA_ComponentMask_W : I32BitEnumAttrCaseBit<"w", 3>; - -def DXSA_ComponentMask : I32BitEnumAttr< - "ComponentMask", "component mask (subset of x, y, z, w)", [ - DXSA_ComponentMask_X, - DXSA_ComponentMask_Y, - DXSA_ComponentMask_Z, - DXSA_ComponentMask_W - ]> { - let separator = ", "; - let cppNamespace = "::mlir::dxsa"; - let genSpecializedAttr = 0; -} - -def DXSA_ComponentMaskAttr : - EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def DXSA_ConstantBufferAccessPattern_ImmediateIndexed : I32EnumAttrCase<"immediateIndexed", 0>; def DXSA_ConstantBufferAccessPattern_DynamicIndexed : I32EnumAttrCase<"dynamicIndexed", 1>; @@ -513,7 +480,7 @@ def DXSA_Operand : DXSA_Op<"operand"> { OptionalAttr:$modifier, OptionalAttr:$min_precision, OptionalAttr:$non_uniform); - let results = (outs DXSA_OperandType:$operand); + let results = (outs DXSA_LegacyOperandType:$operand); let assemblyFormat = "$operands attr-dict"; } @@ -524,7 +491,7 @@ def DXSA_OperandImm : DXSA_Op<"operand.imm"> { }]; let arguments = (ins AnyAttrOf<[I32ElementsAttr, I64ElementsAttr]>:$imm); - let results = (outs DXSA_OperandType:$operand); + let results = (outs DXSA_LegacyOperandType:$operand); let assemblyFormat = "attr-dict"; } @@ -545,7 +512,7 @@ def DXSA_IndexRel : DXSA_Op<"index.rel"> { TODO }]; - let arguments = (ins DXSA_OperandType:$operand); + let arguments = (ins DXSA_LegacyOperandType:$operand); let results = (outs DXSA_IndexType:$index); let assemblyFormat = "$operand attr-dict"; } @@ -556,7 +523,7 @@ def DXSA_IndexRelImm : DXSA_Op<"index.rel.imm"> { TODO }]; - let arguments = (ins DXSA_OperandType:$operand, StrAttr:$op, I64Attr:$imm); + let arguments = (ins DXSA_LegacyOperandType:$operand, StrAttr:$op, I64Attr:$imm); let results = (outs DXSA_IndexType:$index); let assemblyFormat = "$operand attr-dict"; } @@ -567,7 +534,7 @@ def DXSA_Instruction : DXSA_Op<"instruction"> { TODO }]; - let arguments = (ins Variadic:$operands, StrAttr:$mnemonic); + let arguments = (ins Variadic:$operands, StrAttr:$mnemonic); let results = (outs); let assemblyFormat = "$mnemonic $operands attr-dict"; } @@ -868,7 +835,7 @@ def DXSA_DclInputPs : DXSA_Op<"dcl_input_ps"> { ``` }]; let arguments = (ins DXSA_InterpolationModeAttr:$mode, - DXSA_OperandType:$operand); + DXSA_LegacyOperandType:$operand); let assemblyFormat = "$mode $operand attr-dict"; } @@ -886,7 +853,7 @@ def DXSA_DclInputPsSiv : DXSA_Op<"dcl_input_ps_siv"> { ``` }]; let arguments = (ins DXSA_InterpolationModeAttr:$mode, - DXSA_OperandType:$operand, + DXSA_LegacyOperandType:$operand, DXSA_SystemValueNameAttr:$name); let assemblyFormat = "$mode $operand `,` $name attr-dict"; } @@ -903,7 +870,7 @@ def DXSA_DclInputPsSgv : DXSA_Op<"dcl_input_ps_sgv"> { dxsa.dcl_input_ps_sgv %v0, ``` }]; - let arguments = (ins DXSA_OperandType:$operand, + let arguments = (ins DXSA_LegacyOperandType:$operand, DXSA_SystemValueNameAttr:$name); let assemblyFormat = "$operand `,` $name attr-dict"; } diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSATypes.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSATypes.td index ce312dbfd71b..444455570e56 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSATypes.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSATypes.td @@ -28,8 +28,12 @@ def DXSA_IndexType : DXSA_Type<"Index", "index"> { }]; } -def DXSA_OperandType : DXSA_Type<"Operand", "operand"> { - let summary = "dxsa operand type"; +// The `Legacy` prefix frees the bare `OperandType` C++ name for the new +// `DXSA_OperandType` enum (the operand-kind classifier shared by +// `#dxsa.dst_operand` / `#dxsa.src_operand`). The textual mnemonic +// `!dxsa.operand` is unchanged so existing tests keep working. +def DXSA_LegacyOperandType : DXSA_Type<"LegacyOperand", "operand"> { + let summary = "dxsa legacy operand type"; let description = [{ TODO }]; diff --git a/mlir/lib/Dialect/DXSA/IR/CMakeLists.txt b/mlir/lib/Dialect/DXSA/IR/CMakeLists.txt index 5de37e1cc01d..2ef0b7884f4f 100644 --- a/mlir/lib/Dialect/DXSA/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/DXSA/IR/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_dialect_library(MLIRDXSADialect DXSA.cpp + DXSAImmediate.cpp + DXSAOperand.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/DXSA diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index 0482ba734552..834c2eda3d61 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -30,10 +30,7 @@ void DXSADialect::initialize() { #define GET_TYPEDEF_LIST #include "mlir/Dialect/DXSA/IR/DXSAOpsTypes.cpp.inc" >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "mlir/Dialect/DXSA/IR/DXSAOpsAttributes.cpp.inc" - >(); + registerAttributes(); } /// Declarations for custom-directive helpers used by the @@ -275,13 +272,6 @@ static void printHexTokens(OpAsmPrinter &printer, Operation *, printer << "]>"; } -//===----------------------------------------------------------------------===// -// TableGen'd attribute method definitions -//===----------------------------------------------------------------------===// - -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/DXSA/IR/DXSAOpsAttributes.cpp.inc" - //===----------------------------------------------------------------------===// // TableGen'd type method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.cpp b/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.cpp new file mode 100644 index 000000000000..87212d38e54a --- /dev/null +++ b/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.cpp @@ -0,0 +1,225 @@ +//===--------- DXSAImmediate.cpp - DXSA immediate-literal codec ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "DXSAImmediate.h" + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/bit.h" +#include "llvm/Support/Format.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::dxsa; + +//===----------------------------------------------------------------------===// +// Immediate parsing for `l(...)` / `d(...)` payloads +//===----------------------------------------------------------------------===// + +/// Parses one literal element of an `l(...)` / `d(...)` payload. Accepts a +/// float literal, a decimal integer, or a hex integer, and reinterprets the +/// value as the raw bit pattern at the requested width. +static ParseResult parseImmValue(AsmParser &parser, bool is64, uint64_t &bits) { + auto loc = parser.getCurrentLocation(); + Attribute attr; + if (parser.parseAttribute(attr)) + return failure(); + if (auto floatAttr = llvm::dyn_cast(attr)) { + if (is64) { + bits = llvm::bit_cast(floatAttr.getValueAsDouble()); + } else { + auto fValue = static_cast(floatAttr.getValueAsDouble()); + bits = llvm::bit_cast(fValue); + } + return success(); + } + if (auto intAttr = llvm::dyn_cast(attr)) { + auto width = is64 ? 64u : 32u; + auto value = intAttr.getValue(); + if (value.getBitWidth() < width) { + value = value.sext(width); + } else if (value.getBitWidth() > width) { + auto neededBits = value.isNegative() ? value.getSignificantBits() + : value.getActiveBits(); + if (neededBits > width) + return parser.emitError(loc) + << neededBits << "-bit immediate does not fit in " << width + << "-bit literal"; + value = value.trunc(width); + } + bits = value.getZExtValue(); + return success(); + } + return parser.emitError(loc, "expected float or integer literal"); +} + +/// Parses the parenthesised comma-separated immediate list, e.g. +/// `(1.0, 2.0, 3.0, 4.0)` or `(0x3F800000)`, appending each value's raw bits to +/// `bits`. The element width follows `Int`: `int32_t` for `l`, `int64_t` for +/// `d`. +template +static ParseResult parseImmValueList(AsmParser &parser, + SmallVectorImpl &bits) { + return parser.parseCommaSeparatedList( + AsmParser::Delimiter::Paren, [&]() -> ParseResult { + uint64_t v = 0; + if (parseImmValue(parser, /*is64=*/sizeof(Int) == 8, v)) + return failure(); + bits.push_back(static_cast(v)); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// Immediate printing for `l(...)` / `d(...)` payloads +//===----------------------------------------------------------------------===// + +namespace { +struct DecodedImmValue { + uint64_t bits = 0; + unsigned byteWidth = 0; + std::optional intValue; + std::optional floatValue; + std::optional doubleValue; +}; +} // namespace + +template +static SmallString<48> floatSpellingForcingDecimalPoint(Float value) { + SmallString<48> out; + llvm::raw_svector_ostream os(out); + os << llvm::format("%g", value); + StringRef text(out.data(), out.size()); + if (text.contains('.')) + return out; + size_t exponentPos = text.find_first_of("eE"); + if (exponentPos == StringRef::npos) { + out.push_back('.'); + out.push_back('0'); + } else { + out.insert(out.begin() + exponentPos, '0'); + out.insert(out.begin() + exponentPos, '.'); + } + return out; +} + +template +static bool floatSpellingReparsesToIdenticalBits(Float value, UInt bits) { + auto text = floatSpellingForcingDecimalPoint(value); + double reparsed; + return !StringRef(text).getAsDouble(reparsed) && + llvm::bit_cast(static_cast(reparsed)) == bits; +} + +template +static DecodedImmValue decodeImmValue(Int value) { + using UInt = std::make_unsigned_t; + constexpr bool is32 = std::is_same_v; + auto bits = static_cast(value); + + DecodedImmValue imm; + imm.bits = bits; + imm.byteWidth = sizeof(UInt); + + // Positive zero is the only value spelled as a bare `0`; every other pattern, + // negative zero included, is offered to the float reading first. + if (bits == 0) { + imm.intValue = 0; + return imm; + } + + auto f = llvm::bit_cast(bits); + constexpr Float minMagnitude = is32 ? 1e-30f : 1e-300; + if (std::isfinite(f) && (f == Float(0) || std::abs(f) >= minMagnitude) && + floatSpellingReparsesToIdenticalBits(f, bits)) { + if constexpr (is32) + imm.floatValue = f; + else + imm.doubleValue = f; + } + + constexpr Int intThreshold = is32 ? 0x10000 : 0x100000; + if (value >= -intThreshold && value <= intThreshold) + imm.intValue = value; + + return imm; +} + +static void printImmValue(llvm::raw_ostream &os, const DecodedImmValue &imm) { + if (imm.floatValue || imm.doubleValue) { + os << (imm.floatValue ? floatSpellingForcingDecimalPoint(*imm.floatValue) + : floatSpellingForcingDecimalPoint(*imm.doubleValue)); + return; + } + if (imm.intValue) { + os << *imm.intValue; + return; + } + os << llvm::format_hex(imm.bits, imm.byteWidth * 2 + 2, true); +} + +//===----------------------------------------------------------------------===// +// Printing and parsing for immediate-literal operand. +//===----------------------------------------------------------------------===// + +/// Parses the body of a 32-bit immediate-literal operand. +ParseResult mlir::dxsa::parseImm32Body(AsmParser &parser, + DenseI32ArrayAttr &values) { + SmallVector bits; + if (parseImmValueList(parser, bits)) + return failure(); + values = DenseI32ArrayAttr::get(parser.getContext(), bits); + return success(); +} + +/// Parses the body of a 64-bit immediate-literal operand. +ParseResult mlir::dxsa::parseImm64Body(AsmParser &parser, + DenseI64ArrayAttr &values64) { + SmallVector bits; + if (parseImmValueList(parser, bits)) + return failure(); + values64 = DenseI64ArrayAttr::get(parser.getContext(), bits); + return success(); +} + +/// Returns the default `OperandComponents` for an immediate-literal +/// payload of the given length. 1 maps to scalar, 4 to vector; other +/// lengths fall through to `reserved` so the round-trip remains lossless. +OperandComponents mlir::dxsa::immComponentsFor(size_t count) { + switch (count) { + case 0: + return OperandComponents::none; + case 1: + return OperandComponents::scalar; + case 4: + return OperandComponents::vector; + default: + return OperandComponents::reserved; + } +} + +void mlir::dxsa::printImm32(AsmPrinter &printer, ArrayRef values) { + printer << "l("; + llvm::interleaveComma(values, printer.getStream(), [&](int32_t v) { + printImmValue(printer.getStream(), decodeImmValue(v)); + }); + printer << ")"; +} + +void mlir::dxsa::printImm64(AsmPrinter &printer, ArrayRef values64) { + printer << "d("; + llvm::interleaveComma(values64, printer.getStream(), [&](int64_t v) { + printImmValue(printer.getStream(), decodeImmValue(v)); + }); + printer << ")"; +} diff --git a/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.h b/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.h new file mode 100644 index 000000000000..35278bb2c602 --- /dev/null +++ b/mlir/lib/Dialect/DXSA/IR/DXSAImmediate.h @@ -0,0 +1,30 @@ +//===--------- DXSAImmediate.h - DXSA immediate-literal codec ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_DIALECT_DXSA_IR_DXSAIMMEDIATE_H +#define MLIR_LIB_DIALECT_DXSA_IR_DXSAIMMEDIATE_H + +#include "mlir/Dialect/DXSA/IR/DXSA.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace dxsa { + +OperandComponents immComponentsFor(size_t count); + +void printImm32(AsmPrinter &printer, ArrayRef values); +ParseResult parseImm32Body(AsmParser &parser, DenseI32ArrayAttr &values); + +void printImm64(AsmPrinter &printer, ArrayRef values64); +ParseResult parseImm64Body(AsmParser &parser, DenseI64ArrayAttr &values64); + +} // namespace dxsa +} // namespace mlir + +#endif // MLIR_LIB_DIALECT_DXSA_IR_DXSAIMMEDIATE_H diff --git a/mlir/lib/Dialect/DXSA/IR/DXSAOperand.cpp b/mlir/lib/Dialect/DXSA/IR/DXSAOperand.cpp new file mode 100644 index 000000000000..1e41df7cc39a --- /dev/null +++ b/mlir/lib/Dialect/DXSA/IR/DXSAOperand.cpp @@ -0,0 +1,831 @@ +//===--------- DXSAOperand.cpp - DXSA operand attributes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/DXSA/IR/DXSA.h" + +#include "DXSAImmediate.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#include +#include +#include + +using namespace mlir; +using namespace mlir::dxsa; + +//===----------------------------------------------------------------------===// +// TableGen'd attribute method definitions +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/DXSA/IR/DXSAOpsAttributes.cpp.inc" + +void DXSADialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/DXSA/IR/DXSAOpsAttributes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Helpers for operand attributes custom asm +//===----------------------------------------------------------------------===// + +static constexpr StringLiteral nonUniformKeyword = "nonuniform"; + +/// Returns the default components for the given operand type. +/// Printing omits the explicit keyword when the value matches this default. +static OperandComponents defaultComponentsFor(OperandType type) { + switch (type) { + case OperandType::r: + case OperandType::v: + case OperandType::o: + case OperandType::x: + case OperandType::vicp: + case OperandType::vocp: + case OperandType::vpc: + case OperandType::vDomain: + case OperandType::vThreadID: + case OperandType::vThreadGroupID: + case OperandType::vThreadIDInGroup: + case OperandType::cycleCounter: + return OperandComponents::vector; + case OperandType::l: + case OperandType::d: + case OperandType::oDepth: + case OperandType::oDepthGE: + case OperandType::oDepthLE: + case OperandType::oMask: + case OperandType::oStencilRef: + case OperandType::vCoverage: + case OperandType::vGSInstanceID: + case OperandType::vForkInstanceID: + case OperandType::vJoinInstanceID: + case OperandType::vOutputControlPointID: + case OperandType::vInnerCoverage: + case OperandType::vPrim: + case OperandType::vThreadIDInGroupFlattened: + return OperandComponents::scalar; + case OperandType::s: + case OperandType::t: + case OperandType::cb: + case OperandType::icb: + case OperandType::u: + case OperandType::g: + case OperandType::m: + case OperandType::label: + case OperandType::null: + case OperandType::rasterizer: + case OperandType::fb: + case OperandType::ft: + case OperandType::fp: + case OperandType::funcInput: + case OperandType::funcOutput: + case OperandType::thisPtr: + return OperandComponents::none; + } + llvm_unreachable("unknown operand type"); +} + +static bool isImmediateType(OperandType type) { + return type == OperandType::l || type == OperandType::d; +} + +static char charFromComponent(int64_t component) { + return component == 0 ? 'x' + : component == 1 ? 'y' + : component == 2 ? 'z' + : 'w'; +} + +static std::optional componentFromKeyword(StringRef keyword) { + if (keyword == "x") + return 0; + if (keyword == "y") + return 1; + if (keyword == "z") + return 2; + if (keyword == "w") + return 3; + return std::nullopt; +} + +/// Parses a single component keyword `x`, `y`, `z` or `w`. +static ParseResult parseComponent(AsmParser &parser, + SmallVectorImpl &components) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + auto component = componentFromKeyword(keyword); + if (!component) + return parser.emitError(parser.getCurrentLocation()) + << "unknown component: `" << keyword << "`"; + components.push_back(*component); + return success(); +} + +static OptionalParseResult +parseOptionalComponents(AsmParser &parser, + SmallVectorImpl &components) { + if (failed(parser.parseOptionalLess())) + return std::nullopt; + return failure(parser.parseCommaSeparatedList( + AsmParser::Delimiter::None, + [&] { return parseComponent(parser, components); }) || + parser.parseGreater()); +} + +static ParseResult parseComponents(AsmParser &parser, + SmallVectorImpl &components) { + OptionalParseResult result = parseOptionalComponents(parser, components); + if (!result.has_value()) + return parser.emitError(parser.getCurrentLocation()) << "expected '<'"; + return *result; +} + +//===----------------------------------------------------------------------===// +// SwizzleAttr +//===----------------------------------------------------------------------===// + +LogicalResult SwizzleAttr::verify(function_ref emitError, + ArrayRef components) { + if (components.size() != 1 && components.size() != 4) + return emitError() << "swizzle must have 1 or 4 components, got " + << components.size(); + for (int64_t component : components) { + if (component < 0 || component > 3) + return emitError() << "component must be in [0, 3], got " << component; + } + return success(); +} + +Attribute SwizzleAttr::parse(AsmParser &parser, Type) { + auto loc = parser.getCurrentLocation(); + SmallVector components; + if (parseComponents(parser, components)) + return {}; + return parser.getChecked(loc, parser.getContext(), components); +} + +void SwizzleAttr::print(AsmPrinter &printer) const { + printer << '<'; + llvm::interleaveComma(getComponents(), printer.getStream(), + [&](int64_t c) { printer << charFromComponent(c); }); + printer << '>'; +} + +//===----------------------------------------------------------------------===// +// IndexAttr +//===----------------------------------------------------------------------===// + +LogicalResult IndexAttr::verify(function_ref emitError, + IntegerAttr imm, SrcOperandAttr relative) { + if (!imm && !relative) + return emitError() << "index must be either immediate or relative"; + if (imm) { + auto type = imm.getType(); + if (!type.isInteger(32) && !type.isInteger(64)) + return emitError() << "unsupported index type: " << type; + } + return success(); +} + +/// Parses an index entry that begins with the already-parsed integer `value`, +/// e.g. `2`, `2 : i64`, `2 : i64 + r<...>`. +static IndexAttr parseIntIndexEntry(AsmParser &parser, MLIRContext *ctx, + SMLoc loc, uint64_t value) { + auto immType = IntegerType::get(ctx, 32); + if (succeeded(parser.parseOptionalColon())) + if (parser.parseType(immType)) + return {}; + auto immAttr = IntegerAttr::get(immType, value); + + SrcOperandAttr relative; + if (succeeded(parser.parseOptionalPlus())) + if (parser.parseCustomAttributeWithFallback(relative)) + return {}; + return parser.getChecked(loc, ctx, immAttr, relative); +} + +Attribute IndexAttr::parse(AsmParser &parser, Type) { + auto loc = parser.getCurrentLocation(); + auto *ctx = parser.getContext(); + + uint64_t value; + auto intResult = parser.parseOptionalInteger(value); + if (intResult.has_value()) { + if (failed(*intResult)) + return {}; + return parseIntIndexEntry(parser, ctx, loc, value); + } + + SrcOperandAttr relative; + if (parser.parseCustomAttributeWithFallback(relative)) + return {}; + return parser.getChecked(loc, ctx, IntegerAttr(), relative); +} + +void IndexAttr::print(AsmPrinter &printer) const { + if (auto imm = getImm()) { + printer << imm.getValue(); + if (imm.getType().isInteger(64)) + printer << " : i64"; + } + if (auto relative = getRelative()) { + if (getImm()) + printer << " + "; + printer.printStrippedAttrOrType(relative); + } +} + +//===----------------------------------------------------------------------===// +// Shared helpers for DstOperandAttr / SrcOperandAttr asm +//===----------------------------------------------------------------------===// + +namespace { + +struct DstOperandBody { + OperandIndexAttr index; + OperandComponentsAttr components; + ComponentMaskAttr mask; + OperandMinPrecisionAttr minPrecision; +}; + +struct SrcOperandBody { + OperandIndexAttr index; + OperandComponentsAttr components; + SwizzleAttr swizzle; + OperandMinPrecisionAttr minPrecision; + UnitAttr nonUniform; +}; + +} // namespace + +static OptionalParseResult tryParseIndexList(AsmParser &parser, SMLoc fieldLoc, + OperandIndexAttr &result) { + if (failed(parser.parseOptionalLSquare())) + return std::nullopt; + if (result) + return parser.emitError(fieldLoc) << "duplicate index list"; + + SmallVector entries; + if (failed(parser.parseOptionalRSquare())) { + if (parser.parseCommaSeparatedList( + AsmParser::Delimiter::None, + [&]() -> ParseResult { + IndexAttr entry; + if (parser.parseCustomAttributeWithFallback(entry)) + return failure(); + entries.push_back(entry); + return success(); + }) || + parser.parseRSquare()) + return failure(); + } + result = OperandIndexAttr::get(parser.getContext(), entries); + return success(); +} + +static SwizzleAttr identitySwizzle(MLIRContext *ctx) { + return SwizzleAttr::get(ctx, ArrayRef{0, 1, 2, 3}); +} + +static void applySrcOperandDefaults(MLIRContext *ctx, OperandType type, + SrcOperandBody &body) { + if (!body.components) + body.components = + OperandComponentsAttr::get(ctx, defaultComponentsFor(type)); + if (!body.swizzle && body.components.getValue() == OperandComponents::vector) + body.swizzle = identitySwizzle(ctx); +} + +static OptionalParseResult parseSrcOperandBody(AsmParser &parser, + SrcOperandBody &body); + +static FailureOr parseRelativeSrcOperand(AsmParser &parser, + OperandType type) { + auto *ctx = parser.getContext(); + SrcOperandBody body; + OptionalParseResult bodyResult = parseSrcOperandBody(parser, body); + if (bodyResult.has_value() && failed(*bodyResult)) + return failure(); + applySrcOperandDefaults(ctx, type, body); + return SrcOperandAttr::get(ctx, type, body.index, body.components, + body.minPrecision, body.nonUniform, body.swizzle, + /*modifier=*/OperandModifierAttr(), + /*values=*/DenseI32ArrayAttr(), + /*values64=*/DenseI64ArrayAttr()); +} + +static ParseResult setSingleIndexEntry(AsmParser &parser, MLIRContext *ctx, + SMLoc fieldLoc, OperandIndexAttr &index, + IndexAttr entry) { + if (index) + return parser.emitError(fieldLoc) << "duplicate index list"; + index = OperandIndexAttr::get(ctx, ArrayRef{entry}); + return success(); +} + +/// Parses a single bare index entry (`5`, `5 : i64`, `5 + r<...>`), absent when +/// the next token is not an integer. +static OptionalParseResult tryParseImmIndex(AsmParser &parser, MLIRContext *ctx, + SMLoc fieldLoc, + OperandIndexAttr &index) { + auto loc = parser.getCurrentLocation(); + uint64_t value = 0; + OptionalParseResult intResult = parser.parseOptionalInteger(value); + if (!intResult.has_value()) + return std::nullopt; + if (failed(*intResult)) + return failure(); + + IndexAttr entry = parseIntIndexEntry(parser, ctx, loc, value); + if (!entry) + return failure(); + return setSingleIndexEntry(parser, ctx, fieldLoc, index, entry); +} + +static ParseResult parseOperandBody(AsmParser &parser, + function_ref parseField) { + if (succeeded(parser.parseOptionalGreater())) + return success(); + return failure( + parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseField) || + parser.parseGreater()); +} + +/// Handles a keyword that names an operand type used as a relative index +/// (`r<...>`), absent when `keyword` is not an operand type. +static OptionalParseResult +keywordToRelativeIndex(AsmParser &parser, MLIRContext *ctx, SMLoc fieldLoc, + StringRef keyword, OperandIndexAttr &index) { + auto type = symbolizeOperandType(keyword); + if (!type) + return std::nullopt; + if (isImmediateType(*type)) + return parser.emitError(fieldLoc) + << "immediate operand `" << keyword << "` cannot be an index"; + auto relative = parseRelativeSrcOperand(parser, *type); + if (failed(relative)) + return failure(); + return setSingleIndexEntry(parser, ctx, fieldLoc, index, + IndexAttr::get(ctx, IntegerAttr(), *relative)); +} + +static OptionalParseResult +keywordToComponents(AsmParser &parser, MLIRContext *ctx, SMLoc fieldLoc, + StringRef keyword, OperandComponentsAttr &components) { + auto parsed = symbolizeOperandComponents(keyword); + if (!parsed) + return std::nullopt; + if (components) + return parser.emitError(fieldLoc) << "duplicate component count"; + components = OperandComponentsAttr::get(ctx, *parsed); + return success(); +} + +static OptionalParseResult +keywordToMinPrecision(AsmParser &parser, MLIRContext *ctx, SMLoc fieldLoc, + StringRef keyword, + OperandMinPrecisionAttr &minPrecision) { + auto parsed = symbolizeOperandMinPrecision(keyword); + if (!parsed) + return std::nullopt; + if (minPrecision) + return parser.emitError(fieldLoc) << "duplicate min precision"; + minPrecision = OperandMinPrecisionAttr::get(ctx, *parsed); + return success(); +} + +/// Parses the keyword-led operand-body fields shared by source and destination +/// operands. `bodyKind` ("source" / "destination") only shapes the diagnostic. +static ParseResult +parseKeywordOperandField(AsmParser &parser, MLIRContext *ctx, SMLoc fieldLoc, + StringRef bodyKind, OperandIndexAttr &index, + OperandComponentsAttr &components, + OperandMinPrecisionAttr &minPrecision) { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + if (auto r = keywordToRelativeIndex(parser, ctx, fieldLoc, keyword, index); + r.has_value()) + return *r; + if (auto r = keywordToComponents(parser, ctx, fieldLoc, keyword, components); + r.has_value()) + return *r; + if (auto r = + keywordToMinPrecision(parser, ctx, fieldLoc, keyword, minPrecision); + r.has_value()) + return *r; + return parser.emitError(fieldLoc) << "unexpected keyword in " << bodyKind + << " operand body: `" << keyword << "`"; +} + +//===----------------------------------------------------------------------===// +// DstOperandAttr +//===----------------------------------------------------------------------===// + +/// Builds a destination writemask from a parsed component list. +static FailureOr +buildComponentMask(function_ref emitError, + MLIRContext *ctx, ArrayRef components) { + if (components.empty() || components.size() > 4) + return emitError() << "destination mask must have 1 to 4 components, got " + << components.size(); + auto bits = static_cast(0); + for (int64_t c : components) { + auto bit = static_cast(1u << c); + if ((bits & bit) == bit) + return emitError() << "duplicate component `" << charFromComponent(c) + << "` in destination mask"; + bits = bits | bit; + } + return ComponentMaskAttr::get(ctx, bits); +} + +/// Parses a destination writemask ``, absent when the next token is +/// not `<`. +static OptionalParseResult tryParseMask(AsmParser &parser, MLIRContext *ctx, + SMLoc fieldLoc, + ComponentMaskAttr &mask) { + SmallVector components; + OptionalParseResult componentList = + parseOptionalComponents(parser, components); + if (!componentList.has_value()) + return std::nullopt; + if (failed(*componentList)) + return failure(); + if (mask) + return parser.emitError(fieldLoc) << "duplicate mask"; + FailureOr built = buildComponentMask( + [&] { return parser.emitError(fieldLoc); }, ctx, components); + if (failed(built)) + return failure(); + mask = *built; + return success(); +} + +static OptionalParseResult parseDstOperandBody(AsmParser &parser, + DstOperandBody &body) { + if (failed(parser.parseOptionalLess())) + return std::nullopt; + auto *ctx = parser.getContext(); + return parseOperandBody(parser, [&]() -> ParseResult { + SMLoc fieldLoc = parser.getCurrentLocation(); + if (auto r = tryParseIndexList(parser, fieldLoc, body.index); r.has_value()) + return *r; + if (auto r = tryParseImmIndex(parser, ctx, fieldLoc, body.index); + r.has_value()) + return *r; + if (auto r = tryParseMask(parser, ctx, fieldLoc, body.mask); r.has_value()) + return *r; + return parseKeywordOperandField(parser, ctx, fieldLoc, "destination", + body.index, body.components, + body.minPrecision); + }); +} + +static OperandComponentsAttr +nonDefaultComponents(OperandComponentsAttr components, OperandType type) { + if (components && components.getValue() == defaultComponentsFor(type)) + return {}; + return components; +} + +static ComponentMaskAttr nonDefaultMask(ComponentMaskAttr mask) { + if (mask && bitEnumContainsAll(mask.getValue(), + ComponentMask::x | ComponentMask::y | + ComponentMask::z | ComponentMask::w)) + return {}; + return mask; +} + +static SwizzleAttr nonDefaultSwizzle(SwizzleAttr swizzle) { + if (swizzle && swizzle.getComponents() == ArrayRef{0, 1, 2, 3}) + return {}; + return swizzle; +} + +static void printOperandIndex(AsmPrinter &printer, OperandIndexAttr index) { + bool useBrackets = index.size() != 1; + if (useBrackets) + printer << '['; + llvm::interleaveComma(index, printer.getStream(), [&](IndexAttr entry) { + printer.printStrippedAttrOrType(entry); + }); + if (useBrackets) + printer << ']'; +} + +static void printOperandComponents(AsmPrinter &printer, + OperandComponentsAttr components) { + printer << stringifyOperandComponents(components.getValue()); +} + +static void printOperandMinPrecision(AsmPrinter &printer, + OperandMinPrecisionAttr minPrecision) { + printer << stringifyOperandMinPrecision(minPrecision.getValue()); +} + +static void printOperandMask(AsmPrinter &printer, ComponentMaskAttr mask) { + printer.printStrippedAttrOrType(mask); +} + +static void printOperandBody(AsmPrinter &printer, + ArrayRef> fields) { + if (fields.empty()) + return; + printer << '<'; + llvm::interleaveComma(fields, printer.getStream(), + [](const std::function &field) { field(); }); + printer << '>'; +} + +LogicalResult DstOperandAttr::verify( + function_ref emitError, OperandType /*type*/, + OperandIndexAttr /*index*/, OperandComponentsAttr components, + OperandMinPrecisionAttr /*minPrecision*/, ComponentMaskAttr /*mask*/) { + if (!components) + return emitError() << "component count is required"; + return success(); +} + +Attribute DstOperandAttr::parse(AsmParser &parser, Type) { + auto loc = parser.getCurrentLocation(); + auto *ctx = parser.getContext(); + + StringRef typeKeyword; + if (parser.parseKeyword(&typeKeyword)) + return {}; + auto operandType = symbolizeOperandType(typeKeyword); + if (!operandType) { + parser.emitError(loc) << "unknown operand type: `" << typeKeyword << "`"; + return {}; + } + if (isImmediateType(*operandType)) { + parser.emitError(loc) << "immediate operand `" << typeKeyword + << "` cannot be a destination"; + return {}; + } + + DstOperandBody body; + OptionalParseResult bodyResult = parseDstOperandBody(parser, body); + if (bodyResult.has_value() && failed(*bodyResult)) + return {}; + + OperandComponentsAttr components = body.components; + if (!components) + components = + OperandComponentsAttr::get(ctx, defaultComponentsFor(*operandType)); + ComponentMaskAttr mask = body.mask; + if (!mask && components.getValue() == OperandComponents::vector) + mask = ComponentMaskAttr::get(ctx, ComponentMask::x | ComponentMask::y | + ComponentMask::z | ComponentMask::w); + + return parser.getChecked(loc, ctx, *operandType, body.index, + components, body.minPrecision, mask); +} + +void DstOperandAttr::print(AsmPrinter &printer) const { + auto index = getIndex(); + auto minPrecision = getMinPrecision(); + auto components = nonDefaultComponents(getComponents(), getType()); + auto mask = nonDefaultMask(getMask()); + + printer << stringifyOperandType(getType()); + + SmallVector, 4> fields; + if (index) + fields.push_back([&] { printOperandIndex(printer, index); }); + if (components) + fields.push_back([&] { printOperandComponents(printer, components); }); + if (minPrecision) + fields.push_back([&] { printOperandMinPrecision(printer, minPrecision); }); + if (mask) + fields.push_back([&] { printOperandMask(printer, mask); }); + + printOperandBody(printer, fields); +} + +//===----------------------------------------------------------------------===// +// SrcOperandAttr custom asm +//===----------------------------------------------------------------------===// + +static Attribute parseImmSrcOperand(AsmParser &parser, MLIRContext *ctx, + SMLoc loc, OperandType type, + OperandModifierAttr modifier) { + DenseI32ArrayAttr values; + DenseI64ArrayAttr values64; + if (type == OperandType::d ? parseImm64Body(parser, values64) + : parseImm32Body(parser, values)) + return {}; + auto count = values ? values.size() : values64.size(); + auto components = OperandComponentsAttr::get(ctx, immComponentsFor(count)); + return parser.getChecked( + loc, ctx, type, /*index=*/OperandIndexAttr(), components, + /*minPrecision=*/OperandMinPrecisionAttr(), /*nonUniform=*/UnitAttr(), + /*swizzle=*/SwizzleAttr(), modifier, values, values64); +} + +static Attribute +parseNegAndAbsModifier(AsmParser &parser, MLIRContext *ctx, + function_ref parseBody) { + bool hasMinus = succeeded(parser.parseOptionalMinus()); + bool hasAbs = succeeded(parser.parseOptionalVerticalBar()); + + OperandModifierAttr modifier; + if (hasMinus && hasAbs) + modifier = OperandModifierAttr::get(ctx, OperandModifier::abs_neg); + else if (hasMinus) + modifier = OperandModifierAttr::get(ctx, OperandModifier::neg); + else if (hasAbs) + modifier = OperandModifierAttr::get(ctx, OperandModifier::abs); + + Attribute result = parseBody(modifier); + if (!result) + return {}; + if (hasAbs && parser.parseVerticalBar()) + return {}; + return result; +} + +static void printNegAndAbsModifier(AsmPrinter &printer, + OperandModifierAttr modifier, + function_ref printBody) { + bool isNeg = modifier && (modifier.getValue() == OperandModifier::neg || + modifier.getValue() == OperandModifier::abs_neg); + bool isAbs = modifier && (modifier.getValue() == OperandModifier::abs || + modifier.getValue() == OperandModifier::abs_neg); + if (isNeg) + printer << '-'; + if (isAbs) + printer << '|'; + + printBody(); + + if (isAbs) + printer << '|'; +} + +/// Parses the source component selection `<...>`, absent when the next token +/// is not `<`. The list size (1 or 4) distinguishes a single-component select +/// from a full swizzle. +static OptionalParseResult tryParseSwizzle(AsmParser &parser, MLIRContext *ctx, + SMLoc fieldLoc, + SwizzleAttr &swizzle) { + SmallVector components; + OptionalParseResult componentList = + parseOptionalComponents(parser, components); + if (!componentList.has_value()) + return std::nullopt; + if (failed(*componentList)) + return failure(); + auto parsed = parser.getChecked(fieldLoc, ctx, components); + if (!parsed) + return failure(); + if (swizzle) + return parser.emitError(fieldLoc) << "duplicate swizzle"; + swizzle = parsed; + return success(); +} + +/// Parses the bare `nonuniform` keyword, absent when it is not the next token. +static OptionalParseResult tryParseNonUniform(AsmParser &parser, + MLIRContext *ctx, SMLoc fieldLoc, + UnitAttr &nonUniform) { + if (failed(parser.parseOptionalKeyword(nonUniformKeyword))) + return std::nullopt; + if (nonUniform) + return parser.emitError(fieldLoc) << "duplicate " << nonUniformKeyword; + nonUniform = UnitAttr::get(ctx); + return success(); +} + +static OptionalParseResult parseSrcOperandBody(AsmParser &parser, + SrcOperandBody &body) { + if (failed(parser.parseOptionalLess())) + return std::nullopt; + auto *ctx = parser.getContext(); + return parseOperandBody(parser, [&]() -> ParseResult { + SMLoc fieldLoc = parser.getCurrentLocation(); + if (auto r = tryParseIndexList(parser, fieldLoc, body.index); r.has_value()) + return *r; + if (auto r = tryParseImmIndex(parser, ctx, fieldLoc, body.index); + r.has_value()) + return *r; + if (auto r = tryParseSwizzle(parser, ctx, fieldLoc, body.swizzle); + r.has_value()) + return *r; + if (auto r = tryParseNonUniform(parser, ctx, fieldLoc, body.nonUniform); + r.has_value()) + return *r; + return parseKeywordOperandField(parser, ctx, fieldLoc, "source", body.index, + body.components, body.minPrecision); + }); +} + +static void printSrcOperandBody(AsmPrinter &printer, SrcOperandAttr attr) { + auto index = attr.getIndex(); + auto minPrecision = attr.getMinPrecision(); + auto nonUniform = attr.getNonUniform(); + + auto components = nonDefaultComponents(attr.getComponents(), attr.getType()); + auto swizzle = nonDefaultSwizzle(attr.getSwizzle()); + + printer << stringifyOperandType(attr.getType()); + + SmallVector, 6> fields; + if (index) + fields.push_back([&] { printOperandIndex(printer, index); }); + if (components) + fields.push_back([&] { printOperandComponents(printer, components); }); + if (minPrecision) + fields.push_back([&] { printOperandMinPrecision(printer, minPrecision); }); + if (nonUniform) + fields.push_back([&] { printer << nonUniformKeyword; }); + if (swizzle) + fields.push_back([&] { printer.printStrippedAttrOrType(swizzle); }); + + printOperandBody(printer, fields); +} + +LogicalResult SrcOperandAttr::verify( + function_ref emitError, OperandType type, + OperandIndexAttr /*index*/, OperandComponentsAttr components, + OperandMinPrecisionAttr /*minPrecision*/, UnitAttr /*nonUniform*/, + SwizzleAttr /*swizzle*/, OperandModifierAttr modifier, + DenseI32ArrayAttr values, DenseI64ArrayAttr values64) { + bool hasValues = static_cast(values); + bool hasValues64 = static_cast(values64); + + if (!components) + return emitError() << "component count is required"; + if (hasValues && hasValues64) + return emitError() << "values and values64 are mutually exclusive"; + if (hasValues && type != OperandType::l) + return emitError() << "values is only valid on type `l`, got `" + << stringifyOperandType(type) << "`"; + if (hasValues64 && type != OperandType::d) + return emitError() << "values64 is only valid on type `d`, got `" + << stringifyOperandType(type) << "`"; + if (type == OperandType::l && !hasValues) + return emitError() << "type `l` requires a values literal"; + if (type == OperandType::d && !hasValues64) + return emitError() << "type `d` requires a values64 literal"; + if (isImmediateType(type) && modifier) + return emitError() << "immediate operands cannot have a source modifier"; + return success(); +} + +Attribute SrcOperandAttr::parse(AsmParser &parser, Type) { + auto loc = parser.getCurrentLocation(); + auto *ctx = parser.getContext(); + + return parseNegAndAbsModifier( + parser, ctx, [&](OperandModifierAttr modifier) -> Attribute { + auto typeLoc = parser.getCurrentLocation(); + StringRef typeKeyword; + if (parser.parseKeyword(&typeKeyword)) + return {}; + auto type = symbolizeOperandType(typeKeyword); + if (!type) { + parser.emitError(typeLoc) + << "unknown operand type: `" << typeKeyword << "`"; + return {}; + } + + if (isImmediateType(*type)) + return parseImmSrcOperand(parser, ctx, loc, *type, modifier); + + SrcOperandBody body; + OptionalParseResult bodyResult = parseSrcOperandBody(parser, body); + if (bodyResult.has_value() && failed(*bodyResult)) + return {}; + + applySrcOperandDefaults(ctx, *type, body); + return parser.getChecked( + loc, ctx, *type, body.index, body.components, body.minPrecision, + body.nonUniform, body.swizzle, modifier, + /*values=*/DenseI32ArrayAttr(), + /*values64=*/DenseI64ArrayAttr()); + }); +} + +void SrcOperandAttr::print(AsmPrinter &printer) const { + if (auto values = getValues()) + printImm32(printer, values.asArrayRef()); + else if (auto values64 = getValues64()) + printImm64(printer, values64.asArrayRef()); + else + printNegAndAbsModifier(printer, getModifier(), + [&] { printSrcOperandBody(printer, *this); }); +} diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index 427af2cb30b9..03889f698c9e 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -467,14 +467,14 @@ class DXBuilder { Operand buildOperandImm32(ArrayRef values, FileLineColLoc loc) { Operation *op = dxsa::OperandImm::create( - builder, loc, builder.getType(), + builder, loc, builder.getType(), builder.getI32VectorAttr(values)); return op->getResults()[0]; } Operand buildOperandImm64(ArrayRef values, FileLineColLoc loc) { Operation *op = dxsa::OperandImm::create( - builder, loc, builder.getType(), + builder, loc, builder.getType(), builder.getI64VectorAttr(values)); return op->getResults()[0]; } @@ -519,7 +519,8 @@ class DXBuilder { break; } Operation *op = dxsa::Operand::create( - builder, loc, builder.getType(), indices, attrs); + builder, loc, builder.getType(), indices, + attrs); return op->getResults()[0]; } @@ -664,6 +665,134 @@ class DXBuilder { indexAttr); } + dxsa::DstOperandAttr + buildDstOperandAttr(dxsa::OperandType operandType, + const OperandComponents &components, + ArrayRef indexEntries, + std::optional opModifier) { + auto componentsValue = components.num == 0 ? dxsa::OperandComponents::none + : components.num == 1 + ? dxsa::OperandComponents::scalar + : dxsa::OperandComponents::vector; + auto componentsAttr = + dxsa::OperandComponentsAttr::get(context, componentsValue); + + dxsa::ComponentMaskAttr maskAttr; + if (components.kind == OperandComponentsKind::Mask) + maskAttr = dxsa::ComponentMaskAttr::get( + context, decodeComponentMask(components.mask)); + + dxsa::OperandIndexAttr indexAttr; + if (!indexEntries.empty()) + indexAttr = dxsa::OperandIndexAttr::get(context, indexEntries); + + dxsa::OperandMinPrecisionAttr minPrecisionAttr; + if (opModifier && opModifier->minPrecision != 0) { + if (auto p = dxsa::symbolizeOperandMinPrecision(opModifier->minPrecision)) + minPrecisionAttr = dxsa::OperandMinPrecisionAttr::get(context, *p); + } + + return dxsa::DstOperandAttr::get(context, operandType, indexAttr, + componentsAttr, minPrecisionAttr, + maskAttr); + } + + dxsa::SrcOperandAttr + buildSrcOperandAttr(dxsa::OperandType operandType, + const OperandComponents &components, + ArrayRef indexEntries, + std::optional opModifier, + ArrayRef values, ArrayRef values64) { + auto componentsValue = components.num == 0 ? dxsa::OperandComponents::none + : components.num == 1 + ? dxsa::OperandComponents::scalar + : dxsa::OperandComponents::vector; + auto componentsAttr = + dxsa::OperandComponentsAttr::get(context, componentsValue); + + dxsa::SwizzleAttr swizzleAttr; + if (components.kind == OperandComponentsKind::Swizzle) { + SmallVector swizzleComponents; + for (unsigned int i : components.swizzle) + swizzleComponents.push_back(static_cast(i)); + swizzleAttr = dxsa::SwizzleAttr::get(context, swizzleComponents); + } else if (components.kind == OperandComponentsKind::One) { + swizzleAttr = dxsa::SwizzleAttr::get( + context, ArrayRef{static_cast(components.one)}); + } + + dxsa::OperandIndexAttr indexAttr; + if (!indexEntries.empty()) + indexAttr = dxsa::OperandIndexAttr::get(context, indexEntries); + + dxsa::OperandModifierAttr modifierAttr; + dxsa::OperandMinPrecisionAttr minPrecisionAttr; + UnitAttr nonUniformAttr; + if (opModifier) { + if (opModifier->modifier != 0) { + if (auto m = dxsa::symbolizeOperandModifier(opModifier->modifier)) + modifierAttr = dxsa::OperandModifierAttr::get(context, *m); + } + if (opModifier->minPrecision != 0) { + if (auto p = + dxsa::symbolizeOperandMinPrecision(opModifier->minPrecision)) + minPrecisionAttr = dxsa::OperandMinPrecisionAttr::get(context, *p); + } + if (opModifier->nonUniform != 0) + nonUniformAttr = UnitAttr::get(context); + } + + auto valuesAttr = values.empty() ? DenseI32ArrayAttr() + : DenseI32ArrayAttr::get(context, values); + auto values64Attr = values64.empty() + ? DenseI64ArrayAttr() + : DenseI64ArrayAttr::get(context, values64); + + return dxsa::SrcOperandAttr::get( + context, operandType, indexAttr, componentsAttr, minPrecisionAttr, + nonUniformAttr, swizzleAttr, modifierAttr, valuesAttr, values64Attr); + } + + dxsa::IndexAttr buildOperandIndexImm32(int32_t imm) { + return dxsa::IndexAttr::get(context, builder.getI32IntegerAttr(imm), + dxsa::SrcOperandAttr()); + } + + dxsa::IndexAttr buildOperandIndexImm64(int64_t imm) { + return dxsa::IndexAttr::get(context, builder.getI64IntegerAttr(imm), + dxsa::SrcOperandAttr()); + } + + dxsa::IndexAttr buildOperandIndexRelative(dxsa::SrcOperandAttr relative) { + return dxsa::IndexAttr::get(context, IntegerAttr(), relative); + } + + dxsa::IndexAttr + buildOperandIndexImm32PlusRelative(int32_t imm, + dxsa::SrcOperandAttr relative) { + return dxsa::IndexAttr::get(context, builder.getI32IntegerAttr(imm), + relative); + } + + dxsa::IndexAttr + buildOperandIndexImm64PlusRelative(int64_t imm, + dxsa::SrcOperandAttr relative) { + return dxsa::IndexAttr::get(context, builder.getI64IntegerAttr(imm), + relative); + } + + template + Instruction buildBinaryOp(dxsa::DstOperandAttr dst, dxsa::SrcOperandAttr lhs, + dxsa::SrcOperandAttr rhs, uint32_t preciseMask, + Location loc) { + auto preciseAttr = + preciseMask + ? dxsa::ComponentMaskAttr::get( + context, static_cast(preciseMask)) + : dxsa::ComponentMaskAttr(); + return OpT::create(builder, loc, dst, lhs, rhs, preciseAttr); + } + Instruction buildDclInput(dxsa::InlineOperandAttr operand, Location loc) { return dxsa::DclInput::create(builder, loc, operand); } @@ -1371,6 +1500,153 @@ class Parser { indices); } + struct OperandFields { + dxsa::OperandType type; + OperandComponents components; + SmallVector indexEntries; + std::optional modifier; + SmallVector values; + SmallVector values64; + }; + + FailureOr parseOperandIndex(uint32_t indexType) { + switch (indexType) { + case D3D10_SB_OPERAND_INDEX_IMMEDIATE32: { + auto value = parseToken(); + FAILURE_IF_FAILED(value); + return builder.buildOperandIndexImm32(static_cast(*value)); + } + case D3D10_SB_OPERAND_INDEX_IMMEDIATE64: { + auto high = parseToken(); + FAILURE_IF_FAILED(high); + auto low = parseToken(); + FAILURE_IF_FAILED(low); + return builder.buildOperandIndexImm64((((int64_t)*high) << 32) | *low); + } + case D3D10_SB_OPERAND_INDEX_RELATIVE: { + auto relative = parseSrcOperand(); + FAILURE_IF_FAILED(relative); + return builder.buildOperandIndexRelative(*relative); + } + case D3D10_SB_OPERAND_INDEX_IMMEDIATE32_PLUS_RELATIVE: { + auto value = parseToken(); + FAILURE_IF_FAILED(value); + auto relative = parseSrcOperand(); + FAILURE_IF_FAILED(relative); + return builder.buildOperandIndexImm32PlusRelative( + static_cast(*value), *relative); + } + case D3D10_SB_OPERAND_INDEX_IMMEDIATE64_PLUS_RELATIVE: { + auto high = parseToken(); + FAILURE_IF_FAILED(high); + auto low = parseToken(); + FAILURE_IF_FAILED(low); + auto relative = parseSrcOperand(); + FAILURE_IF_FAILED(relative); + return builder.buildOperandIndexImm64PlusRelative( + (((int64_t)*high) << 32) | *low, *relative); + } + default: + return emitError(getLocation(), "invalid operand index representation: ") + << indexType; + } + } + + FailureOr parseOperandFields() { + auto token = parseToken(); + FAILURE_IF_FAILED(token); + + auto loc = getLocation(); + auto rawOperandType = DECODE_D3D10_SB_OPERAND_TYPE(*token); + auto isExtended = DECODE_IS_D3D10_SB_OPERAND_EXTENDED(*token); + + auto type = dxsa::symbolizeOperandType(rawOperandType); + if (!type) + return emitError(loc, "unknown operand type: ") << rawOperandType; + + auto components = parseOperandComponents(*token); + FAILURE_IF_FAILED(components); + + auto indexTypes = parseOperandIndexTypes(*token); + FAILURE_IF_FAILED(indexTypes); + + OperandFields decoded; + decoded.type = *type; + decoded.components = *components; + + if (isExtended) { + auto extToken = parseToken(); + FAILURE_IF_FAILED(extToken); + auto opMod = parseOperandExtendedModifier(*extToken); + FAILURE_IF_FAILED(opMod); + decoded.modifier = *opMod; + } + + if (isImmOperand(*token)) { + if (rawOperandType == D3D10_SB_OPERAND_TYPE_IMMEDIATE64) { + for (unsigned i = 0; i < 2; ++i) { + auto high = parseToken(); + FAILURE_IF_FAILED(high); + auto low = parseToken(); + FAILURE_IF_FAILED(low); + decoded.values64.push_back((((int64_t)*high) << 32) | *low); + } + return decoded; + } + for (uint32_t i = 0; i < components->num; ++i) { + auto value = parseToken(); + FAILURE_IF_FAILED(value); + decoded.values.push_back(static_cast(*value)); + } + return decoded; + } + + for (uint32_t indexType : *indexTypes) { + auto entry = parseOperandIndex(indexType); + FAILURE_IF_FAILED(entry); + decoded.indexEntries.push_back(*entry); + } + return decoded; + } + + FailureOr parseDstOperand() { + auto loc = getLocation(); + auto fields = parseOperandFields(); + FAILURE_IF_FAILED(fields); + if (!fields->values.empty() || !fields->values64.empty()) + return emitError(loc, "immediate operand `") + << dxsa::stringifyOperandType(fields->type) + << "` cannot be a destination"; + return builder.buildDstOperandAttr(fields->type, fields->components, + fields->indexEntries, fields->modifier); + } + + FailureOr parseSrcOperand() { + auto fields = parseOperandFields(); + FAILURE_IF_FAILED(fields); + return builder.buildSrcOperandAttr(fields->type, fields->components, + fields->indexEntries, fields->modifier, + fields->values, fields->values64); + } + + template + FailureOr + decodeSaturableBinaryOp(size_t beginOffset, uint32_t length, bool saturate, + uint32_t preciseMask, Location loc) { + auto dst = parseDstOperand(); + FAILURE_IF_FAILED(dst); + auto lhs = parseSrcOperand(); + FAILURE_IF_FAILED(lhs); + auto rhs = parseSrcOperand(); + FAILURE_IF_FAILED(rhs); + if (failed(verifyInstructionLength(beginOffset, length))) + return failure(); + if (saturate) + return builder.buildBinaryOp(*dst, *lhs, *rhs, preciseMask, + loc); + return builder.buildBinaryOp(*dst, *lhs, *rhs, preciseMask, loc); + } + FailureOr parseDclInput(Location loc) { auto operand = parseInlineOperand(); FAILURE_IF_FAILED(operand); @@ -1828,6 +2104,13 @@ class Parser { unsigned numOperands = instrInfo[opcode].numOperands; + switch (opcode) { + case D3D10_SB_OPCODE_ADD: + return decodeSaturableBinaryOp( + beginOffset, instructionLengthInTokens, modifier.saturate, + modifier.preciseMask, getLocation()); + } + SmallVector operands; for (unsigned i = 0; i < numOperands; ++i) { FailureOr operand = parseOperand(); diff --git a/mlir/test/Target/DXSA/add.mlir b/mlir/test/Target/DXSA/add.mlir new file mode 100644 index 000000000000..63636132ab18 --- /dev/null +++ b/mlir/test/Target/DXSA/add.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/add.bin | FileCheck %s + +// CHECK: dxsa.add r<0>, r<1>, r<2> +// CHECK-NEXT: dxsa.add_sat r<0>, r<1>, r<2> +// CHECK-NEXT: dxsa.add r<0>, r<1>, l(1.0, 2.0, 3.0, 4.0) +// CHECK-NEXT: dxsa.add r<0>, l(1.0, 2.0, 3.0, 4.0), l(5.0, 6.0, 7.0, 8.0) +// CHECK-NEXT: dxsa.add r<0>, l(1.0, 2.0, 3.0, 4.0), r<2> +// CHECK-NEXT: dxsa.add r<0>, -r<1>, |r<2>| +// CHECK-NEXT: dxsa.add r<0>, r<1, >, r<2> +// CHECK-NEXT: dxsa.add r<0>, r<1>, d(1.0, 2.0) +// CHECK-NEXT: dxsa.add r<0>, v>>, r<2> +// CHECK-NEXT: dxsa.add r<0>, cb<[0, 2 + r<1, >], vector>, r<2> +// CHECK-NEXT: dxsa.add o<0>, r<1>, r<2> +// CHECK-NEXT: dxsa.add null, r<1>, r<2> +// CHECK-NEXT: dxsa.add r<0>, vPrim, r<2> +// CHECK-NEXT: dxsa.add r<0, >, r<1>, r<2> +// CHECK-NEXT: dxsa.add r<0>, r<1, >, r<2> diff --git a/mlir/test/Target/DXSA/inputs/add.bin b/mlir/test/Target/DXSA/inputs/add.bin new file mode 100644 index 000000000000..731c25912c25 Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/add.bin differ diff --git a/mlir/test/Target/DXSA/operands/components.mlir b/mlir/test/Target/DXSA/operands/components.mlir new file mode 100644 index 000000000000..e4705a467741 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/components.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Canonical: `r` is implicitly `vector`, the keyword is omitted. +// CHECK: dxsa.add r<0>, r<1>, r<2> +dxsa.add r<0>, r<1>, r<2> + +// ----- + +// Override: explicitly downsize an `r` to `scalar`. +// CHECK: dxsa.add r<0>, r<1, scalar>, r<2> +dxsa.add r<0>, r<1, scalar>, r<2> + +// ----- + +// Override: explicitly request `none` on an `r` (unusual but legal). +// CHECK: dxsa.add r<0>, r<1, none>, r<2> +dxsa.add r<0>, r<1, none>, r<2> + +// ----- + +// Override: explicit `reserved` (DXBC `N_COMPONENT`) preserved for round-trip. +// CHECK: dxsa.add r<0>, r<1, reserved>, r<2> +dxsa.add r<0>, r<1, reserved>, r<2> + +// ----- + +// Canonical: `g` is implicitly `none` (a handle), the keyword is omitted. +// CHECK: dxsa.add r<0>, g<0>, r<2> +dxsa.add r<0>, g<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, g<0, vector>, r<2> +dxsa.add r<0>, g<0, vector>, r<2> + +// ----- + +// Canonical: `vPrim` is implicitly `scalar`. With no other body field set +// the operand prints as the bare type keyword without `<...>`. +// CHECK: dxsa.add r<0>, vPrim, r<2> +dxsa.add r<0>, vPrim, r<2> + +// ----- + +// Override: explicitly request `vector` on `vPrim`. +// CHECK: dxsa.add r<0>, vPrim, r<2> +dxsa.add r<0>, vPrim, r<2> diff --git a/mlir/test/Target/DXSA/operands/index.mlir b/mlir/test/Target/DXSA/operands/index.mlir new file mode 100644 index 000000000000..22f3db333402 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/index.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0>, v<0>, r<2> +dxsa.add r<0>, v<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<42>, r<2> +dxsa.add r<0>, v<42>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<42 : i64>, r<2> +dxsa.add r<0>, v<42 : i64>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v>>, r<2> +dxsa.add r<0>, v>>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, cb<[0, 2 + r<1, >]>, r<2> +dxsa.add r<0>, cb<[0, 2 + r<1, >]>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, cb<2 : i64 + r<3, >>, r<2> +dxsa.add r<0>, cb<2 : i64 + r<3, >>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, cb<[1 : i64, 2 + r<3, >, r<5, >]>, r<2> +dxsa.add r<0>, cb<[1 : i64, 2 + r<3, >, r<5, >]>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, vPrim, r<2> +dxsa.add r<0>, vPrim, r<2> diff --git a/mlir/test/Target/DXSA/operands/inline_operand.mlir b/mlir/test/Target/DXSA/operands/inline_operand.mlir new file mode 100644 index 000000000000..d192052ea261 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/inline_operand.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0>, r<1>, r<2> +dxsa.add r<0>, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0>, r<2> +dxsa.add r<0>, v<0>, r<2> + +// ----- + +// CHECK: dxsa.add o<0>, r<1>, r<2> +dxsa.add o<0>, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, x<[1, 4]>, r<2> +dxsa.add r<0>, x<[1, 4]>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, cb<[0, 5]>, r<2> +dxsa.add r<0>, cb<[0, 5]>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, icb<5>, r<2> +dxsa.add r<0>, icb<5>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, s<0>, r<2> +dxsa.add r<0>, s<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, t<0>, r<2> +dxsa.add r<0>, t<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, u<0>, r<2> +dxsa.add r<0>, u<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, g<0>, r<2> +dxsa.add r<0>, g<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, m<0>, r<2> +dxsa.add r<0>, m<0>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, l(1.0, 2.0, 3.0, 4.0), r<2> +dxsa.add r<0>, l(1.0, 2.0, 3.0, 4.0), r<2> + +// ----- + +// CHECK: dxsa.add r<0>, d(1.0, 2.0), r<2> +dxsa.add r<0>, d(1.0, 2.0), r<2> + +// ----- + +// CHECK: dxsa.add null, r<1>, r<2> +dxsa.add null, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, vPrim, r<2> +dxsa.add r<0>, vPrim, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, vCoverage, r<2> +dxsa.add r<0>, vCoverage, r<2> + +// ----- + +// CHECK: dxsa.add oDepth, r<1>, r<2> +dxsa.add oDepth, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, vDomain, r<2> +dxsa.add r<0>, vDomain, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, -|r<1, min16f, nonuniform, >|, r<2> +dxsa.add r<0>, -|r<1, min16f, nonuniform, >|, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, r<1, min16f, nonuniform, >, r<2> +dxsa.add r<0>, r<1, nonuniform, , min16f>, r<2> diff --git a/mlir/test/Target/DXSA/operands/inline_operand_invalid.mlir b/mlir/test/Target/DXSA/operands/inline_operand_invalid.mlir new file mode 100644 index 000000000000..dfc008cfe584 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/inline_operand_invalid.mlir @@ -0,0 +1,79 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{unknown operand type: `qq`}} +dxsa.add r<0>, qq<0>, r<2> + +// ----- + +// expected-error@+1 {{unknown component: `q`}} +dxsa.add r<0>, r<1, >, r<2> + +// ----- + +// expected-error@+1 {{duplicate component `x` in destination mask}} +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// expected-error@+1 {{duplicate mask}} +dxsa.add r<0, , >, r<1>, r<2> + +// ----- + +// expected-error@+1 {{duplicate swizzle}} +dxsa.add r<0>, r<1, , >, r<2> + +// ----- + +// expected-error@+1 {{duplicate component count}} +dxsa.add r<0>, r<1, scalar, vector>, r<2> + +// ----- + +// expected-error@+1 {{duplicate min precision}} +dxsa.add r<0>, r<1, min16f, min2_8f>, r<2> + +// ----- + +// expected-error@+1 {{duplicate nonuniform}} +dxsa.add r<0>, r<1, nonuniform, nonuniform>, r<2> + +// ----- + +// expected-error@+1 {{duplicate index list}} +dxsa.add r<0>, r<1, 2>, r<2> + +// ----- + +// expected-error@+1 {{swizzle must have 1 or 4 components, got 2}} +dxsa.add r<0>, r<1, >, r<2> + +// ----- + +// expected-error@+1 {{duplicate component `x` in destination mask}} +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// expected-error@+1 {{unexpected keyword in destination operand body: `nonuniform`}} +dxsa.add r<0, nonuniform>, r<1>, r<2> + +// ----- + +// expected-error@+1 {{immediate operands cannot have a source modifier}} +dxsa.add r<0>, -l(1.0), r<2> + +// ----- + +// expected-error@+1 {{immediate operands cannot have a source modifier}} +dxsa.add r<0>, |d(1.0)|, r<2> + +// ----- + +// expected-error@+1 {{37-bit immediate does not fit in 32-bit literal}} +dxsa.add r<0>, l(0x1FFFFFFFFF), r<2> + +// ----- + +// expected-error@+1 {{69-bit immediate does not fit in 64-bit literal}} +dxsa.add r<0>, d(0x1FFFFFFFFFFFFFFFFF : i128), r<2> diff --git a/mlir/test/Target/DXSA/operands/mask.mlir b/mlir/test/Target/DXSA/operands/mask.mlir new file mode 100644 index 000000000000..9957f307c743 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/mask.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> + +// ----- + +// Mask components are normalized to canonical order in the print, regardless +// of the input order. +// CHECK: dxsa.add r<0, >, r<1>, r<2> +dxsa.add r<0, >, r<1>, r<2> diff --git a/mlir/test/Target/DXSA/operands/min_precision.mlir b/mlir/test/Target/DXSA/operands/min_precision.mlir new file mode 100644 index 000000000000..c99c16cce08d --- /dev/null +++ b/mlir/test/Target/DXSA/operands/min_precision.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0>, v<0, min16f, >, r<2> +dxsa.add r<0>, v<0, min16f, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, min2_8f, >, r<2> +dxsa.add r<0>, v<0, min2_8f, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, min16i, >, r<2> +dxsa.add r<0>, v<0, min16i, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, min16u, >, r<2> +dxsa.add r<0>, v<0, min16u, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, min16f>, r<2> +dxsa.add r<0>, v<0, min16f>, r<2> diff --git a/mlir/test/Target/DXSA/operands/modifier.mlir b/mlir/test/Target/DXSA/operands/modifier.mlir new file mode 100644 index 000000000000..116d8f9f8cd9 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/modifier.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Negation. +// CHECK: dxsa.add r<0>, -r<1>, r<2> +dxsa.add r<0>, -r<1>, r<2> + +// ----- + +// Absolute value. +// CHECK: dxsa.add r<0>, |r<1>|, r<2> +dxsa.add r<0>, |r<1>|, r<2> + +// ----- + +// Negated absolute value. +// CHECK: dxsa.add r<0>, -|r<1>|, r<2> +dxsa.add r<0>, -|r<1>|, r<2> diff --git a/mlir/test/Target/DXSA/operands/nonuniform.mlir b/mlir/test/Target/DXSA/operands/nonuniform.mlir new file mode 100644 index 000000000000..3e900a889078 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/nonuniform.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0>, r<1, nonuniform>, r<2> +dxsa.add r<0>, r<1, nonuniform>, r<2> + +// ----- + +// Combined with min precision: ordering in the print is fixed. +// CHECK: dxsa.add r<0>, r<1, min16f, nonuniform>, r<2> +dxsa.add r<0>, r<1, nonuniform, min16f>, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, t<5, nonuniform>, r<2> +dxsa.add r<0>, t<5, nonuniform>, r<2> diff --git a/mlir/test/Target/DXSA/operands/select.mlir b/mlir/test/Target/DXSA/operands/select.mlir new file mode 100644 index 000000000000..afa7ab79aa11 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/select.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> diff --git a/mlir/test/Target/DXSA/operands/swizzle.mlir b/mlir/test/Target/DXSA/operands/swizzle.mlir new file mode 100644 index 000000000000..80a26f51f015 --- /dev/null +++ b/mlir/test/Target/DXSA/operands/swizzle.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Identity swizzle on a source collapses to the canonical print. +// CHECK: dxsa.add r<0>, v<0>, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// Order is preserved in the print. +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2> + +// ----- + +// CHECK: dxsa.add r<0>, v<0, >, r<2> +dxsa.add r<0>, v<0, >, r<2>