//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Math dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "math-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the /// given type is not a 32-bit scalar/vector type. static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc) { if (auto vectorType = dyn_cast(type)) { if (!vectorType.getElementType().isInteger(32)) return nullptr; SmallVector values(vectorType.getNumElements(), value); return builder.create(loc, type, builder.getI32VectorAttr(values)); } if (type.isInteger(32)) return builder.create(loc, type, builder.getI32IntegerAttr(value)); return nullptr; } /// Check if the type is supported by math-to-spirv conversion. We expect to /// only see scalars and vectors at this point, with higher-level types already /// lowered. static bool isSupportedSourceType(Type originalType) { if (originalType.isIntOrIndexOrFloat()) return true; if (auto vecTy = dyn_cast(originalType)) { if (!vecTy.getElementType().isIntOrIndexOrFloat()) return false; if (vecTy.isScalable()) return false; if (vecTy.getRank() > 1) return false; return true; } return false; } /// Check if all `sourceOp` types are supported by math-to-spirv conversion. /// Notify of a match failure othwerise and return a `failure` result. /// This is intended to simplify type checks in `OpConversionPattern`s. static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp) { auto allTypes = llvm::to_vector(sourceOp->getOperandTypes()); llvm::append_range(allTypes, sourceOp->getResultTypes()); for (Type ty : allTypes) { if (!isSupportedSourceType(ty)) { return rewriter.notifyMatchFailure( sourceOp, llvm::formatv( "unsupported source type for Math to SPIR-V conversion: {0}", ty)); } } return success(); } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts elementwise unary, binary, and ternary standard operations to /// SPIR-V operations. Checks that source `Op` types are supported. template struct CheckedElementwiseOpPattern final : public spirv::ElementwiseOpPattern { using BasePattern = typename spirv::ElementwiseOpPattern; using BasePattern::BasePattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) return res; return BasePattern::matchAndRewrite(op, adaptor, rewriter); } }; /// Converts math.copysign to SPIR-V ops. struct CopySignPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp); failed(res)) return res; Type type = getTypeConverter()->convertType(copySignOp.getType()); if (!type) return failure(); FloatType floatType; if (auto scalarType = dyn_cast(copySignOp.getType())) { floatType = scalarType; } else if (auto vectorType = dyn_cast(copySignOp.getType())) { floatType = cast(vectorType.getElementType()); } else { return failure(); } Location loc = copySignOp.getLoc(); int bitwidth = floatType.getWidth(); Type intType = rewriter.getIntegerType(bitwidth); uint64_t intValue = uint64_t(1) << (bitwidth - 1); Value signMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, intValue)); Value valueMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); if (auto vectorType = dyn_cast(type)) { assert(vectorType.getRank() == 1); int count = vectorType.getNumElements(); intType = VectorType::get(count, intType); SmallVector signSplat(count, signMask); signMask = rewriter.create(loc, intType, signSplat); SmallVector valueSplat(count, valueMask); valueMask = rewriter.create(loc, intType, valueSplat); } Value lhsCast = rewriter.create(loc, intType, adaptor.getLhs()); Value rhsCast = rewriter.create(loc, intType, adaptor.getRhs()); Value value = rewriter.create( loc, intType, ValueRange{lhsCast, valueMask}); Value sign = rewriter.create( loc, intType, ValueRange{rhsCast, signMask}); Value result = rewriter.create(loc, intType, ValueRange{value, sign}); rewriter.replaceOpWithNewOp(copySignOp, type, result); return success(); } }; /// Converts math.ctlz to SPIR-V ops. /// /// SPIR-V does not have a direct operations for counting leading zeros. If /// Shader capability is supported, we can leverage GL FindUMsb to calculate /// it. struct CountLeadingZerosPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res)) return res; Type type = getTypeConverter()->convertType(countOp.getType()); if (!type) return failure(); // We can only support 32-bit integer types for now. unsigned bitwidth = 0; if (isa(type)) bitwidth = type.getIntOrFloatBitWidth(); if (auto vectorType = dyn_cast(type)) bitwidth = vectorType.getElementTypeBitWidth(); if (bitwidth != 32) return failure(); Location loc = countOp.getLoc(); Value input = adaptor.getOperand(); Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); Value msb = rewriter.create(loc, input); // We need to subtract from 31 given that the index returned by GLSL // FindUMsb is counted from the least significant bit. Theoretically this // also gives the correct result even if the integer has all zero bits, in // which case GL FindUMsb would return -1. Value subMsb = rewriter.create(loc, val31, msb); // However, certain Vulkan implementations have driver bugs for the corner // case where the input is zero. And.. it can be smart to optimize a select // only involving the corner case. So separately compute the result when the // input is either zero or one. Value subInput = rewriter.create(loc, val32, input); Value cmp = rewriter.create(loc, input, val1); rewriter.replaceOpWithNewOp(countOp, cmp, subInput, subMsb); return success(); } }; /// Converts math.expm1 to SPIR-V ops. /// /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to /// these operations. template struct ExpM1OpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); if (LogicalResult res = checkSourceOpTypes(rewriter, operation); failed(res)) return res; Location loc = operation.getLoc(); Type type = this->getTypeConverter()->convertType(operation.getType()); if (!type) return failure(); Value exp = rewriter.create(loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); } }; /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. template struct Log1pOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); if (LogicalResult res = checkSourceOpTypes(rewriter, operation); failed(res)) return res; Location loc = operation.getLoc(); Type type = this->getTypeConverter()->convertType(operation.getType()); if (!type) return failure(); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); Value onePlus = rewriter.create(loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; /// Converts math.powf to SPIRV-Ops. struct PowFOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res)) return res; Type dstType = getTypeConverter()->convertType(powfOp.getType()); if (!dstType) return failure(); // Get the scalar float type. FloatType scalarFloatType; if (auto scalarType = dyn_cast(powfOp.getType())) { scalarFloatType = scalarType; } else if (auto vectorType = dyn_cast(powfOp.getType())) { scalarFloatType = cast(vectorType.getElementType()); } else { return failure(); } // Get int type of the same shape as the float type. Type scalarIntType = rewriter.getIntegerType(32); Type intType = scalarIntType; if (auto vectorType = dyn_cast(adaptor.getRhs().getType())) { auto shape = vectorType.getShape(); intType = VectorType::get(shape, scalarIntType); } // Per GL Pow extended instruction spec: // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." Location loc = powfOp.getLoc(); Value zero = spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter); Value lessThan = rewriter.create(loc, adaptor.getLhs(), zero); Value abs = rewriter.create(loc, adaptor.getLhs()); // TODO: The following just forcefully casts y into an integer value in // order to properly propagate the sign, assuming integer y cases. It // doesn't cover other cases and should be fixed. // Cast exponent to integer and calculate exponent % 2 != 0. Value intRhs = rewriter.create(loc, intType, adaptor.getRhs()); Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); Value bitwiseAndOne = rewriter.create(loc, intRhs, intOne); Value isOdd = rewriter.create(loc, bitwiseAndOne, intOne); // calculate pow based on abs(lhs)^rhs. Value pow = rewriter.create(loc, abs, adaptor.getRhs()); Value negate = rewriter.create(loc, pow); // if the exponent is odd and lhs < 0, negate the result. Value shouldNegate = rewriter.create(loc, lessThan, isOdd); rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, pow); return success(); } }; /// Converts math.round to GLSL SPIRV extended ops. struct RoundOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res)) return res; Location loc = roundOp.getLoc(); Value operand = roundOp.getOperand(); Type ty = operand.getType(); Type ety = getElementTypeOrSelf(ty); auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); Value half; if (VectorType vty = dyn_cast(ty)) { half = rewriter.create( loc, vty, DenseElementsAttr::get(vty, rewriter.getFloatAttr(ety, 0.5).getValue())); } else { half = rewriter.create( loc, ty, rewriter.getFloatAttr(ety, 0.5)); } auto abs = rewriter.create(loc, operand); auto floor = rewriter.create(loc, abs); auto sub = rewriter.create(loc, abs, floor); auto greater = rewriter.create(loc, sub, half); auto select = rewriter.create(loc, greater, one, zero); auto add = rewriter.create(loc, floor, select); rewriter.replaceOpWithNewOp(roundOp, add, operand); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns patterns.add(typeConverter, patterns.getContext()); // GLSL patterns patterns .add, ExpM1OpPattern, PowFOpPattern, RoundOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, ExpM1OpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern, CheckedElementwiseOpPattern>( typeConverter, patterns.getContext()); } } // namespace mlir