Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -397,32 +397,27 @@ void [[min_sm=6.10]] __builtin_VectorAccumulate(in LinAlg<c> InputVector, in RWB

// LinAlg intrinsics

// TODO: Replace all int MatrixRef with MatrixRef type
// TODO: Replace all int GroupSharedMem with groupshared memory
void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(int MatrixRef, numeric value);
void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(int MatrixRefDest, int MatrixRefSrc, bool transpose);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(int MatrixRef);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(int MatrixRef, int32_only threadLocalIndex);
numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(int MatrixRef, int32_only threadLocalIndex);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(int MatrixRef, int32_only threadLocalIndex, numeric value);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);
int32_only [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(int MatrixRefA, int MatrixRefB, int MatrixRefC);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(int MatrixRefA, int MatrixRefB, int MatrixRefC);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(int MatrixRefRHS, int MatrixRefLHS);

// TODO: Fix vector types
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(int MatrixRef);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(int MatrixRef);

void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);

// TODO: Fix vector types
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(int MatrixRef);
void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeric value);
void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix);
uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex);
numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(in LinAlgMatrix matrix, in uint threadLocalIndex);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature here doesn't seem to indicate that any matrix is "updated" during the accumulation. Should matrixA be an inout, or otherwise?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, there is an issue here in that regard but the fix (coming in a later PR)[1] doesn't make this look any clearer. The API is weird that even when something is updated "in place" a new SSA value is produced.

At the DXIL level this function is

declare %dx.types.LinAlgMatrix<mangling> @dx.op.linAlgMatrixMultiplyAccumulate.[MatTyR].[MatTyA].[MatTyB].[MatTyC](
  immarg i32,                        ; opcode
  %dx.types.LinAlgMatrix<mangling>,  ; matrix A
  %dx.types.LinAlgMatrix<mangling>   ; matrix B
  %dx.types.LinAlgMatrix<mangling>   ; matrix C
  )

so its better to think of the function as R = A * B + C. R is just missing because it was a spec bug that was fixed after this PR was written (and its a lot easier on me to punt the fix instead of doing a bunch of complicated rebases)

1: The final signature of this function will actually be:
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the very next PR contains this fix. I suppose I could merge them into one 2-topic PR if folks prefer

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I don't mind doing whatever is easier to you

void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
Comment on lines +413 to +415
Copy link
Member

@hekota hekota Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the double Matrix in the name intentional? I think __builtin_LinAlg_MatrixMultiply would be just fine.

And shouldn't these return a Matrix value instead of having an out parameter?

Copy link
Collaborator Author

@V-FEXrt V-FEXrt Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The double is intentional yes! There is MatrixMatrixMultiply (aka Matrix x Matrix) and MatrixVectorMultiply (aka Matrix x Vector)

Nope, out parameter is necessary for overload resolution! They get translated into the correct return shape in lowering. Since we can overloading on just the return type it needs to be captured in the parameter list

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean "cannot do overloading on just the return type"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep!

void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);

} namespace

Expand Down