//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { template using ConvertFastMath = arith::AttrConvertFastMathToLLVM; template using ConvertFMFMathToLLVMPattern = VectorConvertToLLVMPattern; using AbsFOpLowering = ConvertFMFMathToLLVMPattern; using CeilOpLowering = ConvertFMFMathToLLVMPattern; using CopySignOpLowering = ConvertFMFMathToLLVMPattern; using CosOpLowering = ConvertFMFMathToLLVMPattern; using CtPopFOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = ConvertFMFMathToLLVMPattern; using ExpOpLowering = ConvertFMFMathToLLVMPattern; using FloorOpLowering = ConvertFMFMathToLLVMPattern; using FmaOpLowering = ConvertFMFMathToLLVMPattern; using Log10OpLowering = ConvertFMFMathToLLVMPattern; using Log2OpLowering = ConvertFMFMathToLLVMPattern; using LogOpLowering = ConvertFMFMathToLLVMPattern; using PowFOpLowering = ConvertFMFMathToLLVMPattern; using FPowIOpLowering = ConvertFMFMathToLLVMPattern; using RoundEvenOpLowering = ConvertFMFMathToLLVMPattern; using RoundOpLowering = ConvertFMFMathToLLVMPattern; using SinOpLowering = ConvertFMFMathToLLVMPattern; using SqrtOpLowering = ConvertFMFMathToLLVMPattern; using FTruncOpLowering = ConvertFMFMathToLLVMPattern; // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. template struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = IntOpWithFlagLowering; LogicalResult matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); if (!isa(operandType)) { rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand(), false); return success(); } auto vectorType = dyn_cast(resultType); if (!vectorType) return failure(); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { return rewriter.create(loc, llvm1DVectorTy, operands[0], false); }, rewriter); } }; using CountLeadingZerosOpLowering = IntOpWithFlagLowering; using CountTrailingZerosOpLowering = IntOpWithFlagLowering; using AbsIOpLowering = IntOpWithFlagLowering; // A `expm1` is converted into `exp - 1`. struct ExpM1OpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath expAttrs(op); ConvertFastMath subAttrs(op); if (!isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(cast(resultType), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto exp = rewriter.create(loc, adaptor.getOperand(), expAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, operandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); } auto vectorType = dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto exp = rewriter.create( loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); return rewriter.create( loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); }, rewriter); } }; // A `log1p` is converted into `log(1 + ...)`. struct Log1pOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return rewriter.notifyMatchFailure(op, "unsupported operand type"); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath addAttrs(op); ConvertFastMath logAttrs(op); if (!isa(operandType)) { LLVM::ConstantOp one = LLVM::isCompatibleVectorType(operandType) ? rewriter.create( loc, operandType, SplatElementsAttr::get(cast(resultType), floatOne)) : rewriter.create(loc, operandType, floatOne); auto add = rewriter.create( loc, operandType, ValueRange{one, adaptor.getOperand()}, addAttrs.getAttrs()); rewriter.replaceOpWithNewOp(op, operandType, ValueRange{add}, logAttrs.getAttrs()); return success(); } auto vectorType = dyn_cast(resultType); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto add = rewriter.create(loc, llvm1DVectorTy, ValueRange{one, operands[0]}, addAttrs.getAttrs()); return rewriter.create( loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } }; // A `rsqrt` is converted into `1 / sqrt`. struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = cast(getElementTypeOrSelf(resultType)); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); ConvertFastMath sqrtAttrs(op); ConvertFastMath divAttrs(op); if (!isa(operandType)) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(cast(resultType), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto sqrt = rewriter.create(loc, adaptor.getOperand(), sqrtAttrs.getAttrs()); rewriter.replaceOpWithNewOp( op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); } auto vectorType = dyn_cast(resultType); if (!vectorType) return failure(); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto sqrt = rewriter.create( loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); return rewriter.create( loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); }, rewriter); } }; struct ConvertMathToLLVMPass : public impl::ConvertMathToLLVMPassBase { using Base::Base; void runOnOperation() override { RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p) { if (approximateLog1p) patterns.add(converter); // clang-format off patterns.add< AbsFOpLowering, AbsIOpLowering, CeilOpLowering, CopySignOpLowering, CosOpLowering, CountLeadingZerosOpLowering, CountTrailingZerosOpLowering, CtPopFOpLowering, Exp2OpLowering, ExpM1OpLowering, ExpOpLowering, FPowIOpLowering, FloorOpLowering, FmaOpLowering, Log10OpLowering, Log2OpLowering, LogOpLowering, PowFOpLowering, RoundEvenOpLowering, RoundOpLowering, RsqrtOpLowering, SinOpLowering, SqrtOpLowering, FTruncOpLowering >(converter); // clang-format on } //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert Math to LLVM. struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateMathToLLVMConversionPatterns(typeConverter, patterns); } }; } // namespace void mlir::registerConvertMathToLLVMInterface(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) { dialect->addInterfaces(); }); }