//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; namespace mlir { namespace tensor { namespace { struct CastOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); if (llvm::isa(castOp.getResult().getType()) && llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } }; struct DimOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForIndexValue(Operation *op, Value value, ValueBoundsConstraintSet &cstr) const { auto dimOp = cast(op); assert(value == dimOp.getResult() && "invalid value"); auto constIndex = dimOp.getConstantIndex(); if (!constIndex.has_value()) return; cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); } }; struct EmptyOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto emptyOp = cast(op); assert(value == emptyOp.getResult() && "invalid value"); cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim]; } }; struct ExtractSliceOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto extractSliceOp = cast(op); assert(value == extractSliceOp.getResult() && "invalid value"); llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims(); int64_t ctr = -1; for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) { // Skip over rank-reduced dimensions. if (!dropped.test(i)) ++ctr; if (ctr == dim) { cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i]; return; } } llvm_unreachable("could not find non-rank-reduced dim"); } }; struct PadOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto padOp = cast(op); assert(value == padOp.getResult() && "invalid value"); AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim); AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]); AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]); cstr.bound(value)[dim] == srcSize + lowPad + highPad; } }; struct RankOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForIndexValue(Operation *op, Value value, ValueBoundsConstraintSet &cstr) const { auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); auto tensorType = llvm::dyn_cast(rankOp.getTensor().getType()); if (!tensorType) return; cstr.bound(value) == tensorType.getRank(); } }; } // namespace } // namespace tensor } // namespace mlir void mlir::tensor::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { tensor::CastOp::attachInterface(*ctx); tensor::DimOp::attachInterface(*ctx); tensor::EmptyOp::attachInterface(*ctx); tensor::ExtractSliceOp::attachInterface( *ctx); tensor::PadOp::attachInterface(*ctx); tensor::RankOp::attachInterface(*ctx); // Note: ValueBoundsOpInterface implementation is not required for ops that // implement `DestinationStyleOpInterface` (for querying shaped OpResults). }); }