Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DXSA_Op<string mnemonic, list<Trait> traits = []> :
Op<DXSADialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// DXSA module top-level container op for a DXBC tokenized program
// DXSA module - top-level container op for a DXBC tokenized program
//===----------------------------------------------------------------------===//

def DXSA_ProgramType_PixelShader : I32EnumAttrCase<"pixel_shader", 0>;
Expand Down Expand Up @@ -572,6 +572,28 @@ def DXSA_Instruction : DXSA_Op<"instruction"> {
let assemblyFormat = "$mnemonic $operands attr-dict";
}

// Fallback op that carries the raw tokens of one instruction verbatim when
// they cannot be decoded into a structured op. It takes a single attribute
// and produces no results.
def DXSA_Unknown : DXSA_Op<"unknown"> {
let summary = "raw tokens fallback for an undecodable instruction";
let description = [{
The `dxsa.unknown` operation represents one instruction whose raw
tokens could not be decoded into a structured op - unknown opcode,
undecodable payload, length mismatch, truncated tail, etc. It acts
as a disassembler-style fallback so the surrounding well-formed
instructions still appear in IR.

Example:

```mlir
dxsa.unknown <tokens = [0x030000FF, 0xDEADBEEF, 0x12345678]>
```
}];

// The undecoded tokens, stored as a dense i32 array attribute.
let arguments = (ins DenseI32ArrayAttr:$tokens);
let results = (outs);
// Requests a hand-written C++ `verify()` for this op.
let hasVerifier = 1;
// `$tokens` is printed and parsed through the custom `HexTokens` directive
// rather than the default attribute syntax.
let assemblyFormat = "custom<HexTokens>($tokens) attr-dict";
}

def DXSA_InlineOperandType_Temp : I32EnumAttrCase<"temp", 0>;
def DXSA_InlineOperandType_Input : I32EnumAttrCase<"input", 1>;
def DXSA_InlineOperandType_Output : I32EnumAttrCase<"output", 2>;
Expand Down
51 changes: 51 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Format.h"

using namespace mlir;
using namespace mlir::dxsa;
Expand All @@ -35,6 +36,12 @@ void DXSADialect::initialize() {
>();
}

/// Declarations for custom-directive helpers used by the
/// TableGen-generated print/parse methods.
static ParseResult parseHexTokens(OpAsmParser &parser, DenseI32ArrayAttr &attr);
static void printHexTokens(OpAsmPrinter &printer, Operation *,
DenseI32ArrayAttr attr);

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -211,6 +218,50 @@ LogicalResult DclResourceStructured::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// UnknownOp
//===----------------------------------------------------------------------===//

/// Verifies that the op carries at least one raw token.
LogicalResult Unknown::verify() {
  if (!getTokens().empty())
    return success();
  return emitOpError("tokens must not be empty");
}

/// Custom-directive parser for the token list of the unknown op.
///
/// Accepts the form `<tokens = [0x..., ...]>`. On success the parsed values
/// are stored in `attr` as a dense i32 array; on any syntax error the function
/// returns failure and leaves `attr` untouched.
static ParseResult parseHexTokens(OpAsmParser &parser,
                                  DenseI32ArrayAttr &attr) {
  SmallVector<int32_t> parsedTokens;

  // Each element is read as an unsigned 32-bit integer and stored with the
  // same bit pattern in the signed element type of the attribute.
  auto parseElement = [&parser, &parsedTokens]() -> ParseResult {
    uint32_t raw;
    if (parser.parseInteger(raw))
      return failure();
    parsedTokens.push_back(static_cast<int32_t>(raw));
    return success();
  };

  // Leading `<tokens =`.
  if (parser.parseLess() || parser.parseKeyword("tokens") ||
      parser.parseEqual())
    return failure();

  // Bracketed, comma-separated token list.
  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
                                     parseElement))
    return failure();

  // Closing `>`.
  if (parser.parseGreater())
    return failure();

  attr = DenseI32ArrayAttr::get(parser.getContext(), parsedTokens);
  return success();
}

/// Custom-directive printer for the token list of the unknown op.
///
/// Emits `<tokens = [...]>` where every token is written as uppercase,
/// 8-digit, 0x-prefixed hex and tokens are separated by ", ".
static void printHexTokens(OpAsmPrinter &printer, Operation *,
                           DenseI32ArrayAttr attr) {
  llvm::raw_ostream &os = printer.getStream();

  os << "<tokens = [";
  bool isFirst = true;
  for (int32_t token : attr.asArrayRef()) {
    if (!isFirst)
      os << ", ";
    isFirst = false;
    // Reinterpret as unsigned so the raw 32-bit pattern is what gets printed.
    os << llvm::format_hex(static_cast<uint32_t>(token), /*Width=*/10,
                           /*Upper=*/true);
  }
  os << "]>";
}

