419 lines
17 KiB
C++
419 lines
17 KiB
C++
|
//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
|
||
|
#include "../SPIRVCommon/Pattern.h"
|
||
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||
|
#include "mlir/Pass/Pass.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace index;
|
||
|
|
||
|
namespace {
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Trivial Conversions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
|
||
|
using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
|
||
|
using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
|
||
|
using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
|
||
|
using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
|
||
|
using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
|
||
|
using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
|
||
|
using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
|
||
|
using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
|
||
|
using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
|
||
|
using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
|
||
|
|
||
|
using ConvertIndexShl =
|
||
|
spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
|
||
|
using ConvertIndexShrS =
|
||
|
spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
|
||
|
using ConvertIndexShrU =
|
||
|
spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
|
||
|
|
||
|
/// It is the case that when we convert bitwise operations to SPIR-V operations
|
||
|
/// we must take into account the special pattern in SPIR-V that if the
|
||
|
/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
|
||
|
/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
|
||
|
/// index.add is never a boolean operation so we can directly convert it to the
|
||
|
/// Bitwise[And|Or]Op.
|
||
|
using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
|
||
|
using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
|
||
|
using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertConstantBool
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Converts index.bool.constant operation to spirv.Constant.
|
||
|
struct ConvertIndexConstantBoolOpPattern final
|
||
|
: OpConversionPattern<BoolConstantOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
|
||
|
op.getValueAttr());
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertConstant
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
|
||
|
// when required.
|
||
|
struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
|
||
|
Type indexType = typeConverter->getIndexType();
|
||
|
|
||
|
APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
|
||
|
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||
|
op, indexType, IntegerAttr::get(indexType, value));
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexCeilDivS
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
|
||
|
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
|
||
|
/// conversion in IndexToLLVM.
|
||
|
struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
Location loc = op.getLoc();
|
||
|
Value n = adaptor.getLhs();
|
||
|
Type n_type = n.getType();
|
||
|
Value m = adaptor.getRhs();
|
||
|
|
||
|
// Define the constants
|
||
|
Value zero = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, 0));
|
||
|
Value posOne = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, 1));
|
||
|
Value negOne = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, -1));
|
||
|
|
||
|
// Compute `x`.
|
||
|
Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
|
||
|
Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
|
||
|
|
||
|
// Compute the positive result.
|
||
|
Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
|
||
|
Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
|
||
|
Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
|
||
|
|
||
|
// Compute the negative result.
|
||
|
Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
|
||
|
Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
|
||
|
Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
|
||
|
|
||
|
// Pick the positive result if `n` and `m` have the same sign and `n` is
|
||
|
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
|
||
|
Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
|
||
|
Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
|
||
|
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
|
||
|
Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
|
||
|
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexCeilDivU
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
|
||
|
/// from the equivalent conversion in IndexToLLVM.
|
||
|
struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
Location loc = op.getLoc();
|
||
|
Value n = adaptor.getLhs();
|
||
|
Type n_type = n.getType();
|
||
|
Value m = adaptor.getRhs();
|
||
|
|
||
|
// Define the constants
|
||
|
Value zero = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, 0));
|
||
|
Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
|
||
|
IntegerAttr::get(n_type, 1));
|
||
|
|
||
|
// Compute the non-zero result.
|
||
|
Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
|
||
|
Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
|
||
|
Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
|
||
|
|
||
|
// Pick the result
|
||
|
Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
|
||
|
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexFloorDivS
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
|
||
|
/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
|
||
|
/// in IndexToLLVM.
|
||
|
struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
Location loc = op.getLoc();
|
||
|
Value n = adaptor.getLhs();
|
||
|
Type n_type = n.getType();
|
||
|
Value m = adaptor.getRhs();
|
||
|
|
||
|
// Define the constants
|
||
|
Value zero = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, 0));
|
||
|
Value posOne = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, 1));
|
||
|
Value negOne = rewriter.create<spirv::ConstantOp>(
|
||
|
loc, n_type, IntegerAttr::get(n_type, -1));
|
||
|
|
||
|
// Compute `x`.
|
||
|
Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
|
||
|
Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
|
||
|
|
||
|
// Compute the negative result
|
||
|
Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
|
||
|
Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
|
||
|
Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
|
||
|
|
||
|
// Compute the positive result.
|
||
|
Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
|
||
|
|
||
|
// Pick the negative result if `n` and `m` have different signs and `n` is
|
||
|
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
|
||
|
Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
|
||
|
Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
|
||
|
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
|
||
|
|
||
|
Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
|
||
|
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexCast
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Convert a cast op. If the materialized index type is the same as the other
|
||
|
/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
|
||
|
/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
|
||
|
/// zero extend when the result bitwidth is larger.
|
||
|
template <typename CastOp, typename ConvertOp>
|
||
|
struct ConvertIndexCast final : OpConversionPattern<CastOp> {
|
||
|
using OpConversionPattern<CastOp>::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
|
||
|
Type indexType = typeConverter->getIndexType();
|
||
|
|
||
|
Type srcType = adaptor.getInput().getType();
|
||
|
Type dstType = op.getType();
|
||
|
if (isa<IndexType>(srcType)) {
|
||
|
srcType = indexType;
|
||
|
}
|
||
|
if (isa<IndexType>(dstType)) {
|
||
|
dstType = indexType;
|
||
|
}
|
||
|
|
||
|
if (srcType == dstType) {
|
||
|
rewriter.replaceOp(op, adaptor.getInput());
|
||
|
} else {
|
||
|
rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
|
||
|
adaptor.getOperands());
|
||
|
}
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
|
||
|
using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexCmp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
// Helper template to replace the operation
|
||
|
template <typename ICmpOp>
|
||
|
static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) {
|
||
|
rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
// We must convert the predicates to the corresponding int comparions.
|
||
|
switch (op.getPred()) {
|
||
|
case IndexCmpPredicate::EQ:
|
||
|
return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::NE:
|
||
|
return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::SGE:
|
||
|
return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::SGT:
|
||
|
return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::SLE:
|
||
|
return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::SLT:
|
||
|
return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::UGE:
|
||
|
return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::UGT:
|
||
|
return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::ULE:
|
||
|
return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
|
||
|
case IndexCmpPredicate::ULT:
|
||
|
return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConvertIndexSizeOf
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
|
||
|
struct ConvertIndexSizeOf final : OpConversionPattern<SizeOfOp> {
|
||
|
using OpConversionPattern::OpConversionPattern;
|
||
|
|
||
|
LogicalResult
|
||
|
matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
|
||
|
ConversionPatternRewriter &rewriter) const override {
|
||
|
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
|
||
|
Type indexType = typeConverter->getIndexType();
|
||
|
unsigned bitwidth = typeConverter->getIndexTypeBitwidth();
|
||
|
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||
|
op, indexType, IntegerAttr::get(indexType, bitwidth));
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
} // namespace
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Pattern Population
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||
|
RewritePatternSet &patterns) {
|
||
|
patterns.add<
|
||
|
// clang-format off
|
||
|
ConvertIndexAdd,
|
||
|
ConvertIndexSub,
|
||
|
ConvertIndexMul,
|
||
|
ConvertIndexDivS,
|
||
|
ConvertIndexDivU,
|
||
|
ConvertIndexRemS,
|
||
|
ConvertIndexRemU,
|
||
|
ConvertIndexMaxS,
|
||
|
ConvertIndexMaxU,
|
||
|
ConvertIndexMinS,
|
||
|
ConvertIndexMinU,
|
||
|
ConvertIndexShl,
|
||
|
ConvertIndexShrS,
|
||
|
ConvertIndexShrU,
|
||
|
ConvertIndexAnd,
|
||
|
ConvertIndexOr,
|
||
|
ConvertIndexXor,
|
||
|
ConvertIndexConstantBoolOpPattern,
|
||
|
ConvertIndexConstantOpPattern,
|
||
|
ConvertIndexCeilDivSPattern,
|
||
|
ConvertIndexCeilDivUPattern,
|
||
|
ConvertIndexFloorDivSPattern,
|
||
|
ConvertIndexCastS,
|
||
|
ConvertIndexCastU,
|
||
|
ConvertIndexCmpPattern,
|
||
|
ConvertIndexSizeOf
|
||
|
>(typeConverter, patterns.getContext());
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ODS-Generated Definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
namespace mlir {
|
||
|
#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS
|
||
|
#include "mlir/Conversion/Passes.h.inc"
|
||
|
} // namespace mlir
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Pass Definition
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
namespace {
|
||
|
struct ConvertIndexToSPIRVPass
|
||
|
: public impl::ConvertIndexToSPIRVPassBase<ConvertIndexToSPIRVPass> {
|
||
|
using Base::Base;
|
||
|
|
||
|
void runOnOperation() override {
|
||
|
Operation *op = getOperation();
|
||
|
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
|
||
|
std::unique_ptr<SPIRVConversionTarget> target =
|
||
|
SPIRVConversionTarget::get(targetAttr);
|
||
|
|
||
|
SPIRVConversionOptions options;
|
||
|
options.use64bitIndex = this->use64bitIndex;
|
||
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||
|
|
||
|
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
|
||
|
// in patterns for other dialects.
|
||
|
target->addLegalOp<UnrealizedConversionCastOp>();
|
||
|
|
||
|
// Allow the spirv operations we are converting to
|
||
|
target->addLegalDialect<spirv::SPIRVDialect>();
|
||
|
// Fail hard when there are any remaining 'index' ops.
|
||
|
target->addIllegalDialect<index::IndexDialect>();
|
||
|
|
||
|
RewritePatternSet patterns(&getContext());
|
||
|
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
|
||
|
|
||
|
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
|
||
|
signalPassFailure();
|
||
|
}
|
||
|
};
|
||
|
} // namespace
|