//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/APSInt.h" namespace mlir { bool isZeroIndex(OpFoldResult v) { if (!v) return false; if (auto attr = llvm::dyn_cast_if_present(v)) { IntegerAttr intAttr = dyn_cast(attr); return intAttr && intAttr.getValue().isZero(); } if (auto cst = v.get().getDefiningOp()) return cst.value() == 0; return false; } std::tuple, SmallVector, SmallVector> getOffsetsSizesAndStrides(ArrayRef ranges) { SmallVector offsets, sizes, strides; offsets.reserve(ranges.size()); sizes.reserve(ranges.size()); strides.reserve(ranges.size()); for (const auto &[offset, size, stride] : ranges) { offsets.push_back(offset); sizes.push_back(size); strides.push_back(stride); } return std::make_tuple(offsets, sizes, strides); } /// Helper function to dispatch an OpFoldResult into `staticVec` if: /// a) it is an IntegerAttr /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. /// In such dynamic cases, a copy of the `sentinel` value is also pushed to /// `staticVec`. This is useful to extract mixed static and dynamic entries that /// come from an AttrSizedOperandSegments trait. void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl &dynamicVec, SmallVectorImpl &staticVec) { auto v = llvm::dyn_cast_if_present(ofr); if (!v) { APInt apInt = cast(ofr.get()).getValue(); staticVec.push_back(apInt.getSExtValue()); return; } dynamicVec.push_back(v); staticVec.push_back(ShapedType::kDynamic); } void dispatchIndexOpFoldResults(ArrayRef ofrs, SmallVectorImpl &dynamicVec, SmallVectorImpl &staticVec) { for (OpFoldResult ofr : ofrs) dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec); } /// Given a value, try to extract a constant Attribute. If this fails, return /// the original value. OpFoldResult getAsOpFoldResult(Value val) { if (!val) return OpFoldResult(); Attribute attr; if (matchPattern(val, m_Constant(&attr))) return attr; return val; } /// Given an array of values, try to extract a constant Attribute from each /// value. If this fails, return the original value. SmallVector getAsOpFoldResult(ValueRange values) { return llvm::to_vector( llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); })); } /// Convert `arrayAttr` to a vector of OpFoldResult. SmallVector getAsOpFoldResult(ArrayAttr arrayAttr) { SmallVector res; res.reserve(arrayAttr.size()); for (Attribute a : arrayAttr) res.push_back(a); return res; } OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) { return IntegerAttr::get(IndexType::get(ctx), val); } SmallVector getAsIndexOpFoldResult(MLIRContext *ctx, ArrayRef values) { return llvm::to_vector(llvm::map_range( values, [ctx](int64_t v) { return getAsIndexOpFoldResult(ctx, v); })); } /// If ofr is a constant integer or an IntegerAttr, return the integer. std::optional getConstantIntValue(OpFoldResult ofr) { // Case 1: Check for Constant integer. if (auto val = llvm::dyn_cast_if_present(ofr)) { APSInt intVal; if (matchPattern(val, m_ConstantInt(&intVal))) return intVal.getSExtValue(); return std::nullopt; } // Case 2: Check for IntegerAttr. Attribute attr = llvm::dyn_cast_if_present(ofr); if (auto intAttr = dyn_cast_or_null(attr)) return intAttr.getValue().getSExtValue(); return std::nullopt; } std::optional> getConstantIntValues(ArrayRef ofrs) { bool failed = false; SmallVector res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) { auto cv = getConstantIntValue(ofr); if (!cv.has_value()) failed = true; return cv.has_value() ? cv.value() : 0; }); if (failed) return std::nullopt; return res; } /// Return true if `ofr` is constant integer equal to `value`. bool isConstantIntValue(OpFoldResult ofr, int64_t value) { auto val = getConstantIntValue(ofr); return val && *val == value; } /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. /// Ignore integer bitwidth and type mismatch that come from the fact there is /// no IndexAttr and that IndexType has no bitwidth. bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); if (cst1 && cst2 && *cst1 == *cst2) return true; auto v1 = llvm::dyn_cast_if_present(ofr1), v2 = llvm::dyn_cast_if_present(ofr2); return v1 && v1 == v2; } bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, ArrayRef ofrs2) { if (ofrs1.size() != ofrs2.size()) return false; for (auto [ofr1, ofr2] : llvm::zip_equal(ofrs1, ofrs2)) if (!isEqualConstantIntOrValue(ofr1, ofr2)) return false; return true; } /// Return a vector of OpFoldResults with the same size a staticValues, but all /// elements for which ShapedType::isDynamic is true, will be replaced by /// dynamicValues. SmallVector getMixedValues(ArrayRef staticValues, ValueRange dynamicValues, Builder &b) { SmallVector res; res.reserve(staticValues.size()); unsigned numDynamic = 0; unsigned count = static_cast(staticValues.size()); for (unsigned idx = 0; idx < count; ++idx) { int64_t value = staticValues[idx]; res.push_back(ShapedType::isDynamic(value) ? OpFoldResult{dynamicValues[numDynamic++]} : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); } return res; } /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. std::pair> decomposeMixedValues(Builder &b, const SmallVectorImpl &mixedValues) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { if (it.is()) { staticValues.push_back(cast(it.get()).getInt()); } else { staticValues.push_back(ShapedType::kDynamic); dynamicValues.push_back(it.get()); } } return {b.getI64ArrayAttr(staticValues), dynamicValues}; } /// Helper to sort `values` according to matching `keys`. template static SmallVector getValuesSortedByKeyImpl(ArrayRef keys, ArrayRef values, llvm::function_ref compare) { if (keys.empty()) return SmallVector{values}; assert(keys.size() == values.size() && "unexpected mismatching sizes"); auto indices = llvm::to_vector(llvm::seq(0, values.size())); std::sort(indices.begin(), indices.end(), [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); }); SmallVector res; res.reserve(values.size()); for (int64_t i = 0, e = indices.size(); i < e; ++i) res.push_back(values[indices[i]]); return res; } SmallVector getValuesSortedByKey(ArrayRef keys, ArrayRef values, llvm::function_ref compare) { return getValuesSortedByKeyImpl(keys, values, compare); } SmallVector getValuesSortedByKey(ArrayRef keys, ArrayRef values, llvm::function_ref compare) { return getValuesSortedByKeyImpl(keys, values, compare); } SmallVector getValuesSortedByKey(ArrayRef keys, ArrayRef values, llvm::function_ref compare) { return getValuesSortedByKeyImpl(keys, values, compare); } /// Return the number of iterations for a loop with a lower bound `lb`, upper /// bound `ub` and step `step`. std::optional constantTripCount(OpFoldResult lb, OpFoldResult ub, OpFoldResult step) { if (lb == ub) return 0; std::optional lbConstant = getConstantIntValue(lb); if (!lbConstant) return std::nullopt; std::optional ubConstant = getConstantIntValue(ub); if (!ubConstant) return std::nullopt; std::optional stepConstant = getConstantIntValue(step); if (!stepConstant) return std::nullopt; return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant); } bool hasValidSizesOffsets(SmallVector sizesOrOffsets) { return llvm::none_of(sizesOrOffsets, [](int64_t value) { return !ShapedType::isDynamic(value) && value < 0; }); } bool hasValidStrides(SmallVector strides) { return llvm::none_of(strides, [](int64_t value) { return !ShapedType::isDynamic(value) && value == 0; }); } LogicalResult foldDynamicIndexList(SmallVectorImpl &ofrs, bool onlyNonNegative, bool onlyNonZero) { bool valuesChanged = false; for (OpFoldResult &ofr : ofrs) { if (ofr.is()) continue; Attribute attr; if (matchPattern(ofr.get(), m_Constant(&attr))) { // Note: All ofrs have index type. if (onlyNonNegative && *getConstantIntValue(attr) < 0) continue; if (onlyNonZero && *getConstantIntValue(attr) == 0) continue; ofr = attr; valuesChanged = true; } } return success(valuesChanged); } LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl &offsetsOrSizes) { return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true, /*onlyNonZero=*/false); } LogicalResult foldDynamicStrideList(SmallVectorImpl &strides) { return foldDynamicIndexList(strides, /*onlyNonNegative=*/false, /*onlyNonZero=*/true); } } // namespace mlir