//===----------------------------------------------------------------------===//
// TableGen'd attribute method definitions
//===----------------------------------------------------------------------===//
Expand Down
138 changes: 108 additions & 30 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,13 +523,32 @@ class DXBuilder {
return op->getResults()[0];
}

/// Returns how many operations the current insertion block holds.
size_t getNumOps() {
  auto *insertionBlock = builder.getInsertionBlock();
  return insertionBlock->getOperations().size();
}

/// Erases operations from the end of the insertion block until at most
/// `numOps` remain, then moves the insertion point back to the block's end.
///
/// Does nothing but reset the insertion point when the block already holds
/// `numOps` operations or fewer.
void rewindOpsTo(size_t numOps) {
  auto *block = builder.getInsertionBlock();
  // The block's operation list is an intrusive list whose size() walks every
  // node, so count once and decrement instead of re-querying per erased op.
  // Each erase() below removes exactly one operation from this block.
  for (size_t remaining = block->getOperations().size(); remaining > numOps;
       --remaining)
    block->back().erase(); // reverse order: use-def stays valid
  builder.setInsertionPointToEnd(block);
}

/// Creates a generic `dxsa::Instruction` op named `name` with the given
/// operands at `loc`.
///
/// `modifier` is accepted but is not used when building the op.
Instruction buildInstruction(StringRef name, ArrayRef<Operand> operands,
                             const InstructionModifier &modifier,
                             FileLineColLoc loc) {
  auto mnemonicAttr = builder.getStringAttr(name);
  return dxsa::Instruction::create(builder, loc, operands, mnemonicAttr);
}

/// Creates a `dxsa::Unknown` op at `loc` that stores `tokens` verbatim.
///
/// The attribute's element type is signed, so each token is cast to
/// `int32_t` before the array attribute is built.
Instruction buildUnknown(ArrayRef<uint32_t> tokens, Location loc) {
  SmallVector<int32_t> signedTokens;
  signedTokens.reserve(tokens.size());
  for (uint32_t token : tokens)
    signedTokens.push_back(static_cast<int32_t>(token));

  auto tokensAttr = DenseI32ArrayAttr::get(builder.getContext(), signedTokens);
  return dxsa::Unknown::create(builder, loc, tokensAttr);
}

