//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the folders and canonicalization patterns for SPIR-V ops. // //===----------------------------------------------------------------------===// #include #include #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" using namespace mlir; //===----------------------------------------------------------------------===// // Common utility functions //===----------------------------------------------------------------------===// /// Returns the boolean value under the hood if the given `boolAttr` is a scalar /// or splat vector bool constant. static std::optional getScalarOrSplatBoolAttr(Attribute attr) { if (!attr) return std::nullopt; if (auto boolAttr = llvm::dyn_cast(attr)) return boolAttr.getValue(); if (auto splatAttr = llvm::dyn_cast(attr)) if (splatAttr.getElementType().isInteger(1)) return splatAttr.getSplatValue(); return std::nullopt; } // Extracts an element from the given `composite` by following the given // `indices`. Returns a null Attribute if error happens. static Attribute extractCompositeElement(Attribute composite, ArrayRef indices) { // Check that given composite is a constant. if (!composite) return {}; // Return composite itself if we reach the end of the index chain. if (indices.empty()) return composite; if (auto vector = llvm::dyn_cast(composite)) { assert(indices.size() == 1 && "must have exactly one index for a vector"); return vector.getValues()[indices[0]]; } if (auto array = llvm::dyn_cast(composite)) { assert(!indices.empty() && "must have at least one index for an array"); return extractCompositeElement(array.getValue()[indices[0]], indices.drop_front()); } return {}; } static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) { bool div0 = b.isZero(); bool overflow = a.isMinSignedValue() && b.isAllOnes(); return div0 || overflow; } //===----------------------------------------------------------------------===// // TableGen'erated canonicalizers //===----------------------------------------------------------------------===// namespace { #include "SPIRVCanonicalization.inc" } // namespace //===----------------------------------------------------------------------===// // spirv.AccessChainOp //===----------------------------------------------------------------------===// namespace { /// Combines chained `spirv::AccessChainOp` operations into one /// `spirv::AccessChainOp` operation. struct CombineChainedAccessChain final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp, PatternRewriter &rewriter) const override { auto parentAccessChainOp = accessChainOp.getBasePtr().getDefiningOp(); if (!parentAccessChainOp) { return failure(); } // Combine indices. SmallVector indices(parentAccessChainOp.getIndices()); llvm::append_range(indices, accessChainOp.getIndices()); rewriter.replaceOpWithNewOp( accessChainOp, parentAccessChainOp.getBasePtr(), indices); return success(); } }; } // namespace void spirv::AccessChainOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add(context); } //===----------------------------------------------------------------------===// // spirv.IAddCarry //===----------------------------------------------------------------------===// // We are required to use CompositeConstructOp to create a constant struct as // they are not yet implemented as constant, hence we can not do so in a fold. struct IAddCarryFold final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value lhs = op.getOperand1(); Value rhs = op.getOperand2(); Type constituentType = lhs.getType(); // iaddcarry (x, 0) = <0, x> if (matchPattern(rhs, m_Zero())) { Value constituents[2] = {rhs, lhs}; rewriter.replaceOpWithNewOp(op, op.getType(), constituents); return success(); } // According to the SPIR-V spec: // // Result Type must be from OpTypeStruct. The struct must have two // members... // // Member 0 of the result gets the low-order bits (full component width) of // the addition. // // Member 1 of the result gets the high-order (carry) bit of the result of // the addition. That is, it gets the value 1 if the addition overflowed // the component width, and 0 otherwise. Attribute lhsAttr; Attribute rhsAttr; if (!matchPattern(lhs, m_Constant(&lhsAttr)) || !matchPattern(rhs, m_Constant(&rhsAttr))) return failure(); auto adds = constFoldBinaryOp( {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { return a + b; }); if (!adds) return failure(); auto carrys = constFoldBinaryOp( ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) { APInt zero = APInt::getZero(a.getBitWidth()); return a.ult(b) ? (zero + 1) : zero; }); if (!carrys) return failure(); Value addsVal = rewriter.create(loc, constituentType, adds); Value carrysVal = rewriter.create(loc, constituentType, carrys); // Create empty struct Value undef = rewriter.create(loc, op.getType()); // Fill in adds at id 0 Value intermediate = rewriter.create(loc, addsVal, undef, 0); // Fill in carrys at id 1 rewriter.replaceOpWithNewOp(op, carrysVal, intermediate, 1); return success(); } }; void spirv::IAddCarryOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // spirv.[S|U]MulExtended //===----------------------------------------------------------------------===// // We are required to use CompositeConstructOp to create a constant struct as // they are not yet implemented as constant, hence we can not do so in a fold. template struct MulExtendedFold final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value lhs = op.getOperand1(); Value rhs = op.getOperand2(); Type constituentType = lhs.getType(); // [su]mulextended (x, 0) = <0, 0> if (matchPattern(rhs, m_Zero())) { Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter); Value constituents[2] = {zero, zero}; rewriter.replaceOpWithNewOp(op, op.getType(), constituents); return success(); } // According to the SPIR-V spec: // // Result Type must be from OpTypeStruct. The struct must have two // members... // // Member 0 of the result gets the low-order bits of the multiplication. // // Member 1 of the result gets the high-order bits of the multiplication. Attribute lhsAttr; Attribute rhsAttr; if (!matchPattern(lhs, m_Constant(&lhsAttr)) || !matchPattern(rhs, m_Constant(&rhsAttr))) return failure(); auto lowBits = constFoldBinaryOp( {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { return a * b; }); if (!lowBits) return failure(); auto highBits = constFoldBinaryOp( {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt c; if (IsSigned) { c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); } else { c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); } return c.extractBits(bitWidth, bitWidth); // Extract high result }); if (!highBits) return failure(); Value lowBitsVal = rewriter.create(loc, constituentType, lowBits); Value highBitsVal = rewriter.create(loc, constituentType, highBits); // Create empty struct Value undef = rewriter.create(loc, op.getType()); // Fill in lowBits at id 0 Value intermediate = rewriter.create(loc, lowBitsVal, undef, 0); // Fill in highBits at id 1 rewriter.replaceOpWithNewOp(op, highBitsVal, intermediate, 1); return success(); } }; using SMulExtendedOpFold = MulExtendedFold; void spirv::SMulExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } struct UMulExtendedOpXOne final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value lhs = op.getOperand1(); Value rhs = op.getOperand2(); Type constituentType = lhs.getType(); // umulextended (x, 1) = if (matchPattern(rhs, m_One())) { Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter); Value constituents[2] = {lhs, zero}; rewriter.replaceOpWithNewOp(op, op.getType(), constituents); return success(); } return failure(); } }; using UMulExtendedOpFold = MulExtendedFold; void spirv::UMulExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // spirv.UMod //===----------------------------------------------------------------------===// // Input: // %0 = spirv.UMod %arg0, %const32 : i32 // %1 = spirv.UMod %0, %const4 : i32 // Output: // %0 = spirv.UMod %arg0, %const32 : i32 // %1 = spirv.UMod %arg0, %const4 : i32 // The transformation is only applied if one divisor is a multiple of the other. // TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants struct UModSimplification final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override { auto prevUMod = umodOp.getOperand(0).getDefiningOp(); if (!prevUMod) return failure(); IntegerAttr prevValue; IntegerAttr currValue; if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) || !matchPattern(umodOp.getOperand(1), m_Constant(&currValue))) return failure(); APInt prevConstValue = prevValue.getValue(); APInt currConstValue = currValue.getValue(); // Ensure that one divisor is a multiple of the other. If not, fail the // transformation. if (prevConstValue.urem(currConstValue) != 0 && currConstValue.urem(prevConstValue) != 0) return failure(); // The transformation is safe. Replace the existing UMod operation with a // new UMod operation, using the original dividend and the current divisor. rewriter.replaceOpWithNewOp( umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1)); return success(); } }; void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // spirv.BitcastOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) { Value curInput = getOperand(); if (getType() == curInput.getType()) return curInput; // Look through nested bitcasts. if (auto prevCast = curInput.getDefiningOp()) { Value prevInput = prevCast.getOperand(); if (prevInput.getType() == getType()) return prevInput; getOperandMutable().assign(prevInput); return getResult(); } // TODO(kuhar): Consider constant-folding the operand attribute. return {}; } //===----------------------------------------------------------------------===// // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) { Value compositeOp = getComposite(); while (auto insertOp = compositeOp.getDefiningOp()) { if (getIndices() == insertOp.getIndices()) return insertOp.getObject(); compositeOp = insertOp.getComposite(); } if (auto constructOp = compositeOp.getDefiningOp()) { auto type = llvm::cast(constructOp.getType()); if (getIndices().size() == 1 && constructOp.getConstituents().size() == type.getNumElements()) { auto i = llvm::cast(*getIndices().begin()); if (i.getValue().getSExtValue() < static_cast(constructOp.getConstituents().size())) return constructOp.getConstituents()[i.getValue().getSExtValue()]; } } auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) { return static_cast(llvm::cast(attr).getInt()); }); return extractCompositeElement(adaptor.getComposite(), indexVector); } //===----------------------------------------------------------------------===// // spirv.Constant //===----------------------------------------------------------------------===// OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); } //===----------------------------------------------------------------------===// // spirv.IAdd //===----------------------------------------------------------------------===// OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) { // x + 0 = x if (matchPattern(getOperand2(), m_Zero())) return getOperand1(); // According to the SPIR-V spec: // // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; }); } //===----------------------------------------------------------------------===// // spirv.IMul //===----------------------------------------------------------------------===// OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) { // x * 0 == 0 if (matchPattern(getOperand2(), m_Zero())) return getOperand2(); // x * 1 = x if (matchPattern(getOperand2(), m_One())) return getOperand1(); // According to the SPIR-V spec: // // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; }); } //===----------------------------------------------------------------------===// // spirv.ISub //===----------------------------------------------------------------------===// OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) { // x - x = 0 if (getOperand1() == getOperand2()) return Builder(getContext()).getIntegerAttr(getType(), 0); // According to the SPIR-V spec: // // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) - b; }); } //===----------------------------------------------------------------------===// // spirv.SDiv //===----------------------------------------------------------------------===// OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) { // sdiv (x, 1) = x if (matchPattern(getOperand2(), m_One())) return getOperand1(); // According to the SPIR-V spec: // // Signed-integer division of Operand 1 divided by Operand 2. // Results are computed per component. Behavior is undefined if Operand 2 is // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum // representable value for the operands' type, causing signed overflow. // // So don't fold during undefined behavior. bool div0OrOverflow = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (div0OrOverflow || isDivZeroOrOverflow(a, b)) { div0OrOverflow = true; return a; } return a.sdiv(b); }); return div0OrOverflow ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.SMod //===----------------------------------------------------------------------===// OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) { // smod (x, 1) = 0 if (matchPattern(getOperand2(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // According to SPIR-V spec: // // Signed remainder operation for the remainder whose sign matches the sign // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is // undefined if Operand 2 is -1 and Operand 1 is the minimum representable // value for the operands' type, causing signed overflow. Otherwise, the // result is the remainder r of Operand 1 divided by Operand 2 where if // r ≠ 0, the sign of r is the same as the sign of Operand 2. // // So don't fold during undefined behavior bool div0OrOverflow = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (div0OrOverflow || isDivZeroOrOverflow(a, b)) { div0OrOverflow = true; return a; } APInt c = a.abs().urem(b.abs()); if (c.isZero()) return c; if (b.isNegative()) { APInt zero = APInt::getZero(c.getBitWidth()); return a.isNegative() ? (zero - c) : (b + c); } return a.isNegative() ? (b - c) : c; }); return div0OrOverflow ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.SRem //===----------------------------------------------------------------------===// OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) { // x % 1 = 0 if (matchPattern(getOperand2(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // According to SPIR-V spec: // // Signed remainder operation for the remainder whose sign matches the sign // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is // undefined if Operand 2 is -1 and Operand 1 is the minimum representable // value for the operands' type, causing signed overflow. Otherwise, the // result is the remainder r of Operand 1 divided by Operand 2 where if // r ≠ 0, the sign of r is the same as the sign of Operand 1. // Don't fold if it would do undefined behavior. bool div0OrOverflow = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (div0OrOverflow || isDivZeroOrOverflow(a, b)) { div0OrOverflow = true; return a; } return a.srem(b); }); return div0OrOverflow ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.UDiv //===----------------------------------------------------------------------===// OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) { // udiv (x, 1) = x if (matchPattern(getOperand2(), m_One())) return getOperand1(); // According to the SPIR-V spec: // // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is // undefined if Operand 2 is 0. // // So don't fold during undefined behavior. bool div0 = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (div0 || b.isZero()) { div0 = true; return a; } return a.udiv(b); }); return div0 ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.UMod //===----------------------------------------------------------------------===// OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) { // umod (x, 1) = 0 if (matchPattern(getOperand2(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // According to the SPIR-V spec: // // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is // undefined if Operand 2 is 0. // // So don't fold during undefined behavior. bool div0 = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (div0 || b.isZero()) { div0 = true; return a; } return a.urem(b); }); return div0 ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.SNegate //===----------------------------------------------------------------------===// OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) { // -(-x) = 0 - (0 - x) = x auto op = getOperand(); if (auto negateOp = op.getDefiningOp()) return negateOp->getOperand(0); // According to the SPIR-V spec: // // Signed-integer subtract of Operand from zero. return constFoldUnaryOp( adaptor.getOperands(), [](const APInt &a) { APInt zero = APInt::getZero(a.getBitWidth()); return zero - a; }); } //===----------------------------------------------------------------------===// // spirv.NotOp //===----------------------------------------------------------------------===// OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) { // !(!x) = x auto op = getOperand(); if (auto notOp = op.getDefiningOp()) return notOp->getOperand(0); // According to the SPIR-V spec: // // Complement the bits of Operand. return constFoldUnaryOp(adaptor.getOperands(), [&](APInt a) { a.flipAllBits(); return a; }); } //===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) { if (std::optional rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { // x && true = x if (*rhs) return getOperand1(); // x && false = false if (!*rhs) return adaptor.getOperand2(); } return Attribute(); } //===----------------------------------------------------------------------===// // spirv.LogicalEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) { // x == x -> true if (getOperand1() == getOperand2()) { auto trueAttr = BoolAttr::get(getContext(), true); if (isa(getType())) return trueAttr; if (auto vecTy = dyn_cast(getType())) return SplatElementsAttr::get(vecTy, trueAttr); } return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a == b ? APInt::getAllOnes(1) : APInt::getZero(1); }); } //===----------------------------------------------------------------------===// // spirv.LogicalNotEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) { if (std::optional rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { // x != false -> x if (!rhs.value()) return getOperand1(); } // x == x -> false if (getOperand1() == getOperand2()) { auto falseAttr = BoolAttr::get(getContext(), false); if (isa(getType())) return falseAttr; if (auto vecTy = dyn_cast(getType())) return SplatElementsAttr::get(vecTy, falseAttr); } return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a == b ? APInt::getZero(1) : APInt::getAllOnes(1); }); } //===----------------------------------------------------------------------===// // spirv.LogicalNot //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) { // !(!x) = x auto op = getOperand(); if (auto notOp = op.getDefiningOp()) return notOp->getOperand(0); // According to the SPIR-V spec: // // Complement the bits of Operand. return constFoldUnaryOp(adaptor.getOperands(), [](const APInt &a) { APInt zero = APInt::getZero(1); return a == 1 ? zero : (zero + 1); }); } void spirv::LogicalNotOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results .add( context); } //===----------------------------------------------------------------------===// // spirv.LogicalOr //===----------------------------------------------------------------------===// OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) { if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { if (*rhs) { // x || true = true return adaptor.getOperand2(); } if (!*rhs) { // x || false = x return getOperand1(); } } return Attribute(); } //===----------------------------------------------------------------------===// // spirv.IEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) { // x == x -> true if (getOperand1() == getOperand2()) { auto trueAttr = BoolAttr::get(getContext(), true); if (isa(getType())) return trueAttr; if (auto vecTy = dyn_cast(getType())) return SplatElementsAttr::get(vecTy, trueAttr); } return constFoldBinaryOp( adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) { return a == b ? APInt::getAllOnes(1) : APInt::getZero(1); }); } //===----------------------------------------------------------------------===// // spirv.INotEqualOp //===----------------------------------------------------------------------===// OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) { // x == x -> false if (getOperand1() == getOperand2()) { auto falseAttr = BoolAttr::get(getContext(), false); if (isa(getType())) return falseAttr; if (auto vecTy = dyn_cast(getType())) return SplatElementsAttr::get(vecTy, falseAttr); } return constFoldBinaryOp( adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) { return a == b ? APInt::getZero(1) : APInt::getAllOnes(1); }); } //===----------------------------------------------------------------------===// // spirv.ShiftLeftLogical //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftLeftLogicalOp::fold( spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) { // x << 0 -> x if (matchPattern(adaptor.getOperand2(), m_Zero())) { return getOperand1(); } // Unfortunately due to below undefined behaviour can't fold 0 for Base. // Results are computed per component, and within each component, per bit... // // The result is undefined if Shift is greater than or equal to the bit width // of the components of Base. // // So we can use the APInt << method, but don't fold if undefined behaviour. bool shiftToLarge = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (shiftToLarge || b.uge(a.getBitWidth())) { shiftToLarge = true; return a; } return a << b; }); return shiftToLarge ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.ShiftRightArithmetic //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftRightArithmeticOp::fold( spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) { // x >> 0 -> x if (matchPattern(adaptor.getOperand2(), m_Zero())) { return getOperand1(); } // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base. // Results are computed per component, and within each component, per bit... // // The result is undefined if Shift is greater than or equal to the bit width // of the components of Base. // // So we can use the APInt ashr method, but don't fold if undefined behaviour. bool shiftToLarge = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (shiftToLarge || b.uge(a.getBitWidth())) { shiftToLarge = true; return a; } return a.ashr(b); }); return shiftToLarge ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.ShiftRightLogical //===----------------------------------------------------------------------===// OpFoldResult spirv::ShiftRightLogicalOp::fold( spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) { // x >> 0 -> x if (matchPattern(adaptor.getOperand2(), m_Zero())) { return getOperand1(); } // Unfortunately due to below undefined behaviour can't fold 0 for Base. // Results are computed per component, and within each component, per bit... // // The result is undefined if Shift is greater than or equal to the bit width // of the components of Base. // // So we can use the APInt lshr method, but don't fold if undefined behaviour. bool shiftToLarge = false; auto res = constFoldBinaryOp( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { if (shiftToLarge || b.uge(a.getBitWidth())) { shiftToLarge = true; return a; } return a.lshr(b); }); return shiftToLarge ? Attribute() : res; } //===----------------------------------------------------------------------===// // spirv.BitwiseAndOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) { // x & x -> x if (getOperand1() == getOperand2()) { return getOperand1(); } APInt rhsMask; if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) { // x & 0 -> 0 if (rhsMask.isZero()) return getOperand2(); // x & -> x if (rhsMask.isAllOnes()) return getOperand1(); // (UConvert x : iN to iK) & -> UConvert x if (auto zext = getOperand1().getDefiningOp()) { int valueBits = getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth(); if (rhsMask.zextOrTrunc(valueBits).isAllOnes()) return getOperand1(); } } // According to the SPIR-V spec: // // Type is a scalar or vector of integer type. // Results are computed per component, and within each component, per bit. // So we can use the APInt & method. return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a & b; }); } //===----------------------------------------------------------------------===// // spirv.BitwiseOrOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) { // x | x -> x if (getOperand1() == getOperand2()) { return getOperand1(); } APInt rhsMask; if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) { // x | 0 -> x if (rhsMask.isZero()) return getOperand1(); // x | -> if (rhsMask.isAllOnes()) return getOperand2(); } // According to the SPIR-V spec: // // Type is a scalar or vector of integer type. // Results are computed per component, and within each component, per bit. // So we can use the APInt | method. return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a | b; }); } //===----------------------------------------------------------------------===// // spirv.BitwiseXorOp //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) { // x ^ 0 -> x if (matchPattern(adaptor.getOperand2(), m_Zero())) { return getOperand1(); } // x ^ x -> 0 if (getOperand1() == getOperand2()) return Builder(getContext()).getZeroAttr(getType()); // According to the SPIR-V spec: // // Type is a scalar or vector of integer type. // Results are computed per component, and within each component, per bit. // So we can use the APInt ^ method. return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a ^ b; }); } //===----------------------------------------------------------------------===// // spirv.mlir.selection //===----------------------------------------------------------------------===// namespace { // Blocks from the given `spirv.mlir.selection` operation must satisfy the // following layout: // // +-----------------------------------------------+ // | header block | // | spirv.BranchConditionalOp %cond, ^case0, ^case1 | // +-----------------------------------------------+ // / \ // ... // // // +------------------------+ +------------------------+ // | case #0 | | case #1 | // | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 | // | spirv.Branch ^merge | | spirv.Branch ^merge | // +------------------------+ +------------------------+ // // // ... // \ / // v // +-------------+ // | merge block | // +-------------+ // struct ConvertSelectionOpToSelect final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp, PatternRewriter &rewriter) const override { Operation *op = selectionOp.getOperation(); Region &body = op->getRegion(0); // Verifier allows an empty region for `spirv.mlir.selection`. if (body.empty()) { return failure(); } // Check that region consists of 4 blocks: // header block, `true` block, `false` block and merge block. if (llvm::range_size(body) != 4) { return failure(); } Block *headerBlock = selectionOp.getHeaderBlock(); if (!onlyContainsBranchConditionalOp(headerBlock)) { return failure(); } auto brConditionalOp = cast(headerBlock->front()); Block *trueBlock = brConditionalOp.getSuccessor(0); Block *falseBlock = brConditionalOp.getSuccessor(1); Block *mergeBlock = selectionOp.getMergeBlock(); if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock))) return failure(); Value trueValue = getSrcValue(trueBlock); Value falseValue = getSrcValue(falseBlock); Value ptrValue = getDstPtr(trueBlock); auto storeOpAttributes = cast(trueBlock->front())->getAttrs(); auto selectOp = rewriter.create( selectionOp.getLoc(), trueValue.getType(), brConditionalOp.getCondition(), trueValue, falseValue); rewriter.create(selectOp.getLoc(), ptrValue, selectOp.getResult(), storeOpAttributes); // `spirv.mlir.selection` is not needed anymore. rewriter.eraseOp(op); return success(); } private: // Checks that given blocks follow the following rules: // 1. Each conditional block consists of two operations, the first operation // is a `spirv.Store` and the last operation is a `spirv.Branch`. // 2. Each `spirv.Store` uses the same pointer and the same memory attributes. // 3. A control flow goes into the given merge block from the given // conditional blocks. LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock, Block *mergeBlock) const; bool onlyContainsBranchConditionalOp(Block *block) const { return llvm::hasSingleElement(*block) && isa(block->front()); } bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { return lhs->getDiscardableAttrDictionary() == rhs->getDiscardableAttrDictionary() && lhs.getProperties() == rhs.getProperties(); } // Returns a source value for the given block. Value getSrcValue(Block *block) const { auto storeOp = cast(block->front()); return storeOp.getValue(); } // Returns a destination value for the given block. Value getDstPtr(Block *block) const { auto storeOp = cast(block->front()); return storeOp.getPtr(); } }; LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { // Each block must consists of 2 operations. if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) { return failure(); } auto trueBrStoreOp = dyn_cast(trueBlock->front()); auto trueBrBranchOp = dyn_cast(*std::next(trueBlock->begin())); auto falseBrStoreOp = dyn_cast(falseBlock->front()); auto falseBrBranchOp = dyn_cast(*std::next(falseBlock->begin())); if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp || !falseBrBranchOp) { return failure(); } // Checks that given type is valid for `spirv.SelectOp`. // According to SPIR-V spec: // "Before version 1.4, Result Type must be a pointer, scalar, or vector. // Starting with version 1.4, Result Type can additionally be a composite type // other than a vector." bool isScalarOrVector = llvm::cast(trueBrStoreOp.getValue().getType()) .isScalarOrVector(); // Check that each `spirv.Store` uses the same pointer, memory access // attributes and a valid type of the value. if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) || !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) { return failure(); } if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) || (falseBrBranchOp->getSuccessor(0) != mergeBlock)) { return failure(); } return success(); } } // namespace void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); }