bolt/deps/llvm-18.1.8/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
2025-02-14 19:21:04 +01:00

883 lines
34 KiB
C++

//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
template <typename Op>
struct VecOpToScalarOp : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
};
// Callback type for getting pre-generated FuncOp implementing
// an operation of the given type.
using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
// Pattern to convert scalar IPowIOp into a call of outlined
// software implementation.
class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
public:
IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
: OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
/// Convert IPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of IPowI are linearized.
LogicalResult matchAndRewrite(math::IPowIOp op,
PatternRewriter &rewriter) const final;
private:
GetFuncCallbackTy getFuncOpCallback;
};
// Pattern to convert scalar FPowIOp into a call of outlined
// software implementation.
class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
public:
FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
: OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
/// Convert FPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of FPowI are linearized.
LogicalResult matchAndRewrite(math::FPowIOp op,
PatternRewriter &rewriter) const final;
private:
GetFuncCallbackTy getFuncOpCallback;
};
// Pattern to convert scalar ctlz into a call of outlined software
// implementation.
class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
public:
CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
: OpRewritePattern<math::CountLeadingZerosOp>(context),
getFuncOpCallback(cb) {}
/// Convert ctlz into a call to a local function implementing
/// the count leading zeros operation.
LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) const final;
private:
GetFuncCallbackTy getFuncOpCallback;
};
} // namespace
template <typename Op>
LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
Type opType = op.getType();
Location loc = op.getLoc();
auto vecType = dyn_cast<VectorType>(opType);
if (!vecType)
return rewriter.notifyMatchFailure(op, "not a vector operation");
if (!vecType.hasRank())
return rewriter.notifyMatchFailure(op, "unknown vector rank");
ArrayRef<int64_t> shape = vecType.getShape();
int64_t numElements = vecType.getNumElements();
Type resultElementType = vecType.getElementType();
Attribute initValueAttr;
if (isa<FloatType>(resultElementType))
initValueAttr = FloatAttr::get(resultElementType, 0.0);
else
initValueAttr = IntegerAttr::get(resultElementType, 0);
Value result = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (Value input : op->getOperands())
operands.push_back(
rewriter.create<vector::ExtractOp>(loc, input, positions));
Value scalarOp =
rewriter.create<Op>(loc, vecType.getElementType(), operands);
result =
rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, result);
return success();
}
static FunctionType getElementalFuncTypeForOp(Operation *op) {
SmallVector<Type, 1> resultTys(op->getNumResults());
SmallVector<Type, 2> inputTys(op->getNumOperands());
std::transform(op->result_type_begin(), op->result_type_end(),
resultTys.begin(),
[](Type ty) { return getElementTypeOrSelf(ty); });
std::transform(op->operand_type_begin(), op->operand_type_end(),
inputTys.begin(),
[](Type ty) { return getElementTypeOrSelf(ty); });
return FunctionType::get(op->getContext(), inputTys, resultTys);
}
/// Create linkonce_odr function to implement the power function with
/// the given \p elementType type inside \p module. The \p elementType
/// must be IntegerType, an the created function has
/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
///
/// template <typename T>
/// T __mlir_math_ipowi_*(T b, T p) {
/// if (p == T(0))
/// return T(1);
/// if (p < T(0)) {
/// if (b == T(0))
/// return T(1) / T(0); // trigger div-by-zero
/// if (b == T(1))
/// return T(1);
/// if (b == T(-1)) {
/// if (p & T(1))
/// return T(-1);
/// return T(1);
/// }
/// return T(0);
/// }
/// T result = T(1);
/// while (true) {
/// if (p & T(1))
/// result *= b;
/// p >>= T(1);
/// if (p == T(0))
/// return result;
/// b *= b;
/// }
/// }
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
assert(isa<IntegerType>(elementType) &&
"non-integer element type for IPowIOp");
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
std::string funcName("__mlir_math_ipowi");
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << elementType;
FunctionType funcType = FunctionType::get(
builder.getContext(), {elementType, elementType}, elementType);
auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
funcOp->setAttr("llvm.linkage", linkage);
funcOp.setPrivate();
Block *entryBlock = funcOp.addEntryBlock();
Region *funcBody = entryBlock->getParent();
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
Value zeroValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, 0));
Value oneValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, 1));
Value minusOneValue = builder.create<arith::ConstantOp>(
elementType,
builder.getIntegerAttr(elementType,
APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
/*isSigned=*/true)));
// if (p == T(0))
// return T(1);
auto pIsZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
Block *thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(oneValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(pIsZero->getBlock());
builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
// if (p < T(0)) {
builder.setInsertionPointToEnd(fallthroughBlock);
auto pIsNeg =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
// if (b == T(0))
builder.createBlock(funcBody);
auto bIsZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
// return T(1) / T(0);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(
builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(0)).
builder.setInsertionPointToEnd(bIsZero->getBlock());
builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
// if (b == T(1))
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsOne =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
// return T(1);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(1)).
builder.setInsertionPointToEnd(bIsOne->getBlock());
builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
// if (b == T(-1)) {
builder.setInsertionPointToEnd(fallthroughBlock);
auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
bArg, minusOneValue);
// if (p & T(1))
builder.createBlock(funcBody);
auto pIsOdd = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
zeroValue);
// return T(-1);
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(minusOneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(pIsOdd->getBlock());
builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
// return T(1);
// } // b == T(-1)
builder.setInsertionPointToEnd(fallthroughBlock);
builder.create<func::ReturnOp>(oneValue);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (b == T(-1)).
builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
fallthroughBlock);
// return T(0);
// } // (p < T(0))
builder.setInsertionPointToEnd(fallthroughBlock);
builder.create<func::ReturnOp>(zeroValue);
Block *loopHeader = builder.createBlock(
funcBody, funcBody->end(), {elementType, elementType, elementType},
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set up conditional branch for (p < T(0)).
builder.setInsertionPointToEnd(pIsNeg->getBlock());
// Set initial values of 'result', 'b' and 'p' for the loop.
builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
ValueRange{oneValue, bArg, pArg});
// T result = T(1);
// while (true) {
// if (p & T(1))
// result *= b;
// p >>= T(1);
// if (p == T(0))
// return result;
// b *= b;
// }
Value resultTmp = loopHeader->getArgument(0);
Value baseTmp = loopHeader->getArgument(1);
Value powerTmp = loopHeader->getArgument(2);
builder.setInsertionPointToEnd(loopHeader);
// if (p & T(1))
auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::ne,
builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & T(1)).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= T(1);
builder.setInsertionPointToEnd(fallthroughBlock);
Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
// if (p == T(0))
auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
newPowerTmp, zeroValue);
// return result;
thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(newResultTmp);
fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == T(0)).
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
builder.create<cf::BranchOp>(
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
return funcOp;
}
/// Convert IPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of IPowI are linearized.
LogicalResult
IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
PatternRewriter &rewriter) const {
auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
if (!baseType)
return rewriter.notifyMatchFailure(op, "non-integer base operand");
// The outlined software implementation must have been already
// generated.
func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
if (!elementFunc)
return rewriter.notifyMatchFailure(op, "missing software implementation");
rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
return success();
}
/// Create linkonce_odr function to implement the power function with
/// the given \p funcType type inside \p module. The \p funcType must be
/// 'FloatType (*)(FloatType, IntegerType)' function type.
///
/// template <typename T>
/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
/// if (p == Tp{0})
/// return Tb{1};
/// bool isNegativePower{p < Tp{0}}
/// bool isMin{p == std::numeric_limits<Tp>::min()};
/// if (isMin) {
/// p = std::numeric_limits<Tp>::max();
/// } else if (isNegativePower) {
/// p = -p;
/// }
/// Tb result = Tb{1};
/// Tb origBase = Tb{b};
/// while (true) {
/// if (p & Tp{1})
/// result *= b;
/// p >>= Tp{1};
/// if (p == Tp{0})
/// break;
/// b *= b;
/// }
/// if (isMin) {
/// result *= origBase;
/// }
/// if (isNegativePower) {
/// result = Tb{1} / result;
/// }
/// return result;
/// }
static func::FuncOp createElementFPowIFunc(ModuleOp *module,
FunctionType funcType) {
auto baseType = cast<FloatType>(funcType.getInput(0));
auto powType = cast<IntegerType>(funcType.getInput(1));
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
std::string funcName("__mlir_math_fpowi");
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << baseType;
nameOS << '_' << powType;
auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
funcOp->setAttr("llvm.linkage", linkage);
funcOp.setPrivate();
Block *entryBlock = funcOp.addEntryBlock();
Region *funcBody = entryBlock->getParent();
Value bArg = funcOp.getArgument(0);
Value pArg = funcOp.getArgument(1);
builder.setInsertionPointToEnd(entryBlock);
Value oneBValue = builder.create<arith::ConstantOp>(
baseType, builder.getFloatAttr(baseType, 1.0));
Value zeroPValue = builder.create<arith::ConstantOp>(
powType, builder.getIntegerAttr(powType, 0));
Value onePValue = builder.create<arith::ConstantOp>(
powType, builder.getIntegerAttr(powType, 1));
Value minPValue = builder.create<arith::ConstantOp>(
powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
powType.getWidth())));
Value maxPValue = builder.create<arith::ConstantOp>(
powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
powType.getWidth())));
// if (p == Tp{0})
// return Tb{1};
auto pIsZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
Block *thenBlock = builder.createBlock(funcBody);
builder.create<func::ReturnOp>(oneBValue);
Block *fallthroughBlock = builder.createBlock(funcBody);
// Set up conditional branch for (p == Tp{0}).
builder.setInsertionPointToEnd(pIsZero->getBlock());
builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
builder.setInsertionPointToEnd(fallthroughBlock);
// bool isNegativePower{p < Tp{0}}
auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
zeroPValue);
// bool isMin{p == std::numeric_limits<Tp>::min()};
auto pIsMin =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
// if (isMin) {
// p = std::numeric_limits<Tp>::max();
// } else if (isNegativePower) {
// p = -p;
// }
Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
// Tb result = Tb{1};
// Tb origBase = Tb{b};
// while (true) {
// if (p & Tp{1})
// result *= b;
// p >>= Tp{1};
// if (p == Tp{0})
// break;
// b *= b;
// }
Block *loopHeader = builder.createBlock(
funcBody, funcBody->end(), {baseType, baseType, powType},
{builder.getLoc(), builder.getLoc(), builder.getLoc()});
// Set initial values of 'result', 'b' and 'p' for the loop.
builder.setInsertionPointToEnd(pInit->getBlock());
builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
// Create loop body.
Value resultTmp = loopHeader->getArgument(0);
Value baseTmp = loopHeader->getArgument(1);
Value powerTmp = loopHeader->getArgument(2);
builder.setInsertionPointToEnd(loopHeader);
// if (p & Tp{1})
auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::ne,
builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
thenBlock = builder.createBlock(funcBody);
// result *= b;
Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(thenBlock);
builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
// Set up conditional branch for (p & Tp{1}).
builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
resultTmp);
// Merged 'result'.
newResultTmp = fallthroughBlock->getArgument(0);
// p >>= Tp{1};
builder.setInsertionPointToEnd(fallthroughBlock);
Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
// if (p == Tp{0})
auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
newPowerTmp, zeroPValue);
// break;
//
// The conditional branch is finalized below with a jump to
// the loop exit block.
fallthroughBlock = builder.createBlock(funcBody);
// b *= b;
// }
builder.setInsertionPointToEnd(fallthroughBlock);
Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
// Pass new values for 'result', 'b' and 'p' to the loop header.
builder.create<cf::BranchOp>(
ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
// Set up conditional branch for early loop exit:
// if (p == Tp{0})
// break;
Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
fallthroughBlock, ValueRange{});
// if (isMin) {
// result *= origBase;
// }
newResultTmp = loopExit->getArgument(0);
thenBlock = builder.createBlock(funcBody);
fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(loopExit);
builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
/// if (isNegativePower) {
/// result = Tb{1} / result;
/// }
newResultTmp = fallthroughBlock->getArgument(0);
thenBlock = builder.createBlock(funcBody);
Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
builder.getLoc());
builder.setInsertionPointToEnd(fallthroughBlock);
builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
newResultTmp);
builder.setInsertionPointToEnd(thenBlock);
newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
builder.create<cf::BranchOp>(newResultTmp, returnBlock);
// return result;
builder.setInsertionPointToEnd(returnBlock);
builder.create<func::ReturnOp>(returnBlock->getArgument(0));
return funcOp;
}
/// Convert FPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of FPowI are linearized.
LogicalResult
FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
PatternRewriter &rewriter) const {
if (dyn_cast<VectorType>(op.getType()))
return rewriter.notifyMatchFailure(op, "non-scalar operation");
FunctionType funcType = getElementalFuncTypeForOp(op);
// The outlined software implementation must have been already
// generated.
func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
if (!elementFunc)
return rewriter.notifyMatchFailure(op, "missing software implementation");
rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
return success();
}
/// Create function to implement the ctlz function the given \p elementType type
/// inside \p module. The \p elementType must be IntegerType, an the created
/// function has 'IntegerType (*)(IntegerType)' function type.
///
/// template <typename T>
/// T __mlir_math_ctlz_*(T x) {
/// bits = sizeof(x) * 8;
/// if (x == 0)
/// return bits;
///
/// uint32_t n = 0;
/// for (int i = 1; i < bits; ++i) {
/// if (x < 0) continue;
/// n++;
/// x <<= 1;
/// }
/// return n;
/// }
///
/// Converts to (for i32):
///
/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
/// %c_32 = arith.constant 32 : index
/// %c_0 = arith.constant 0 : i32
/// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
/// %out = scf.if %arg_eq_zero {
/// scf.yield %c_32 : i32
/// } else {
/// %c_1index = arith.constant 1 : index
/// %c_1i32 = arith.constant 1 : i32
/// %n = arith.constant 0 : i32
/// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
/// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
/// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
/// %yield_val = scf.if %cond {
/// scf.yield %arg_iter, %n_iter : i32, i32
/// } else {
/// %arg_next = arith.shli %arg_iter, %c_1i32 : i32
/// %n_next = arith.addi %n_iter, %c_1i32 : i32
/// scf.yield %arg_next, %n_next : i32, i32
/// }
/// scf.yield %yield_val: i32, i32
/// }
/// scf.yield %n_out : i32
/// }
/// return %out: i32
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
LLVM_DEBUG({
DBGS() << "non-integer element type for CtlzFunc; type was: ";
elementType.print(llvm::dbgs());
});
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
Location loc = module->getLoc();
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
std::string funcName("__mlir_math_ctlz");
llvm::raw_string_ostream nameOS(funcName);
nameOS << '_' << elementType;
FunctionType funcType =
FunctionType::get(builder.getContext(), {elementType}, elementType);
auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
// LinkonceODR ensures that there is only one implementation of this function
// across all math.ctlz functions that are lowered in this way.
LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
Attribute linkage =
LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
funcOp->setAttr("llvm.linkage", linkage);
funcOp.setPrivate();
// set the insertion point to the start of the function
Block *funcBody = funcOp.addEntryBlock();
builder.setInsertionPointToStart(funcBody);
Value arg = funcOp.getArgument(0);
Type indexType = builder.getIndexType();
Value bitWidthValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, bitWidth));
Value zeroValue = builder.create<arith::ConstantOp>(
elementType, builder.getIntegerAttr(elementType, 0));
Value inputEqZero =
builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
// if input == 0, return bit width, else enter loop.
scf::IfOp ifOp = builder.create<scf::IfOp>(
elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
Value oneIndex = elseBuilder.create<arith::ConstantOp>(
indexType, elseBuilder.getIndexAttr(1));
Value oneValue = elseBuilder.create<arith::ConstantOp>(
elementType, elseBuilder.getIntegerAttr(elementType, 1));
Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
indexType, elseBuilder.getIndexAttr(bitWidth));
Value nValue = elseBuilder.create<arith::ConstantOp>(
elementType, elseBuilder.getIntegerAttr(elementType, 0));
auto loop = elseBuilder.create<scf::ForOp>(
oneIndex, bitWidthIndex, oneIndex,
// Initial values for two loop induction variables, the arg which is being
// shifted left in each iteration, and the n value which tracks the count
// of leading zeros.
ValueRange{arg, nValue},
// Callback to build the body of the for loop
// if (arg < 0) {
// continue;
// } else {
// n++;
// arg <<= 1;
// }
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
Value argIter = args[0];
Value nIter = args[1];
Value argIsNonNegative = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, argIter, zeroValue);
scf::IfOp ifOp = b.create<scf::IfOp>(
loc, argIsNonNegative,
[&](OpBuilder &b, Location loc) {
// If arg is negative, continue (effectively, break)
b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
},
[&](OpBuilder &b, Location loc) {
// Otherwise, increment n and shift arg left.
Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
});
b.create<scf::YieldOp>(loc, ifOp.getResults());
});
elseBuilder.create<scf::YieldOp>(loop.getResult(1));
builder.create<func::ReturnOp>(ifOp.getResult(0));
return funcOp;
}
/// Convert ctlz into a call to a local function implementing the ctlz
/// operation.
LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) const {
if (dyn_cast<VectorType>(op.getType()))
return rewriter.notifyMatchFailure(op, "non-scalar operation");
Type type = getElementTypeOrSelf(op.getResult().getType());
func::FuncOp elementFunc = getFuncOpCallback(op, type);
if (!elementFunc)
return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
diag << "Missing software implementation for op " << op->getName()
<< " and type " << type;
});
rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
return success();
}
namespace {
struct ConvertMathToFuncsPass
: public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
ConvertMathToFuncsPass() = default;
ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
: impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
void runOnOperation() override;
private:
// Return true, if this FPowI operation must be converted
// because the width of its exponent's type is greater than
// or equal to minWidthOfFPowIExponent option value.
bool isFPowIConvertible(math::FPowIOp op);
// Generate outlined implementations for power operations
// and store them in funcImpls map.
void generateOpImplementations();
// A map between pairs of (operation, type) deduced from operations that this
// pass will convert, and the corresponding outlined software implementations
// of these operations for the given type.
DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
};
} // namespace
bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
auto expTy =
dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
}
void ConvertMathToFuncsPass::generateOpImplementations() {
ModuleOp module = getOperation();
module.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
if (!convertCtlz)
return;
Type resultType = getElementTypeOrSelf(op.getResult().getType());
// Generate the software implementation of this operation,
// if it has not been generated yet.
auto key = std::pair(op->getName(), resultType);
auto entry = funcImpls.try_emplace(key, func::FuncOp{});
if (entry.second)
entry.first->second = createCtlzFunc(&module, resultType);
})
.Case<math::IPowIOp>([&](math::IPowIOp op) {
Type resultType = getElementTypeOrSelf(op.getResult().getType());
// Generate the software implementation of this operation,
// if it has not been generated yet.
auto key = std::pair(op->getName(), resultType);
auto entry = funcImpls.try_emplace(key, func::FuncOp{});
if (entry.second)
entry.first->second = createElementIPowIFunc(&module, resultType);
})
.Case<math::FPowIOp>([&](math::FPowIOp op) {
if (!isFPowIConvertible(op))
return;
FunctionType funcType = getElementalFuncTypeForOp(op);
// Generate the software implementation of this operation,
// if it has not been generated yet.
// FPowI implementations are mapped via the FunctionType
// created from the operation's result and operands.
auto key = std::pair(op->getName(), funcType);
auto entry = funcImpls.try_emplace(key, func::FuncOp{});
if (entry.second)
entry.first->second = createElementFPowIFunc(&module, funcType);
});
});
}
void ConvertMathToFuncsPass::runOnOperation() {
ModuleOp module = getOperation();
// Create outlined implementations for power operations.
generateOpImplementations();
RewritePatternSet patterns(&getContext());
patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
VecOpToScalarOp<math::CountLeadingZerosOp>>(
patterns.getContext());
// For the given Type Returns FuncOp stored in funcImpls map.
auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
auto it = funcImpls.find(std::pair(op->getName(), type));
if (it == funcImpls.end())
return {};
return it->second;
};
patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
getFuncOpByType);
if (convertCtlz)
patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
func::FuncDialect, scf::SCFDialect,
vector::VectorDialect>();
target.addIllegalOp<math::IPowIOp>();
if (convertCtlz)
target.addIllegalOp<math::CountLeadingZerosOp>();
target.addDynamicallyLegalOp<math::FPowIOp>(
[this](math::FPowIOp op) { return !isFPowIConvertible(op); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}