-
Notifications
You must be signed in to change notification settings - Fork 830
[SM6.10] Update HL LinAlg builtin signatures #8173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -397,32 +397,27 @@ void [[min_sm=6.10]] __builtin_VectorAccumulate(in LinAlg<c> InputVector, in RWB | |
|
|
||
| // LinAlg intrinsics | ||
|
|
||
| // TODO: Replace all int MatrixRef with MatrixRef type | ||
| // TODO: Replace all int GroupSharedMem with groupshared memory | ||
| void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(int MatrixRef, numeric value); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(int MatrixRefDest, int MatrixRefSrc, bool transpose); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(int MatrixRef); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(int MatrixRef, int32_only threadLocalIndex); | ||
| numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(int MatrixRef, int32_only threadLocalIndex); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(int MatrixRef, int32_only threadLocalIndex, numeric value); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout); | ||
| int32_only [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(int MatrixRefA, int MatrixRefB, int MatrixRefC); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(int MatrixRefA, int MatrixRefB, int MatrixRefC); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(int MatrixRefRHS, int MatrixRefLHS); | ||
|
|
||
| // TODO: Fix vector types | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(int MatrixRef); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(int MatrixRef); | ||
|
|
||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout); | ||
|
|
||
| // TODO: Fix vector types | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(int MatrixRef); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeric value); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); | ||
| uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix); | ||
| uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex); | ||
| numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(in LinAlgMatrix matrix, in uint threadLocalIndex); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); | ||
| uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout(); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS); | ||
|
Comment on lines
+413
to
+415
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the double Matrix in the name intentional? I think And shouldn't these return a Matrix value instead of having an out parameter?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The double is intentional yes! There is MatrixMatrixMultiply (aka Matrix x Matrix) and MatrixVectorMultiply (aka Matrix x Vector) Nope, out parameter is necessary for overload resolution! They get translated into the correct return shape in lowering. Since we can overloading on just the return type it needs to be captured in the parameter list
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean "cannot do overloading on just the return type"?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep! |
||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout); | ||
| void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB); | ||
|
|
||
| } namespace | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature here doesn't seem to indicate that any matrix is "updated" during the accumulation. Should matrixA be an
inout, or otherwise?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, there is an issue here in that regard but the fix (coming in a later PR)[1] doesn't make this look any clearer. The API is weird that even when something is updated "in place" a new SSA value is produced.
At the DXIL level this function is
so its better to think of the function as
R = A * B + C. R is just missing because it was a spec bug that was fixed after this PR was written (and its a lot easier on me to punt the fix instead of doing a bunch of complicated rebases)1: The final signature of this function will actually be:
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the very next PR contains this fix. I suppose I could merge them into one 2-topic PR if folks prefer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, I don't mind doing whatever is easier to you