//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" using namespace mlir; using namespace mlir::bufferization; namespace { /// Bufferization of arith.constant. Replace with memref.get_global. struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto constantOp = cast(op); Attribute memorySpace; if (options.defaultMemorySpace.has_value()) memorySpace = *options.defaultMemorySpace; else return constantOp->emitError("could not infer memory space"); // Only ranked tensors are supported. if (!isa(constantOp.getType())) return failure(); // Only constants inside a module are supported. auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) return failure(); // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr globalOp = getGlobalFor(constantOp, options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; replaceOpWithNewBufferizedOp( rewriter, op, globalMemref.getType(), globalMemref.getName()); return success(); } bool isWritable(Operation *op, Value value, const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. assert(isa(value)); return false; } }; struct IndexCastOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {{op->getResult(0), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast(op); auto resultTensorType = cast(castOp.getType()); FailureOr source = getBuffer(rewriter, castOp.getIn(), options); if (failed(source)) return failure(); auto sourceType = cast(source->getType()); // Result type should have same layout and address space as the source type. BaseMemRefType resultType; if (auto rankedMemRefType = dyn_cast(sourceType)) { resultType = MemRefType::get( rankedMemRefType.getShape(), resultTensorType.getElementType(), rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); } else { auto unrankedMemrefType = cast(sourceType); resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), unrankedMemrefType.getMemorySpace()); } replaceOpWithNewBufferizedOp(rewriter, op, resultType, *source); return success(); } }; /// Bufferization of arith.select. Just replace the operands. struct SelectOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {{op->getOpResult(0) /*result*/, BufferRelation::Equivalent, /*isDefinite=*/false}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto selectOp = cast(op); Location loc = selectOp.getLoc(); // Elementwise conditions are not supported yet. To bufferize such an op, // it could be lowered to an elementwise "linalg.generic" with a new // "tensor.empty" out tensor, followed by "empty tensor elimination". Such // IR will bufferize. if (!selectOp.getCondition().getType().isInteger(1)) return op->emitOpError("only i1 condition values are supported"); // TODO: It would be more efficient to copy the result of the `select` op // instead of its OpOperands. In the worst case, 2 copies are inserted at // the moment (one for each tensor). When copying the op result, only one // copy would be needed. FailureOr maybeTrueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options); FailureOr maybeFalseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options); if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer)) return failure(); Value trueBuffer = *maybeTrueBuffer; Value falseBuffer = *maybeFalseBuffer; // The "true" and the "false" operands must have the same type. If the // buffers have different types, they differ only in their layout map. Cast // both of them to the most dynamic MemRef type. if (trueBuffer.getType() != falseBuffer.getType()) { auto targetType = bufferization::getBufferType(selectOp.getResult(), options); if (failed(targetType)) return failure(); if (trueBuffer.getType() != *targetType) trueBuffer = rewriter.create(loc, *targetType, trueBuffer); if (falseBuffer.getType() != *targetType) falseBuffer = rewriter.create(loc, *targetType, falseBuffer); } replaceOpWithNewBufferizedOp( rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); return success(); } FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); auto trueType = bufferization::getBufferType(selectOp.getTrueValue(), options, invocationStack); auto falseType = bufferization::getBufferType(selectOp.getFalseValue(), options, invocationStack); if (failed(trueType) || failed(falseType)) return failure(); if (*trueType == *falseType) return *trueType; if (trueType->getMemorySpace() != falseType->getMemorySpace()) return op->emitError("inconsistent memory space on true/false operands"); // If the buffers have different types, they differ only in their layout // map. auto memrefType = llvm::cast(*trueType); return getMemRefTypeWithFullyDynamicLayout( RankedTensorType::get(memrefType.getShape(), memrefType.getElementType()), memrefType.getMemorySpace()); } }; } // namespace void mlir::arith::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) { ConstantOp::attachInterface(*ctx); IndexCastOp::attachInterface(*ctx); SelectOp::attachInterface(*ctx); }); }