//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; namespace mlir { namespace memref { namespace { template struct AllocOpInterface : public ValueBoundsOpInterface::ExternalModel, OpTy> { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto allocOp = cast(op); assert(value == allocOp.getResult() && "invalid value"); cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim]; } }; struct CastOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto castOp = cast(op); assert(value == castOp.getResult() && "invalid value"); if (llvm::isa(castOp.getResult().getType()) && llvm::isa(castOp.getSource().getType())) { cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim); } } }; struct DimOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForIndexValue(Operation *op, Value value, ValueBoundsConstraintSet &cstr) const { auto dimOp = cast(op); assert(value == dimOp.getResult() && "invalid value"); auto constIndex = dimOp.getConstantIndex(); if (!constIndex.has_value()) return; cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex); } }; struct GetGlobalOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto getGlobalOp = cast(op); assert(value == getGlobalOp.getResult() && "invalid value"); auto type = getGlobalOp.getType(); assert(!type.isDynamicDim(dim) && "expected static dim"); cstr.bound(value)[dim] == type.getDimSize(dim); } }; struct RankOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForIndexValue(Operation *op, Value value, ValueBoundsConstraintSet &cstr) const { auto rankOp = cast(op); assert(value == rankOp.getResult() && "invalid value"); auto memrefType = llvm::dyn_cast(rankOp.getMemref().getType()); if (!memrefType) return; cstr.bound(value) == memrefType.getRank(); } }; struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel { void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { auto subViewOp = cast(op); assert(value == subViewOp.getResult() && "invalid value"); llvm::SmallBitVector dropped = subViewOp.getDroppedDims(); int64_t ctr = -1; for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) { // Skip over rank-reduced dimensions. if (!dropped.test(i)) ++ctr; if (ctr == dim) { cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i]; return; } } llvm_unreachable("could not find non-rank-reduced dim"); } }; } // namespace } // namespace memref } // namespace mlir void mlir::memref::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { memref::AllocOp::attachInterface>( *ctx); memref::AllocaOp::attachInterface< memref::AllocOpInterface>(*ctx); memref::CastOp::attachInterface(*ctx); memref::DimOp::attachInterface(*ctx); memref::GetGlobalOp::attachInterface(*ctx); memref::RankOp::attachInterface(*ctx); memref::SubViewOp::attachInterface(*ctx); }); }