//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
// NOTE(review): the two includes below have lost their header names in this
// copy of the file -- restore them from the original source.
#include
#include

namespace mlir {
#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// NOTE(review): in this copy of the file the angle-bracket template arguments
// (on create/cast/replaceOpWithNewOp, OpConversionPattern, the `template`
// headers, ...) are missing. Comments below that name an operation are
// inferred from value names and must be confirmed against the original source.

// Conversion pattern matching complex::AbsOp.
//
// Two values (`real`, `imag`) are created from the operand, each is combined
// with itself (`realSqr`, `imagSqr`), the two results are combined into
// `sqNorm`, and the op is replaced by a new op taking `sqNorm`. The op's
// fast-math flags are forwarded to the three intermediate ops but not to the
// replacement op.
// NOTE(review): going by the value names this is presumably
// sqrt(re * re + im * im) -- confirm.
struct AbsOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = op.getType();

    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real = rewriter.create(loc, type, adaptor.getComplex());
    Value imag = rewriter.create(loc, type, adaptor.getComplex());
    Value realSqr = rewriter.create(loc, real, real, fmf.getValue());
    Value imagSqr = rewriter.create(loc, imag, imag, fmf.getValue());
    Value sqNorm = rewriter.create(loc, realSqr, imagSqr, fmf.getValue());

    rewriter.replaceOpWithNewOp(op, sqNorm);
    return success();
  }
};

// Conversion pattern matching complex::Atan2Op, built from the identity
// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
// In the code below `lhs` plays the role of y and `rhs` the role of x (see
// `rhsPlusILhs`). The intermediate values are created with the complex result
// `type`; the constants `zero`, `one` and `negativeOne` use its element type.
struct Atan2OpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast(op.getType());
    Type elementType = type.getElementType();

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();

    // sqrt(x**2 + y**2), per the value names.
    Value rhsSquared = b.create(type, rhs, rhs);
    Value lhsSquared = b.create(type, lhs, lhs);
    Value rhsSquaredPlusLhsSquared = b.create(type, rhsSquared, lhsSquared);
    Value sqrtOfRhsSquaredPlusLhsSquared =
        b.create(type, rhsSquaredPlusLhsSquared);

    // x + i * y, where `i` is assembled from the pair (zero, one).
    Value zero = b.create(elementType, b.getZeroAttr(elementType));
    Value one = b.create(elementType, b.getFloatAttr(elementType, 1));
    Value i = b.create(type, zero, one);
    Value iTimesLhs = b.create(i, lhs);
    Value rhsPlusILhs = b.create(rhs, iTimesLhs);

    Value divResult = b.create(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
    Value logResult = b.create(divResult);

    // -i, assembled from the pair (zero, negativeOne); it is the first operand
    // of the replacement op, `logResult` the second.
    Value negativeOne = b.create(
        elementType, b.getFloatAttr(elementType, -1));
    Value negativeI = b.create(type, zero, negativeOne);

    rewriter.replaceOpWithNewOp(op, negativeI, logResult);
    return success();
  }
};

// Conversion pattern for a complex comparison op `ComparisonOp`.
//
// The real parts and the imaginary parts of the two operands are compared
// separately using `p`, and the two comparison results become the operands of
// the replacement op.
// NOTE(review): the template parameter list (which should declare
// `ComparisonOp` and `p`), the condition of `std::conditional_t` that chooses
// between arith::AndIOp and arith::OrIOp, and the template argument of
// replaceOpWithNewOp (presumably `ResultCombiner`) are truncated in this copy.
template struct ComparisonOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;
  using ResultCombiner = std::conditional_t::value, arith::AndIOp, arith::OrIOp>;

  LogicalResult
  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    // Element type of the (complex) left-hand operand.
    auto type = cast(adaptor.getLhs().getType()).getElementType();

    Value realLhs = rewriter.create(loc, type, adaptor.getLhs());
    Value imagLhs = rewriter.create(loc, type, adaptor.getLhs());
    Value realRhs = rewriter.create(loc, type, adaptor.getRhs());
    Value imagRhs = rewriter.create(loc, type, adaptor.getRhs());
    Value realComparison = rewriter.create(loc, p, realLhs, realRhs);
    Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs);

    rewriter.replaceOpWithNewOp(op, realComparison, imagComparison);
    return success();
  }
};

