//===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include namespace mlir { #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { // Functor to resolve the function name corresponding to the given complex // result type. struct ComplexTypeResolver { std::optional operator()(Type type) const { auto complexType = cast(type); auto elementType = complexType.getElementType(); if (!isa(elementType)) return {}; return elementType.getIntOrFloatBitWidth() == 64; } }; // Functor to resolve the function name corresponding to the given float result // type. struct FloatTypeResolver { std::optional operator()(Type type) const { auto elementType = cast(type); if (!isa(elementType)) return {}; return elementType.getIntOrFloatBitWidth() == 64; } }; // Pattern to convert scalar complex operations to calls to libm functions. // Additionally the libm function signatures are declared. // TypeResolver is a functor returning the libm function name according to the // expected type double or float. template struct ScalarOpToLibmCall : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc, StringRef doubleFunc, PatternBenefit benefit) : OpRewritePattern(context, benefit), floatFunc(floatFunc), doubleFunc(doubleFunc){}; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; private: std::string floatFunc, doubleFunc; }; } // namespace template LogicalResult ScalarOpToLibmCall::matchAndRewrite( Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto isDouble = TypeResolver()(op.getType()); if (!isDouble.has_value()) return failure(); auto name = *isDouble ? doubleFunc : floatFunc; auto opFunc = dyn_cast_or_null( SymbolTable::lookupSymbolIn(module, name)); // Forward declare function if it hasn't already been if (!opFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&module->getRegion(0).front()); auto opFunctionTy = FunctionType::get( rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); opFunc = rewriter.create(rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); } assert(isa(SymbolTable::lookupSymbolIn(module, name))); rewriter.replaceOpWithNewOp(op, name, op.getType(), op->getOperands()); return success(); } void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add>(patterns.getContext(), "cpowf", "cpow", benefit); patterns.add>(patterns.getContext(), "csqrtf", "csqrt", benefit); patterns.add>(patterns.getContext(), "ctanhf", "ctanh", benefit); patterns.add>(patterns.getContext(), "ccosf", "ccos", benefit); patterns.add>(patterns.getContext(), "csinf", "csin", benefit); patterns.add>(patterns.getContext(), "conjf", "conj", benefit); patterns.add>(patterns.getContext(), "clogf", "clog", benefit); patterns.add>( patterns.getContext(), "cabsf", "cabs", benefit); patterns.add>( patterns.getContext(), "cargf", "carg", benefit); patterns.add>(patterns.getContext(), "ctanf", "ctan", benefit); } namespace { struct ConvertComplexToLibmPass : public impl::ConvertComplexToLibmBase { void runOnOperation() override; }; } // namespace void ConvertComplexToLibmPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertComplexToLibmPass() { return std::make_unique(); }