//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include namespace mlir { #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = VectorConvertToLLVMPattern; using ExtFOpLowering = VectorConvertToLLVMPattern; using ExtSIOpLowering = VectorConvertToLLVMPattern; using ExtUIOpLowering = VectorConvertToLLVMPattern; using FPToSIOpLowering = VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; using MaximumFOpLowering = VectorConvertToLLVMPattern; using MaxNumFOpLowering = VectorConvertToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; using MinimumFOpLowering = VectorConvertToLLVMPattern; using MinNumFOpLowering = VectorConvertToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; using ShLIOpLowering = VectorConvertToLLVMPattern; using ShRSIOpLowering = VectorConvertToLLVMPattern; using ShRUIOpLowering = VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; using TruncIOpLowering = VectorConvertToLLVMPattern; using UIToFPOpLowering = VectorConvertToLLVMPattern; using XOrIOpLowering = VectorConvertToLLVMPattern; //===----------------------------------------------------------------------===// // Op Lowering Patterns //===----------------------------------------------------------------------===// /// Directly lower to LLVM op. struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// The lowering of index_cast becomes an integer conversion since index /// becomes an integer. If the bit width of the source and target integer /// types is the same, just erase the cast. If the target type is wider, /// sign-extend the value, otherwise truncate it. template struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; using IndexCastOpSILowering = IndexCastOpLowering; using IndexCastOpUILowering = IndexCastOpLowering; struct AddUIExtendedOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; template struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; using MulSIExtendedOpLowering = MulIExtendedOpLowering; using MulUIExtendedOpLowering = MulIExtendedOpLowering; struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // ConstantOpLowering //===----------------------------------------------------------------------===// LogicalResult ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), *getTypeConverter(), rewriter); } //===----------------------------------------------------------------------===// // IndexCastOpLowering //===----------------------------------------------------------------------===// template LogicalResult IndexCastOpLowering::matchAndRewrite( OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { Type resultType = op.getResult().getType(); Type targetElementType = this->typeConverter->convertType(getElementTypeOrSelf(resultType)); Type sourceElementType = this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); if (targetBits == sourceBits) { rewriter.replaceOp(op, adaptor.getIn()); return success(); } // Handle the scalar and 1D vector cases. Type operandType = adaptor.getIn().getType(); if (!isa(operandType)) { Type targetType = this->typeConverter->convertType(resultType); if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); else rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); return success(); } if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()), [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { return rewriter.create(op.getLoc(), llvm1DVectorTy, adaptor.getIn()); } return rewriter.create(op.getLoc(), llvm1DVectorTy, adaptor.getIn()); }, rewriter); } //===----------------------------------------------------------------------===// // AddUIExtendedOpLowering //===----------------------------------------------------------------------===// LogicalResult AddUIExtendedOpLowering::matchAndRewrite( arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type sumResultType = op.getSum().getType(); Type overflowResultType = op.getOverflow().getType(); if (!LLVM::isCompatibleType(operandType)) return failure(); MLIRContext *ctx = rewriter.getContext(); Location loc = op.getLoc(); // Handle the scalar and 1D vector cases. if (!isa(operandType)) { Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); Value addOverflow = rewriter.create( loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = rewriter.create(loc, addOverflow, 0); Value overflowExtracted = rewriter.create(loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } if (!isa(sumResultType)) return rewriter.notifyMatchFailure(loc, "expected vector result types"); return rewriter.notifyMatchFailure(loc, "ND vector types are not supported yet"); } //===----------------------------------------------------------------------===// // MulIExtendedOpLowering //===----------------------------------------------------------------------===// template LogicalResult MulIExtendedOpLowering::matchAndRewrite( ArithMulOp op, typename ArithMulOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { Type resultType = adaptor.getLhs().getType(); if (!LLVM::isCompatibleType(resultType)) return failure(); Location loc = op.getLoc(); // Handle the scalar and 1D vector cases. Because LLVM does not have a // matching extended multiplication intrinsic, perform regular multiplication // on operands zero-extended to i(2*N) bits, and truncate the results back to // iN types. if (!isa(resultType)) { // Shift amount necessary to extract the high bits from widened result. TypedAttr shiftValAttr; if (auto intTy = dyn_cast(resultType)) { unsigned resultBitwidth = intTy.getWidth(); auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); } else { auto vecTy = cast(resultType); unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); auto attrTy = VectorType::get( vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); shiftValAttr = SplatElementsAttr::get( attrTy, APInt(resultBitwidth * 2, resultBitwidth)); } Type wideType = shiftValAttr.getType(); assert(LLVM::isCompatibleType(wideType) && "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t; Value lhsExt = rewriter.create(loc, wideType, adaptor.getLhs()); Value rhsExt = rewriter.create(loc, wideType, adaptor.getRhs()); Value mulExt = rewriter.create(loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. Value low = rewriter.create(loc, resultType, mulExt); Value shiftVal = rewriter.create(loc, shiftValAttr); Value highExt = rewriter.create(loc, mulExt, shiftVal); Value high = rewriter.create(loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); } if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return rewriter.notifyMatchFailure(op, "ND vector types are not supported yet"); } //===----------------------------------------------------------------------===// // CmpIOpLowering //===----------------------------------------------------------------------===// // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums // share numerical values so just cast. template static LLVMPredType convertCmpPredicate(PredType pred) { return static_cast(pred); } LogicalResult CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. if (!isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); return success(); } if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } //===----------------------------------------------------------------------===// // CmpFOpLowering //===----------------------------------------------------------------------===// LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); LLVM::FastmathFlags fmf = arith::convertArithFastMathFlagsToLLVM(op.getFastmath()); // Handle the scalar and 1D vector cases. if (!isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); return success(); } if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); }, rewriter); } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ArithToLLVMConversionPass : public impl::ArithToLLVMConversionPassBase { using Base::Base; void runOnOperation() override { LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(&getContext()); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert MemRef to LLVM. struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); } }; } // namespace void mlir::arith::registerConvertArithToLLVMInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { dialect->addInterfaces(); }); } //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void mlir::arith::populateArithToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< AddFOpLowering, AddIOpLowering, AndIOpLowering, AddUIExtendedOpLowering, BitcastOpLowering, ConstantOpLowering, CmpFOpLowering, CmpIOpLowering, DivFOpLowering, DivSIOpLowering, DivUIOpLowering, ExtFOpLowering, ExtSIOpLowering, ExtUIOpLowering, FPToSIOpLowering, FPToUIOpLowering, IndexCastOpSILowering, IndexCastOpUILowering, MaximumFOpLowering, MaxNumFOpLowering, MaxSIOpLowering, MaxUIOpLowering, MinimumFOpLowering, MinNumFOpLowering, MinSIOpLowering, MinUIOpLowering, MulFOpLowering, MulIOpLowering, MulSIExtendedOpLowering, MulUIExtendedOpLowering, NegFOpLowering, OrIOpLowering, RemFOpLowering, RemSIOpLowering, RemUIOpLowering, SelectOpLowering, ShLIOpLowering, ShRSIOpLowering, ShRUIOpLowering, SIToFPOpLowering, SubFOpLowering, SubIOpLowering, TruncFOpLowering, TruncIOpLowering, UIToFPOpLowering, XOrIOpLowering >(converter); // clang-format on }