//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" using namespace mlir; using namespace index; namespace { //===----------------------------------------------------------------------===// // ConvertIndexCeilDivS //===----------------------------------------------------------------------===// /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); Value zero = rewriter.create(loc, n.getType(), 0); Value posOne = rewriter.create(loc, n.getType(), 1); Value negOne = rewriter.create(loc, n.getType(), -1); // Compute `x`. Value mPos = rewriter.create(loc, LLVM::ICmpPredicate::sgt, m, zero); Value x = rewriter.create(loc, mPos, negOne, posOne); // Compute the positive result. Value nPlusX = rewriter.create(loc, n, x); Value nPlusXDivM = rewriter.create(loc, nPlusX, m); Value posRes = rewriter.create(loc, nPlusXDivM, posOne); // Compute the negative result. Value negN = rewriter.create(loc, zero, n); Value negNDivM = rewriter.create(loc, negN, m); Value negRes = rewriter.create(loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. Value nPos = rewriter.create(loc, LLVM::ICmpPredicate::sgt, n, zero); Value sameSign = rewriter.create(loc, LLVM::ICmpPredicate::eq, nPos, mPos); Value nNonZero = rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); Value cmp = rewriter.create(loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexCeilDivU //===----------------------------------------------------------------------===// /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); Value zero = rewriter.create(loc, n.getType(), 0); Value one = rewriter.create(loc, n.getType(), 1); // Compute the non-zero result. Value minusOne = rewriter.create(loc, n, one); Value quotient = rewriter.create(loc, minusOne, m); Value plusOne = rewriter.create(loc, quotient, one); // Pick the result. Value cmp = rewriter.create(loc, LLVM::ICmpPredicate::eq, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexFloorDivS //===----------------------------------------------------------------------===// /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then /// `n*m < 0 ? -1 - (x-n)/m : n/m`. struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Value m = adaptor.getRhs(); Value zero = rewriter.create(loc, n.getType(), 0); Value posOne = rewriter.create(loc, n.getType(), 1); Value negOne = rewriter.create(loc, n.getType(), -1); // Compute `x`. Value mNeg = rewriter.create(loc, LLVM::ICmpPredicate::slt, m, zero); Value x = rewriter.create(loc, mNeg, posOne, negOne); // Compute the negative result. Value xMinusN = rewriter.create(loc, x, n); Value xMinusNDivM = rewriter.create(loc, xMinusN, m); Value negRes = rewriter.create(loc, negOne, xMinusNDivM); // Compute the positive result. Value posRes = rewriter.create(loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. Value nNeg = rewriter.create(loc, LLVM::ICmpPredicate::slt, n, zero); Value diffSign = rewriter.create(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg); Value nNonZero = rewriter.create(loc, LLVM::ICmpPredicate::ne, n, zero); Value cmp = rewriter.create(loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, negRes, posRes); return success(); } }; //===----------------------------------------------------------------------===// // CovnertIndexCast //===----------------------------------------------------------------------===// /// Convert a cast op. If the materialized index type is the same as the other /// type, fold away the op. Otherwise, truncate or extend the op as appropriate. /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts /// zero extend when the result bitwidth is larger. template struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern { using mlir::ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type in = adaptor.getInput().getType(); Type out = this->getTypeConverter()->convertType(op.getType()); if (in == out) rewriter.replaceOp(op, adaptor.getInput()); else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth()) rewriter.replaceOpWithNewOp(op, out, adaptor.getInput()); else rewriter.replaceOpWithNewOp(op, out, adaptor.getInput()); return success(); } }; using ConvertIndexCastS = ConvertIndexCast; using ConvertIndexCastU = ConvertIndexCast; //===----------------------------------------------------------------------===// // ConvertIndexCmp //===----------------------------------------------------------------------===// /// Assert that the LLVM comparison enum lines up with index's enum. static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs, IndexCmpPredicate rhs) { return static_cast(lhs) == static_cast(rhs); } static_assert( LLVM::getMaxEnumValForICmpPredicate() == getMaxEnumValForIndexCmpPredicate() && checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) && checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) && checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) && checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) && checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) && checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) && checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) && checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) && checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) && checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT), "LLVM ICmpPredicate mismatches IndexCmpPredicate"); struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // The LLVM enum has the same values as the index predicate enums. rewriter.replaceOpWithNewOp( op, *LLVM::symbolizeICmpPredicate(static_cast(op.getPred())), adaptor.getLhs(), adaptor.getRhs()); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexSizeOf //===----------------------------------------------------------------------===// /// Lower `index.sizeof` to a constant with the value of the index bitwidth. struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, getTypeConverter()->getIndexType(), getTypeConverter()->getIndexTypeBitwidth()); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexConstant //===----------------------------------------------------------------------===// /// Convert an index constant. Truncate the value as appropriate. struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type type = getTypeConverter()->getIndexType(); APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth()); rewriter.replaceOpWithNewOp( op, type, IntegerAttr::get(type, value)); return success(); } }; //===----------------------------------------------------------------------===// // Trivial Conversions //===----------------------------------------------------------------------===// using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexDivS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexDivU = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexRemS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexRemU = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMaxS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMaxU = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMinS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexMinU = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexShrS = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexShrU = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexAnd = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexOr = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexXor = mlir::OneToOneConvertToLLVMPattern; using ConvertIndexBoolConstant = mlir::OneToOneConvertToLLVMPattern; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void index::populateIndexToLLVMConversionPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.insert< // clang-format off ConvertIndexAdd, ConvertIndexSub, ConvertIndexMul, ConvertIndexDivS, ConvertIndexDivU, ConvertIndexRemS, ConvertIndexRemU, ConvertIndexMaxS, ConvertIndexMaxU, ConvertIndexMinS, ConvertIndexMinU, ConvertIndexShl, ConvertIndexShrS, ConvertIndexShrU, ConvertIndexAnd, ConvertIndexOr, ConvertIndexXor, ConvertIndexCeilDivS, ConvertIndexCeilDivU, ConvertIndexFloorDivS, ConvertIndexCastS, ConvertIndexCastU, ConvertIndexCmp, ConvertIndexSizeOf, ConvertIndexConstant, ConvertIndexBoolConstant // clang-format on >(typeConverter); } //===----------------------------------------------------------------------===// // ODS-Generated Definitions //===----------------------------------------------------------------------===// namespace mlir { #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ConvertIndexToLLVMPass : public impl::ConvertIndexToLLVMPassBase { using Base::Base; void runOnOperation() override; }; } // namespace void ConvertIndexToLLVMPass::runOnOperation() { // Configure dialect conversion. ConversionTarget target(getContext()); target.addIllegalDialect(); target.addLegalDialect(); // Set LLVM lowering options. LowerToLLVMOptions options(&getContext()); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter typeConverter(&getContext(), options); // Populate patterns and run the conversion. RewritePatternSet patterns(&getContext()); populateIndexToLLVMConversionPatterns(typeConverter, patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert Index to LLVM. struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateIndexToLLVMConversionPatterns(typeConverter, patterns); } }; } // namespace void mlir::index::registerConvertIndexToLLVMInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) { dialect->addInterfaces(); }); }