Conversation

Contributor

@tyb0807 tyb0807 commented Dec 23, 2025

Stacked PRs, do not merge.

The wave.iterate and wave.yield operations now accept both WaveTensorInRegister
(before conversion) and VectorOfAnyType (after conversion).

Type compatibility and verification logic were updated to handle both tensor
and vector type combinations.

Fixes #624.

Implements elements per thread propagation for MMA operations.

Fixes iree-org#608.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>

Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory

This fixes false positives where memory resharding was incorrectly
flagged as a propagation error.

Fixes iree-org#622.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Signed-off-by: tyb0807 <sontuan.vu@amd.com>
@tyb0807 tyb0807 requested a review from ftynse December 23, 2025 01:57
Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args,
Arg<Variadic<WaveTensorType>, "Captured values">:$captures
// Accept both WaveTensorType (before PropagateElementsPerThread) and AnyVectorOfAnyRank (after)
Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Carried values">:$iter_args,
Contributor

Why can't this be WaveTensorInRegisters? That constraint already accepts tensors with no address space, tensors in the register address space, and 1-D vectors. And we most likely don't want any vector of any rank here, which would include scalable, 0-D, and other nonsense.

Contributor Author

This is because WaveTensorInRegisters doesn't work with Variadic; I think that's because it's a TypeConstraint (or something like that) and not a Type.

Contributor

Add a comment explaining that; this is something that should be fixed upstream eventually.

Contributor

And we still don't want any vector of any rank here. We specifically want a 1-D vector. It would also be significantly easier to maintain if you created a named TableGen entity for it.
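
A minimal sketch of what such a named entity could look like (the def name and summary string are hypothetical; AnyTypeOf and VectorOfRank are upstream ODS helpers, assuming mlir/IR/CommonTypeConstraints.td and the Wave dialect's type definitions are included):

// Hypothetical named constraint: a Wave tensor before
// PropagateElementsPerThread, or a plain 1-D vector after it. Deriving
// from AnyTypeOf yields a Type, so it composes with Variadic, unlike a
// bare TypeConstraint. VectorOfRank<[1]> rules out 0-D and higher-rank
// vectors; scalable 1-D vectors may still need an extra predicate.
def WaveTensorOrRegisterVector
    : AnyTypeOf<[WaveTensorType, VectorOfRank<[1]>],
                "Wave tensor or 1-D vector">;

It could then be used for both operand lists, e.g. Arg<Variadic<WaveTensorOrRegisterVector>, "Carried values">:$iter_args.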

std::optional<int64_t> value = hyper.getSymbolValue(name);
#ifndef NDEBUG
if (!value) {
llvm::errs() << "symbol: " << name << "\n";
Contributor

why remove this?

Contributor

Indeed, this is extra output that will be printed before the assertion, not debug output.

Contributor

@ftynse ftynse left a comment

@tgymnich , when reviewing stacked PRs (the target branch is not main), click on the last commit and only review that one to avoid making comments on things that should be addressed in other PRs.

@tgymnich
Contributor

> @tgymnich , when reviewing stacked PRs (the target branch is not main), click on the last commit and only review that one to avoid making comments on things that should be addressed in other PRs.

@ftynse could we instead just change the base to the PR below in the stack?

@ftynse
Contributor

ftynse commented Dec 24, 2025

I can adapt to the style folks use. I did accidentally click "squash and merge" on a stacked PR like this before, polluted the other branch, and had to do a bunch of force-pushing and PR re-opening to fix that.


llvm::cast<wave::WaveTensorType>(rhs),
/*includeAddressSpace=*/true)
.succeeded();
// Handle both WaveTensorType and VectorType combinations
Contributor

I usually don't comment on this, but please just get into the habit (or configure your code-generation assistant) of using a full stop at the end of a sentence in comments.

"result #" + istr, resultTensor, allDims)))
return mlir::failure();

// Both are wave tensors - use existing shape verification logic
Contributor

"existing" doesn't make sense in the standalone comment when reading the code. It only makes sense when reading the diff.


return
}

// CHECK-LABEL: @iterate_multidim_vectors
Contributor

We don't want these.

Comment on lines +435 to +447
// CHECK-LABEL: @iterate_vector_captures
func.func @iterate_vector_captures() {
%iter_arg = arith.constant dense<1.0> : vector<8xf32>
%capture = arith.constant dense<2.0> : vector<4xf16>

// CHECK: wave.iterate @I iter_args(%{{.*}}) captures(%{{.*}})
%result = wave.iterate @I iter_args(%iter_arg) captures(%capture) {
^bb0(%in_arg: vector<8xf32>, %cap: vector<4xf16>):
// CHECK: wave.yield %{{.*}} : vector<8xf32>
wave.yield %in_arg : vector<8xf32>
} : (vector<8xf32>, vector<4xf16>) -> (vector<8xf32>)
return
}
Contributor

Why do we need this and what does it test?

// CHECK-LABEL: @iterate_with_vectors_after_ept
func.func @iterate_with_vectors_after_ept(%mem: !wave.tensor<[@M] of f32, <global>>)
attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, I = 4}>,
wave.constraints = [#wave.hardware_constraint<threads_per_wave = 64, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_32x32x8_f16>, vector_shapes = {M = 1, N = 1, K = 8}, max_bits_per_load = 128>]} {
Contributor

Do we need all of these constraints?

// CHECK: wave.iterate @I iter_args({{.*}})
%result = wave.iterate @I iter_args(%init) {
^bb0(%arg: !wave.tensor<[@M] of f32, <register>>):
// Simple operation within the loop - should also work with vectors
Contributor

At which point maybe this should // CHECK that it actually does?

wave.yield %doubled : !wave.tensor<[@M] of f32, <register>>
} : (!wave.tensor<[@M] of f32, <register>>) -> (!wave.tensor<[@M] of f32, <register>>)

// Write should also work with the vector result
Contributor

How do we know it does?

Development

Successfully merging this pull request may close these issues:
Support vector types in wave.iterate and wave.yield operations