//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Pass/Pass.h" using namespace mlir; using namespace index; namespace { //===----------------------------------------------------------------------===// // Trivial Conversions //===----------------------------------------------------------------------===// using ConvertIndexAdd = spirv::ElementwiseOpPattern; using ConvertIndexSub = spirv::ElementwiseOpPattern; using ConvertIndexMul = spirv::ElementwiseOpPattern; using ConvertIndexDivS = spirv::ElementwiseOpPattern; using ConvertIndexDivU = spirv::ElementwiseOpPattern; using ConvertIndexRemS = spirv::ElementwiseOpPattern; using ConvertIndexRemU = spirv::ElementwiseOpPattern; using ConvertIndexMaxS = spirv::ElementwiseOpPattern; using ConvertIndexMaxU = spirv::ElementwiseOpPattern; using ConvertIndexMinS = spirv::ElementwiseOpPattern; using ConvertIndexMinU = spirv::ElementwiseOpPattern; using ConvertIndexShl = spirv::ElementwiseOpPattern; using ConvertIndexShrS = spirv::ElementwiseOpPattern; using ConvertIndexShrU = spirv::ElementwiseOpPattern; /// It is the case that when we convert bitwise operations to SPIR-V operations /// we must take into account the special pattern in SPIR-V that if the /// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise, /// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However, /// index.add is never a boolean operation so we can directly convert it to the /// Bitwise[And|Or]Op. using ConvertIndexAnd = spirv::ElementwiseOpPattern; using ConvertIndexOr = spirv::ElementwiseOpPattern; using ConvertIndexXor = spirv::ElementwiseOpPattern; //===----------------------------------------------------------------------===// // ConvertConstantBool //===----------------------------------------------------------------------===// // Converts index.bool.constant operation to spirv.Constant. struct ConvertIndexConstantBoolOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.getValueAttr()); return success(); } }; //===----------------------------------------------------------------------===// // ConvertConstant //===----------------------------------------------------------------------===// // Converts index.constant op to spirv.Constant. Will truncate from i64 to i32 // when required. struct ConvertIndexConstantOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *typeConverter = this->template getTypeConverter(); Type indexType = typeConverter->getIndexType(); APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth()); rewriter.replaceOpWithNewOp( op, indexType, IntegerAttr::get(indexType, value)); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexCeilDivS //===----------------------------------------------------------------------===// /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent /// conversion in IndexToLLVM. struct ConvertIndexCeilDivSPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Type n_type = n.getType(); Value m = adaptor.getRhs(); // Define the constants Value zero = rewriter.create( loc, n_type, IntegerAttr::get(n_type, 0)); Value posOne = rewriter.create( loc, n_type, IntegerAttr::get(n_type, 1)); Value negOne = rewriter.create( loc, n_type, IntegerAttr::get(n_type, -1)); // Compute `x`. Value mPos = rewriter.create(loc, m, zero); Value x = rewriter.create(loc, mPos, negOne, posOne); // Compute the positive result. Value nPlusX = rewriter.create(loc, n, x); Value nPlusXDivM = rewriter.create(loc, nPlusX, m); Value posRes = rewriter.create(loc, nPlusXDivM, posOne); // Compute the negative result. Value negN = rewriter.create(loc, zero, n); Value negNDivM = rewriter.create(loc, negN, m); Value negRes = rewriter.create(loc, zero, negNDivM); // Pick the positive result if `n` and `m` have the same sign and `n` is // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. Value nPos = rewriter.create(loc, n, zero); Value sameSign = rewriter.create(loc, nPos, mPos); Value nNonZero = rewriter.create(loc, n, zero); Value cmp = rewriter.create(loc, sameSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexCeilDivU //===----------------------------------------------------------------------===// /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken /// from the equivalent conversion in IndexToLLVM. struct ConvertIndexCeilDivUPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Type n_type = n.getType(); Value m = adaptor.getRhs(); // Define the constants Value zero = rewriter.create( loc, n_type, IntegerAttr::get(n_type, 0)); Value one = rewriter.create(loc, n_type, IntegerAttr::get(n_type, 1)); // Compute the non-zero result. Value minusOne = rewriter.create(loc, n, one); Value quotient = rewriter.create(loc, minusOne, m); Value plusOne = rewriter.create(loc, quotient, one); // Pick the result Value cmp = rewriter.create(loc, n, zero); rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexFloorDivS //===----------------------------------------------------------------------===// /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then /// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion /// in IndexToLLVM. struct ConvertIndexFloorDivSPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value n = adaptor.getLhs(); Type n_type = n.getType(); Value m = adaptor.getRhs(); // Define the constants Value zero = rewriter.create( loc, n_type, IntegerAttr::get(n_type, 0)); Value posOne = rewriter.create( loc, n_type, IntegerAttr::get(n_type, 1)); Value negOne = rewriter.create( loc, n_type, IntegerAttr::get(n_type, -1)); // Compute `x`. Value mNeg = rewriter.create(loc, m, zero); Value x = rewriter.create(loc, mNeg, posOne, negOne); // Compute the negative result Value xMinusN = rewriter.create(loc, x, n); Value xMinusNDivM = rewriter.create(loc, xMinusN, m); Value negRes = rewriter.create(loc, negOne, xMinusNDivM); // Compute the positive result. Value posRes = rewriter.create(loc, n, m); // Pick the negative result if `n` and `m` have different signs and `n` is // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. Value nNeg = rewriter.create(loc, n, zero); Value diffSign = rewriter.create(loc, nNeg, mNeg); Value nNonZero = rewriter.create(loc, n, zero); Value cmp = rewriter.create(loc, diffSign, nNonZero); rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); return success(); } }; //===----------------------------------------------------------------------===// // ConvertIndexCast //===----------------------------------------------------------------------===// /// Convert a cast op. If the materialized index type is the same as the other /// type, fold away the op. Otherwise, use the Convert SPIR-V operation. /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts /// zero extend when the result bitwidth is larger. template struct ConvertIndexCast final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *typeConverter = this->template getTypeConverter(); Type indexType = typeConverter->getIndexType(); Type srcType = adaptor.getInput().getType(); Type dstType = op.getType(); if (isa(srcType)) { srcType = indexType; } if (isa(dstType)) { dstType = indexType; } if (srcType == dstType) { rewriter.replaceOp(op, adaptor.getInput()); } else { rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); } return success(); } }; using ConvertIndexCastS = ConvertIndexCast; using ConvertIndexCastU = ConvertIndexCast; //===----------------------------------------------------------------------===// // ConvertIndexCmp //===----------------------------------------------------------------------===// // Helper template to replace the operation template static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); return success(); } struct ConvertIndexCmpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We must convert the predicates to the corresponding int comparions. switch (op.getPred()) { case IndexCmpPredicate::EQ: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::NE: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::SGE: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::SGT: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::SLE: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::SLT: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::UGE: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::UGT: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::ULE: return rewriteCmpOp(op, adaptor, rewriter); case IndexCmpPredicate::ULT: return rewriteCmpOp(op, adaptor, rewriter); } } }; //===----------------------------------------------------------------------===// // ConvertIndexSizeOf //===----------------------------------------------------------------------===// /// Lower `index.sizeof` to a constant with the value of the index bitwidth. struct ConvertIndexSizeOf final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *typeConverter = this->template getTypeConverter(); Type indexType = typeConverter->getIndexType(); unsigned bitwidth = typeConverter->getIndexTypeBitwidth(); rewriter.replaceOpWithNewOp( op, indexType, IntegerAttr::get(indexType, bitwidth)); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< // clang-format off ConvertIndexAdd, ConvertIndexSub, ConvertIndexMul, ConvertIndexDivS, ConvertIndexDivU, ConvertIndexRemS, ConvertIndexRemU, ConvertIndexMaxS, ConvertIndexMaxU, ConvertIndexMinS, ConvertIndexMinU, ConvertIndexShl, ConvertIndexShrS, ConvertIndexShrU, ConvertIndexAnd, ConvertIndexOr, ConvertIndexXor, ConvertIndexConstantBoolOpPattern, ConvertIndexConstantOpPattern, ConvertIndexCeilDivSPattern, ConvertIndexCeilDivUPattern, ConvertIndexFloorDivSPattern, ConvertIndexCastS, ConvertIndexCastU, ConvertIndexCmpPattern, ConvertIndexSizeOf >(typeConverter, patterns.getContext()); } //===----------------------------------------------------------------------===// // ODS-Generated Definitions //===----------------------------------------------------------------------===// namespace mlir { #define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ConvertIndexToSPIRVPass : public impl::ConvertIndexToSPIRVPassBase { using Base::Base; void runOnOperation() override { Operation *op = getOperation(); spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); std::unique_ptr target = SPIRVConversionTarget::get(targetAttr); SPIRVConversionOptions options; options.use64bitIndex = this->use64bitIndex; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull // in patterns for other dialects. target->addLegalOp(); // Allow the spirv operations we are converting to target->addLegalDialect(); // Fail hard when there are any remaining 'index' ops. target->addIllegalDialect(); RewritePatternSet patterns(&getContext()); index::populateIndexToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, *target, std::move(patterns)))) signalPassFailure(); } }; } // namespace