Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 6 additions & 21 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4138,36 +4138,23 @@ LogicalResult SubsetOp::inferReturnTypes(

// Derive valid shape from parent valid dims when possible.
SmallVector<int64_t> validShape;
constexpr int64_t kDynamicValidDim = -1;
ArrayRef<int64_t> parentValid = sourceType.getValidShape();
for (size_t i = 0, e = resultShape.size(); i < e; ++i) {
int64_t sizeDim = resultShape[i];
int64_t vdim = sizeDim;

if (parentValid.size() == resultShape.size()) {
int64_t pv = parentValid[i];
if (pv < 0) {
vdim = kDynamicValidDim;
} else {
int64_t off = 0;
// operands: [source, offsets...]
if (operands.size() > 1 + i) {
auto offOpt = getConstIndexValue(operands[1 + i]);
if (!offOpt) {
vdim = kDynamicValidDim;
validShape.push_back(vdim);
continue;
}
off = *offOpt;
// In current subset usage, valid dims are treated as static.
// Only refine when both parent valid and offset are compile-time constants.
if (pv >= 0 && operands.size() > 1 + i) {
auto offOpt = getConstIndexValue(operands[1 + i]);
if (offOpt) {
int64_t off = *offOpt;
// Interpret parent valid dims as a per-tile "period" when the parent
// buffer is wider than the valid region (e.g. ping/pong workspace).
// This avoids inferring a zero valid dim when taking a view at an
// offset equal to the parent valid dim.
//
// Example:
// parent: shape 32x64, valid 32x32
// subset: offset [0,32], sizes [32,32]
// should infer v_col=32 (not 0).
int64_t diff = 0;
if (pv > 0) {
int64_t offMod = off % pv;
Expand All @@ -4178,8 +4165,6 @@ LogicalResult SubsetOp::inferReturnTypes(
if (diff < 0)
diff = 0;
vdim = std::min<int64_t>(sizeDim, diff);
} else {
vdim = kDynamicValidDim;
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2246,8 +2246,13 @@ struct SubviewToEmitCPattern : public OpConversionPattern<memref::SubViewOp> {
}
if (auto ot = dyn_cast<emitc::OpaqueType>(tileCandidate.getType())) {
auto tyStr = ot.getValue();
if (tyStr.find("Tile<") != std::string::npos ||
tyStr.find("ConvTile<") != std::string::npos) {
const bool isPtrLike = tyStr.ends_with("*");
bool isTileLike = tyStr.find("Tile<") != std::string::npos ||
tyStr.find("ConvTile<") != std::string::npos;
if (!isTileLike && !isPtrLike && tyStr.find("Tile") != std::string::npos)
isTileLike = true;

if (isTileLike && !isPtrLike) {
std::string elemTok = elemTypeToString(srcType.getElementType());
std::string qualifier = "__gm__";
if (auto asAttr =
Expand Down
81 changes: 48 additions & 33 deletions lib/PTO/Transforms/PTOViewToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,36 +295,22 @@ static Value computeSubsetValidDim(IRRewriter &rewriter, Location loc,
int64_t pvConst = 0, offConst = 0;
if (getConstIndexValue(parentValid, pvConst) &&
getConstIndexValue(offset, offConst)) {
int64_t diff = 0;
if (pvConst > 0) {
int64_t offMod = offConst % pvConst;
if (offMod < 0)
offMod += pvConst;
diff = pvConst - offMod; // in [1, pvConst] when pvConst>0
if (pvConst >= 0) {
int64_t diff = 0;
if (pvConst > 0) {
int64_t offMod = offConst % pvConst;
if (offMod < 0)
offMod += pvConst;
diff = pvConst - offMod; // in [1, pvConst] when pvConst>0
}
if (diff < 0)
diff = 0;
int64_t clipped = std::min<int64_t>(size, diff);
return rewriter.create<arith::ConstantIndexOp>(loc, clipped);
}
if (diff < 0)
diff = 0;
int64_t clipped = std::min<int64_t>(size, diff);
return rewriter.create<arith::ConstantIndexOp>(loc, clipped);
}

Value pv = ensureIndex(rewriter, loc, parentValid, anchorOp);
Value off = ensureIndex(rewriter, loc, offset, anchorOp);

// Use the same "periodic valid dims" rule as SubsetOp::inferReturnTypes:
// diff = pv - (off % pv), so offsets that land on the next tile (off == pv)
// still produce a full valid dim (diff == pv), instead of 0.
Type i64Ty = rewriter.getI64Type();
Value pvI64 = rewriter.create<arith::IndexCastOp>(loc, i64Ty, pv);
Value offI64 = rewriter.create<arith::IndexCastOp>(loc, i64Ty, off);
Value remI64 = rewriter.create<arith::RemUIOp>(loc, offI64, pvI64);
Value diffI64 = rewriter.create<arith::SubIOp>(loc, pvI64, remI64);
Value diff = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
diffI64);

Value lt = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, diff,
sizeVal);
return rewriter.create<arith::SelectOp>(loc, lt, diff, sizeVal);
// Keep static valid dims when runtime values are not constant.
return sizeVal;
}

static void dumpPretty(Operation *op, llvm::raw_ostream &os) {
Expand Down Expand Up @@ -812,7 +798,12 @@ struct PTOViewToMemrefPass

// 1. Source must be memref already
Value src = op->getOperand(0);
auto srcMrTy = dyn_cast<MemRefType>(src.getType());
// If the source is a bound tile, subview the underlying memref to avoid
// materializing a tile->pointer cast in later lowering.
Value subviewSrc = src;
if (auto bind = src.getDefiningOp<pto::BindTileOp>())
subviewSrc = bind.getSource();
auto srcMrTy = dyn_cast<MemRefType>(subviewSrc.getType());
if (!srcMrTy) {
op.emitError("pto.subset source must be lowered to memref first");
signalPassFailure();
Expand All @@ -833,6 +824,8 @@ struct PTOViewToMemrefPass

// 3. Offsets (mixed)
SmallVector<OpFoldResult> mixedOffsets;
SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(op.getOffsets().size());
for (Value o : op.getOffsets()) {
IntegerAttr constAttr;
bool isStatic = false;
Expand All @@ -843,10 +836,13 @@ struct PTOViewToMemrefPass
constAttr = rewriter.getIndexAttr(cInt.value());
isStatic = true;
}
if (isStatic)
if (isStatic) {
mixedOffsets.push_back(constAttr);
else
staticOffsets.push_back(constAttr.getInt());
} else {
mixedOffsets.push_back(ensureIndex(rewriter, loc, o, op));
staticOffsets.push_back(ShapedType::kDynamic);
}
}

// 3.1 Layout-aware checks for boxed tiles (SLayout != NoneBox)
Expand Down Expand Up @@ -948,7 +944,26 @@ struct PTOViewToMemrefPass
}
(void)srcOffset;

auto resultLayout = StridedLayoutAttr::get(ctx, ShapedType::kDynamic, srcStrides);
// If source offset/strides and subset offsets are all static, preserve
// a static offset in the result type to satisfy memref.subview verifier.
int64_t resultOffset = ShapedType::kDynamic;
bool allOffsetsStatic = (srcOffset != ShapedType::kDynamic);
if (allOffsetsStatic) {
int64_t totalOffset = srcOffset;
for (size_t i = 0; i < staticSizes.size(); ++i) {
if (i >= static_cast<size_t>(srcStrides.size()) ||
srcStrides[i] == ShapedType::kDynamic ||
staticOffsets[i] == ShapedType::kDynamic) {
allOffsetsStatic = false;
break;
}
totalOffset += staticOffsets[i] * srcStrides[i];
}
if (allOffsetsStatic)
resultOffset = totalOffset;
}

auto resultLayout = StridedLayoutAttr::get(ctx, resultOffset, srcStrides);
auto resultMemRefType =
MemRefType::get(staticSizes, srcMrTy.getElementType(), resultLayout,
srcMrTy.getMemorySpace());
Expand All @@ -960,7 +975,7 @@ struct PTOViewToMemrefPass
mixedStrides.push_back(rewriter.getIndexAttr(1));

auto sv = rewriter.create<memref::SubViewOp>(
loc, resultMemRefType, src, mixedOffsets, mixedSizes, mixedStrides);
loc, resultMemRefType, subviewSrc, mixedOffsets, mixedSizes, mixedStrides);

// 6. Re-bind tile metadata (config + valid dims)
Value parentVRow;
Expand Down
Loading
Loading