bolt/deps/llvm-18.1.8/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
2025-02-14 19:21:04 +01:00

129 lines
4.8 KiB
C++

//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
namespace mlir {
namespace memref {
namespace {
template <typename OpTy>
struct AllocOpInterface
: public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
OpTy> {
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
auto allocOp = cast<OpTy>(op);
assert(value == allocOp.getResult() && "invalid value");
cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
}
};
struct CastOpInterface
: public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
auto castOp = cast<CastOp>(op);
assert(value == castOp.getResult() && "invalid value");
if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
llvm::isa<MemRefType>(castOp.getSource().getType())) {
cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
}
}
};
struct DimOpInterface
: public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto dimOp = cast<DimOp>(op);
assert(value == dimOp.getResult() && "invalid value");
auto constIndex = dimOp.getConstantIndex();
if (!constIndex.has_value())
return;
cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
}
};
struct GetGlobalOpInterface
: public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
GetGlobalOp> {
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
auto getGlobalOp = cast<GetGlobalOp>(op);
assert(value == getGlobalOp.getResult() && "invalid value");
auto type = getGlobalOp.getType();
assert(!type.isDynamicDim(dim) && "expected static dim");
cstr.bound(value)[dim] == type.getDimSize(dim);
}
};
struct RankOpInterface
: public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto rankOp = cast<RankOp>(op);
assert(value == rankOp.getResult() && "invalid value");
auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
if (!memrefType)
return;
cstr.bound(value) == memrefType.getRank();
}
};
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
auto subViewOp = cast<SubViewOp>(op);
assert(value == subViewOp.getResult() && "invalid value");
llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
int64_t ctr = -1;
for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
// Skip over rank-reduced dimensions.
if (!dropped.test(i))
++ctr;
if (ctr == dim) {
cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
return;
}
}
llvm_unreachable("could not find non-rank-reduced dim");
}
};
} // namespace
} // namespace memref
} // namespace mlir
void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
*ctx);
memref::AllocaOp::attachInterface<
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
});
}