// Default conversion which applies the BinaryStandardOp separately on the real
// and imaginary parts. Can for example be used for complex::AddOp and
// complex::SubOp.
// NOTE(review): in this copy of the file the angle-bracket template arguments
// (on create/cast/replaceOpWithNewOp, OpConversionPattern, std::pair, the
// `template` headers, ...) are missing. Comments below that name an operation
// are inferred from value names and must be confirmed against the original
// source.
//
// Element-wise binary pattern: `resultReal` is built from the real parts of
// both operands, `resultImag` from the imaginary parts, and the op is replaced
// by a value of the original complex `type` made from the two. The op's
// fast-math flags are forwarded to both element-wise ops.
template struct BinaryComplexOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value realLhs = b.create(elementType, adaptor.getLhs());
    Value realRhs = b.create(elementType, adaptor.getRhs());
    Value resultReal = b.create(elementType, realLhs, realRhs, fmf.getValue());
    Value imagLhs = b.create(elementType, adaptor.getLhs());
    Value imagRhs = b.create(elementType, adaptor.getRhs());
    Value resultImag = b.create(elementType, imagLhs, imagRhs, fmf.getValue());
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Base pattern shared by CosOpConversion and SinOpConversion.
//
// matchAndRewrite() creates the common values (`scaledExp`, `reciprocalExp`,
// `sin`, `cos`) and delegates to the pure-virtual combine(), whose returned
// pair supplies the two components of the replacement value.
template struct TrigonometricOpConversion : public OpConversionPattern {
  using OpAdaptor = typename OpConversionPattern::OpAdaptor;

  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    Value real = rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag = rewriter.create(loc, elementType, adaptor.getComplex());

    // Trigonometric ops use a set of common building blocks to convert to real
    // ops. Here we create these building blocks and call into an op-specific
    // implementation in the subclass to combine them.
    Value half = rewriter.create(
        loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
    // `exp` is created from the imaginary part; `sin` and `cos` from the real
    // part.
    Value exp = rewriter.create(loc, imag);
    // NOTE(review): the next two lines read identically only because the op
    // kinds are missing here; the derived classes' comments describe these
    // values as 0.5*t and 0.5/t with t = exp(y).
    Value scaledExp = rewriter.create(loc, half, exp);
    Value reciprocalExp = rewriter.create(loc, half, exp);
    Value sin = rewriter.create(loc, real);
    Value cos = rewriter.create(loc, real);

    auto resultPair =
        combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);

    rewriter.replaceOpWithNewOp(op, type, resultPair.first, resultPair.second);
    return success();
  }

  // Op-specific combination of the building blocks; returns the pair
  // (real part, imaginary part) of the result.
  virtual std::pair
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const = 0;
};

// Cosine specialization of TrigonometricOpConversion.
struct CosOpConversion : public TrigonometricOpConversion {
  using TrigonometricOpConversion::TrigonometricOpConversion;

  std::pair
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex cosine is defined as:
    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
    Value sum = rewriter.create(loc, reciprocalExp, scaledExp);
    Value resultReal = rewriter.create(loc, sum, cos);
    Value diff = rewriter.create(loc, reciprocalExp, scaledExp);
    Value resultImag = rewriter.create(loc, diff, sin);
    return {resultReal, resultImag};
  }
};

