//===- IndexOps.cpp - Index operation definitions --------------------------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::index; //===----------------------------------------------------------------------===// // IndexDialect //===----------------------------------------------------------------------===// void IndexDialect::registerOperations() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" >(); } Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value, Type type, Location loc) { // Materialize bool constants as `i1`. if (auto boolValue = dyn_cast(value)) { if (!type.isSignlessInteger(1)) return nullptr; return b.create(loc, type, boolValue); } // Materialize integer attributes as `index`. if (auto indexValue = dyn_cast(value)) { if (!llvm::isa(indexValue.getType()) || !llvm::isa(type)) return nullptr; assert(indexValue.getValue().getBitWidth() == IndexType::kInternalStorageBitWidth); return b.create(loc, indexValue); } return nullptr; } //===----------------------------------------------------------------------===// // Fold Utilities //===----------------------------------------------------------------------===// /// Fold an index operation irrespective of the target bitwidth. The /// operation must satisfy the property: /// /// ``` /// trunc(f(a, b)) = f(trunc(a), trunc(b)) /// ``` /// /// For all values of `a` and `b`. The function accepts a lambda that computes /// the integer result, which in turn must satisfy the above property. static OpFoldResult foldBinaryOpUnchecked( ArrayRef operands, function_ref(const APInt &, const APInt &)> calculate) { assert(operands.size() == 2 && "binary operation expected 2 operands"); auto lhs = dyn_cast_if_present(operands[0]); auto rhs = dyn_cast_if_present(operands[1]); if (!lhs || !rhs) return {}; std::optional result = calculate(lhs.getValue(), rhs.getValue()); if (!result) return {}; assert(result->trunc(32) == calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32))); return IntegerAttr::get(IndexType::get(lhs.getContext()), *result); } /// Fold an index operation only if the truncated 64-bit result matches the /// 32-bit result for operations that don't satisfy the above property. These /// are operations where the upper bits of the operands can affect the lower /// bits of the results. /// /// The function accepts a lambda that computes the integer result in both /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is /// not folded. static OpFoldResult foldBinaryOpChecked( ArrayRef operands, function_ref(const APInt &, const APInt &lhs)> calculate) { assert(operands.size() == 2 && "binary operation expected 2 operands"); auto lhs = dyn_cast_if_present(operands[0]); auto rhs = dyn_cast_if_present(operands[1]); // Only fold index operands. if (!lhs || !rhs) return {}; // Compute the 64-bit result and the 32-bit result. std::optional result64 = calculate(lhs.getValue(), rhs.getValue()); if (!result64) return {}; std::optional result32 = calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)); if (!result32) return {}; // Compare the truncated 64-bit result to the 32-bit result. if (result64->trunc(32) != *result32) return {}; // The operation can be folded for these particular operands. return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64); } //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// OpFoldResult AddOp::fold(FoldAdaptor adaptor) { if (OpFoldResult result = foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; })) return result; if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) { // Fold `add(x, 0) -> x`. if (rhs.getValue().isZero()) return getLhs(); } return {}; } //===----------------------------------------------------------------------===// // SubOp //===----------------------------------------------------------------------===// OpFoldResult SubOp::fold(FoldAdaptor adaptor) { if (OpFoldResult result = foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; })) return result; if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) { // Fold `sub(x, 0) -> x`. if (rhs.getValue().isZero()) return getLhs(); } return {}; } //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// OpFoldResult MulOp::fold(FoldAdaptor adaptor) { if (OpFoldResult result = foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; })) return result; if (auto rhs = dyn_cast_or_null(adaptor.getRhs())) { // Fold `mul(x, 1) -> x`. if (rhs.getValue().isOne()) return getLhs(); // Fold `mul(x, 0) -> 0`. if (rhs.getValue().isZero()) return rhs; } return {}; } //===----------------------------------------------------------------------===// // DivSOp //===----------------------------------------------------------------------===// OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold division by zero. if (rhs.isZero()) return std::nullopt; return lhs.sdiv(rhs); }); } //===----------------------------------------------------------------------===// // DivUOp //===----------------------------------------------------------------------===// OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold division by zero. if (rhs.isZero()) return std::nullopt; return lhs.udiv(rhs); }); } //===----------------------------------------------------------------------===// // CeilDivSOp //===----------------------------------------------------------------------===// /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. static std::optional calculateCeilDivS(const APInt &n, const APInt &m) { // Don't fold division by zero. if (m.isZero()) return std::nullopt; // Short-circuit the zero case. if (n.isZero()) return n; bool mGtZ = m.sgt(0); if (n.sgt(0) != mGtZ) { // If the operands have different signs, compute the negative result. Signed // division overflow is not possible, since if `m == -1`, `n` can be at most // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement. return -(-n).sdiv(m); } // Otherwise, compute the positive result. Signed division overflow is not // possible since if `m == -1`, `x` will be `1`. int64_t x = mGtZ ? -1 : 1; return (n + x).sdiv(m) + 1; } OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS); } //===----------------------------------------------------------------------===// // CeilDivUOp //===----------------------------------------------------------------------===// OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) { // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`. return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &n, const APInt &m) -> std::optional { // Don't fold division by zero. if (m.isZero()) return std::nullopt; // Short-circuit the zero case. if (n.isZero()) return n; return (n - 1).udiv(m) + 1; }); } //===----------------------------------------------------------------------===// // FloorDivSOp //===----------------------------------------------------------------------===// /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then /// `n*m < 0 ? -1 - (x-n)/m : n/m`. static std::optional calculateFloorDivS(const APInt &n, const APInt &m) { // Don't fold division by zero. if (m.isZero()) return std::nullopt; // Short-circuit the zero case. if (n.isZero()) return n; bool mLtZ = m.slt(0); if (n.slt(0) == mLtZ) { // If the operands have the same sign, compute the positive result. return n.sdiv(m); } // If the operands have different signs, compute the negative result. Signed // division overflow is not possible since if `m == -1`, `x` will be 1 and // `n` can be at most `INT_MAX`. int64_t x = mLtZ ? 1 : -1; return -1 - (x - n).sdiv(m); } OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS); } //===----------------------------------------------------------------------===// // RemSOp //===----------------------------------------------------------------------===// OpFoldResult RemSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold division by zero. if (rhs.isZero()) return std::nullopt; return lhs.srem(rhs); }); } //===----------------------------------------------------------------------===// // RemUOp //===----------------------------------------------------------------------===// OpFoldResult RemUOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold division by zero. if (rhs.isZero()) return std::nullopt; return lhs.urem(rhs); }); } //===----------------------------------------------------------------------===// // MaxSOp //===----------------------------------------------------------------------===// OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs.sgt(rhs) ? lhs : rhs; }); } //===----------------------------------------------------------------------===// // MaxUOp //===----------------------------------------------------------------------===// OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs.ugt(rhs) ? lhs : rhs; }); } //===----------------------------------------------------------------------===// // MinSOp //===----------------------------------------------------------------------===// OpFoldResult MinSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs.slt(rhs) ? lhs : rhs; }); } //===----------------------------------------------------------------------===// // MinUOp //===----------------------------------------------------------------------===// OpFoldResult MinUOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs.ult(rhs) ? lhs : rhs; }); } //===----------------------------------------------------------------------===// // ShlOp //===----------------------------------------------------------------------===// OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { return foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // We cannot fold if the RHS is greater than or equal to 32 because // this would be UB in 32-bit systems but not on 64-bit systems. RHS is // already treated as unsigned. if (rhs.uge(32)) return {}; return lhs << rhs; }); } //===----------------------------------------------------------------------===// // ShrSOp //===----------------------------------------------------------------------===// OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold if RHS is greater than or equal to 32. if (rhs.uge(32)) return {}; return lhs.ashr(rhs); }); } //===----------------------------------------------------------------------===// // ShrUOp //===----------------------------------------------------------------------===// OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { return foldBinaryOpChecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) -> std::optional { // Don't fold if RHS is greater than or equal to 32. if (rhs.uge(32)) return {}; return lhs.lshr(rhs); }); } //===----------------------------------------------------------------------===// // AndOp //===----------------------------------------------------------------------===// OpFoldResult AndOp::fold(FoldAdaptor adaptor) { return foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; }); } //===----------------------------------------------------------------------===// // OrOp //===----------------------------------------------------------------------===// OpFoldResult OrOp::fold(FoldAdaptor adaptor) { return foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; }); } //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// OpFoldResult XOrOp::fold(FoldAdaptor adaptor) { return foldBinaryOpUnchecked( adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; }); } //===----------------------------------------------------------------------===// // CastSOp //===----------------------------------------------------------------------===// static OpFoldResult foldCastOp(Attribute input, Type type, function_ref extFn, function_ref extOrTruncFn) { auto attr = dyn_cast_if_present(input); if (!attr) return {}; const APInt &value = attr.getValue(); if (isa(type)) { // When casting to an index type, perform the cast assuming a 64-bit target. // The result can be truncated to 32 bits as needed and always be correct. // This is because `cast32(cast64(value)) == cast32(value)`. APInt result = extOrTruncFn(value, 64); return IntegerAttr::get(type, result); } // When casting from an index type, we must ensure the results respect // `cast_t(value) == cast_t(trunc32(value))`. auto intType = cast(type); unsigned width = intType.getWidth(); // If the result type is at most 32 bits, then the cast can always be folded // because it is always a truncation. if (width <= 32) { APInt result = value.trunc(width); return IntegerAttr::get(type, result); } // If the result type is at least 64 bits, then the cast is always a // extension. The results will differ if `trunc32(value) != value)`. if (width >= 64) { if (extFn(value.trunc(32), 64) != value) return {}; APInt result = extFn(value, width); return IntegerAttr::get(type, result); } // Otherwise, we just have to check the property directly. APInt result = value.trunc(width); if (result != extFn(value.trunc(32), width)) return {}; return IntegerAttr::get(type, result); } bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { return llvm::isa(lhsTypes.front()) != llvm::isa(rhsTypes.front()); } OpFoldResult CastSOp::fold(FoldAdaptor adaptor) { return foldCastOp( adaptor.getInput(), getType(), [](const APInt &x, unsigned width) { return x.sext(width); }, [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); }); } //===----------------------------------------------------------------------===// // CastUOp //===----------------------------------------------------------------------===// bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) { return llvm::isa(lhsTypes.front()) != llvm::isa(rhsTypes.front()); } OpFoldResult CastUOp::fold(FoldAdaptor adaptor) { return foldCastOp( adaptor.getInput(), getType(), [](const APInt &x, unsigned width) { return x.zext(width); }, [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); }); } //===----------------------------------------------------------------------===// // CmpOp //===----------------------------------------------------------------------===// /// Compare two integers according to the comparison predicate. bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred) { switch (pred) { case IndexCmpPredicate::EQ: return lhs.eq(rhs); case IndexCmpPredicate::NE: return lhs.ne(rhs); case IndexCmpPredicate::SGE: return lhs.sge(rhs); case IndexCmpPredicate::SGT: return lhs.sgt(rhs); case IndexCmpPredicate::SLE: return lhs.sle(rhs); case IndexCmpPredicate::SLT: return lhs.slt(rhs); case IndexCmpPredicate::UGE: return lhs.uge(rhs); case IndexCmpPredicate::UGT: return lhs.ugt(rhs); case IndexCmpPredicate::ULE: return lhs.ule(rhs); case IndexCmpPredicate::ULT: return lhs.ult(rhs); } llvm_unreachable("unhandled IndexCmpPredicate predicate"); } /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the /// values of `cstA` and `cstB`, the max or min operation, and the comparison /// predicate. Check whether the value folds in both 32-bit and 64-bit /// arithmetic and to the same value. static std::optional foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred) { ConstantIntRanges lhsRange = TypeSwitch(lhsOp) .Case([&](MinSOp op) { return ConstantIntRanges::fromSigned( APInt::getSignedMinValue(width), cstA); }) .Case([&](MinUOp op) { return ConstantIntRanges::fromUnsigned( APInt::getMinValue(width), cstA); }) .Case([&](MaxSOp op) { return ConstantIntRanges::fromSigned( cstA, APInt::getSignedMaxValue(width)); }) .Case([&](MaxUOp op) { return ConstantIntRanges::fromUnsigned( cstA, APInt::getMaxValue(width)); }); return intrange::evaluatePred(static_cast(pred), lhsRange, ConstantIntRanges::constant(cstB)); } /// Return the result of `cmp(pred, x, x)` static bool compareSameArgs(IndexCmpPredicate pred) { switch (pred) { case IndexCmpPredicate::EQ: case IndexCmpPredicate::SGE: case IndexCmpPredicate::SLE: case IndexCmpPredicate::UGE: case IndexCmpPredicate::ULE: return true; case IndexCmpPredicate::NE: case IndexCmpPredicate::SGT: case IndexCmpPredicate::SLT: case IndexCmpPredicate::UGT: case IndexCmpPredicate::ULT: return false; } } OpFoldResult CmpOp::fold(FoldAdaptor adaptor) { // Attempt to fold if both inputs are constant. auto lhs = dyn_cast_if_present(adaptor.getLhs()); auto rhs = dyn_cast_if_present(adaptor.getRhs()); if (lhs && rhs) { // Perform the comparison in 64-bit and 32-bit. bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred()); bool result32 = compareIndices(lhs.getValue().trunc(32), rhs.getValue().trunc(32), getPred()); if (result64 == result32) return BoolAttr::get(getContext(), result64); } // Fold `cmp(max/min(x, cstA), cstB)`. Operation *lhsOp = getLhs().getDefiningOp(); IntegerAttr cstA; if (isa_and_nonnull(lhsOp) && matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { std::optional result64 = foldCmpOfMaxOrMin( lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred()); std::optional result32 = foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), rhs.getValue().trunc(32), 32, getPred()); // Fold if the 32-bit and 64-bit results are the same. if (result64 && result32 && *result64 == *result32) return BoolAttr::get(getContext(), *result64); } // Fold `cmp(x, x)` if (getLhs() == getRhs()) return BoolAttr::get(getContext(), compareSameArgs(getPred())); return {}; } /// Canonicalize /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) { IntegerAttr cmpRhs; IntegerAttr cmpLhs; bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) && cmpRhs.getValue().isZero(); bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) && cmpLhs.getValue().isZero(); if (!rhsIsZero && !lhsIsZero) return rewriter.notifyMatchFailure(op.getLoc(), "cmp is not comparing something with 0"); SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp() : op.getRhs().getDefiningOp(); if (!subOp) return rewriter.notifyMatchFailure( op.getLoc(), "non-zero operand is not a result of subtraction"); index::CmpOp newCmp; if (rhsIsZero) newCmp = rewriter.create(op.getLoc(), op.getPred(), subOp.getLhs(), subOp.getRhs()); else newCmp = rewriter.create(op.getLoc(), op.getPred(), subOp.getRhs(), subOp.getLhs()); rewriter.replaceOp(op, newCmp); return success(); } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void ConstantOp::getAsmResultNames( function_ref setNameFn) { SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "idx" << getValueAttr().getValue(); setNameFn(getResult(), specialName.str()); } OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) { build(b, state, b.getIndexType(), b.getIndexAttr(value)); } //===----------------------------------------------------------------------===// // BoolConstantOp //===----------------------------------------------------------------------===// OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void BoolConstantOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), getValue() ? "true" : "false"); } //===----------------------------------------------------------------------===// // ODS-Generated Definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"