diff --git a/docs/DXIL.rst b/docs/DXIL.rst index 2afd65f55e..8007a1ef48 100644 --- a/docs/DXIL.rst +++ b/docs/DXIL.rst @@ -3076,7 +3076,7 @@ ID Name Description 2147483656 RayQuery_CandidateTriangleObjectPosition returns candidate triangle vertices in object space as <9 x float> 2147483657 RayQuery_CommittedTriangleObjectPosition returns committed triangle vertices in object space as <9 x float> 2147483658 HitObject_TriangleObjectPosition returns triangle vertices in object space as <9 x float> -2147483659 ReservedD0 reserved +2147483659 LinAlgMatrixMultiplyAccumulate Returns the resulting matrix from multiplying A and B and accumulating into C 2147483660 LinAlgFillMatrix fills a matrix with a scalar value 2147483661 LinAlgCopyConvertMatrix Converts and copies the element and use type of the source matrix to the destination matrix with optional transpose 2147483662 LinAlgMatrixLoadFromDescriptor fills a matrix with data from a [RW]ByteAddressBuffer @@ -3088,7 +3088,7 @@ ID Name Description 2147483668 LinAlgMatrixStoreToDescriptor stores a matrix to a RWByteAddressBuffer 2147483669 LinAlgMatrixStoreToMemory stores a matrix to groupshared memory 2147483670 LinAlgMatrixQueryAccumulatorLayout returns comptime 0 when accumulator matrix are A layout, 1 when B layout -2147483671 LinAlgMatrixMulOp applies a multiplication op to matrix C using A and B as parameters +2147483671 LinAlgMatrixMultiply Returns the resulting matrix from multiplying A and B 2147483672 LinAlgMatrixAccumulate accumulate A or B matrix into Accumulator matrix following LHS += RHS 2147483673 LinAlgMatVecMul Multiplies a MxK dimension matrix and a K sized input vector 2147483674 LinAlgMatVecMulAdd Multiplies a MxK dimension matrix and a K sized input vector then adds a M sized bias vector diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index dfb835aa00..eb38ec6e70 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -524,7 +524,6 @@ 
static const OpCodeTableID TableID = OpCodeTableID::ExperimentalOps; // Enumeration for ExperimentalOps DXIL operations enum class OpCode : unsigned { // - ReservedD0 = 11, // reserved ReservedD1 = 30, // reserved ReservedD2 = 31, // reserved ReservedD3 = 32, // reserved @@ -573,8 +572,11 @@ enum class OpCode : unsigned { 14, // fills a matrix with data from a [RW]ByteAddressBuffer LinAlgMatrixLoadFromMemory = 15, // fills a matrix with data from a groupshared array - LinAlgMatrixMulOp = - 23, // applies a multiplication op to matrix C using A and B as parameters + LinAlgMatrixMultiply = + 23, // Returns the resulting matrix from multiplying A and B + LinAlgMatrixMultiplyAccumulate = + 11, // Returns the resulting matrix from multiplying A and B and + // accumulating into C LinAlgMatrixOuterProduct = 29, // Outer products an M sized vector and a N // sized vector producing an MxN matrix LinAlgMatrixQueryAccumulatorLayout = @@ -1263,8 +1265,11 @@ enum class OpCode : unsigned { EXP_OPCODE(ExperimentalOps, HitObject_TriangleObjectPosition), // returns triangle vertices in // object space as <9 x float> - // ReservedD0 = 0x8000000B, 2147483659U, -2147483637 - EXP_OPCODE(ExperimentalOps, ReservedD0), // reserved + // LinAlgMatrixMultiplyAccumulate = 0x8000000B, 2147483659U, -2147483637 + EXP_OPCODE(ExperimentalOps, + LinAlgMatrixMultiplyAccumulate), // Returns the resulting matrix + // from multiplying A and B and + // accumulating into C // LinAlgFillMatrix = 0x8000000C, 2147483660U, -2147483636 EXP_OPCODE(ExperimentalOps, LinAlgFillMatrix), // fills a matrix with a scalar value @@ -1316,10 +1321,10 @@ enum class OpCode : unsigned { LinAlgMatrixQueryAccumulatorLayout), // returns comptime 0 when // accumulator matrix are A // layout, 1 when B layout - // LinAlgMatrixMulOp = 0x80000017, 2147483671U, -2147483625 + // LinAlgMatrixMultiply = 0x80000017, 2147483671U, -2147483625 EXP_OPCODE(ExperimentalOps, - LinAlgMatrixMulOp), // applies a multiplication op to matrix C - 
// using A and B as parameters + LinAlgMatrixMultiply), // Returns the resulting matrix from + // multiplying A and B // LinAlgMatrixAccumulate = 0x80000018, 2147483672U, -2147483624 EXP_OPCODE(ExperimentalOps, LinAlgMatrixAccumulate), // accumulate A or B matrix into @@ -1529,7 +1534,8 @@ enum class OpCodeClass : unsigned { LinAlgMatrixLength, LinAlgMatrixLoadFromDescriptor, LinAlgMatrixLoadFromMemory, - LinAlgMatrixMulOp, + LinAlgMatrixMultiply, + LinAlgMatrixMultiplyAccumulate, LinAlgMatrixOuterProduct, LinAlgMatrixQueryAccumulatorLayout, LinAlgMatrixSetElement, @@ -1725,7 +1731,7 @@ enum class OpCodeClass : unsigned { NodeOutputIsValid, OutputComplete, - NumOpClasses = 224, // exclusive last value of enumeration + NumOpClasses = 225, // exclusive last value of enumeration }; // OPCODECLASS-ENUM:END diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index 2f388fdcd3..8c48202ce0 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -10500,6 +10500,41 @@ struct DxilInst_HitObject_TriangleObjectPosition { void set_hitObject(llvm::Value *val) { Instr->setOperand(1, val); } }; +/// This instruction Returns the resulting matrix from multiplying A and B and +/// accumulating into C +struct DxilInst_LinAlgMatrixMultiplyAccumulate { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_LinAlgMatrixMultiplyAccumulate(llvm::Instruction *pInstr) + : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::LinAlgMatrixMultiplyAccumulate); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixA = 1, + arg_matrixB = 2, + arg_matrixC = 3, + }; + // Accessors + 
llvm::Value *get_matrixA() const { return Instr->getOperand(1); } + void set_matrixA(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_matrixB() const { return Instr->getOperand(2); } + void set_matrixB(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_matrixC() const { return Instr->getOperand(3); } + void set_matrixC(llvm::Value *val) { Instr->setOperand(3, val); } +}; + /// This instruction fills a matrix with a scalar value struct DxilInst_LinAlgFillMatrix { llvm::Instruction *Instr; @@ -10859,15 +10894,14 @@ struct DxilInst_LinAlgMatrixQueryAccumulatorLayout { bool requiresUniformInputs() const { return false; } }; -/// This instruction applies a multiplication op to matrix C using A and B as -/// parameters -struct DxilInst_LinAlgMatrixMulOp { +/// This instruction Returns the resulting matrix from multiplying A and B +struct DxilInst_LinAlgMatrixMultiply { llvm::Instruction *Instr; // Construction and identification - DxilInst_LinAlgMatrixMulOp(llvm::Instruction *pInstr) : Instr(pInstr) {} + DxilInst_LinAlgMatrixMultiply(llvm::Instruction *pInstr) : Instr(pInstr) {} operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst(Instr, - hlsl::OP::OpCode::LinAlgMatrixMulOp); + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::LinAlgMatrixMultiply); } // Validation support bool isAllowed() const { return true; } diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 1393474b48..fa9e0fde4c 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2823,16 +2823,15 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = { {{0x2}}, {{0x0}}}, // Overloads: f - {OC::ReservedD0, - "ReservedD0", - OCC::Reserved, - "reserved", - Attribute::None, - 0, - {}, - {}}, // Overloads: v - // Linear Algebra Operations + {OC::LinAlgMatrixMultiplyAccumulate, + "LinAlgMatrixMultiplyAccumulate", + OCC::LinAlgMatrixMultiplyAccumulate, + "linAlgMatrixMultiplyAccumulate", + 
Attribute::None, + 4, + {{0x200}, {0x200}, {0x200}, {0x200}}, + {{0x0}, {0x0}, {0x0}, {0x0}}}, // Overloads: o,o,o,o {OC::LinAlgFillMatrix, "LinAlgFillMatrix", OCC::LinAlgFillMatrix, @@ -2921,10 +2920,10 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = { 0, {}, {}}, // Overloads: v - {OC::LinAlgMatrixMulOp, - "LinAlgMatrixMulOp", - OCC::LinAlgMatrixMulOp, - "linAlgMatrixMulOp", + {OC::LinAlgMatrixMultiply, + "LinAlgMatrixMultiply", + OCC::LinAlgMatrixMultiply, + "linAlgMatrixMultiply", Attribute::None, 3, {{0x200}, {0x200}, {0x200}}, @@ -3950,15 +3949,16 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation, minor = 10; return; } - // Instructions: LinAlgFillMatrix=2147483660, - // LinAlgCopyConvertMatrix=2147483661, LinAlgMatrixLoadFromMemory=2147483663, - // LinAlgMatrixLength=2147483664, LinAlgMatrixGetCoordinate=2147483665, - // LinAlgMatrixGetElement=2147483666, LinAlgMatrixSetElement=2147483667, + // Instructions: LinAlgMatrixMultiplyAccumulate=2147483659, + // LinAlgFillMatrix=2147483660, LinAlgCopyConvertMatrix=2147483661, + // LinAlgMatrixLoadFromMemory=2147483663, LinAlgMatrixLength=2147483664, + // LinAlgMatrixGetCoordinate=2147483665, LinAlgMatrixGetElement=2147483666, + // LinAlgMatrixSetElement=2147483667, // LinAlgMatrixStoreToDescriptor=2147483668, - // LinAlgMatrixStoreToMemory=2147483669, LinAlgMatrixMulOp=2147483671, + // LinAlgMatrixStoreToMemory=2147483669, LinAlgMatrixMultiply=2147483671, // LinAlgMatrixAccumulate=2147483672, // LinAlgMatrixAccumulateToMemory=2147483676 - if ((2147483660 <= op && op <= 2147483661) || + if ((2147483659 <= op && op <= 2147483661) || (2147483663 <= op && op <= 2147483669) || (2147483671 <= op && op <= 2147483672) || op == 2147483676) { major = 6; @@ -6557,13 +6557,14 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pHit); break; - // - case OpCode::ReservedD0: - A(pV); + // Linear Algebra Operations + case OpCode::LinAlgMatrixMultiplyAccumulate: + EXT(0); 
A(pI32); + EXT(1); + EXT(2); + EXT(3); break; - - // Linear Algebra Operations case OpCode::LinAlgFillMatrix: EXT(0); A(pI32); @@ -6637,7 +6638,7 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI32); A(pI32); break; - case OpCode::LinAlgMatrixMulOp: + case OpCode::LinAlgMatrixMultiply: EXT(0); A(pI32); EXT(1); @@ -7013,7 +7014,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::GetGroupWaveIndex: case OpCode::GetGroupWaveCount: case OpCode::ClusterID: - case OpCode::ReservedD0: case OpCode::LinAlgMatrixQueryAccumulatorLayout: case OpCode::ReservedD1: case OpCode::ReservedD2: @@ -7070,13 +7070,20 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { return llvm::StructType::get(Ctx, {FT->getParamType(1), FT->getParamType(2)}); + case OpCode::LinAlgMatrixMultiplyAccumulate: + if (FT->getNumParams() < 4) + return nullptr; + return llvm::StructType::get(Ctx, + {FT->getReturnType(), FT->getParamType(1), + FT->getParamType(2), FT->getParamType(3)}); + case OpCode::LinAlgMatrixSetElement: if (FT->getNumParams() < 4) return nullptr; return llvm::StructType::get( Ctx, {FT->getReturnType(), FT->getParamType(1), FT->getParamType(3)}); - case OpCode::LinAlgMatrixMulOp: + case OpCode::LinAlgMatrixMultiply: case OpCode::LinAlgMatrixAccumulate: case OpCode::LinAlgMatVecMul: case OpCode::LinAlgMatrixOuterProduct: diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 22ea3c77d0..2665c441a6 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -7680,9 +7680,9 @@ constexpr IntrinsicLower gLowerTable[] = { {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate, EmptyLower, DXIL::OpCode::LinAlgMatrixAccumulate}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiply, EmptyLower, - DXIL::OpCode::LinAlgMatrixMulOp}, + DXIL::OpCode::LinAlgMatrixMultiply}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiplyAccumulate, - EmptyLower, 
DXIL::OpCode::LinAlgMatrixMulOp}, + EmptyLower, DXIL::OpCode::LinAlgMatrixMultiplyAccumulate}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixQueryAccumulatorLayout, EmptyLower, DXIL::OpCode::LinAlgMatrixQueryAccumulatorLayout}, {IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor, EmptyLower, diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index 94ff9e4b33..e269a25877 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -411,7 +411,7 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix ma void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); -void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); +void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC); void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp); void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp); diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 0b98fed0a0..36884121f3 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -1174,8 +1174,8 @@ def populate_categories_and_models_ExperimentalOps(self): + "LinAlgMatrixGetCoordinate,LinAlgMatrixGetElement," 
+ "LinAlgMatrixSetElement,LinAlgMatrixStoreToDescriptor," + "LinAlgMatrixLoadFromMemory,LinAlgMatrixStoreToMemory," - + "LinAlgMatrixAccumulateToMemory,LinAlgMatrixMulOp," - + "LinAlgMatrixAccumulate" + + "LinAlgMatrixAccumulateToMemory,LinAlgMatrixMultiply," + + "LinAlgMatrixMultiplyAccumulate,LinAlgMatrixAccumulate" ): i.category = "Linear Algebra Operations" i.shader_model = experimental_sm @@ -6341,7 +6341,19 @@ def populate_ExperimentalOps(self): ) # Linear Algebra Ops - op_table.reserve_dxil_op_range("ReservedD", 1) + add_dxil_op( + "LinAlgMatrixMultiplyAccumulate", + "LinAlgMatrixMultiplyAccumulate", + "Returns the resulting matrix from multiplying A and B and accumulating into C", + "o,o,o,o", + "", + [ + db_dxil_param(0, "$x0", "", "resulting matrix"), + db_dxil_param(2, "$x1", "matrixA", "A matrix"), + db_dxil_param(3, "$x2", "matrixB", "B matrix"), + db_dxil_param(4, "$x3", "matrixC", "C matrix"), + ], + ) add_dxil_op( "LinAlgFillMatrix", @@ -6530,9 +6542,9 @@ def populate_ExperimentalOps(self): ) add_dxil_op( - "LinAlgMatrixMulOp", - "LinAlgMatrixMulOp", - "applies a multiplication op to matrix C using A and B as parameters", + "LinAlgMatrixMultiply", + "LinAlgMatrixMultiply", + "Returns the resulting matrix from multiplying A and B", "o,o,o", "", [