510 lines
20 KiB
C++
510 lines
20 KiB
C++
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
|
#include "flang/Optimizer/Builder/HLFIRTools.h"
|
|
#include "flang/Optimizer/Builder/IntrinsicCall.h"
|
|
#include "flang/Optimizer/Builder/Todo.h"
|
|
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
|
#include "flang/Optimizer/Dialect/FIRType.h"
|
|
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
|
|
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
|
|
#include "flang/Optimizer/HLFIR/HLFIROps.h"
|
|
#include "flang/Optimizer/HLFIR/Passes.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include <mlir/IR/MLIRContext.h>
|
|
#include <optional>
|
|
|
|
namespace hlfir {
|
|
#define GEN_PASS_DEF_LOWERHLFIRINTRINSICS
|
|
#include "flang/Optimizer/HLFIR/Passes.h.inc"
|
|
} // namespace hlfir
|
|
|
|
namespace {
|
|
|
|
/// Base class for passes converting transformational intrinsic operations into
|
|
/// runtime calls
|
|
template <class OP>
|
|
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
|
|
public:
|
|
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
|
|
: mlir::OpRewritePattern<OP>{ctx} {
|
|
// required for cases where intrinsics are chained together e.g.
|
|
// matmul(matmul(a, b), c)
|
|
// because converting the inner operation then invalidates the
|
|
// outer operation: causing the pattern to apply recursively.
|
|
//
|
|
// This is safe because we always progress with each iteration. Circular
|
|
// applications of operations are not expressible in MLIR because we use
|
|
// an SSA form and one must become first. E.g.
|
|
// %a = hlfir.matmul %b %d
|
|
// %b = hlfir.matmul %a %d
|
|
// cannot be written.
|
|
// MSVC needs the this->
|
|
this->setHasBoundedRewriteRecursion(true);
|
|
}
|
|
|
|
protected:
|
|
struct IntrinsicArgument {
|
|
mlir::Value val; // allowed to be null if the argument is absent
|
|
mlir::Type desiredType;
|
|
};
|
|
|
|
/// Lower the arguments to the intrinsic: adding necessary boxing and
|
|
/// conversion to match the signature of the intrinsic in the runtime library.
|
|
llvm::SmallVector<fir::ExtendedValue, 3>
|
|
lowerArguments(mlir::Operation *op,
|
|
const llvm::ArrayRef<IntrinsicArgument> &args,
|
|
mlir::PatternRewriter &rewriter,
|
|
const fir::IntrinsicArgumentLoweringRules *argLowering) const {
|
|
mlir::Location loc = op->getLoc();
|
|
fir::FirOpBuilder builder{rewriter, op};
|
|
|
|
llvm::SmallVector<fir::ExtendedValue, 3> ret;
|
|
llvm::SmallVector<std::function<void()>, 2> cleanupFns;
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
mlir::Value arg = args[i].val;
|
|
mlir::Type desiredType = args[i].desiredType;
|
|
if (!arg) {
|
|
ret.emplace_back(fir::getAbsentIntrinsicArgument());
|
|
continue;
|
|
}
|
|
hlfir::Entity entity{arg};
|
|
|
|
fir::ArgLoweringRule argRules =
|
|
fir::lowerIntrinsicArgumentAs(*argLowering, i);
|
|
switch (argRules.lowerAs) {
|
|
case fir::LowerIntrinsicArgAs::Value: {
|
|
if (args[i].desiredType != arg.getType()) {
|
|
arg = builder.createConvert(loc, desiredType, arg);
|
|
entity = hlfir::Entity{arg};
|
|
}
|
|
auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
|
|
if (cleanup)
|
|
cleanupFns.push_back(*cleanup);
|
|
ret.emplace_back(exv);
|
|
} break;
|
|
case fir::LowerIntrinsicArgAs::Addr: {
|
|
auto [exv, cleanup] =
|
|
hlfir::convertToAddress(loc, builder, entity, desiredType);
|
|
if (cleanup)
|
|
cleanupFns.push_back(*cleanup);
|
|
ret.emplace_back(exv);
|
|
} break;
|
|
case fir::LowerIntrinsicArgAs::Box: {
|
|
auto [box, cleanup] =
|
|
hlfir::convertToBox(loc, builder, entity, desiredType);
|
|
if (cleanup)
|
|
cleanupFns.push_back(*cleanup);
|
|
ret.emplace_back(box);
|
|
} break;
|
|
case fir::LowerIntrinsicArgAs::Inquired: {
|
|
if (args[i].desiredType != arg.getType()) {
|
|
arg = builder.createConvert(loc, desiredType, arg);
|
|
entity = hlfir::Entity{arg};
|
|
}
|
|
// Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
|
|
// are translated to fir::ExtendedValue without transofrmation (notably,
|
|
// pointers/allocatable are not dereferenced).
|
|
// TODO: once lowering to FIR retires, UBOUND and LBOUND can be
|
|
// simplified since the fir.box lowered here are now guarenteed to
|
|
// contain the local lower bounds thanks to the hlfir.declare (the extra
|
|
// rebox can be removed).
|
|
auto [exv, cleanup] =
|
|
hlfir::translateToExtendedValue(loc, builder, entity);
|
|
if (cleanup)
|
|
cleanupFns.push_back(*cleanup);
|
|
ret.emplace_back(exv);
|
|
} break;
|
|
}
|
|
}
|
|
|
|
if (cleanupFns.size()) {
|
|
auto oldInsertionPoint = builder.saveInsertionPoint();
|
|
builder.setInsertionPointAfter(op);
|
|
for (std::function<void()> cleanup : cleanupFns)
|
|
cleanup();
|
|
builder.restoreInsertionPoint(oldInsertionPoint);
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
void processReturnValue(mlir::Operation *op,
|
|
const fir::ExtendedValue &resultExv, bool mustBeFreed,
|
|
fir::FirOpBuilder &builder,
|
|
mlir::PatternRewriter &rewriter) const {
|
|
mlir::Location loc = op->getLoc();
|
|
|
|
mlir::Value firBase = fir::getBase(resultExv);
|
|
mlir::Type firBaseTy = firBase.getType();
|
|
|
|
std::optional<hlfir::EntityWithAttributes> resultEntity;
|
|
if (fir::isa_trivial(firBaseTy)) {
|
|
// Some intrinsics return i1 when the original operation
|
|
// produces fir.logical<>, so we may need to cast it.
|
|
firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase);
|
|
resultEntity = hlfir::EntityWithAttributes{firBase};
|
|
} else {
|
|
resultEntity =
|
|
hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
|
|
fir::FortranVariableFlagsAttr{});
|
|
}
|
|
|
|
if (resultEntity->isVariable()) {
|
|
hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
|
|
loc, *resultEntity, builder.createBool(loc, mustBeFreed));
|
|
resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()};
|
|
}
|
|
|
|
mlir::Value base = resultEntity->getBase();
|
|
if (!mlir::isa<hlfir::ExprType>(base.getType())) {
|
|
for (mlir::Operation *use : op->getResult(0).getUsers()) {
|
|
if (mlir::isa<hlfir::DestroyOp>(use))
|
|
rewriter.eraseOp(use);
|
|
}
|
|
}
|
|
rewriter.replaceAllUsesWith(op->getResults(), {base});
|
|
rewriter.replaceOp(op, base);
|
|
}
|
|
};
|
|
|
|
// Given an integer or array of integer type, calculate the Kind parameter from
|
|
// the width for use in runtime intrinsic calls.
|
|
static unsigned getKindForType(mlir::Type ty) {
|
|
mlir::Type eltty = hlfir::getFortranElementType(ty);
|
|
unsigned width = eltty.cast<mlir::IntegerType>().getWidth();
|
|
return width / 8;
|
|
}
|
|
|
|
template <class OP>
|
|
class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
|
|
using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion;
|
|
using IntrinsicArgument =
|
|
typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
|
|
using HlfirIntrinsicConversion<OP>::lowerArguments;
|
|
using HlfirIntrinsicConversion<OP>::processReturnValue;
|
|
|
|
protected:
|
|
auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
|
|
mlir::PatternRewriter &rewriter,
|
|
std::string opName) const {
|
|
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
|
|
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
|
|
inArgs.push_back({operation.getDim(), i32});
|
|
inArgs.push_back({operation.getMask(), logicalType});
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
|
|
return lowerArguments(operation, inArgs, rewriter, argLowering);
|
|
};
|
|
|
|
auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
|
|
mlir::PatternRewriter &rewriter, std::string opName,
|
|
fir::FirOpBuilder builder) const {
|
|
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
|
|
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
|
|
inArgs.push_back({operation.getDim(), i32});
|
|
inArgs.push_back({operation.getMask(), logicalType});
|
|
mlir::Value kind = builder.createIntegerConstant(
|
|
operation->getLoc(), i32, getKindForType(operation.getType()));
|
|
inArgs.push_back({kind, i32});
|
|
inArgs.push_back({operation.getBack(), i32});
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
|
|
return lowerArguments(operation, inArgs, rewriter, argLowering);
|
|
};
|
|
|
|
auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
|
|
mlir::PatternRewriter &rewriter,
|
|
std::string opName) const {
|
|
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
|
|
inArgs.push_back({operation.getMask(), logicalType});
|
|
inArgs.push_back({operation.getDim(), i32});
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
|
|
return lowerArguments(operation, inArgs, rewriter, argLowering);
|
|
};
|
|
|
|
public:
|
|
mlir::LogicalResult
|
|
matchAndRewrite(OP operation,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
std::string opName;
|
|
if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
|
|
opName = "sum";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) {
|
|
opName = "product";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) {
|
|
opName = "maxval";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) {
|
|
opName = "minval";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) {
|
|
opName = "minloc";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) {
|
|
opName = "maxloc";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) {
|
|
opName = "any";
|
|
} else if constexpr (std::is_same_v<OP, hlfir::AllOp>) {
|
|
opName = "all";
|
|
} else {
|
|
return mlir::failure();
|
|
}
|
|
|
|
fir::FirOpBuilder builder{rewriter, operation.getOperation()};
|
|
const mlir::Location &loc = operation->getLoc();
|
|
|
|
mlir::Type i32 = builder.getI32Type();
|
|
mlir::Type logicalType = fir::LogicalType::get(
|
|
builder.getContext(), builder.getKindMap().defaultLogicalKind());
|
|
|
|
llvm::SmallVector<fir::ExtendedValue, 0> args;
|
|
|
|
if constexpr (std::is_same_v<OP, hlfir::SumOp> ||
|
|
std::is_same_v<OP, hlfir::ProductOp> ||
|
|
std::is_same_v<OP, hlfir::MaxvalOp> ||
|
|
std::is_same_v<OP, hlfir::MinvalOp>) {
|
|
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
|
|
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
|
|
std::is_same_v<OP, hlfir::MaxlocOp>) {
|
|
args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
|
|
builder);
|
|
} else {
|
|
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
|
|
}
|
|
|
|
mlir::Type scalarResultType =
|
|
hlfir::getFortranElementType(operation.getType());
|
|
|
|
auto [resultExv, mustBeFreed] =
|
|
fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args);
|
|
|
|
processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>;
|
|
|
|
using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>;
|
|
|
|
using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>;
|
|
|
|
using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>;
|
|
|
|
using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>;
|
|
|
|
using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>;
|
|
|
|
using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>;
|
|
|
|
using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>;
|
|
|
|
struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
|
|
using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(hlfir::CountOp count,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
fir::FirOpBuilder builder{rewriter, count.getOperation()};
|
|
const mlir::Location &loc = count->getLoc();
|
|
|
|
mlir::Type i32 = builder.getI32Type();
|
|
mlir::Type logicalType = fir::LogicalType::get(
|
|
builder.getContext(), builder.getKindMap().defaultLogicalKind());
|
|
|
|
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
|
|
inArgs.push_back({count.getMask(), logicalType});
|
|
inArgs.push_back({count.getDim(), i32});
|
|
mlir::Value kind = builder.createIntegerConstant(
|
|
count->getLoc(), i32, getKindForType(count.getType()));
|
|
inArgs.push_back({kind, i32});
|
|
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering("count");
|
|
llvm::SmallVector<fir::ExtendedValue, 3> args =
|
|
lowerArguments(count, inArgs, rewriter, argLowering);
|
|
|
|
mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType());
|
|
|
|
auto [resultExv, mustBeFreed] =
|
|
fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args);
|
|
|
|
processReturnValue(count, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
|
|
using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(hlfir::MatmulOp matmul,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
|
|
const mlir::Location &loc = matmul->getLoc();
|
|
|
|
mlir::Value lhs = matmul.getLhs();
|
|
mlir::Value rhs = matmul.getRhs();
|
|
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
|
|
inArgs.push_back({lhs, lhs.getType()});
|
|
inArgs.push_back({rhs, rhs.getType()});
|
|
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
|
|
llvm::SmallVector<fir::ExtendedValue, 2> args =
|
|
lowerArguments(matmul, inArgs, rewriter, argLowering);
|
|
|
|
mlir::Type scalarResultType =
|
|
hlfir::getFortranElementType(matmul.getType());
|
|
|
|
auto [resultExv, mustBeFreed] =
|
|
fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args);
|
|
|
|
processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
struct DotProductOpConversion
|
|
: public HlfirIntrinsicConversion<hlfir::DotProductOp> {
|
|
using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(hlfir::DotProductOp dotProduct,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
|
|
const mlir::Location &loc = dotProduct->getLoc();
|
|
|
|
mlir::Value lhs = dotProduct.getLhs();
|
|
mlir::Value rhs = dotProduct.getRhs();
|
|
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
|
|
inArgs.push_back({lhs, lhs.getType()});
|
|
inArgs.push_back({rhs, rhs.getType()});
|
|
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
|
|
llvm::SmallVector<fir::ExtendedValue, 2> args =
|
|
lowerArguments(dotProduct, inArgs, rewriter, argLowering);
|
|
|
|
mlir::Type scalarResultType =
|
|
hlfir::getFortranElementType(dotProduct.getType());
|
|
|
|
auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
|
|
builder, loc, "dot_product", scalarResultType, args);
|
|
|
|
processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
class TransposeOpConversion
|
|
: public HlfirIntrinsicConversion<hlfir::TransposeOp> {
|
|
using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(hlfir::TransposeOp transpose,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
|
|
const mlir::Location &loc = transpose->getLoc();
|
|
|
|
mlir::Value arg = transpose.getArray();
|
|
llvm::SmallVector<IntrinsicArgument, 1> inArgs;
|
|
inArgs.push_back({arg, arg.getType()});
|
|
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering("transpose");
|
|
llvm::SmallVector<fir::ExtendedValue, 1> args =
|
|
lowerArguments(transpose, inArgs, rewriter, argLowering);
|
|
|
|
mlir::Type scalarResultType =
|
|
hlfir::getFortranElementType(transpose.getType());
|
|
|
|
auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
|
|
builder, loc, "transpose", scalarResultType, args);
|
|
|
|
processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
struct MatmulTransposeOpConversion
|
|
: public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
|
|
using HlfirIntrinsicConversion<
|
|
hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
|
|
|
|
mlir::LogicalResult
|
|
matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
|
|
mlir::PatternRewriter &rewriter) const override {
|
|
fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
|
|
const mlir::Location &loc = multranspose->getLoc();
|
|
|
|
mlir::Value lhs = multranspose.getLhs();
|
|
mlir::Value rhs = multranspose.getRhs();
|
|
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
|
|
inArgs.push_back({lhs, lhs.getType()});
|
|
inArgs.push_back({rhs, rhs.getType()});
|
|
|
|
auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
|
|
llvm::SmallVector<fir::ExtendedValue, 2> args =
|
|
lowerArguments(multranspose, inArgs, rewriter, argLowering);
|
|
|
|
mlir::Type scalarResultType =
|
|
hlfir::getFortranElementType(multranspose.getType());
|
|
|
|
auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
|
|
builder, loc, "matmul_transpose", scalarResultType, args);
|
|
|
|
processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
class LowerHLFIRIntrinsics
|
|
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
|
|
public:
|
|
void runOnOperation() override {
|
|
// TODO: make this a pass operating on FuncOp. The issue is that
|
|
// FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
|
|
// intrinsics calls creation. This may create race conflict if the pass is
|
|
// scheduled on FuncOp. A solution could be to provide an optional mutex
|
|
// when building a FirOpBuilder and locking around FuncOp and GlobalOp
|
|
// creation, but this needs a bit more thinking, so at this point the pass
|
|
// is scheduled on the moduleOp.
|
|
mlir::ModuleOp module = this->getOperation();
|
|
mlir::MLIRContext *context = &getContext();
|
|
mlir::RewritePatternSet patterns(context);
|
|
patterns
|
|
.insert<MatmulOpConversion, MatmulTransposeOpConversion,
|
|
AllOpConversion, AnyOpConversion, SumOpConversion,
|
|
ProductOpConversion, TransposeOpConversion, CountOpConversion,
|
|
DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
|
|
MinlocOpConversion, MaxlocOpConversion>(context);
|
|
mlir::ConversionTarget target(*context);
|
|
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
|
|
mlir::func::FuncDialect, fir::FIROpsDialect,
|
|
hlfir::hlfirDialect>();
|
|
target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
|
|
hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
|
|
hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
|
|
hlfir::MaxvalOp, hlfir::MinvalOp, hlfir::MinlocOp,
|
|
hlfir::MaxlocOp>();
|
|
target.markUnknownOpDynamicallyLegal(
|
|
[](mlir::Operation *) { return true; });
|
|
if (mlir::failed(
|
|
mlir::applyFullConversion(module, target, std::move(patterns)))) {
|
|
mlir::emitError(mlir::UnknownLoc::get(context),
|
|
"failure in HLFIR intrinsic lowering");
|
|
signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<mlir::Pass> hlfir::createLowerHLFIRIntrinsicsPass() {
|
|
return std::make_unique<LowerHLFIRIntrinsics>();
|
|
}
|