//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms such operations into an explicit
/// sequence of operations that model their effect on the different metadata.
/// This pass uses affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}.
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);

  // Compute the new strides and offset from the base strides and offset:
  // newStride#i = baseStride#i * subStride#i
  // offset = baseOffset + sum(subOffsets#i * newStrides#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // all of them.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, subviews can drop some dimensions, so some strides and sizes may
  // not end up in the final value that we are replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
/// its effects on the offset, the sizes, and the strides using affine.apply.
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }
    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes`, and `strides` being
/// the replacements for the original `extract_strided_metadata`.
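///
/// For instance (illustrative numbers only), with baseStrides = [16, 1],
/// baseOffset = 0, subOffsets = [2, 3], and subStrides = [1, 2], the rewrite
/// produces strides = [16 * 1, 1 * 2] = [16, 2] and
/// offset = 0 + 2 * 16 + 3 * 1 = 35.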
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(
        rewriter, loc, stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                 for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  // expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type
///   of the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all the constant
///   sizes in expandShapeSizes. (Remember expandShapeSizes has at most one
///   dynamic element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = getStridesAndOffset(sourceType);

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult getProductOfValues(
    ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
    ArrayRef<int64_t> maybeConstants, ArrayRef<OpFoldResult> values,
    llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
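///
/// For example (illustrative only), collapsing a memref<2x3x4xf32> with
/// strides [12, 4, 1] using the reassociation [[0, 1], [2]] yields group
/// strides min(12, 4) = 4 for the first group and 1 for the second, i.e.,
/// the strides of a contiguous memref<6x4xf32>.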
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = getStridesAndOffset(sourceType);

  SmallVector<OpFoldResult> groupStrides;
  ArrayRef<int64_t> srcShape = sourceType.getShape();
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    groupStrides.push_back(ShapedType::isDynamic(currentStride)
                               ? origStrides[currentDim]
                               : builder.getIndexAttr(currentStride));
  }
  if (groupStrides.empty()) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        getStridesAndOffset(collapsedType);
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  // For the general case, we just want the minimum stride
  // since the collapsed dimensions are contiguous.
  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
                                                  builder.getContext());
  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                      groupStrides)};
}

/// Replace `baseBuffer, offset, sizes, strides =
///              extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the expand_shape in that expression and
/// materialize its effects on the sizes and the strides using affine.apply.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    // Build a plain extract_strided_metadata(memref) from
    // extract_strided_metadata(reassociative_reshape_like(memref)).
    Location origLoc = reshape.getLoc();
    Value source = reshape.getSrc();
    auto sourceType = cast<MemRefType>(source.getType());
    unsigned sourceRank = sourceType.getRank();

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

    // Collect statically known information.
    auto [strides, offset] = getStridesAndOffset(sourceType);

    MemRefType reshapeType = reshape.getResultType();
    unsigned reshapeRank = reshapeType.getRank();

    OpFoldResult offsetOfr =
        ShapedType::isDynamic(offset)
            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
            : rewriter.getIndexAttr(offset);

    // Get the special case of 0-D out of the way.
    if (sourceRank == 0) {
      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
      rewriter.replaceOp(reshape, memrefDesc.getResult());
      return success();
    }

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(reshapeRank);
    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(reshapeRank);

    // Compute the reshaped strides and sizes from the base strides and sizes.
    SmallVector<OpFoldResult> origSizes =
        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
    SmallVector<OpFoldResult> origStrides =
        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
    for (; idx != endIdx; ++idx) {
      SmallVector<OpFoldResult> reshapedSizes =
          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

      unsigned groupSize = reshapedSizes.size();
      for (unsigned i = 0; i < groupSize; ++i) {
        finalSizes.push_back(reshapedSizes[i]);
        finalStrides.push_back(reshapedStrides[i]);
      }
    }
    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
           "We should have visited all the input dimensions");
    assert(finalSizes.size() == reshapeRank &&
           "We should have populated all the values");
    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
        offsetOfr, finalSizes, finalStrides);
    rewriter.replaceOp(reshape, memrefDesc.getResult());
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
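    // For example (hypothetical shapes), an allocation of memref<?x16x4xf32>
    // with dynamic size %n yields sizes [%n, 16, 4] and strides [64, 4, 1]:
    // each stride is the product of the sizes of the dimensions to its right.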
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and an identity affine_map for the layout.
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
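///
/// For example, `extract_aligned_pointer_as_index(subview(%src, ...))`
/// becomes `extract_aligned_pointer_as_index(%src)`: the pattern simply
/// forwards the view's source, since a view shares its source's underlying
/// buffer.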
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///            extract_strided_metadata(
///               reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType =
        cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///            extract_strided_metadata(
///               cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStride.isDynamic()
///               ? srcStride
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides, or compute them directly from `src`.
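///
/// For instance (hypothetical types), for
/// `%dst = memref.cast %src : memref<?x?xf32> to memref<4x?xf32>`,
/// the size of dimension 0 becomes the constant 4, while the size of
/// dimension 1 is still taken from `extract_strided_metadata(%src)`.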
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<Value> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<Value> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset =
///            extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
  return std::make_unique<ExpandStridedMetadataPass>();
}