//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include <optional>

using namespace mlir;
using namespace mlir::bufferization;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                              MemRefType destType) {
  auto srcType = llvm::cast<MemRefType>(value.getType());

  // Element type, rank and memory space must match.
  if (srcType.getElementType() != destType.getElementType())
    return failure();
  if (srcType.getMemorySpace() != destType.getMemorySpace())
    return failure();
  if (srcType.getRank() != destType.getRank())
    return failure();

  // In case the affine maps are different, we may need to use a copy if we go
  // from dynamic to static offset or stride (the canonicalization cannot know
  // at this point that it is really cast compatible).
  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
    int64_t sourceOffset, targetOffset;
    SmallVector<int64_t, 4> sourceStrides, targetStrides;
    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
      return false;
    auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
    };
    if (dynamicToStatic(sourceOffset, targetOffset))
      return false;
    for (auto it : zip(sourceStrides, targetStrides))
      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
        return false;
    return true;
  };

  // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
  // ensure that we only generate casts that always succeed at runtime, we
  // check a few extra conditions in `isGuaranteedCastCompatible`.
  if (memref::CastOp::areCastCompatible(srcType, destType) &&
      isGuaranteedCastCompatible(srcType, destType)) {
    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
    return casted;
  }

  auto loc = value.getLoc();
  SmallVector<Value, 4> dynamicOperands;
  for (int i = 0; i < destType.getRank(); ++i) {
    if (destType.getShape()[i] != ShapedType::kDynamic)
      continue;
    Value size = b.create<memref::DimOp>(loc, value, i);
    dynamicOperands.push_back(size);
  }
  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
  // BufferizableOpInterface impl of ToMemrefOp.
  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
  b.create<memref::CopyOp>(loc, value, copy);
  return copy;
}

/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
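///
/// Illustrative example (the types below are chosen for exposition and are
/// not taken from the upstream tests):
///
///   %t = bufferization.to_tensor %m : memref<?xf32>
///   %r = bufferization.to_memref %t : memref<4xf32>
///
/// folds to a cast of the original buffer:
///
///   %r = memref.cast %m : memref<?xf32> to memref<4xf32>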
LogicalResult
mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
                                              ToMemrefOp toMemref) {
  auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
  if (!memrefToTensor)
    return failure();

  Type srcType = memrefToTensor.getMemref().getType();
  Type destType = toMemref.getType();

  // Directly rewrite if the type did not change.
  if (srcType == destType) {
    rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
    return success();
  }

  auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
  auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
  auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);

  // Ranked memref -> Ranked memref cast.
  if (rankedSrcType && rankedDestType) {
    FailureOr<Value> replacement = castOrReallocMemRefValue(
        rewriter, memrefToTensor.getMemref(), rankedDestType);
    if (failed(replacement))
      return failure();

    rewriter.replaceOp(toMemref, *replacement);
    return success();
  }

  // Unranked memref -> Ranked memref cast: May require a copy.
  // TODO: Not implemented at the moment.
  if (unrankedSrcType && rankedDestType)
    return failure();

  // Unranked memref -> unranked memref cast
  // Ranked memref -> unranked memref cast: No copy needed.
  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
         "expected that types are cast compatible");
  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
                                              memrefToTensor.getMemref());
  return success();
}

void mlir::bufferization::populateDynamicDimSizes(
    OpBuilder &b, Location loc, Value shapedValue,
    SmallVector<Value> &dynamicDims) {
  auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
  for (int64_t i = 0; i < shapedType.getRank(); ++i) {
    if (shapedType.isDynamicDim(i)) {
      if (llvm::isa<MemRefType>(shapedType)) {
        dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
      } else {
        assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
        dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
      }
    }
  }
}

//===----------------------------------------------------------------------===//
// AllocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();

  // Nothing to do for dead AllocTensorOps.
  if (getOperation()->getUses().empty()) {
    rewriter.eraseOp(getOperation());
    return success();
  }

  // Get "copy" buffer.
  Value copyBuffer;
  if (getCopy()) {
    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
    if (failed(maybeCopyBuffer))
      return failure();
    copyBuffer = *maybeCopyBuffer;
  }

  // Create memory allocation.
  auto allocType = bufferization::getBufferType(getResult(), options);
  if (failed(allocType))
    return failure();
  SmallVector<Value> dynamicDims = getDynamicSizes();
  if (getCopy()) {
    assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc = options.createAlloc(
      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
  if (failed(alloc))
    return failure();

  // Create memory copy (if any).
  if (getCopy()) {
    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }

  // Replace op.
  replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);

  return success();
}

bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
                                                  const AnalysisState &state) {
  // AllocTensorOps do not write unless they have a `copy` value.
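  // Without a `copy` operand the allocated buffer's contents are undefined,
  // which is why the result does not count as a memory write. `getCopy()`
  // returns a null Value in that case, which converts to `false` below.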
  return static_cast<bool>(getCopy());
}

bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
                                           const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return true;
}

bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
                                            const AnalysisState &state) {
  assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
         "expected copy operand");
  return false;
}

AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
                                                   const AnalysisState &state) {
  // This is a new allocation. It does not alias with any other buffer.
  return {};
}

FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(value == getResult() && "invalid value");

  // Compute memory space of this allocation.
  Attribute memorySpace;
  if (getMemorySpace().has_value()) {
    memorySpace = *getMemorySpace();
  } else if (getCopy()) {
    auto copyBufferType =
        bufferization::getBufferType(getCopy(), options, invocationStack);
    if (failed(copyBufferType))
      return failure();
    memorySpace = copyBufferType->getMemorySpace();
  } else if (options.defaultMemorySpace.has_value()) {
    memorySpace = *options.defaultMemorySpace;
  } else {
    return getOperation()->emitError("could not infer memory space");
  }

  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
}

LogicalResult AllocTensorOp::verify() {
  if (getCopy() && !getDynamicSizes().empty())
    return emitError("dynamic sizes not needed when copying a tensor");
  if (!getCopy() && getType().getNumDynamicDims() !=
                        static_cast<int64_t>(getDynamicSizes().size()))
    return emitError("expected ")
           << getType().getNumDynamicDims() << " dynamic sizes";
  if (getCopy() && getCopy().getType() != getType())
    return emitError("expected that `copy` and return type match");

  // For sparse tensor allocation, we require that none of its
  // uses escapes the function boundary directly.
  if (sparse_tensor::getSparseTensorEncoding(getType())) {
    for (auto &use : getOperation()->getUses())
      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
              use.getOwner()))
        return emitError("sparse tensor allocation should not escape function");
  }

  return success();
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes) {
  build(builder, result, type, dynamicSizes, /*copy=*/Value(),
        /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          RankedTensorType type, ValueRange dynamicSizes,
                          Value copy) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        /*memory_space=*/IntegerAttr());
}

void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
                          TensorType type, ValueRange dynamicSizes, Value copy,
                          IntegerAttr memorySpace) {
  build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
        memorySpace);
}

namespace {
/// Change the type of the result of a `bufferization.alloc_tensor` by making
/// the result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined by a
/// `constant` op.
/// For example:
///
///   %c5 = arith.constant 5 : index
///   %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
///   %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
  using OpRewritePattern<AllocTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocTensorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getCopy())
      return failure();
    SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
    SmallVector<Value> newDynamicSizes;
    unsigned int dynValCounter = 0;
    for (int64_t i = 0; i < op.getType().getRank(); ++i) {
      if (!op.isDynamicDim(i))
        continue;
      Value value = op.getDynamicSizes()[dynValCounter++];
      APInt intVal;
      if (matchPattern(value, m_ConstantInt(&intVal))) {
        int64_t dim = intVal.getSExtValue();
        if (dim >= 0)
          newShape[i] = intVal.getSExtValue();
        else
          newDynamicSizes.push_back(value);
      } else {
        newDynamicSizes.push_back(value);
      }
    }
    RankedTensorType newType = RankedTensorType::get(
        newShape, op.getType().getElementType(), op.getType().getEncoding());
    if (newType == op.getType())
      return failure();
    auto newOp = rewriter.create<AllocTensorOp>(
        op.getLoc(), newType, newDynamicSizes, /*copy=*/Value());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};

struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
    if (!allocTensorOp || !maybeConstantIndex)
      return failure();
    if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(
        dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
    return success();
  }
};
} // namespace

void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *ctx) {
  results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
}

LogicalResult AllocTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(
      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
                      [&](int64_t dim) -> OpFoldResult {
                        if (isDynamicDim(dim))
                          return getDynamicSize(builder, dim);
                        return builder.getIndexAttr(getStaticSize(dim));
                      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
  if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
      parser.parseRParen())
    return failure();
  ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
  OpAsmParser::UnresolvedOperand copyOperand;
  if (copyKeyword.succeeded())
    if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
        parser.parseRParen())
      return failure();
  ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
  OpAsmParser::UnresolvedOperand sizeHintOperand;
  if (sizeHintKeyword.succeeded())
    if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
      return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  TensorType type;
  if (parser.parseCustomTypeWithFallback(type))
    return failure();
  result.addTypes(type);

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
    return failure();
  if (copyKeyword.succeeded())
    if (parser.resolveOperand(copyOperand, type, result.operands))
      return failure();
  if (sizeHintKeyword.succeeded())
    if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
      return failure();
  result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(dynamicSizesOperands.size()),
                           static_cast<int32_t>(copyKeyword.succeeded()),
                           static_cast<int32_t>(sizeHintKeyword.succeeded())}));
  return success();
}

void AllocTensorOp::print(OpAsmPrinter &p) {
  p << "(" << getDynamicSizes() << ")";
  if (getCopy())
    p << " copy(" << getCopy() << ")";
  if (getSizeHint())
    p << " size_hint=" << getSizeHint();
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              AllocTensorOp::getOperandSegmentSizeAttr()});
  p << " : ";
  auto type = getResult().getType();
  if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
    p.printStrippedAttrOrType(validType);
  else
    p << type;
}

Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
  assert(isDynamicDim(idx) && "expected dynamic dim");
  if (getCopy())
    return b.create<tensor::DimOp>(getLoc(), getCopy(), idx);
  return getOperand(getIndexOfDynamicSize(idx));
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
  return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
}

namespace {

/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.getInput();
    if (source.getType() != cloneOp.getType() &&
        !memref::CastOp::areCastCompatible({source.getType()},
                                           {cloneOp.getType()}))
      return failure();

    // Aims to find the dealloc op for the canonical source,
    // which otherwise could prevent removal of unnecessary allocs.
    Value canonicalSource = source;
    while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
               canonicalSource.getDefiningOp()))
      canonicalSource = iface.getViewSource();

    std::optional<Operation *> maybeCloneDeallocOp =
        memref::findDealloc(cloneOp.getOutput());
    // Skip if either of them has > 1 deallocate operations.
    if (!maybeCloneDeallocOp.has_value())
      return failure();
    std::optional<Operation *> maybeSourceDeallocOp =
        memref::findDealloc(canonicalSource);
    if (!maybeSourceDeallocOp.has_value())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    if (source.getType() != cloneOp.getType())
      source = rewriter.create<memref::CastOp>(cloneOp.getLoc(),
                                               cloneOp.getType(), source);
    rewriter.replaceOp(cloneOp, source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // namespace

void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyClones>(context);
}

//===----------------------------------------------------------------------===//
// DeallocTensorOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
                                         const BufferizationOptions &options) {
  FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
  if (failed(buffer))
    return failure();
  rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
  rewriter.eraseOp(getOperation());
  return success();
}

//===----------------------------------------------------------------------===//
// MaterializeInDestinationOp
//===----------------------------------------------------------------------===//

bool MaterializeInDestinationOp::bufferizesToMemoryRead(
    OpOperand &opOperand, const AnalysisState &state) {
  return opOperand == getSourceMutable();
}

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
    OpOperand &opOperand, const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return true;
  }
  return false;
}

bool MaterializeInDestinationOp::mustBufferizeInPlace(
    OpOperand &opOperand, const AnalysisState &state) {
  // The source is only read and not written, so it always bufferizes in-place
  // by default. The destination is written and is forced to bufferize in-place
  // (if it is a tensor).
  return true;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                              const AnalysisState &state) {
  if (opOperand == getDestMutable()) {
    assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
    return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
  }
  return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
                                      const BufferizationOptions &options) {
  bool tensorDest = isa<TensorType>(getDest().getType());
  Value buffer;
  if (tensorDest) {
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
    if (failed(maybeBuffer))
      return failure();
    buffer = *maybeBuffer;
  } else {
    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
    buffer = getDest();
  }
  auto srcBuffer = getBuffer(rewriter, getSource(), options);
  if (failed(srcBuffer))
    return failure();
  if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
    return failure();
  replaceOpWithBufferizedValues(rewriter, getOperation(),
                                tensorDest ? ValueRange(buffer) : ValueRange());
  return success();
}

bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
    const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
  // As elements are copied from the "source" buffer to the "dest" buffer,
  // already copied elements are not read a second time.
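  // In other words, element (i_0, ..., i_n) of "source" is only accessed
  // together with element (i_0, ..., i_n) of "dest"; that is what
  // "elementwise access" means in this interface.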
  return true;
}

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  if (getOperation()->getNumResults() == 1) {
    assert(isa<RankedTensorType>(getDest().getType()) &&
           "expected tensor type");
    reifiedReturnShapes.resize(1,
                               SmallVector<OpFoldResult>(getType().getRank()));
    reifiedReturnShapes[0] =
        tensor::getMixedSizes(builder, getLoc(), getDest());
  }
  return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
                                                        Location loc) {
  if (isa<TensorType>(getDest().getType())) {
    // The subset is the entire destination tensor.
    return getDest();
  }

  // The "restrict" attribute is transferred from this op to the newly created
  // to_tensor op. If this op does not have the "restrict" attribute, the
  // subset extraction cannot be built because there is no guarantee that there
  // is no pre-existing "restrict" to_tensor op with the same/an aliasing
  // destination.
  if (!getRestrict())
    return {};

  // Build a bufferization.to_tensor op.
  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
  assert(getRestrict() &&
         "expected that ops with memrefs dest have 'restrict'");
  setRestrict(false);
  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                    getWritable());
}

bool MaterializeInDestinationOp::isEquivalentSubset(
    Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
  return equivalenceFn(getDest(), candidate);
}

SmallVector<Value>
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
  return {getDest()};
}

OpOperand &MaterializeInDestinationOp::getSourceOperand() {
  return getOperation()->getOpOperand(0) /*source*/;
}

bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

bool MaterializeInDestinationOp::operatesOnDisjointSubset(
    SubsetOpInterface subsetOp,
    function_ref<bool(Value, Value)> equivalenceFn) {
  return false;
}

LogicalResult MaterializeInDestinationOp::verify() {
  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
    return emitOpError("'dest' must be a tensor or a memref");
  if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
    if (getOperation()->getNumResults() != 1)
      return emitOpError("tensor 'dest' implies exactly one tensor result");
    if (destType != getResult().getType())
      return emitOpError("result and 'dest' types must match");
  }
  if (isa<BaseMemRefType>(getDest().getType()) &&
      getOperation()->getNumResults() != 0)
    return emitOpError("memref 'dest' implies zero results");
  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'restrict' is valid only for memref destinations");
  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
    return emitOpError("'writable' must be specified if and only if the "
                       "destination is of memref type");
  return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
                                       OperationState &state, Value source,
                                       Value dest) {
  auto destTensorType = dyn_cast<TensorType>(dest.getType());
  build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
        source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
                                            const AnalysisState &state) {
  return isa<TensorType>(getDest().getType()) ? true : getWritable();
}
MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
  return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (isa<BaseMemRefType>(getDest().getType()))
    effects.emplace_back(MemoryEffects::Write::get(), getDest(),
                         SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//

bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
  return getWritable();
}

OpFoldResult ToTensorOp::fold(FoldAdaptor) {
  if (auto toMemref = getMemref().getDefiningOp<ToMemrefOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
    if (toMemref->getBlock() == this->getOperation()->getBlock() &&
        toMemref->getNextNode() == this->getOperation())
      return toMemref.getTensor();
  return {};
}

namespace {

struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
    if (!memrefToTensorOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DimOp>(
        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfToTensorFolder>(context);
}

//===----------------------------------------------------------------------===//
// ToMemrefOp
//===----------------------------------------------------------------------===//

OpFoldResult ToMemrefOp::fold(FoldAdaptor) {
  if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
    if (memrefToTensor.getMemref().getType() == getType())
      return memrefToTensor.getMemref();
  return {};
}

namespace {

/// Replace tensor.cast + to_memref by to_memref + memref.cast.
struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        toMemref.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
        tensorCastOperand.getOperand().getType());
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<ToMemrefOp>(toMemref.getLoc(), memrefType,
                                               tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
                                                memref);
    return success();
  }
};

/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
/// cast if necessary.
struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
  using OpRewritePattern<ToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                PatternRewriter &rewriter) const final {
    return foldToMemrefToTensorPair(rewriter, toMemref);
  }
};

/// Fold a load on a to_memref operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfToMemref : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto toMemref = load.getMemref().getDefiningOp<ToMemrefOp>();
    if (!toMemref)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toMemref.getTensor(),
                                                   load.getIndices());
    return success();
  }
};

/// Fold dim of a to_memref into the dim of the tensor.
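///
/// Illustrative example (types chosen for exposition only):
///
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// folds to
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>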
struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<ToMemrefOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
                                               dimOp.getIndex());
    return success();
  }
};

} // namespace

void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
              ToMemrefToTensorFolding>(context);
}

LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error
  // or not. (And not whether the pattern matched or not.)
  return success();
}

std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
                                                 Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::inferReturnTypes(
    MLIRContext *context, std::optional<::mlir::Location> location,
    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
  inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
                                          IntegerType::get(context, 1));
  return success();
}

LogicalResult DeallocOp::verify() {
  if (getMemrefs().size() != getConditions().size())
    return emitOpError(
        "must have the same number of conditions as memrefs to deallocate");
  if (getRetained().size() != getUpdatedConditions().size())
    return emitOpError("must have the same number of updated conditions "
                       "(results) as retained operands");
  return success();
}

static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {
  if (deallocOp.getMemrefs() == memrefs &&
      deallocOp.getConditions() == conditions)
    return failure();

  rewriter.modifyOpInPlace(deallocOp, [&]() {
    deallocOp.getMemrefsMutable().assign(memrefs);
    deallocOp.getConditionsMutable().assign(conditions);
  });
  return success();
}

namespace {

/// Remove duplicate values in the list of memrefs to be deallocated. We need
/// to make sure the corresponding condition value is updated accordingly
/// since the two conditions might not cover the same set of cases. In that
/// case, we have to combine them (by computing their disjunction).
///
/// Example:
/// ```mlir
/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = arith.ori %arg1, %arg2 : i1
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
/// ```
struct DeallocRemoveDuplicateDeallocMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique memrefs to be deallocated.
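    // `memrefToCondition` maps each memref to the index of its first
    // occurrence in `newConditions`, so the conditions of later duplicates
    // can be OR'ed into that entry.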
    DenseMap<Value, unsigned> memrefToCondition;
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [i, memref, cond] :
         llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (memrefToCondition.count(memref)) {
        // If the dealloc conditions don't match, we need to make sure that the
        // dealloc happens on the union of cases.
        Value &newCond = newConditions[memrefToCondition[memref]];
        if (newCond != cond)
          newCond =
              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
      } else {
        memrefToCondition.insert({memref, newConditions.size()});
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// Remove duplicate values in the list of retained memrefs. We need to make
/// sure the corresponding result condition value is replaced properly.
///
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
/// ```
struct DeallocRemoveDuplicateRetainedMemrefs
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    // Unique retained values.
    DenseMap<Value, unsigned> seen;
    SmallVector<Value> newRetained;
    SmallVector<unsigned> resultReplacementIdx;
    unsigned i = 0;
    for (auto retained : deallocOp.getRetained()) {
      if (seen.count(retained)) {
        resultReplacementIdx.push_back(seen[retained]);
        continue;
      }

      seen[retained] = i;
      newRetained.push_back(retained);
      resultReplacementIdx.push_back(i++);
    }

    // Return failure if we don't change anything such that we don't run into
    // an infinite loop of pattern applications.
    if (newRetained.size() == deallocOp.getRetained().size())
      return failure();

    // We need to create a new op because the number of results is always the
    // same as the number of retained operands.
    auto newDeallocOp =
        rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(),
                                   deallocOp.getConditions(), newRetained);
    SmallVector<Value> replacements(
        llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
          return newDeallocOp.getUpdatedConditions()[idx];
        }));
    rewriter.replaceOp(deallocOp, replacements);
    return success();
  }
};

/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
/// ```
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    if (deallocOp.getMemrefs().empty()) {
      Value constFalse = rewriter.create<arith::ConstantOp>(
          deallocOp.getLoc(), rewriter.getBoolAttr(false));
      rewriter.replaceOp(
          deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
                                        constFalse));
      return success();
    }
    return failure();
  }
};

/// Removes memrefs from the deallocation list if their associated condition
/// is always 'false'.
///
/// Example:
/// ```
/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
///                       if (%arg2, %false)
/// ```
/// becomes
/// ```
/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
/// ```
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (!matchPattern(cond, m_Zero())) {
        newMemrefs.push_back(memref);
        newConditions.push_back(cond);
      }
    }

    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                  rewriter);
  }
};

/// A `memref.extract_strided_metadata` op is often inserted to get the base
/// memref if the operand is not already guaranteed to be the result of a
/// memref allocation operation. This canonicalization pattern removes this
/// extraction operation if the operand is now produced by an allocation
/// operation (e.g., due to other canonicalizations simplifying the IR).
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
///   %alloc : memref<2xi32> -> memref<i32>, index, index, index
/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
/// ```
/// is canonicalized to
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
/// ```
struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs(
        llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
          auto extractStridedOp =
              memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
          if (!extractStridedOp)
            return memref;
          Value allocMemref = extractStridedOp.getOperand();
          auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
          if (!allocOp)
            return memref;
          if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
            return allocMemref;
          return memref;
        }));

    return updateDeallocIfChanged(deallocOp, newMemrefs,
                                  deallocOp.getConditions(), rewriter);
  }
};

/// Removes pairs of `bufferization.dealloc` and alloc operations if there is
/// no other user of the allocated value and the allocating operation can be
/// safely removed. If the same value is present multiple times, this pattern
/// relies on other canonicalization patterns to remove the duplicate first.
///
/// Example:
/// ```mlir
/// %alloc = memref.alloc() : memref<2xi32>
/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
/// ```
/// is canonicalized to
/// ```mlir
/// bufferization.dealloc (%arg0 : ...) if (%true)
/// ```
struct RemoveAllocDeallocPairWhenNoOtherUsers
    : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> newMemrefs, newConditions;
    SmallVector<Operation *> toDelete;
    for (auto [memref, cond] :
         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
      if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
        // Check that it is indeed an allocate effect, that the op has no other
        // side effects (which would not allow us to remove the op), and that
        // there are no other users.
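        // The `hasOneUse` check is sufficient because duplicate entries in
        // the dealloc list are expected to be removed first by
        // DeallocRemoveDuplicateDeallocMemrefs (see the pattern above).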
        if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
            hasSingleEffect<MemoryEffects::Allocate>(allocOp, memref) &&
            memref.hasOneUse()) {
          toDelete.push_back(allocOp);
          continue;
        }
      }
      newMemrefs.push_back(memref);
      newConditions.push_back(cond);
    }

    if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
                                      rewriter)))
      return failure();

    for (Operation *op : toDelete)
      rewriter.eraseOp(op);

    return success();
  }
};

} // anonymous namespace

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  populateDeallocOpCanonicalizationPatterns(results, context);
}

void bufferization::populateDeallocOpCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
               DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
               EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
               RemoveAllocDeallocPairWhenNoOtherUsers>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"