//===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Transforms/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; using namespace mlir::arith; /// Build Arith IR for the given affine map and its operands. static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, ValueRange operands) { assert(map.getNumResults() == 1 && "multiple results not supported yet"); std::function buildExpr = [&](AffineExpr e) -> Value { switch (e.getKind()) { case AffineExprKind::Constant: return b.create(loc, cast(e).getValue()); case AffineExprKind::DimId: return operands[cast(e).getPosition()]; case AffineExprKind::SymbolId: return operands[cast(e).getPosition() + map.getNumDims()]; case AffineExprKind::Add: { auto binaryExpr = cast(e); return b.create(loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mul: { auto binaryExpr = cast(e); return b.create(loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::FloorDiv: { auto binaryExpr = cast(e); return b.create(loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::CeilDiv: { auto binaryExpr = cast(e); return b.create(loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mod: { auto binaryExpr = cast(e); return b.create(loc, buildExpr(binaryExpr.getLHS()), buildExpr(binaryExpr.getRHS())); } } llvm_unreachable("unsupported AffineExpr kind"); }; return buildExpr(map.getResult(0)); } static FailureOr reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { // Compute bound. AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) return failure(); // Materialize tensor.dim/memref.dim ops. SmallVector operands; for (auto valueDim : mapOperands) { Value value = valueDim.first; std::optional dim = valueDim.second; if (!dim.has_value()) { // This is an index-typed value. assert(value.getType().isIndex() && "expected index type"); operands.push_back(value); continue; } assert(cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. operands.push_back(b.create(loc, value, *dim)); } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. operands.push_back(b.create(loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } } // Check for special cases where no arith ops are needed. if (boundMap.isSingleConstant()) { // Bound is a constant: return an IntegerAttr. return static_cast( b.getIndexAttr(boundMap.getSingleConstantResult())); } // No arith ops are needed if the bound is a single SSA value. if (auto expr = dyn_cast(boundMap.getResult(0))) return static_cast(operands[expr.getPosition()]); if (auto expr = dyn_cast(boundMap.getResult(0))) return static_cast( operands[expr.getPosition() + boundMap.getNumDims()]); // General case: build Arith ops. return static_cast(buildArithValue(b, loc, boundMap, operands)); } FailureOr mlir::arith::reifyShapedValueDimBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { auto reifyToOperands = [&](Value v, std::optional d) { // We are trying to reify a bound for `value` in terms of the owning op's // operands. Construct a stop condition that evaluates to "true" for any SSA // value expect for `value`. I.e., the bound will be computed in terms of // any SSA values expect for `value`. The first such values are operands of // the owner of `value`. return v != value; }; return reifyValueBound(b, loc, type, value, dim, stopCondition ? stopCondition : reifyToOperands, closedUB); } FailureOr mlir::arith::reifyIndexValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { auto reifyToOperands = [&](Value v, std::optional d) { return v != value; }; return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, stopCondition ? stopCondition : reifyToOperands, closedUB); }