//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::LLVM; using namespace mlir::arith; //===----------------------------------------------------------------------===// // ComplexStructBuilder implementation. //===----------------------------------------------------------------------===// static constexpr unsigned kRealPosInComplexNumberStruct = 0; static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, Location loc, Type type) { Value val = builder.create(loc, type); return ComplexStructBuilder(val); } void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, Value real) { setPtr(builder, loc, kRealPosInComplexNumberStruct, real); } Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRealPosInComplexNumberStruct); } void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, Value imaginary) { setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); } Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); } //===----------------------------------------------------------------------===// // Conversion patterns. //===----------------------------------------------------------------------===// namespace { struct AbsOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ComplexStructBuilder complexStruct(adaptor.getComplex()); Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value sqNorm = rewriter.create( loc, rewriter.create(loc, real, real, fmf), rewriter.create(loc, imag, imag, fmf), fmf); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); } }; struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), *getTypeConverter(), rewriter); } }; struct CreateOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Pack real and imaginary part in a complex number struct. auto loc = complexOp.getLoc(); auto structType = typeConverter->convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, adaptor.getReal()); complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary()); rewriter.replaceOp(complexOp, {complexStruct}); return success(); } }; struct ReOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::ReOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Extract real part from the complex number struct. ComplexStructBuilder complexStruct(adaptor.getComplex()); Value real = complexStruct.real(rewriter, op.getLoc()); rewriter.replaceOp(op, real); return success(); } }; struct ImOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::ImOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Extract imaginary part from the complex number struct. ComplexStructBuilder complexStruct(adaptor.getComplex()); Value imaginary = complexStruct.imaginary(rewriter, op.getLoc()); rewriter.replaceOp(op, imaginary); return success(); } }; struct BinaryComplexOperands { std::complex lhs; std::complex rhs; }; template BinaryComplexOperands unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) { auto loc = op.getLoc(); // Extract real and imaginary values from operands. BinaryComplexOperands unpacked; ComplexStructBuilder lhs(adaptor.getLhs()); unpacked.lhs.real(lhs.real(rewriter, loc)); unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); ComplexStructBuilder rhs(adaptor.getRhs()); unpacked.rhs.real(rhs.real(rewriter, loc)); unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); return unpacked; } struct AddOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); rewriter.replaceOp(op, {result}); return success(); } }; struct DivOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); Value rhsSqNorm = rewriter.create( loc, rewriter.create(loc, rhsRe, rhsRe, fmf), rewriter.create(loc, rhsIm, rhsIm, fmf), fmf); Value resultReal = rewriter.create( loc, rewriter.create(loc, lhsRe, rhsRe, fmf), rewriter.create(loc, lhsIm, rhsIm, fmf), fmf); Value resultImag = rewriter.create( loc, rewriter.create(loc, lhsIm, rhsRe, fmf), rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); result.setReal( rewriter, loc, rewriter.create(loc, resultReal, rhsSqNorm, fmf)); result.setImaginary( rewriter, loc, rewriter.create(loc, resultImag, rhsSqNorm, fmf)); rewriter.replaceOp(op, {result}); return success(); } }; struct MulOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); Value lhsIm = arg.lhs.imag(); Value real = rewriter.create( loc, rewriter.create(loc, rhsRe, lhsRe, fmf), rewriter.create(loc, rhsIm, lhsIm, fmf), fmf); Value imag = rewriter.create( loc, rewriter.create(loc, lhsIm, rhsRe, fmf), rewriter.create(loc, lhsRe, rhsIm, fmf), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); rewriter.replaceOp(op, {result}); return success(); } }; struct SubOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(complex::SubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, adaptor, rewriter); // Initialize complex number struct for result. auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( op.getContext(), convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag(), fmf); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); rewriter.replaceOp(op, {result}); return success(); } }; } // namespace void mlir::populateComplexToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< AbsOpConversion, AddOpConversion, ConstantOpLowering, CreateOpConversion, DivOpConversion, ImOpConversion, MulOpConversion, ReOpConversion, SubOpConversion >(converter); // clang-format on } namespace { struct ConvertComplexToLLVMPass : public impl::ConvertComplexToLLVMPassBase { using Base::Base; void runOnOperation() override; }; } // namespace void ConvertComplexToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect using the converter defined above. RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); populateComplexToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert MemRef to LLVM. struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateComplexToLLVMConversionPatterns(typeConverter, patterns); } }; } // namespace void mlir::registerConvertComplexToLLVMInterface(DialectRegistry ®istry) { registry.addExtension( +[](MLIRContext *ctx, complex::ComplexDialect *dialect) { dialect->addInterfaces(); }); }