//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<OpResult>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<BlockArgument>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
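    // (Illustrative rationale: an out-of-place condition operand would yield
    // a freshly allocated buffer out of the "before" region on every
    // iteration, which is the pattern loop bufferization tries to avoid.)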
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields is not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
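    // (Tensor results of the old op are rebuilt from the new op's memref
    // results via bufferization.to_tensor below; non-tensor results are
    // forwarded unchanged.)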
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
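    // (E.g., both branches may yield a `memref<?xf32>` with identity layout,
    // in which case that type is used for the scf.if result directly. The
    // concrete types depend on the input IR; this is just an illustration.)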
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    SmallVector<BaseMemRefType> yieldedTypes;
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
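    // (Each case must at least agree with the default case on the memory
    // space; differing layouts are reconciled below by falling back to a
    // fully dynamic layout map.)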
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly.
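/// (For example, if both the init_arg and the yielded value bufferize to
/// `memref<5xf32>`, the iter_arg simply gets that type. The types shown here
/// are illustrative only.)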
/// Otherwise, the buffer types can only differ in their layout map and a cast
/// must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between bufferized init_arg type and bufferized yield value type is
    // detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
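    // (E.g., a loop such as `scf.for %i = %c0 to %c0 step %c1 ...` never
    // executes its body, so each result is just the corresponding init_arg,
    // and reading a result reads that init_arg. Illustrative example.)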
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg (or
    // with a newly allocated buffer) and not with other buffers defined
    // outside of the loop. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because
      // the op interface is in the "IR" build unit and the
      // `OneShotAnalysisState` is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
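  ///
  /// A sketch of an equivalent yield (illustrative IR, not from a test case):
  ///
  ///   %r = scf.for %i = %lb to %ub step %c1 iter_args(%t = %init)
  ///       -> (tensor<5xf32>) {
  ///     %u = tensor.insert %cst into %t[%i] : tensor<5xf32>
  ///     scf.yield %u : tensor<5xf32>
  ///   }
  ///
  /// Here %u bufferizes in place onto %t, so the yielded value is equivalent
  /// to the iter_arg. Yielding an unrelated tensor instead would break the
  /// equivalence.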
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield
               ? BufferRelation::Equivalent
               : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg and
    // with no other buffers. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand. If a corresponding
    // OpResult/init_arg pair bufferizes to equivalent buffers, this aliasing
    // requirement is satisfied. Otherwise, we cannot be sure and must yield a
    // new buffer copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
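        // (For example, the yielded buffer may have a static, identity layout
        // while the parent op's bufferized result type uses a fully dynamic
        // layout; castBuffer then inserts a memref.cast to bridge the two.)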
        if (isa<scf::IfOp, scf::IndexSwitchOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis, not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding shared_outs, meaning that the shared_outs bufferize to a
    // read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
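    // (After bufferization the shared_outs are plain memrefs that are updated
    // in place, so the new scf.forall needs no results; uses of the old
    // results are replaced with those buffers at the end of this function.)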
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}