//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include #include namespace mlir { #define GEN_PASS_DEF_CONVERTARITHTOSPIRV #include "mlir/Conversion/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "arith-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Conversion Helpers //===----------------------------------------------------------------------===// /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { if (auto boolAttr = dyn_cast(srcAttr)) return boolAttr; if (auto intAttr = dyn_cast(srcAttr)) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); return {}; } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if conversion fails. static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder) { // If the source number uses less active bits than the target bitwidth, then // it should be safe to convert. if (srcAttr.getValue().isIntN(dstType.getWidth())) return builder.getIntegerAttr(dstType, srcAttr.getInt()); // XXX: Try again by interpreting the source number as a signed value. // Although integers in the standard dialect are signless, they can represent // a signed number. It's the operation decides how to interpret. This is // dangerous, but it seems there is no good way of handling this if we still // want to change the bitwidth. Emit a message at least. if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" << dstAttr << "' for type '" << dstType << "'\n"); return dstAttr; } LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' illegal: cannot fit into target type '" << dstType << "'\n"); return {}; } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if `dstType` is not 32-bit or conversion fails. static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder) { // Only support converting to float for now. if (!dstType.isF32()) return FloatAttr(); // Try to convert the source floating-point number to single precision. APFloat dstVal = srcAttr.getValue(); bool losesInfo = false; APFloat::opStatus status = dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); if (status != APFloat::opOK || losesInfo) { LLVM_DEBUG(llvm::dbgs() << srcAttr << " illegal: cannot fit into converted type '" << dstType << "'\n"); return FloatAttr(); } return builder.getF32FloatAttr(dstVal.convertToFloat()); } /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); if (type.isInteger(1)) return true; if (auto vecType = dyn_cast(type)) return vecType.getElementType().isInteger(1); return false; } /// Creates a scalar/vector integer constant. static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc) { if (auto vectorType = dyn_cast(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); return builder.create(loc, vectorType, attr); } if (auto intType = dyn_cast(type)) return builder.create( loc, type, builder.getIntegerAttr(type, value)); return nullptr; } /// Returns true if scalar/vector type `a` and `b` have the same number of /// bitwidth. static bool hasSameBitwidth(Type a, Type b) { auto getNumBitwidth = [](Type type) { unsigned bw = 0; if (type.isIntOrFloat()) bw = type.getIntOrFloatBitWidth(); else if (auto vecType = dyn_cast(type)) bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); return bw; }; unsigned aBW = getNumBitwidth(a); unsigned bBW = getNumBitwidth(b); return aBW != 0 && bBW != 0 && aBW == bBW; } /// Returns a source type conversion failure for `srcType` and operation `op`. static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert source type '{0}'", srcType)); } /// Returns a source type conversion failure for the result type of `op`. static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { assert(op->getNumResults() == 1); return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); } // TODO: Move to some common place? static std::string getDecorationString(spirv::Decoration decor) { return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor)); } namespace { /// Converts elementwise unary, binary and ternary arith operations to SPIR-V /// operations. Op can potentially support overflow flags. template struct ElementwiseArithOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() <= 3); auto converter = this->template getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (!dstType) { return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert type {0} for SPIR-V", op.getType())); } if (SPIRVOp::template hasTrait() && !getElementTypeOrSelf(op.getType()).isIndex() && dstType != op.getType()) { return op.emitError("bitwidth emulation is not implemented yet on " "unsigned op pattern version"); } auto overflowFlags = arith::IntegerOverflowFlags::none; if (auto overflowIface = dyn_cast(*op)) { if (converter->getTargetEnv().allows( spirv::Extension::SPV_KHR_no_integer_wrap_decoration)) overflowFlags = overflowIface.getOverflowAttr().getValue(); } auto newOp = rewriter.template replaceOpWithNewOp( op, dstType, adaptor.getOperands()); if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw)) newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap), rewriter.getUnitAttr()); if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw)) newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap), rewriter.getUnitAttr()); return success(); } }; //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// /// Converts composite arith.constant operation to spirv.Constant. struct ConstantCompositeOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = dyn_cast(constOp.getType()); if (!srcType || srcType.getNumElements() == 1) return failure(); // arith.constant should only have vector or tenor types. assert((isa(srcType))); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); auto dstElementsAttr = dyn_cast(constOp.getValue()); if (!dstElementsAttr) return failure(); ShapedType dstAttrType = dstElementsAttr.getType(); // If the composite type has more than one dimensions, perform // linearization. if (srcType.getRank() > 1) { if (isa(srcType)) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } else { // TODO: add support for large vectors. return failure(); } } Type srcElemType = srcType.getElementType(); Type dstElemType; // Tensor types are converted to SPIR-V array types; vector types are // converted to SPIR-V vector/array types. if (auto arrayType = dyn_cast(dstType)) dstElemType = arrayType.getElementType(); else dstElemType = cast(dstType).getElementType(); // If the source and destination element types are different, perform // attribute conversion. if (srcElemType != dstElemType) { SmallVector elements; if (isa(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues()) { FloatAttr dstAttr = convertFloatAttr(srcAttr, cast(dstElemType), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } else if (srcElemType.isInteger(1)) { return failure(); } else { for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { IntegerAttr dstAttr = convertIntegerAttr( srcAttr, cast(dstElemType), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } // Unfortunately, we cannot use dialect-specific types for element // attributes; element attributes only works with builtin types. So we // need to prepare another converted builtin types for the destination // elements attribute. if (isa(dstAttrType)) dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } rewriter.replaceOpWithNewOp(constOp, dstType, dstElementsAttr); return success(); } }; /// Converts scalar arith.constant operation to spirv.Constant. struct ConstantScalarOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = constOp.getType(); if (auto shapedType = dyn_cast(srcType)) { if (shapedType.getNumElements() != 1) return failure(); srcType = shapedType.getElementType(); } if (!srcType.isIntOrIndexOrFloat()) return failure(); Attribute cstAttr = constOp.getValue(); if (auto elementsAttr = dyn_cast(cstAttr)) cstAttr = elementsAttr.getSplatValue(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); // Floating-point types. if (isa(srcType)) { auto srcAttr = cast(cstAttr); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); } rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // Bool type. if (srcType.isInteger(1)) { // arith.constant can use 0/1 instead of true/false for i1 values. We need // to handle that here. auto dstAttr = convertBoolAttr(cstAttr, rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. auto srcAttr = cast(cstAttr); IntegerAttr dstAttr = convertIntegerAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } }; //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// /// Returns signed remainder for `lhs` and `rhs` and lets the result follow /// the sign of `signOperand`. /// /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod /// if either operand can be negative. Emulate it via spirv.UMod. template static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { assert(lhs.getType() == rhs.getType()); assert(lhs == signOperand || rhs == signOperand); Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. Value lhsAbs = builder.create(loc, type, lhs); Value rhsAbs = builder.create(loc, type, rhs); Value abs = builder.create(loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) isPositive = builder.create(loc, lhs, lhsAbs); else isPositive = builder.create(loc, rhs, rhsAbs); Value absNegate = builder.create(loc, type, abs); return builder.create(loc, type, isPositive, abs, absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to Vulkan /// restrictions over spirv.SRem and spirv.SMod. struct RemSIOpGLPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value result = emulateSignedRemainder( op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], adaptor.getOperands()[0], rewriter); rewriter.replaceOp(op, result); return success(); } }; /// Converts arith.remsi to OpenCL SPIR-V ops. struct RemSIOpCLPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value result = emulateSignedRemainder( op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], adaptor.getOperands()[0], rewriter); rewriter.replaceOp(op, result); return success(); } }; //===----------------------------------------------------------------------===// // BitwiseOp //===----------------------------------------------------------------------===// /// Converts bitwise operations to SPIR-V operations. This is a special pattern /// other than the BinaryOpPatternPattern because if the operands are boolean /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. template struct BitwiseOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 2); Type dstType = this->getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { rewriter.template replaceOpWithNewOp( op, dstType, adaptor.getOperands()); } else { rewriter.template replaceOpWithNewOp( op, dstType, adaptor.getOperands()); } return success(); } }; //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// /// Converts arith.xori to SPIR-V operations. struct XOrIOpLogicalPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 2); if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); return success(); } }; /// Converts arith.xori to SPIR-V operations if the type of source is i1 or /// vector of i1. struct XOrIOpBooleanPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 2); if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands()); return success(); } }; //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector /// of i1. struct UIToFPI1Pattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtSII1Pattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value operand = adaptor.getIn(); if (!isBoolScalarOrVector(operand.getType())) return failure(); Location loc = op.getLoc(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); Value allOnes; if (auto intTy = dyn_cast(dstType)) { unsigned componentBitwidth = intTy.getWidth(); allOnes = rewriter.create( loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); allOnes = rewriter.create( loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { return rewriter.notifyMatchFailure( loc, llvm::formatv("unhandled type: {0}", dstType)); } Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); rewriter.replaceOpWithNewOp(op, dstType, operand, allOnes, zero); return success(); } }; /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor /// vector of i1. struct ExtSIPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getIn().getType(); if (isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (dstType == srcType) { // We can have the same source and destination type due to type emulation. // Perform bit shifting to make sure we have the proper leading set bits. unsigned srcBW = getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); unsigned dstBW = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); assert(srcBW < dstBW); Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW, rewriter, op.getLoc()); // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. auto shiftLOp = rewriter.create( op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. rewriter.replaceOpWithNewOp( op, dstType, shiftLOp, shiftSize); } else { rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); } return success(); } }; //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// /// Converts arith.extui to spirv.Select if the type of source is i1 or vector /// of i1. struct ExtUII1Pattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; /// Converts arith.extui for cases where the type of source is neither i1 nor /// vector of i1. struct ExtUIPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getIn().getType(); if (isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (dstType == srcType) { // We can have the same source and destination type due to type emulation. // Perform bit masking to make sure we don't pollute downstream consumers // with unwanted bits. Here we need to use the original source type's // bitwidth. unsigned bitwidth = getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); Value mask = getScalarOrVectorConstInt( dstType, llvm::maskTrailingOnes(bitwidth), rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), mask); } else { rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); } return success(); } }; //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector /// of i1. struct TruncII1Pattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (!isBoolScalarOrVector(dstType)) return failure(); Location loc = op.getLoc(); auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); Value maskedSrc = rewriter.create( loc, srcType, adaptor.getOperands()[0], mask); Value isOne = rewriter.create(loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); return success(); } }; /// Converts arith.trunci for cases where the type of result is neither i1 /// nor vector of i1. struct TruncIPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = adaptor.getIn().getType(); Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (isBoolScalarOrVector(dstType)) return failure(); if (dstType == srcType) { // We can have the same source and destination type due to type emulation. // Perform bit masking to make sure we don't pollute downstream consumers // with unwanted bits. Here we need to use the original result type's // bitwidth. unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); Value mask = getScalarOrVectorConstInt( dstType, llvm::maskTrailingOnes(bw), rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, dstType, adaptor.getIn(), mask); } else { // Given this is truncation, either SConvertOp or UConvertOp works. rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); } return success(); } }; //===----------------------------------------------------------------------===// // TypeCastingOp //===----------------------------------------------------------------------===// /// Converts type-casting standard operations to SPIR-V operations. template struct TypeCastingOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Type srcType = adaptor.getOperands().front().getType(); Type dstType = this->getTypeConverter()->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) return failure(); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. rewriter.replaceOp(op, adaptor.getOperands().front()); } else { rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); } return success(); } }; //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// /// Converts integer compare operation on i1 type operands to SPIR-V ops. class CmpIOpBooleanPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = op.getLhs().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return getTypeConversionFailure(rewriter, op, srcType); switch (op.getPredicate()) { case arith::CmpIPredicate::eq: { rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); return success(); } case arith::CmpIPredicate::ne: { rewriter.replaceOpWithNewOp( op, adaptor.getLhs(), adaptor.getRhs()); return success(); } case arith::CmpIPredicate::uge: case arith::CmpIPredicate::ugt: case arith::CmpIPredicate::ule: case arith::CmpIPredicate::ult: { // There are no direct corresponding instructions in SPIR-V for such // cases. Extend them to 32-bit and do comparision then. Type type = rewriter.getI32Type(); if (auto vectorType = dyn_cast(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = rewriter.create(op.getLoc(), type, adaptor.getLhs()); Value extRhs = rewriter.create(op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, extRhs); return success(); } default: break; } return failure(); } }; /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = op.getLhs().getType(); if (isBoolScalarOrVector(srcType)) return failure(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return getTypeConversionFailure(rewriter, op, srcType); switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (spirvOp::template hasTrait() && \ !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \ !hasSameBitwidth(srcType, dstType)) { \ return op.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ adaptor.getRhs()); \ return success(); DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH } return failure(); } }; //===----------------------------------------------------------------------===// // CmpFOpPattern //===----------------------------------------------------------------------===// /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ adaptor.getRhs()); \ return success(); // Ordered. DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); // Unordered. DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); #undef DISPATCH default: break; } return failure(); } }; /// Converts floating point NaN check to SPIR-V ops. This pattern requires /// Kernel capability. class CmpFOpNanKernelPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (op.getPredicate() == arith::CmpFPredicate::ORD) { rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); return success(); } if (op.getPredicate() == arith::CmpFPredicate::UNO) { rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); return success(); } return failure(); } }; /// Converts floating point NaN check to SPIR-V ops. This pattern does not /// require additional capability. class CmpFOpNanNonePattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (op.getPredicate() != arith::CmpFPredicate::ORD && op.getPredicate() != arith::CmpFPredicate::UNO) return failure(); Location loc = op.getLoc(); auto *converter = getTypeConverter(); Value replace; if (converter->getOptions().enableFastMathMode) { if (op.getPredicate() == arith::CmpFPredicate::ORD) { // Ordered comparsion checks if neither operand is NaN. replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); } else { // Unordered comparsion checks if either operand is NaN. replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); replace = rewriter.create(loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) replace = rewriter.create(loc, replace); } rewriter.replaceOp(op, replace); return success(); } }; //===----------------------------------------------------------------------===// // AddUIExtendedOp //===----------------------------------------------------------------------===// /// Converts arith.addui_extended to spirv.IAddCarry. class AddUIExtendedOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); Value result = rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); Value sumResult = rewriter.create( loc, result, llvm::ArrayRef(0)); Value carryValue = rewriter.create( loc, result, llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); Value carryResult = rewriter.create(loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); } }; //===----------------------------------------------------------------------===// // MulIExtendedOp //===----------------------------------------------------------------------===// /// Converts arith.mul*i_extended to spirv.*MulExtended. template class MulIExtendedOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); Value low = rewriter.create(loc, result, llvm::ArrayRef(0)); Value high = rewriter.create(loc, result, llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); } }; //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// /// Converts arith.select to spirv.Select. class SelectOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), adaptor.getTrueValue(), adaptor.getFalseValue()); return success(); } }; //===----------------------------------------------------------------------===// // MinimumFOp, MaximumFOp //===----------------------------------------------------------------------===// /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or /// spirv.CL.fmax/fmin. template class MinimumMaximumFOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *converter = this->template getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); // arith.maximumf/minimumf: // "if one of the arguments is NaN, then the result is also NaN." // spirv.GL.FMax/FMin // "which operand is the result is undefined if one of the operands // is a NaN." // spirv.CL.fmax/fmin: // "If one argument is a NaN, Fmin returns the other argument." Location loc = op.getLoc(); Value spirvOp = rewriter.create(loc, dstType, adaptor.getOperands()); if (converter->getOptions().enableFastMathMode) { rewriter.replaceOp(op, spirvOp); return success(); } Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); Value select1 = rewriter.create(loc, dstType, lhsIsNan, adaptor.getLhs(), spirvOp); Value select2 = rewriter.create(loc, dstType, rhsIsNan, adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); } }; //===----------------------------------------------------------------------===// // MinNumFOp, MaxNumFOp //===----------------------------------------------------------------------===// /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or /// spirv.CL.fmax/fmin. template class MinNumMaxNumFOpPattern final : public OpConversionPattern { template constexpr bool shouldInsertNanGuards() const { return llvm::is_one_of::value; } public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *converter = this->template getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (!dstType) return getTypeConversionFailure(rewriter, op); // arith.maxnumf/minnumf: // "If one of the arguments is NaN, then the result is the other // argument." // spirv.GL.FMax/FMin // "which operand is the result is undefined if one of the operands // is a NaN." // spirv.CL.fmax/fmin: // "If one argument is a NaN, Fmin returns the other argument." Location loc = op.getLoc(); Value spirvOp = rewriter.create(loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards() || converter->getOptions().enableFastMathMode) { rewriter.replaceOp(op, spirvOp); return success(); } Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); Value select1 = rewriter.create(loc, dstType, lhsIsNan, adaptor.getRhs(), spirvOp); Value select2 = rewriter.create(loc, dstType, rhsIsNan, adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void mlir::arith::populateArithToSPIRVPatterns( SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // clang-format off patterns.add< ConstantCompositeOpPattern, ConstantScalarOpPattern, ElementwiseArithOpPattern, ElementwiseArithOpPattern, ElementwiseArithOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, RemSIOpGLPattern, RemSIOpCLPattern, BitwiseOpPattern, BitwiseOpPattern, XOrIOpLogicalPattern, XOrIOpBooleanPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, ExtUIPattern, ExtUII1Pattern, ExtSIPattern, ExtSII1Pattern, TypeCastingOpPattern, TruncIPattern, TruncII1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, UIToFPI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, AddUIExtendedOpPattern, MulIExtendedOpPattern, MulIExtendedOpPattern, SelectOpPattern, MinimumMaximumFOpPattern, MinimumMaximumFOpPattern, MinNumMaxNumFOpPattern, MinNumMaxNumFOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, MinimumMaximumFOpPattern, MinimumMaximumFOpPattern, MinNumMaxNumFOpPattern, MinNumMaxNumFOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern >(typeConverter, patterns.getContext()); // clang-format on // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel // capability is available. patterns.add(typeConverter, patterns.getContext(), /*benefit=*/2); } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ConvertArithToSPIRVPass : public impl::ConvertArithToSPIRVBase { void runOnOperation() override { Operation *op = getOperation(); spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); std::unique_ptr target = SPIRVConversionTarget::get(targetAttr); SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; options.enableFastMathMode = this->enableFastMath; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull // in patterns for other dialects. target->addLegalOp(); // Fail hard when there are any remaining 'arith' ops. target->addIllegalDialect(); RewritePatternSet patterns(&getContext()); arith::populateArithToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, *target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr> mlir::arith::createConvertArithToSPIRVPass() { return std::make_unique(); }