//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include #include namespace hlfir { #define GEN_PASS_DEF_LOWERHLFIRINTRINSICS #include "flang/Optimizer/HLFIR/Passes.h.inc" } // namespace hlfir namespace { /// Base class for passes converting transformational intrinsic operations into /// runtime calls template class HlfirIntrinsicConversion : public mlir::OpRewritePattern { public: explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) : mlir::OpRewritePattern{ctx} { // required for cases where intrinsics are chained together e.g. // matmul(matmul(a, b), c) // because converting the inner operation then invalidates the // outer operation: causing the pattern to apply recursively. // // This is safe because we always progress with each iteration. Circular // applications of operations are not expressible in MLIR because we use // an SSA form and one must become first. E.g. // %a = hlfir.matmul %b %d // %b = hlfir.matmul %a %d // cannot be written. // MSVC needs the this-> this->setHasBoundedRewriteRecursion(true); } protected: struct IntrinsicArgument { mlir::Value val; // allowed to be null if the argument is absent mlir::Type desiredType; }; /// Lower the arguments to the intrinsic: adding necessary boxing and /// conversion to match the signature of the intrinsic in the runtime library. llvm::SmallVector lowerArguments(mlir::Operation *op, const llvm::ArrayRef &args, mlir::PatternRewriter &rewriter, const fir::IntrinsicArgumentLoweringRules *argLowering) const { mlir::Location loc = op->getLoc(); fir::FirOpBuilder builder{rewriter, op}; llvm::SmallVector ret; llvm::SmallVector, 2> cleanupFns; for (size_t i = 0; i < args.size(); ++i) { mlir::Value arg = args[i].val; mlir::Type desiredType = args[i].desiredType; if (!arg) { ret.emplace_back(fir::getAbsentIntrinsicArgument()); continue; } hlfir::Entity entity{arg}; fir::ArgLoweringRule argRules = fir::lowerIntrinsicArgumentAs(*argLowering, i); switch (argRules.lowerAs) { case fir::LowerIntrinsicArgAs::Value: { if (args[i].desiredType != arg.getType()) { arg = builder.createConvert(loc, desiredType, arg); entity = hlfir::Entity{arg}; } auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity); if (cleanup) cleanupFns.push_back(*cleanup); ret.emplace_back(exv); } break; case fir::LowerIntrinsicArgAs::Addr: { auto [exv, cleanup] = hlfir::convertToAddress(loc, builder, entity, desiredType); if (cleanup) cleanupFns.push_back(*cleanup); ret.emplace_back(exv); } break; case fir::LowerIntrinsicArgAs::Box: { auto [box, cleanup] = hlfir::convertToBox(loc, builder, entity, desiredType); if (cleanup) cleanupFns.push_back(*cleanup); ret.emplace_back(box); } break; case fir::LowerIntrinsicArgAs::Inquired: { if (args[i].desiredType != arg.getType()) { arg = builder.createConvert(loc, desiredType, arg); entity = hlfir::Entity{arg}; } // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities // are translated to fir::ExtendedValue without transofrmation (notably, // pointers/allocatable are not dereferenced). // TODO: once lowering to FIR retires, UBOUND and LBOUND can be // simplified since the fir.box lowered here are now guarenteed to // contain the local lower bounds thanks to the hlfir.declare (the extra // rebox can be removed). auto [exv, cleanup] = hlfir::translateToExtendedValue(loc, builder, entity); if (cleanup) cleanupFns.push_back(*cleanup); ret.emplace_back(exv); } break; } } if (cleanupFns.size()) { auto oldInsertionPoint = builder.saveInsertionPoint(); builder.setInsertionPointAfter(op); for (std::function cleanup : cleanupFns) cleanup(); builder.restoreInsertionPoint(oldInsertionPoint); } return ret; } void processReturnValue(mlir::Operation *op, const fir::ExtendedValue &resultExv, bool mustBeFreed, fir::FirOpBuilder &builder, mlir::PatternRewriter &rewriter) const { mlir::Location loc = op->getLoc(); mlir::Value firBase = fir::getBase(resultExv); mlir::Type firBaseTy = firBase.getType(); std::optional resultEntity; if (fir::isa_trivial(firBaseTy)) { // Some intrinsics return i1 when the original operation // produces fir.logical<>, so we may need to cast it. firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase); resultEntity = hlfir::EntityWithAttributes{firBase}; } else { resultEntity = hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result", fir::FortranVariableFlagsAttr{}); } if (resultEntity->isVariable()) { hlfir::AsExprOp asExpr = builder.create( loc, *resultEntity, builder.createBool(loc, mustBeFreed)); resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()}; } mlir::Value base = resultEntity->getBase(); if (!mlir::isa(base.getType())) { for (mlir::Operation *use : op->getResult(0).getUsers()) { if (mlir::isa(use)) rewriter.eraseOp(use); } } rewriter.replaceAllUsesWith(op->getResults(), {base}); rewriter.replaceOp(op, base); } }; // Given an integer or array of integer type, calculate the Kind parameter from // the width for use in runtime intrinsic calls. static unsigned getKindForType(mlir::Type ty) { mlir::Type eltty = hlfir::getFortranElementType(ty); unsigned width = eltty.cast().getWidth(); return width / 8; } template class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; using IntrinsicArgument = typename HlfirIntrinsicConversion::IntrinsicArgument; using HlfirIntrinsicConversion::lowerArguments; using HlfirIntrinsicConversion::processReturnValue; protected: auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName) const { llvm::SmallVector inArgs; inArgs.push_back({operation.getArray(), operation.getArray().getType()}); inArgs.push_back({operation.getDim(), i32}); inArgs.push_back({operation.getMask(), logicalType}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName, fir::FirOpBuilder builder) const { llvm::SmallVector inArgs; inArgs.push_back({operation.getArray(), operation.getArray().getType()}); inArgs.push_back({operation.getDim(), i32}); inArgs.push_back({operation.getMask(), logicalType}); mlir::Value kind = builder.createIntegerConstant( operation->getLoc(), i32, getKindForType(operation.getType())); inArgs.push_back({kind, i32}); inArgs.push_back({operation.getBack(), i32}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, mlir::PatternRewriter &rewriter, std::string opName) const { llvm::SmallVector inArgs; inArgs.push_back({operation.getMask(), logicalType}); inArgs.push_back({operation.getDim(), i32}); auto *argLowering = fir::getIntrinsicArgumentLowering(opName); return lowerArguments(operation, inArgs, rewriter, argLowering); }; public: mlir::LogicalResult matchAndRewrite(OP operation, mlir::PatternRewriter &rewriter) const override { std::string opName; if constexpr (std::is_same_v) { opName = "sum"; } else if constexpr (std::is_same_v) { opName = "product"; } else if constexpr (std::is_same_v) { opName = "maxval"; } else if constexpr (std::is_same_v) { opName = "minval"; } else if constexpr (std::is_same_v) { opName = "minloc"; } else if constexpr (std::is_same_v) { opName = "maxloc"; } else if constexpr (std::is_same_v) { opName = "any"; } else if constexpr (std::is_same_v) { opName = "all"; } else { return mlir::failure(); } fir::FirOpBuilder builder{rewriter, operation.getOperation()}; const mlir::Location &loc = operation->getLoc(); mlir::Type i32 = builder.getI32Type(); mlir::Type logicalType = fir::LogicalType::get( builder.getContext(), builder.getKindMap().defaultLogicalKind()); llvm::SmallVector args; if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); } else if constexpr (std::is_same_v || std::is_same_v) { args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName, builder); } else { args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); } mlir::Type scalarResultType = hlfir::getFortranElementType(operation.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; using SumOpConversion = HlfirReductionIntrinsicConversion; using ProductOpConversion = HlfirReductionIntrinsicConversion; using MaxvalOpConversion = HlfirReductionIntrinsicConversion; using MinvalOpConversion = HlfirReductionIntrinsicConversion; using MinlocOpConversion = HlfirReductionIntrinsicConversion; using MaxlocOpConversion = HlfirReductionIntrinsicConversion; using AnyOpConversion = HlfirReductionIntrinsicConversion; using AllOpConversion = HlfirReductionIntrinsicConversion; struct CountOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult matchAndRewrite(hlfir::CountOp count, mlir::PatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, count.getOperation()}; const mlir::Location &loc = count->getLoc(); mlir::Type i32 = builder.getI32Type(); mlir::Type logicalType = fir::LogicalType::get( builder.getContext(), builder.getKindMap().defaultLogicalKind()); llvm::SmallVector inArgs; inArgs.push_back({count.getMask(), logicalType}); inArgs.push_back({count.getDim(), i32}); mlir::Value kind = builder.createIntegerConstant( count->getLoc(), i32, getKindForType(count.getType())); inArgs.push_back({kind, i32}); auto *argLowering = fir::getIntrinsicArgumentLowering("count"); llvm::SmallVector args = lowerArguments(count, inArgs, rewriter, argLowering); mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args); processReturnValue(count, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; struct MatmulOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult matchAndRewrite(hlfir::MatmulOp matmul, mlir::PatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; const mlir::Location &loc = matmul->getLoc(); mlir::Value lhs = matmul.getLhs(); mlir::Value rhs = matmul.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()}); auto *argLowering = fir::getIntrinsicArgumentLowering("matmul"); llvm::SmallVector args = lowerArguments(matmul, inArgs, rewriter, argLowering); mlir::Type scalarResultType = hlfir::getFortranElementType(matmul.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args); processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; struct DotProductOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult matchAndRewrite(hlfir::DotProductOp dotProduct, mlir::PatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; const mlir::Location &loc = dotProduct->getLoc(); mlir::Value lhs = dotProduct.getLhs(); mlir::Value rhs = dotProduct.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()}); auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product"); llvm::SmallVector args = lowerArguments(dotProduct, inArgs, rewriter, argLowering); mlir::Type scalarResultType = hlfir::getFortranElementType(dotProduct.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( builder, loc, "dot_product", scalarResultType, args); processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; class TransposeOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; mlir::LogicalResult matchAndRewrite(hlfir::TransposeOp transpose, mlir::PatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; const mlir::Location &loc = transpose->getLoc(); mlir::Value arg = transpose.getArray(); llvm::SmallVector inArgs; inArgs.push_back({arg, arg.getType()}); auto *argLowering = fir::getIntrinsicArgumentLowering("transpose"); llvm::SmallVector args = lowerArguments(transpose, inArgs, rewriter, argLowering); mlir::Type scalarResultType = hlfir::getFortranElementType(transpose.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( builder, loc, "transpose", scalarResultType, args); processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; struct MatmulTransposeOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion< hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion; mlir::LogicalResult matchAndRewrite(hlfir::MatmulTransposeOp multranspose, mlir::PatternRewriter &rewriter) const override { fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; const mlir::Location &loc = multranspose->getLoc(); mlir::Value lhs = multranspose.getLhs(); mlir::Value rhs = multranspose.getRhs(); llvm::SmallVector inArgs; inArgs.push_back({lhs, lhs.getType()}); inArgs.push_back({rhs, rhs.getType()}); auto *argLowering = fir::getIntrinsicArgumentLowering("matmul"); llvm::SmallVector args = lowerArguments(multranspose, inArgs, rewriter, argLowering); mlir::Type scalarResultType = hlfir::getFortranElementType(multranspose.getType()); auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( builder, loc, "matmul_transpose", scalarResultType, args); processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter); return mlir::success(); } }; class LowerHLFIRIntrinsics : public hlfir::impl::LowerHLFIRIntrinsicsBase { public: void runOnOperation() override { // TODO: make this a pass operating on FuncOp. The issue is that // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm // intrinsics calls creation. This may create race conflict if the pass is // scheduled on FuncOp. A solution could be to provide an optional mutex // when building a FirOpBuilder and locking around FuncOp and GlobalOp // creation, but this needs a bit more thinking, so at this point the pass // is scheduled on the moduleOp. mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns .insert(context); mlir::ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); if (mlir::failed( mlir::applyFullConversion(module, target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(context), "failure in HLFIR intrinsic lowering"); signalPassFailure(); } } }; } // namespace std::unique_ptr hlfir::createLowerHLFIRIntrinsicsPass() { return std::make_unique(); }