diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td index eb7ae12f3724..e235f2a52364 100644 --- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td +++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td @@ -24,7 +24,7 @@ class DXSA_Op traits = []> : Op; //===----------------------------------------------------------------------===// -// DXSA module — top-level container op for a DXBC tokenized program +// DXSA module - top-level container op for a DXBC tokenized program //===----------------------------------------------------------------------===// def DXSA_ProgramType_PixelShader : I32EnumAttrCase<"pixel_shader", 0>; @@ -572,6 +572,28 @@ def DXSA_Instruction : DXSA_Op<"instruction"> { let assemblyFormat = "$mnemonic $operands attr-dict"; } +def DXSA_Unknown : DXSA_Op<"unknown"> { + let summary = "raw tokens fallback for an undecodable instruction"; + let description = [{ + The `dxsa.unknown` operation represents one instruction whose raw + tokens could not be decoded into a structured op - unknown opcode, + undecodable payload, length mismatch, truncated tail, etc. It acts + as a disassembler-style fallback so the surrounding well-formed + instructions still appear in IR. + + Example: + + ```mlir + dxsa.unknown + ``` + }]; + + let arguments = (ins DenseI32ArrayAttr:$tokens); + let results = (outs); + let hasVerifier = 1; + let assemblyFormat = "custom($tokens) attr-dict"; +} + def DXSA_InlineOperandType_Temp : I32EnumAttrCase<"temp", 0>; def DXSA_InlineOperandType_Input : I32EnumAttrCase<"input", 1>; def DXSA_InlineOperandType_Output : I32EnumAttrCase<"output", 2>; diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp index 7acec7fae4e1..889216a2902b 100644 --- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp +++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Format.h" using namespace mlir; using namespace mlir::dxsa; @@ -35,6 +36,12 @@ void DXSADialect::initialize() { >(); } +/// Declarations for custom-directive helpers used by the +/// TableGen-generated print/parse methods. +static ParseResult parseHexTokens(OpAsmParser &parser, DenseI32ArrayAttr &attr); +static void printHexTokens(OpAsmPrinter &printer, Operation *, + DenseI32ArrayAttr attr); + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -211,6 +218,50 @@ LogicalResult DclResourceStructured::verify() { return success(); } +//===----------------------------------------------------------------------===// +// UnknownOp +//===----------------------------------------------------------------------===// + +LogicalResult Unknown::verify() { + if (getTokens().empty()) + return emitOpError("tokens must not be empty"); + return success(); +} + +/// Parse `` for the unknown op. +static ParseResult parseHexTokens(OpAsmParser &parser, + DenseI32ArrayAttr &attr) { + SmallVector tokens; + auto parseOneToken = [&]() -> ParseResult { + uint32_t value; + if (parser.parseInteger(value)) + return failure(); + tokens.push_back(static_cast(value)); + return success(); + }; + + if (parser.parseLess() || parser.parseKeyword("tokens") || + parser.parseEqual() || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square, + parseOneToken) || + parser.parseGreater()) + return failure(); + + attr = DenseI32ArrayAttr::get(parser.getContext(), tokens); + return success(); +} + +/// Print the tokens array as uppercase, 8-digit, 0x-prefixed hex. +static void printHexTokens(OpAsmPrinter &printer, Operation *, + DenseI32ArrayAttr attr) { + printer << "(t), + /*Width=*/10, /*Upper=*/true); + }); + printer << "]>"; +} + //===----------------------------------------------------------------------===// // TableGen'd attribute method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp index b2adacfc79de..475100222993 100644 --- a/mlir/lib/Target/DXSA/BinaryParser.cpp +++ b/mlir/lib/Target/DXSA/BinaryParser.cpp @@ -523,6 +523,17 @@ class DXBuilder { return op->getResults()[0]; } + size_t getNumOps() { + return builder.getInsertionBlock()->getOperations().size(); + } + + void rewindOpsTo(size_t numOps) { + auto *block = builder.getInsertionBlock(); + while (block->getOperations().size() > numOps) + block->back().erase(); // reverse order: use-def stays valid + builder.setInsertionPointToEnd(block); + } + Instruction buildInstruction(StringRef name, ArrayRef operands, const InstructionModifier &modifier, FileLineColLoc loc) { @@ -530,6 +541,14 @@ class DXBuilder { builder.getStringAttr(name)); } + Instruction buildUnknown(ArrayRef tokens, Location loc) { + auto signedTokens = llvm::map_to_vector( + tokens, [](uint32_t token) { return static_cast(token); }); + return dxsa::Unknown::create( + builder, loc, + DenseI32ArrayAttr::get(builder.getContext(), signedTokens)); + } + Instruction buildDclGlobalFlags(dxsa::GlobalFlags flags, Location loc) { auto flagsAttr = dxsa::GlobalFlagsAttr::get(builder.getContext(), flags); return dxsa::DclGlobalFlags::create(builder, loc, flagsAttr); @@ -783,21 +802,32 @@ class Parser { /// Width of the token in the program binary stream. static constexpr size_t tokenSize = sizeof(uint32_t); + uint32_t getRemainingBytes() { return buffer.size() - currentTokenOffset; } /// Parse the current token and move the cursor to the next one. Token parseToken() { - if ((currentTokenOffset + tokenSize) > buffer.size()) { - emitError(getLocation(), "unexpected end of file"); - return failure(); + if (getRemainingBytes() < tokenSize) { + return emitError(getLocation(), "unexpected end of file"); } - uint32_t value = support::endian::read( + auto value = support::endian::read( buffer.begin() + currentTokenOffset, endianness::little); currentTokenOffset += tokenSize; return value; } + FailureOr> parseTokens(uint32_t numTokens) { + SmallVector tokens(numTokens); + for (uint32_t i = 0; i < numTokens; ++i) { + auto token = parseToken(); + if (failed(token)) + return failure(); + tokens[i] = *token; + } + return tokens; + } + /// Returns location where the last parsed token begins (at offset /// -4 from the currentTokenOffset). FileLineColLoc getLocation(int offset = -4) const { @@ -1712,39 +1742,49 @@ class Parser { return success(); } - FailureOr parseInstruction() { - size_t beginOffset = currentTokenOffset; - Token token = parseToken(); - if (failed(token)) + FailureOr parseInstruction(uint32_t &instructionLengthInTokens) { + auto beginOffset = currentTokenOffset; + instructionLengthInTokens = 1; // Min instruction length + auto opcodeToken0 = parseToken(); + if (failed(opcodeToken0)) return failure(); - FileLineColLoc loc = getLocation(); + uint32_t opcode = DECODE_D3D10_SB_OPCODE_TYPE(*opcodeToken0); - uint32_t opcode = DECODE_D3D10_SB_OPCODE_TYPE(*token); - InstructionModifier modifier; - modifier.preciseMask = DECODE_D3D11_SB_INSTRUCTION_PRECISE_VALUES(*token); - modifier.saturate = DECODE_IS_D3D10_SB_INSTRUCTION_SATURATE_ENABLED(*token); + // CUSTOMDATA carries its total token count (>= 2) in token1. + // Just set the instruction length for the unknown fallback. + if (opcode == D3D10_SB_OPCODE_CUSTOMDATA) { + auto numTokensToken = parseToken(); + FAILURE_IF_FAILED(numTokensToken); + instructionLengthInTokens = std::max(*numTokensToken, 2u); + return failure(); + } - uint32_t length = DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(*token); + instructionLengthInTokens = std::max( + DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(*opcodeToken0), 1u); + + InstructionModifier modifier; + modifier.preciseMask = + DECODE_D3D11_SB_INSTRUCTION_PRECISE_VALUES(*opcodeToken0); + modifier.saturate = + DECODE_IS_D3D10_SB_INSTRUCTION_SATURATE_ENABLED(*opcodeToken0); // TODO: extended instructions: // BOOL b51PlusShader = // BOOL bExtended = DECODE_IS_D3D10_SB_OPCODE_EXTENDED(Token) // ... - if (opcode >= D3D10_SB_NUM_OPCODES) { - emitError(getLocation(), "unknown opcode"); - return failure(); - } - - auto opcodeToken = *token; + if (opcode >= D3D10_SB_NUM_OPCODES) + return emitError(getLocation(), "unknown opcode: ") << opcode; Instruction dclInstruction; - auto parseResult = parseDclInstruction(opcodeToken, loc, dclInstruction); + auto parseResult = + parseDclInstruction(*opcodeToken0, getLocation(), dclInstruction); if (parseResult.has_value()) { if (failed(*parseResult)) return failure(); - if (failed(verifyInstructionLength(beginOffset, length))) + if (failed( + verifyInstructionLength(beginOffset, instructionLengthInTokens))) return failure(); return dclInstruction; } @@ -1759,11 +1799,48 @@ class Parser { operands.push_back(*operand); } - if (failed(verifyInstructionLength(beginOffset, length))) + if (failed(verifyInstructionLength(beginOffset, instructionLengthInTokens))) return failure(); return builder.buildInstruction(instrInfo[opcode].name, operands, modifier, - loc); + getLocation()); + } + + /// On failure, sets `instructionLengthInTokens` for the unknown fallback. + bool tryParseInstructionOrRewind(uint32_t &instructionLengthInTokens) { + auto numOpsBefore = builder.getNumOps(); + + // Scope for ScopedDiagnosticHandler + { + ScopedDiagnosticHandler suppress(name.getContext(), + [](Diagnostic &) { return success(); }); + if (succeeded(parseInstruction(instructionLengthInTokens))) + return true; + } + + builder.rewindOpsTo(numOpsBefore); + return false; + } + + LogicalResult parseUnknownTokens(uint32_t numTokens) { + auto loc = getLocation(); + numTokens = std::min(numTokens, getRemainingBytes() / tokenSize); + auto tokens = parseTokens(numTokens); + FAILURE_IF_FAILED(tokens); + builder.buildUnknown(*tokens, loc); + return success(); + } + + LogicalResult parseNextInstruction() { + auto beginOffset = currentTokenOffset; + uint32_t instructionLengthInTokens = 0; + if (tryParseInstructionOrRewind(instructionLengthInTokens)) + return success(); + + currentTokenOffset = beginOffset; + emitWarning(getLocation()) << "treating next " << instructionLengthInTokens + << " token(s) as unknown"; + return parseUnknownTokens(instructionLengthInTokens); } FailureOr parseModule() { @@ -1779,12 +1856,13 @@ class Parser { name.getContext(), (*header)->major, (*header)->minor); } auto module = builder.createModule(programType, shaderVersion, loc); - while (currentTokenOffset < buffer.size()) { - FailureOr inst = parseInstruction(); - if (failed(inst)) { + while (getRemainingBytes() >= tokenSize) { + if (failed(parseNextInstruction())) return failure(); - } } + if (auto trailingBytes = getRemainingBytes()) + emitWarning(getLocation(0)) + << "ignoring " << trailingBytes << " trailing byte(s)"; return module; } @@ -1799,7 +1877,7 @@ class Parser { /// and shader version. Otherwise return without touching the parser current /// position. FailureOr> parseProgramHeader() { - auto remainingBytes = buffer.size() - currentTokenOffset; + auto remainingBytes = getRemainingBytes(); if (remainingBytes < tokenSize) return std::optional{}; @@ -1841,7 +1919,7 @@ class Parser { } LogicalResult verifyInstructionLength(size_t beginOffset, uint32_t length) { - if (((currentTokenOffset - beginOffset) / 4) != length) { + if (((currentTokenOffset - beginOffset) / tokenSize) != length) { emitError(getLocation(), "instruction length mismatch"); return failure(); } diff --git a/mlir/test/Target/DXSA/inputs/unknown.bin b/mlir/test/Target/DXSA/inputs/unknown.bin new file mode 100644 index 000000000000..367d481a0234 Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/unknown.bin differ diff --git a/mlir/test/Target/DXSA/inputs/unknown_past_eof.bin b/mlir/test/Target/DXSA/inputs/unknown_past_eof.bin new file mode 100644 index 000000000000..468a58fc674f Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/unknown_past_eof.bin differ diff --git a/mlir/test/Target/DXSA/unknown.mlir b/mlir/test/Target/DXSA/unknown.mlir new file mode 100644 index 000000000000..80728a72d2fe --- /dev/null +++ b/mlir/test/Target/DXSA/unknown.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/unknown.bin | FileCheck %s + +// CHECK: dxsa.module { +// CHECK-NEXT: dxsa.dcl_temps 1 +// CHECK-NEXT: dxsa.unknown +// CHECK-NEXT: dxsa.dcl_temps 2 +// CHECK-NEXT: dxsa.unknown +// CHECK-NEXT: dxsa.dcl_temps 3 +// CHECK-NEXT: dxsa.unknown +// CHECK-NEXT: } diff --git a/mlir/test/Target/DXSA/unknown_invalid.mlir b/mlir/test/Target/DXSA/unknown_invalid.mlir new file mode 100644 index 000000000000..d8ad07aae906 --- /dev/null +++ b/mlir/test/Target/DXSA/unknown_invalid.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+1 {{'dxsa.unknown' op tokens must not be empty}} +dxsa.unknown diff --git a/mlir/test/Target/DXSA/unknown_past_eof.mlir b/mlir/test/Target/DXSA/unknown_past_eof.mlir new file mode 100644 index 000000000000..94d70d628b51 --- /dev/null +++ b/mlir/test/Target/DXSA/unknown_past_eof.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate --import-dxsa-bin %S/inputs/unknown_past_eof.bin | FileCheck %s + +// Opcode declares length 5, but only 3 tokens remain in the file. +// The unknown fallback clamps the span to the actual remainder, never past EOF. + +// CHECK: dxsa.module { +// CHECK-NEXT: dxsa.dcl_temps 1 +// CHECK-NEXT: dxsa.unknown +// CHECK-NEXT: }