// Conversion pattern matching complex::DivOp.
//
// Both variants of the quotient described below are emitted, together with
// three special-case results, and the final components are picked by a chain
// of conditions. The special-case values are only reachable through the final
// selection on `resultIsNaN`, which is built from UNO comparisons of
// `resultReal` and `resultImag` against zero.
struct DivOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());

    Value lhsReal = rewriter.create(loc, elementType, adaptor.getLhs());
    Value lhsImag = rewriter.create(loc, elementType, adaptor.getLhs());
    Value rhsReal = rewriter.create(loc, elementType, adaptor.getRhs());
    Value rhsImag = rewriter.create(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.

    // First variant (used when |rhsReal| < |rhsImag|).
    Value rhsRealImagRatio = rewriter.create(loc, rhsReal, rhsImag);
    Value rhsRealImagDenom = rewriter.create(
        loc, rhsImag, rewriter.create(loc, rhsRealImagRatio, rhsReal));
    Value realNumerator1 = rewriter.create(
        loc, rewriter.create(loc, lhsReal, rhsRealImagRatio), lhsImag);
    Value resultReal1 = rewriter.create(loc, realNumerator1, rhsRealImagDenom);
    Value imagNumerator1 = rewriter.create(
        loc, rewriter.create(loc, lhsImag, rhsRealImagRatio), lhsReal);
    Value resultImag1 = rewriter.create(loc, imagNumerator1, rhsRealImagDenom);

    // Second variant (used otherwise).
    Value rhsImagRealRatio = rewriter.create(loc, rhsImag, rhsReal);
    Value rhsImagRealDenom = rewriter.create(
        loc, rhsReal, rewriter.create(loc, rhsImagRealRatio, rhsImag));
    Value realNumerator2 = rewriter.create(
        loc, lhsReal, rewriter.create(loc, lhsImag, rhsImagRealRatio));
    Value resultReal2 = rewriter.create(loc, realNumerator2, rhsImagRealDenom);
    Value imagNumerator2 = rewriter.create(
        loc, lhsImag, rewriter.create(loc, lhsReal, rhsImagRealRatio));
    Value resultImag2 = rewriter.create(loc, imagNumerator2, rhsImagRealDenom);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    Value zero = rewriter.create(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create(loc, rhsReal);
    Value rhsRealIsZero = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create(loc, rhsImag);
    Value rhsImagIsZero = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against zero is true exactly when the other operand is not NaN.
    Value lhsRealIsNotNaN = rewriter.create(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create(
        loc, lhsContainsNotNaNValue,
        rewriter.create(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal = rewriter.create(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create(loc, infWithSignOfRhsReal, lhsReal);
    Value infinityResultImag =
        rewriter.create(loc, infWithSignOfRhsReal, lhsImag);

    // Case 2. Infinite numerator, finite denominator.
    Value rhsRealFinite = rewriter.create(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite = rewriter.create(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create(loc, lhsReal);
    Value lhsRealInfinite = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create(loc, lhsImag);
    Value lhsImagInfinite = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    // A value chosen between `one` and `zero` by the "is infinite" flag, then
    // combined with the original component (per the name: to carry its sign).
    Value lhsRealIsInfWithSign = rewriter.create(
        loc, rewriter.create(loc, lhsRealInfinite, one, zero), lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create(
        loc, rewriter.create(loc, lhsImagInfinite, one, zero), lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create(loc, lhsRealIsInfWithSign, rhsReal);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create(loc, lhsImagIsInfWithSign, rhsImag);
    Value resultReal3 = rewriter.create(
        loc, inf,
        rewriter.create(loc, lhsRealIsInfWithSignTimesRhsReal,
                        lhsImagIsInfWithSignTimesRhsImag));
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create(loc, lhsRealIsInfWithSign, rhsImag);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create(loc, lhsImagIsInfWithSign, rhsReal);
    Value resultImag3 = rewriter.create(
        loc, inf,
        rewriter.create(loc, lhsImagIsInfWithSignTimesRhsReal,
                        lhsRealIsInfWithSignTimesRhsImag));

    // Case 3: Finite numerator, infinite denominator.
    Value lhsRealFinite = rewriter.create(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite = rewriter.create(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create(
        loc, rewriter.create(loc, rhsRealInfinite, one, zero), rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create(
        loc, rewriter.create(loc, rhsImagInfinite, one, zero), rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create(loc, lhsReal, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create(loc, lhsImag, rhsImagIsInfWithSign);
    Value resultReal4 = rewriter.create(
        loc, zero,
        rewriter.create(loc, rhsRealIsInfWithSignTimesLhsReal,
                        rhsImagIsInfWithSignTimesLhsImag));
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create(loc, lhsImag, rhsRealIsInfWithSign);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create(loc, lhsReal, rhsImagIsInfWithSign);
    Value resultImag4 = rewriter.create(
        loc, zero,
        rewriter.create(loc, rhsRealIsInfWithSignTimesLhsImag,
                        rhsImagIsInfWithSignTimesLhsReal));

    // Pick between the two Smith variants.
    Value realAbsSmallerThanImagAbs = rewriter.create(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    // Layer the special cases: case 3 innermost, then case 2, then case 1, so
    // that case 1 takes precedence when several conditions hold.
    Value resultRealSpecialCase3 = rewriter.create(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case chain is consulted only when `resultIsNaN` holds;
    // otherwise the plain Smith result is used.
    Value resultRealIsNaN = rewriter.create(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};

// Conversion pattern matching complex::ExpOp.
//
// `resultReal` is built from (`expReal`, `cosImag`) and `resultImag` from
// (`expReal`, `sinImag`); the op's fast-math flags are forwarded to every
// intermediate op.
// NOTE(review): going by the value names this is presumably
// exp(a + bi) = exp(a) * cos(b) + i * exp(a) * sin(b) -- confirm.
struct ExpOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value real = rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag = rewriter.create(loc, elementType, adaptor.getComplex());
    Value expReal = rewriter.create(loc, real, fmf.getValue());
    Value cosImag = rewriter.create(loc, imag, fmf.getValue());
    Value resultReal =
        rewriter.create(loc, expReal, cosImag, fmf.getValue());
    Value sinImag = rewriter.create(loc, imag, fmf.getValue());
    Value resultImag =
        rewriter.create(loc, expReal, sinImag, fmf.getValue());

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Conversion pattern matching complex::Expm1Op.
//
// An `exp` value is created from the whole complex operand; its real part is
// combined with the constant one (`realMinusOne`) and repacked together with
// its unchanged imaginary part.
struct Expm1OpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value exp = b.create(adaptor.getComplex(), fmf.getValue());

    Value real = b.create(elementType, exp);
    Value one = b.create(elementType, b.getFloatAttr(elementType, 1));
    Value realMinusOne = b.create(real, one, fmf.getValue());
    Value imag = b.create(elementType, exp);

    rewriter.replaceOpWithNewOp(op, type, realMinusOne, imag);
    return success();
  }
};

// Conversion pattern matching complex::LogOp.
//
// `resultReal` is created from `abs` (itself created from the whole operand)
// and `resultImag` from the pair (`imag`, `real`), in that order.
// NOTE(review): presumably log(z) = log(|z|) + i * atan2(im, re) -- confirm.
struct LogOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value abs = b.create(elementType, adaptor.getComplex(), fmf.getValue());
    Value resultReal = b.create(elementType, abs, fmf.getValue());
    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());
    Value resultImag = b.create(elementType, imag, real, fmf.getValue());
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Conversion pattern matching complex::Log1pOp; see the derivation in the
// body. The op's fast-math flags are forwarded to every arithmetic op created
// here.
struct Log1pOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());

    Value half = b.create(elementType, b.getFloatAttr(elementType, 0.5));
    Value one = b.create(elementType, b.getFloatAttr(elementType, 1));
    Value two = b.create(elementType, b.getFloatAttr(elementType, 2));

    // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
    // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
    // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
    // `sumSq` accumulates the three terms a*a, 2*a and b*b in that order.
    Value sumSq = b.create(real, real, fmf.getValue());
    sumSq = b.create(
        sumSq, b.create(real, two, fmf.getValue()), fmf.getValue());
    sumSq = b.create(
        sumSq, b.create(imag, imag, fmf.getValue()), fmf.getValue());
    Value logSumSq = b.create(elementType, sumSq, fmf.getValue());
    Value resultReal = b.create(logSumSq, half, fmf.getValue());

    Value realPlusOne = b.create(real, one, fmf.getValue());

    Value resultImag =
        b.create(elementType, imag, realPlusOne, fmf.getValue());
    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Conversion pattern matching complex::MulOp.
//
// `real` and `imag` are first computed from the four pairwise products of the
// operand components. The operand components are then adjusted step by step
// (Cases 1-3 below), the products are recomputed from the adjusted values, and
// the recomputed result (combined with `inf`) replaces the first one wherever
// `recalc` holds.
struct MulOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast(adaptor.getLhs().getType());
    auto elementType = cast(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
    auto fmfValue = fmf.getValue();

    Value lhsReal = b.create(elementType, adaptor.getLhs());
    Value lhsRealAbs = b.create(lhsReal, fmfValue);
    Value lhsImag = b.create(elementType, adaptor.getLhs());
    Value lhsImagAbs = b.create(lhsImag, fmfValue);
    Value rhsReal = b.create(elementType, adaptor.getRhs());
    Value rhsRealAbs = b.create(rhsReal, fmfValue);
    Value rhsImag = b.create(elementType, adaptor.getRhs());
    Value rhsImagAbs = b.create(rhsImag, fmfValue);

    Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal, fmfValue);
    Value lhsRealTimesRhsRealAbs = b.create(lhsRealTimesRhsReal, fmfValue);
    Value lhsImagTimesRhsImag = b.create(lhsImag, rhsImag, fmfValue);
    Value lhsImagTimesRhsImagAbs = b.create(lhsImagTimesRhsImag, fmfValue);
    Value real =
        b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag, fmfValue);

    Value lhsImagTimesRhsReal = b.create(lhsImag, rhsReal, fmfValue);
    Value lhsImagTimesRhsRealAbs = b.create(lhsImagTimesRhsReal, fmfValue);
    Value lhsRealTimesRhsImag = b.create(lhsReal, rhsImag, fmfValue);
    Value lhsRealTimesRhsImagAbs = b.create(lhsRealTimesRhsImag, fmfValue);
    Value imag =
        b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag, fmfValue);

    // Handle cases where the "naive" calculation results in NaN values.
    // (UNO of a value with itself is true exactly when the value is NaN.)
    Value realIsNan = b.create(arith::CmpFPredicate::UNO, real, real);
    Value imagIsNan = b.create(arith::CmpFPredicate::UNO, imag, imag);
    Value isNan = b.create(realIsNan, imagIsNan);

    Value inf = b.create(
        elementType,
        b.getFloatAttr(elementType,
                       APFloat::getInf(elementType.getFloatSemantics())));

    // Case 1. `lhsReal` or `lhsImag` are infinite.
    Value lhsRealIsInf =
        b.create(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagIsInf =
        b.create(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsIsInf = b.create(lhsRealIsInf, lhsImagIsInf);
    Value rhsRealIsNan =
        b.create(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
    Value rhsImagIsNan =
        b.create(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
    Value zero = b.create(elementType, b.getZeroAttr(elementType));
    Value one = b.create(elementType, b.getFloatAttr(elementType, 1));
    // From here on lhsReal/lhsImag/rhsReal/rhsImag are reassigned in place;
    // later statements see the adjusted values.
    Value lhsRealIsInfFloat = b.create(lhsRealIsInf, one, zero);
    lhsReal = b.create(
        lhsIsInf, b.create(lhsRealIsInfFloat, lhsReal), lhsReal);
    Value lhsImagIsInfFloat = b.create(lhsImagIsInf, one, zero);
    lhsImag = b.create(
        lhsIsInf, b.create(lhsImagIsInfFloat, lhsImag), lhsImag);
    Value lhsIsInfAndRhsRealIsNan = b.create(lhsIsInf, rhsRealIsNan);
    rhsReal = b.create(
        lhsIsInfAndRhsRealIsNan, b.create(zero, rhsReal), rhsReal);
    Value lhsIsInfAndRhsImagIsNan = b.create(lhsIsInf, rhsImagIsNan);
    rhsImag = b.create(
        lhsIsInfAndRhsImagIsNan, b.create(zero, rhsImag), rhsImag);

    // Case 2. `rhsReal` or `rhsImag` are infinite.
    Value rhsRealIsInf =
        b.create(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagIsInf =
        b.create(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsIsInf = b.create(rhsRealIsInf, rhsImagIsInf);
    // These NaN checks read the (possibly already adjusted) lhs components.
    Value lhsRealIsNan =
        b.create(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
    Value lhsImagIsNan =
        b.create(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
    Value rhsRealIsInfFloat = b.create(rhsRealIsInf, one, zero);
    rhsReal = b.create(
        rhsIsInf, b.create(rhsRealIsInfFloat, rhsReal), rhsReal);
    Value rhsImagIsInfFloat = b.create(rhsImagIsInf, one, zero);
    rhsImag = b.create(
        rhsIsInf, b.create(rhsImagIsInfFloat, rhsImag), rhsImag);
    Value rhsIsInfAndLhsRealIsNan = b.create(rhsIsInf, lhsRealIsNan);
    lhsReal = b.create(
        rhsIsInfAndLhsRealIsNan, b.create(zero, lhsReal), lhsReal);
    Value rhsIsInfAndLhsImagIsNan = b.create(rhsIsInf, lhsImagIsNan);
    lhsImag = b.create(
        rhsIsInfAndLhsImagIsNan, b.create(zero, lhsImag), lhsImag);
    Value recalc = b.create(lhsIsInf, rhsIsInf);

    // Case 3. One of the pairwise products of left hand side with right hand
    // side is infinite.
    Value lhsRealTimesRhsRealIsInf = b.create(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
    Value lhsImagTimesRhsImagIsInf = b.create(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
    Value isSpecialCase =
        b.create(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
    Value lhsRealTimesRhsImagIsInf = b.create(
        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
    isSpecialCase = b.create(isSpecialCase, lhsRealTimesRhsImagIsInf);
    Value lhsImagTimesRhsRealIsInf = b.create(
        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
    isSpecialCase = b.create(isSpecialCase, lhsImagTimesRhsRealIsInf);
    // `notRecalc` combines `recalc` with the i1 constant 1.
    Type i1Type = b.getI1Type();
    Value notRecalc = b.create(
        recalc, b.create(i1Type, b.getIntegerAttr(i1Type, 1)));
    isSpecialCase = b.create(isSpecialCase, notRecalc);
    Value isSpecialCaseAndLhsRealIsNan =
        b.create(isSpecialCase, lhsRealIsNan);
    lhsReal = b.create(
        isSpecialCaseAndLhsRealIsNan, b.create(zero, lhsReal), lhsReal);
    Value isSpecialCaseAndLhsImagIsNan =
        b.create(isSpecialCase, lhsImagIsNan);
    lhsImag = b.create(
        isSpecialCaseAndLhsImagIsNan, b.create(zero, lhsImag), lhsImag);
    Value isSpecialCaseAndRhsRealIsNan =
        b.create(isSpecialCase, rhsRealIsNan);
    rhsReal = b.create(
        isSpecialCaseAndRhsRealIsNan, b.create(zero, rhsReal), rhsReal);
    Value isSpecialCaseAndRhsImagIsNan =
        b.create(isSpecialCase, rhsImagIsNan);
    rhsImag = b.create(
        isSpecialCaseAndRhsImagIsNan, b.create(zero, rhsImag), rhsImag);
    recalc = b.create(recalc, isSpecialCase);
    recalc = b.create(isNan, recalc);

    // Recalculate real part.
    lhsRealTimesRhsReal = b.create(lhsReal, rhsReal, fmfValue);
    lhsImagTimesRhsImag = b.create(lhsImag, rhsImag, fmfValue);
    Value newReal =
        b.create(lhsRealTimesRhsReal, lhsImagTimesRhsImag, fmfValue);
    real = b.create(
        recalc, b.create(inf, newReal, fmfValue), real);

    // Recalculate imag part.
    lhsImagTimesRhsReal = b.create(lhsImag, rhsReal, fmfValue);
    lhsRealTimesRhsImag = b.create(lhsReal, rhsImag, fmfValue);
    Value newImag =
        b.create(lhsImagTimesRhsReal, lhsRealTimesRhsImag, fmfValue);
    imag = b.create(
        recalc, b.create(inf, newImag, fmfValue), imag);

    rewriter.replaceOpWithNewOp(op, type, real, imag);
    return success();
  }
};

// Conversion pattern matching complex::NegOp: a unary op is applied to the
// real and to the imaginary part separately (`negReal`, `negImag`) and the
// results are repacked into a value of the original complex `type`.
struct NegOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    Value real = rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag = rewriter.create(loc, elementType, adaptor.getComplex());
    Value negReal = rewriter.create(loc, real);
    Value negImag = rewriter.create(loc, imag);
    rewriter.replaceOpWithNewOp(op, type, negReal, negImag);
    return success();
  }
};

// Sine specialization of TrigonometricOpConversion.
struct SinOpConversion : public TrigonometricOpConversion {
  using TrigonometricOpConversion::TrigonometricOpConversion;

  std::pair
  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
          Value cos, ConversionPatternRewriter &rewriter) const override {
    // Complex sine is defined as:
    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
    // Plugging in:
    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
    // and defining t := exp(y)
    // We get:
    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
    Value sum = rewriter.create(loc, scaledExp, reciprocalExp);
    Value resultReal = rewriter.create(loc, sum, sin);
    Value diff = rewriter.create(loc, scaledExp, reciprocalExp);
    Value resultImag = rewriter.create(loc, diff, cos);
    return {resultReal, resultImag};
  }
};

// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
// NOTE(review): in this copy of the file the angle-bracket template arguments
// (on create/cast/replaceOpWithNewOp, OpConversionPattern, ...) are missing.
// Comments below that name an operation are inferred from value names and
// must be confirmed against the original source.
//
// Conversion pattern matching complex::SqrtOp.
//
// A first candidate (`sqrtAddAbs`) is built from `absLhs`, `absArg` and the
// constant 0.5. Depending on `realIsNegative` and `imagIsNegative`, the
// result components are taken either from that candidate (possibly negated)
// or from `imag` combined with twice the other component. When both `real`
// and `imag` compare equal to zero, both result components are `zero`.
struct SqrtOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    auto type = cast(op.getType());
    Type elementType = type.getElementType();
    Value arg = adaptor.getComplex();

    Value zero = b.create(elementType, b.getZeroAttr(elementType));

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());

    // `absLhs` comes from the real part only, `absArg` from the whole complex
    // argument.
    Value absLhs = b.create(real);
    Value absArg = b.create(elementType, arg);
    Value addAbs = b.create(absLhs, absArg);

    Value half = b.create(elementType, b.getFloatAttr(elementType, 0.5));
    Value halfAddAbs = b.create(addAbs, half);
    Value sqrtAddAbs = b.create(halfAddAbs);

    Value realIsNegative =
        b.create(arith::CmpFPredicate::OLT, real, zero);
    Value imagIsNegative =
        b.create(arith::CmpFPredicate::OLT, imag, zero);

    Value resultReal = sqrtAddAbs;

    Value imagDivTwoResultReal = b.create(
        imag, b.create(resultReal, resultReal));

    Value negativeResultReal = b.create(resultReal);

    // Negative real part: the candidate (with a sign picked by
    // `imagIsNegative`) becomes the imaginary component, and the real
    // component is recomputed from `imag` and that imaginary component.
    Value resultImag = b.create(
        realIsNegative,
        b.create(imagIsNegative, negativeResultReal, resultReal),
        imagDivTwoResultReal);

    resultReal = b.create(
        realIsNegative,
        b.create(
            imag, b.create(resultImag, resultImag)),
        resultReal);

    // A zero argument yields (zero, zero) regardless of the values above.
    Value realIsZero =
        b.create(arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        b.create(arith::CmpFPredicate::OEQ, imag, zero);
    Value argIsZero = b.create(realIsZero, imagIsZero);

    resultReal = b.create(argIsZero, zero, resultReal);
    resultImag = b.create(argIsZero, zero, resultImag);

    rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag);
    return success();
  }
};

// Conversion pattern matching complex::SignOp.
//
// The replacement op takes (`isZero`, the original operand, `sign`), where
// `sign` is assembled from `realSign` and `imagSign`, each built from one
// component and `abs`.
// NOTE(review): presumably a select that returns the operand unchanged when it
// is zero and operand / |operand| otherwise -- confirm.
struct SignOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    Value real = b.create(elementType, adaptor.getComplex());
    Value imag = b.create(elementType, adaptor.getComplex());
    Value zero = b.create(elementType, b.getZeroAttr(elementType));
    Value realIsZero =
        b.create(arith::CmpFPredicate::OEQ, real, zero);
    Value imagIsZero =
        b.create(arith::CmpFPredicate::OEQ, imag, zero);
    Value isZero = b.create(realIsZero, imagIsZero);
    auto abs = b.create(elementType, adaptor.getComplex());
    Value realSign = b.create(real, abs);
    Value imagSign = b.create(imag, abs);
    Value sign = b.create(type, realSign, imagSign);
    rewriter.replaceOpWithNewOp(op, isZero, adaptor.getComplex(), sign);
    return success();
  }
};

// Conversion pattern matching complex::TanOp: creates `cos` and `sin` from the
// operand and replaces the op with a new op taking (`sin`, `cos`).
// NOTE(review): presumably tan(z) = sin(z) / cos(z) -- confirm.
struct TanOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value cos = rewriter.create(loc, adaptor.getComplex());
    Value sin = rewriter.create(loc, adaptor.getComplex());
    rewriter.replaceOpWithNewOp(op, sin, cos);
    return success();
  }
};

// Conversion pattern matching complex::TanhOp.
struct TanhOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());

    // The hyperbolic tangent for complex number can be calculated as follows.
    // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
    // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
    Value real = rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag = rewriter.create(loc, elementType, adaptor.getComplex());
    Value tanhA = rewriter.create(loc, real);
    Value cosB = rewriter.create(loc, imag);
    Value sinB = rewriter.create(loc, imag);
    Value tanB = rewriter.create(loc, sinB, cosB);
    // Numerator: complex value with components (tanhA, tanB).
    Value numerator = rewriter.create(loc, type, tanhA, tanB);
    Value one = rewriter.create(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value mul = rewriter.create(loc, tanhA, tanB);
    // Denominator: complex value with components (one, mul).
    Value denominator = rewriter.create(loc, type, one, mul);
    rewriter.replaceOpWithNewOp(op, numerator, denominator);
    return success();
  }
};

// Conversion pattern matching complex::ConjOp: the real part is kept, a unary
// op is applied to the imaginary part (`negImag`), and both are repacked into
// a value of the original complex `type`.
struct ConjOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast(adaptor.getComplex().getType());
    auto elementType = cast(type.getElementType());
    Value real = rewriter.create(loc, elementType, adaptor.getComplex());
    Value imag = rewriter.create(loc, elementType, adaptor.getComplex());
    Value negImag = rewriter.create(loc, elementType, imag);

    rewriter.replaceOpWithNewOp(op, type, real, negImag);
    return success();
  }
};

/// Converts x^y = (a+bi)^(c+di) to
///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value a, Value b, Value c,
                                 Value d) {
  auto elementType = cast(type.getElementType());

  // Compute (a*a+b*b)^(0.5c).
  Value aaPbb = builder.create(
      builder.create(a, a), builder.create(b, b));
  Value half = builder.create(
      elementType, builder.getFloatAttr(elementType, 0.5));
  Value halfC = builder.create(half, c);
  Value aaPbbTohalfC = builder.create(aaPbb, halfC);

  // Compute exp(-d*atan2(b,a)).
  Value negD = builder.create(d);
  Value argX = builder.create(b, a);
  Value negDArgX = builder.create(negD, argX);
  Value eToNegDArgX = builder.create(negDArgX);

  // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
  Value coeff = builder.create(aaPbbTohalfC, eToNegDArgX);

  // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
  Value lnAaPbb = builder.create(aaPbb);
  Value halfD = builder.create(half, d);
  Value q = builder.create(
      builder.create(c, argX),
      builder.create(halfD, lnAaPbb));

  Value cosQ = builder.create(q);
  Value sinQ = builder.create(q);
  Value zero = builder.create(
      elementType, builder.getFloatAttr(elementType, 0));
  Value one = builder.create(
      elementType, builder.getFloatAttr(elementType, 1));

  Value xEqZero =
      builder.create(arith::CmpFPredicate::OEQ, aaPbb, zero);
  Value yGeZero = builder.create(
      builder.create(arith::CmpFPredicate::OGE, c, zero),
      builder.create(arith::CmpFPredicate::OEQ, d, zero));
  Value cEqZero =
      builder.create(arith::CmpFPredicate::OEQ, c, zero);
  Value complexZero = builder.create(type, zero, zero);
  Value complexOne = builder.create(type, one, zero);
  Value complexOther = builder.create(
      type, builder.create(coeff, cosQ),
      builder.create(coeff, sinQ));

  // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
return builder.create( builder.create(xEqZero, yGeZero), builder.create(cEqZero, complexOne, complexZero), complexOther); } struct PowOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); auto type = cast(adaptor.getLhs().getType()); auto elementType = cast(type.getElementType()); Value a = builder.create(elementType, adaptor.getLhs()); Value b = builder.create(elementType, adaptor.getLhs()); Value c = builder.create(elementType, adaptor.getRhs()); Value d = builder.create(elementType, adaptor.getRhs()); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); return success(); } }; struct RsqrtOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); auto type = cast(adaptor.getComplex().getType()); auto elementType = cast(type.getElementType()); Value a = builder.create(elementType, adaptor.getComplex()); Value b = builder.create(elementType, adaptor.getComplex()); Value c = builder.create( elementType, builder.getFloatAttr(elementType, -0.5)); Value d = builder.create( elementType, builder.getFloatAttr(elementType, 0)); rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); return success(); } }; struct AngleOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = op.getType(); Value real = rewriter.create(loc, type, adaptor.getComplex()); Value imag = rewriter.create(loc, type, adaptor.getComplex()); 
rewriter.replaceOpWithNewOp(op, imag, real); return success(); } }; } // namespace void mlir::populateComplexToStandardConversionPatterns( RewritePatternSet &patterns) { // clang-format off patterns.add< AbsOpConversion, AngleOpConversion, Atan2OpConversion, BinaryComplexOpConversion, BinaryComplexOpConversion, ComparisonOpConversion, ComparisonOpConversion, ConjOpConversion, CosOpConversion, DivOpConversion, ExpOpConversion, Expm1OpConversion, Log1pOpConversion, LogOpConversion, MulOpConversion, NegOpConversion, SignOpConversion, SinOpConversion, SqrtOpConversion, TanOpConversion, TanhOpConversion, PowOpConversion, RsqrtOpConversion >(patterns.getContext()); // clang-format on } namespace { struct ConvertComplexToStandardPass : public impl::ConvertComplexToStandardBase { void runOnOperation() override; }; void ConvertComplexToStandardPass::runOnOperation() { // Convert to the Standard dialect using the converter defined above. RewritePatternSet patterns(&getContext()); populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } } // namespace std::unique_ptr mlir::createConvertComplexToStandardPass() { return std::make_unique(); }