Instruction buildDclGlobalFlags(dxsa::GlobalFlags flags, Location loc) {
auto flagsAttr = dxsa::GlobalFlagsAttr::get(builder.getContext(), flags);
return dxsa::DclGlobalFlags::create(builder, loc, flagsAttr);
Expand Down Expand Up @@ -783,21 +802,32 @@ class Parser {

/// Width of the token in the program binary stream.
static constexpr size_t tokenSize = sizeof(uint32_t);
/// Returns the number of bytes between the cursor (`currentTokenOffset`) and
/// the end of `buffer`.
/// NOTE(review): the difference is returned as `uint32_t`; if `buffer.size()`
/// is wider than 32 bits (presumably `size_t`) the result is narrowed -
/// confirm inputs cannot reach 4 GiB and that the cursor never passes the end.
uint32_t getRemainingBytes() { return buffer.size() - currentTokenOffset; }

/// Parse the current token and move the cursor to the next one.
Token parseToken() {
if ((currentTokenOffset + tokenSize) > buffer.size()) {
emitError(getLocation(), "unexpected end of file");
return failure();
if (getRemainingBytes() < tokenSize) {
return emitError(getLocation(), "unexpected end of file");
}

uint32_t value = support::endian::read<uint32_t>(
auto value = support::endian::read<uint32_t>(
buffer.begin() + currentTokenOffset, endianness::little);
currentTokenOffset += tokenSize;

return value;
}

/// Reads `numTokens` consecutive tokens via `parseToken()`, advancing the
/// cursor past each one, and returns them in order.
///
/// Returns failure as soon as any single token cannot be read.
FailureOr<SmallVector<uint32_t>> parseTokens(uint32_t numTokens) {
  SmallVector<uint32_t> result(numTokens);
  for (uint32_t &slot : result) {
    Token next = parseToken();
    if (failed(next))
      return failure();
    slot = *next;
  }
  return result;
}

/// Returns location where the last parsed token begins (at offset
/// -4 from the currentTokenOffset).
FileLineColLoc getLocation(int offset = -4) const {
Expand Down Expand Up @@ -1712,39 +1742,49 @@ class Parser {
return success();
}

FailureOr<Instruction> parseInstruction() {
size_t beginOffset = currentTokenOffset;
Token token = parseToken();
if (failed(token))
FailureOr<Instruction> parseInstruction(uint32_t &instructionLengthInTokens) {
auto beginOffset = currentTokenOffset;
instructionLengthInTokens = 1; // Min instruction length
auto opcodeToken0 = parseToken();
if (failed(opcodeToken0))
return failure();

FileLineColLoc loc = getLocation();
uint32_t opcode = DECODE_D3D10_SB_OPCODE_TYPE(*opcodeToken0);

uint32_t opcode = DECODE_D3D10_SB_OPCODE_TYPE(*token);
InstructionModifier modifier;
modifier.preciseMask = DECODE_D3D11_SB_INSTRUCTION_PRECISE_VALUES(*token);
modifier.saturate = DECODE_IS_D3D10_SB_INSTRUCTION_SATURATE_ENABLED(*token);
// CUSTOMDATA carries its total token count (>= 2) in token1.
// Just set the instruction length for the unknown fallback.
if (opcode == D3D10_SB_OPCODE_CUSTOMDATA) {
auto numTokensToken = parseToken();
FAILURE_IF_FAILED(numTokensToken);
instructionLengthInTokens = std::max(*numTokensToken, 2u);
return failure();
}

uint32_t length = DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(*token);
instructionLengthInTokens = std::max(
DECODE_D3D10_SB_TOKENIZED_INSTRUCTION_LENGTH(*opcodeToken0), 1u);

InstructionModifier modifier;
modifier.preciseMask =
DECODE_D3D11_SB_INSTRUCTION_PRECISE_VALUES(*opcodeToken0);
modifier.saturate =
DECODE_IS_D3D10_SB_INSTRUCTION_SATURATE_ENABLED(*opcodeToken0);

// TODO: extended instructions:
// BOOL b51PlusShader =
// BOOL bExtended = DECODE_IS_D3D10_SB_OPCODE_EXTENDED(Token)
// ...

if (opcode >= D3D10_SB_NUM_OPCODES) {
emitError(getLocation(), "unknown opcode");
return failure();
}

auto opcodeToken = *token;
if (opcode >= D3D10_SB_NUM_OPCODES)
return emitError(getLocation(), "unknown opcode: ") << opcode;

Instruction dclInstruction;
auto parseResult = parseDclInstruction(opcodeToken, loc, dclInstruction);
auto parseResult =
parseDclInstruction(*opcodeToken0, getLocation(), dclInstruction);
if (parseResult.has_value()) {
if (failed(*parseResult))
return failure();
if (failed(verifyInstructionLength(beginOffset, length)))
if (failed(
verifyInstructionLength(beginOffset, instructionLengthInTokens)))
return failure();
return dclInstruction;
}
Expand All @@ -1759,11 +1799,48 @@ class Parser {
operands.push_back(*operand);
}

if (failed(verifyInstructionLength(beginOffset, length)))
if (failed(verifyInstructionLength(beginOffset, instructionLengthInTokens)))
return failure();

return builder.buildInstruction(instrInfo[opcode].name, operands, modifier,
loc);
getLocation());
}

/// Attempts to decode one structured instruction with diagnostics silenced.
///
/// Returns true when `parseInstruction` succeeds. Otherwise every op added to
/// the insertion block during the attempt is erased and false is returned;
/// `instructionLengthInTokens` then holds the length to use for the unknown
/// fallback.
bool tryParseInstructionOrRewind(uint32_t &instructionLengthInTokens) {
  const auto opsBeforeAttempt = builder.getNumOps();
  bool decoded = false;

  // The handler swallows every diagnostic emitted while it is alive; keep it
  // confined to this scope so it is gone before the rewind below.
  {
    ScopedDiagnosticHandler suppress(name.getContext(),
                                     [](Diagnostic &) { return success(); });
    decoded = succeeded(parseInstruction(instructionLengthInTokens));
  }

  if (!decoded)
    builder.rewindOpsTo(opsBeforeAttempt);
  return decoded;
}

/// Consumes up to `numTokens` raw tokens starting at the cursor and wraps them
/// in a single unknown op.
///
/// The count is clamped to the number of whole tokens left in the buffer, so
/// the read never runs past the end. Returns failure when no token can be
/// consumed (an empty unknown op would be rejected by the op verifier) or
/// when reading a token fails.
LogicalResult parseUnknownTokens(uint32_t numTokens) {
  // Offset 0 names the token at the cursor, i.e. the first token of the
  // unknown span. The default offset (-4) would name the token before it.
  auto loc = getLocation(0);
  numTokens = std::min<uint32_t>(numTokens, getRemainingBytes() / tokenSize);
  if (numTokens == 0) {
    emitError(loc, "no tokens available for unknown instruction");
    return failure();
  }
  auto tokens = parseTokens(numTokens);
  FAILURE_IF_FAILED(tokens);
  builder.buildUnknown(*tokens, loc);
  return success();
}

/// Parses the next instruction, falling back to an unknown op when structured
/// decoding fails.
///
/// On fallback the cursor is rewound to where the instruction began, a warning
/// is emitted at that position, and the instruction's tokens are consumed
/// verbatim. Returns failure only if the fallback itself fails.
LogicalResult parseNextInstruction() {
  const auto beginOffset = currentTokenOffset;
  uint32_t instructionLengthInTokens = 0;
  if (tryParseInstructionOrRewind(instructionLengthInTokens))
    return success();

  // Undo whatever the failed attempt consumed.
  currentTokenOffset = beginOffset;

  // Report the number of tokens that will actually be consumed: the declared
  // length may run past the end of the buffer.
  const uint32_t numUnknownTokens = std::min<uint32_t>(
      instructionLengthInTokens, getRemainingBytes() / tokenSize);

  // Offset 0 points at the instruction's first token now that the cursor is
  // rewound; the default offset would point at the token before it.
  emitWarning(getLocation(0)) << "treating next " << numUnknownTokens
                              << " token(s) as unknown";
  return parseUnknownTokens(numUnknownTokens);
}

FailureOr<Module> parseModule() {
Expand All @@ -1779,12 +1856,13 @@ class Parser {
name.getContext(), (*header)->major, (*header)->minor);
}
auto module = builder.createModule(programType, shaderVersion, loc);
while (currentTokenOffset < buffer.size()) {
FailureOr<Instruction> inst = parseInstruction();
if (failed(inst)) {
while (getRemainingBytes() >= tokenSize) {
if (failed(parseNextInstruction()))
return failure();
}
}
if (auto trailingBytes = getRemainingBytes())
emitWarning(getLocation(0))
<< "ignoring " << trailingBytes << " trailing byte(s)";
return module;
}

Expand All @@ -1799,7 +1877,7 @@ class Parser {
/// and shader version. Otherwise return without touching the parser current
/// position.
FailureOr<std::optional<ProgramHeader>> parseProgramHeader() {
auto remainingBytes = buffer.size() - currentTokenOffset;
auto remainingBytes = getRemainingBytes();
if (remainingBytes < tokenSize)
return std::optional<ProgramHeader>{};

Expand Down Expand Up @@ -1841,7 +1919,7 @@ class Parser {
}

LogicalResult verifyInstructionLength(size_t beginOffset, uint32_t length) {
if (((currentTokenOffset - beginOffset) / 4) != length) {
if (((currentTokenOffset - beginOffset) / tokenSize) != length) {
emitError(getLocation(), "instruction length mismatch");
return failure();
}
Expand Down
Binary file added mlir/test/Target/DXSA/inputs/unknown.bin
Binary file not shown.
Binary file added mlir/test/Target/DXSA/inputs/unknown_past_eof.bin
Binary file not shown.
10 changes: 10 additions & 0 deletions mlir/test/Target/DXSA/unknown.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/unknown.bin | FileCheck %s

// CHECK: dxsa.module {
// CHECK-NEXT: dxsa.dcl_temps 1
// CHECK-NEXT: dxsa.unknown <tokens = [0x030007FF, 0xDEADBEEF, 0x12345678]>
// CHECK-NEXT: dxsa.dcl_temps 2
// CHECK-NEXT: dxsa.unknown <tokens = [0x00000035, 0x00000004, 0x11111111, 0x22222222]>
// CHECK-NEXT: dxsa.dcl_temps 3
// CHECK-NEXT: dxsa.unknown <tokens = [0x03000068, 0x00000005, 0xCAFEBABE]>
// CHECK-NEXT: }
4 changes: 4 additions & 0 deletions mlir/test/Target/DXSA/unknown_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+1 {{'dxsa.unknown' op tokens must not be empty}}
dxsa.unknown <tokens = []>
9 changes: 9 additions & 0 deletions mlir/test/Target/DXSA/unknown_past_eof.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/unknown_past_eof.bin | FileCheck %s

// Opcode declares length 5, but only 3 tokens remain in the file.
// The unknown fallback clamps the span to the actual remainder, never past EOF.

// CHECK: dxsa.module {
// CHECK-NEXT: dxsa.dcl_temps 1
// CHECK-NEXT: dxsa.unknown <tokens = [0x050007FF, 0xAAAAAAAA, 0xBBBBBBBB]>
// CHECK-NEXT: }