//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLIBM #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { // Pattern to convert vector operations to scalar operations. This is needed as // libm calls require scalars. template struct VecOpToScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; // Pattern to promote an op of a smaller floating point type to F32. template struct PromoteOpToF32 : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; // Pattern to convert scalar math operations to calls to libm functions. // Additionally the libm function signatures are declared. template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, StringRef doubleFunc) : OpRewritePattern(context), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; private: std::string floatFunc, doubleFunc; }; template void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx, StringRef floatFunc, StringRef doubleFunc) { patterns.add, PromoteOpToF32>(ctx); patterns.add>(ctx, floatFunc, doubleFunc); } } // namespace template LogicalResult VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); auto loc = op.getLoc(); auto vecType = dyn_cast(opType); if (!vecType) return failure(); if (!vecType.hasRank()) return failure(); auto shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector strides = computeStrides(shape); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(linearIndex, strides); SmallVector operands; for (auto input : op->getOperands()) operands.push_back( rewriter.create(loc, input, positions)); Value scalarOp = rewriter.create(loc, vecType.getElementType(), operands); result = rewriter.create(loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); } template LogicalResult PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto opType = op.getType(); if (!isa(opType)) return failure(); auto loc = op.getLoc(); auto f32 = rewriter.getF32Type(); auto extendedOperands = llvm::to_vector( llvm::map_range(op->getOperands(), [&](Value operand) -> Value { return rewriter.create(loc, f32, operand); })); auto newOp = rewriter.create(loc, f32, extendedOperands); rewriter.replaceOpWithNewOp(op, opType, newOp); return success(); } template LogicalResult ScalarOpToLibmCall::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); if (!isa(type)) return failure(); auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc; auto opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); opFunc = rewriter.create(rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); // By definition Math dialect operations imply LLVM's "readnone" // function attribute, so we can set it here to provide more // optimization opportunities (e.g. LICM) for backends targeting LLVM IR. // This will have to be changed, when strict FP behavior is supported // by Math dialect. opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), UnitAttr::get(rewriter.getContext())); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); rewriter.replaceOpWithNewOp(op, name, op.getType(), op->getOperands()); return success(); } void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); populatePatternsForOp(patterns, ctx, "acosf", "acos"); populatePatternsForOp(patterns, ctx, "acoshf", "acosh"); populatePatternsForOp(patterns, ctx, "asinf", "asin"); populatePatternsForOp(patterns, ctx, "asinhf", "asinh"); populatePatternsForOp(patterns, ctx, "atan2f", "atan2"); populatePatternsForOp(patterns, ctx, "atanf", "atan"); populatePatternsForOp(patterns, ctx, "atanhf", "atanh"); populatePatternsForOp(patterns, ctx, "cbrtf", "cbrt"); populatePatternsForOp(patterns, ctx, "ceilf", "ceil"); populatePatternsForOp(patterns, ctx, "cosf", "cos"); populatePatternsForOp(patterns, ctx, "coshf", "cosh"); populatePatternsForOp(patterns, ctx, "erff", "erf"); populatePatternsForOp(patterns, ctx, "expm1f", "expm1"); populatePatternsForOp(patterns, ctx, "floorf", "floor"); populatePatternsForOp(patterns, ctx, "log1pf", "log1p"); populatePatternsForOp(patterns, ctx, "roundevenf", "roundeven"); populatePatternsForOp(patterns, ctx, "roundf", "round"); populatePatternsForOp(patterns, ctx, "sinf", "sin"); populatePatternsForOp(patterns, ctx, "sinhf", "sinh"); populatePatternsForOp(patterns, ctx, "tanf", "tan"); populatePatternsForOp(patterns, ctx, "tanhf", "tanh"); populatePatternsForOp(patterns, ctx, "truncf", "trunc"); } namespace { struct ConvertMathToLibmPass : public impl::ConvertMathToLibmBase { void runOnOperation() override; }; } // namespace void ConvertMathToLibmPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); populateMathToLibmConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertMathToLibmPass() { return std::make_unique(); }