bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

713 lines
28 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
//===----------------------------------------------------------------------===//
/// Applies the nested matcher operations to `current` and forwards the
/// operands of the body terminator as results of this op.
///
/// `current` is expected to be a structured Linalg operation. The body block
/// argument is bound to it, then each nested non-terminator operation is
/// applied in order.
///
/// Failure handling depends on the failure propagation mode (`Propagate` is
/// assumed when the attribute is absent):
///   - definite failures of nested operations are always returned as is;
///   - in `Propagate` mode, silenceable failures are returned immediately;
///   - otherwise, silenceable failures are silenced, the results that are
///     already known are set, the remaining ones are set to empty lists, and
///     the match succeeds.
DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  // The propagation mode does not change during the match, query it once.
  const bool propagateFailures =
      getFailurePropagationMode().value_or(FailurePropagationMode::Propagate) ==
      FailurePropagationMode::Propagate;

  // First, check if the payload operation is a structured Linalg operation.
  if (!isa<linalg::LinalgOp>(current)) {
    if (propagateFailures)
      return emitSilenceableError() << "expected a Linalg op";

    // If errors are suppressed, succeed and set all results to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op\n");
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Bind `current` to the block argument.
  auto scope = state.make_region_scope(getBodyRegion());
  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
                                    MappedValue(current)))) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  for (Operation &nested : getBody()->without_terminator()) {
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<TransformOpInterface>(nested));
    if (diag.isDefiniteFailure())
      return diag;
    if (diag.succeeded())
      continue;

    // If propagating errors, do this immediately.
    assert(diag.isSilenceableFailure());
    if (propagateFailures)
      return diag;

    // If suppressing errors, print the message into the debug stream before
    // silencing it. Then set all results value that are already known.
    // Results come from the terminator operands, which may be defined in the
    // (single) block of this operation or above it. When they are defined
    // above, they are known to be mapped at this point per SSA dominance.
    // When they are defined in this block, we additionally check if we have
    // already applied the operation that defines them. If not, the
    // corresponding results will be set to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
                      << "\n");
    (void)diag.silence();

    // A terminator operand is known when it is a block argument, is defined
    // outside of the body block, or is defined by an operation that precedes
    // the failed one in the body block.
    auto isAlreadyDefined = [&](OpOperand &terminatorOperand) {
      Operation *definingOp = terminatorOperand.get().getDefiningOp();
      return !definingOp || definingOp->getBlock() != getBody() ||
             definingOp->isBeforeInBlock(&nested);
    };
    auto filtered = llvm::make_filter_range(
        getBody()->getTerminator()->getOpOperands(), isAlreadyDefined);
    SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));

    SmallVector<SmallVector<transform::MappedValue>> mappings;
    detail::prepareValueMappings(mappings, definedOperands, state);
    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
      results.setMappedValues(getResults()[operand.getOperandNumber()],
                              mapping);
    }
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Set the results.
  detail::forwardTerminatorOperands(getBody(), state, results);
  return DiagnosedSilenceableFailure::success();
}
/// Populates `effects` with the memory effects of this op by calling, in
/// order, `onlyReadsHandle` on the `current` operand, `onlyReadsPayload`, and
/// `producesHandle` on the op outputs.
void transform::MatchStructuredOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getCurrent(), effects);
onlyReadsPayload(effects);
producesHandle(getOutputs(), effects);
}
/// Verifies the structural invariants of the op: the body has exactly one
/// argument, the type of that argument implements
/// TransformHandleTypeInterface, and every nested non-terminator operation
/// implements MatchOpInterface.
LogicalResult transform::MatchStructuredOp::verify() {
  Block *matcherBody = getBody();
  if (matcherBody->getNumArguments() != 1)
    return emitOpError() << "expected one body argument";

  Type argumentType = matcherBody->getArgument(0).getType();
  if (!isa<TransformHandleTypeInterface>(argumentType)) {
    return emitOpError() << "expected body argument to implement "
                            "TransformHandleTypeInterface";
  }

  // Report the first nested operation that is not a matcher, if any.
  for (Operation &candidate : matcherBody->without_terminator()) {
    if (!isa<MatchOpInterface>(candidate)) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expects nested operations to implement MatchOpInterface";
      diag.attachNote(candidate.getLoc()) << "offending operation";
      return diag;
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// StructuredOpPredicateOpTrait
//===----------------------------------------------------------------------===//
/// Verifies that `op` is directly nested in a MatchStructuredOp and that
/// `structuredOpHandle` is the first argument of the first block in the
/// parent's first region. When the parent has no region, an empty region, or a
/// block without arguments, verification succeeds here so that the parent's
/// own verifier reports the problem.
LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
    Operation *op, Value structuredOpHandle) {
  Operation *parent = op->getParentOp();
  if (!parent || !isa<MatchStructuredOp>(parent)) {
    return op->emitOpError() << "expects parent op to be '"
                             << MatchStructuredOp::getOperationName() << "'";
  }

  // Bail out here, let the verifier of the parent complain.
  if (parent->getNumRegions() < 1)
    return success();
  Region &parentRegion = parent->getRegion(0);
  if (parentRegion.empty())
    return success();
  Block &entryBlock = parentRegion.front();
  if (entryBlock.getNumArguments() < 1)
    return success();

  if (structuredOpHandle != entryBlock.getArgument(0)) {
    return op->emitOpError()
           << "expected predicate to apply to the surrounding structured op";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredBodyOp
//===----------------------------------------------------------------------===//
/// Matches the body of the structured payload operation `current` against
/// the single condition configured on this op, checked in this order:
/// reduction position, passthrough, contraction. Produces a silenceable error
/// when the condition does not hold and a definite failure when no condition
/// is configured.
DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);

  // Reduction: the output at the given position must be a reduction with
  // exactly one combiner operation.
  std::optional<uint64_t> reductionPosition = getReductionPosition();
  if (reductionPosition) {
    SmallVector<Operation *> combiners;
    if (!matchReduction(structuredOp.getRegionOutputArgs(), *reductionPosition,
                        combiners)) {
      return emitSilenceableError() << "could not match reduction";
    }
    if (combiners.size() == 1)
      return DiagnosedSilenceableFailure::success();
    return emitSilenceableError() << "reduction combiner is not a single op";
  }

  // Passthrough: the terminator operands must be the region input arguments.
  if (getPassthrough()) {
    Block &payloadBody = structuredOp->getRegion(0).front();
    if (payloadBody.getTerminator()->getOperands() !=
        structuredOp.getRegionInputArgs()) {
      return emitSilenceableError() << "not a passthrough";
    }
    return DiagnosedSilenceableFailure::success();
  }

  std::optional<ArrayAttr> contractionOps = getContraction();
  if (!contractionOps)
    return emitDefiniteFailure() << "unknown body condition";

  // Contraction: the elementwise and reduction operations of the body must
  // have the names listed, in this order, in the contraction attribute.
  Block &payloadBody = structuredOp->getRegion(0).front();
  std::string message;
  llvm::raw_string_ostream os(message);
  auto hasExpectedNames = [&](Operation *elem, Operation *red) {
    if (elem->getName().getStringRef() !=
        cast<StringAttr>((*contractionOps)[0]).getValue())
      return false;
    return red->getName().getStringRef() ==
           cast<StringAttr>((*contractionOps)[1]).getValue();
  };
  if (linalg::detail::isContractionBody(payloadBody, hasExpectedNames, os))
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError() << "contraction: " << os.str();
}
/// Verifies that at most one of the reduction position, passthrough and
/// contraction attributes is specified, and that the contraction attribute,
/// when present, has exactly two elements.
LogicalResult transform::MatchStructuredBodyOp::verify() {
  int64_t numOptions = 0;
  if (getReductionPosition().has_value())
    ++numOptions;
  if (getPassthrough())
    ++numOptions;
  if (getContraction().has_value())
    ++numOptions;

  if (numOptions > 1) {
    // List the names of all mutually exclusive attributes in the message.
    std::string attributeNames;
    llvm::raw_string_ostream os(attributeNames);
    llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
                                               getPassthroughAttrName(),
                                               getContractionAttrName()},
                          os);
    return emitOpError() << "only one of {" << os.str() << "} is allowed";
  }

  std::optional<ArrayAttr> contractionAttr = getContraction();
  if (contractionAttr && contractionAttr->size() != 2) {
    return emitOpError() << "expects " << getContractionAttrName()
                         << " to contain two elements";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredClassifyContractionDimsOp
//===----------------------------------------------------------------------===//
/// Infers the contraction dimensions of the structured payload operation
/// `current` and associates the batch, m, n and k dimension lists with the
/// corresponding results as lists of 64-bit integer attributes. Produces a
/// silenceable error when the dimensions cannot be inferred.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(structuredOp);
  if (failed(dims))
    return emitSilenceableError() << "could not infer contraction dimensions";

  Builder builder(current->getContext());
  // Wraps each dimension index into an i64 integer attribute.
  auto toI64Params = [&builder](ArrayRef<unsigned> indices) {
    SmallVector<Attribute> params;
    params.reserve(indices.size());
    for (unsigned index : indices)
      params.push_back(builder.getI64IntegerAttr(index));
    return params;
  };

  results.setParams(cast<OpResult>(getBatch()), toI64Params(dims->batch));
  results.setParams(cast<OpResult>(getM()), toI64Params(dims->m));
  results.setParams(cast<OpResult>(getN()), toI64Params(dims->n));
  results.setParams(cast<OpResult>(getK()), toI64Params(dims->k));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredClassifyConvolutionDimsOp
//===----------------------------------------------------------------------===//
/// Infers the convolution dimensions of the structured payload operation
/// `current` and associates each dimension class, as well as the strides and
/// dilations, with the corresponding result as a list of 64-bit integer
/// attributes. Produces a silenceable error when the dimensions cannot be
/// inferred.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  FailureOr<linalg::ConvolutionDimensions> dims =
      linalg::inferConvolutionDims(structuredOp);
  if (failed(dims))
    return emitSilenceableError() << "could not infer convolution dimensions";

  Builder builder(current->getContext());
  // Wraps each integer of the list (either `unsigned` dimension indices or
  // `int64_t` strides/dilations) into an i64 integer attribute.
  auto toI64Params = [&builder](const auto &values) {
    SmallVector<Attribute> params;
    params.reserve(values.size());
    for (auto value : values)
      params.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(value)));
    return params;
  };

  results.setParams(cast<OpResult>(getBatch()), toI64Params(dims->batch));
  results.setParams(cast<OpResult>(getOutputImage()),
                    toI64Params(dims->outputImage));
  results.setParams(cast<OpResult>(getOutputChannel()),
                    toI64Params(dims->outputChannel));
  results.setParams(cast<OpResult>(getFilterLoop()),
                    toI64Params(dims->filterLoop));
  results.setParams(cast<OpResult>(getInputChannel()),
                    toI64Params(dims->inputChannel));
  results.setParams(cast<OpResult>(getDepth()), toI64Params(dims->depth));
  results.setParams(cast<OpResult>(getStrides()), toI64Params(dims->strides));
  results.setParams(cast<OpResult>(getDilations()),
                    toI64Params(dims->dilations));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Utilities for structured match predicates.
//===----------------------------------------------------------------------===//
/// Verifies that every value of `list` also appears in `reference`. On the
/// first value of `list` that is missing from `reference`, returns a
/// silenceable error emitted at `loc`; its text is `message` in which the
/// "{0}" placeholder is replaced by the missing value. Returns success when
/// no value is missing.
static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
                                               ArrayRef<int64_t> list,
                                               Location loc,
                                               const char *message) {
  for (int64_t candidate : list) {
    const bool isMissing = llvm::none_of(reference, [candidate](unsigned ref) {
      return static_cast<int64_t>(ref) == candidate;
    });
    if (isMissing)
      return emitSilenceableFailure(loc) << llvm::formatv(message, candidate);
  }
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//
/// Checks the selected iteration dimensions of the structured payload
/// operation `current`. When the parallel or reduction kind is requested,
/// all selected dimensions must be of that kind, otherwise a silenceable
/// error is produced. When the op has a result, the static loop ranges of the
/// selected dimensions are associated with it as 64-bit integer attributes.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> selectedDims;
  DiagnosedSilenceableFailure selection =
      getDimensionsFor(structuredOp, selectedDims);
  if (!selection.succeeded())
    return selection;

  // If asked to check for the kind of dimension, perform the check.
  const bool expectParallel = getParallel();
  if (expectParallel || getReduction()) {
    SmallVector<unsigned> dimsOfKind;
    const char *message = nullptr;
    if (expectParallel) {
      structuredOp.getParallelDims(dimsOfKind);
      message = "expects dimension #{0} to be parallel";
    } else {
      structuredOp.getReductionDims(dimsOfKind);
      message = "expects dimension #{0} to be reduction";
    }
    DiagnosedSilenceableFailure kindCheck =
        containsAll(dimsOfKind, selectedDims, getLoc(), message);
    if (!kindCheck.succeeded())
      return kindCheck;
  }

  // If not capturing, we are done here.
  if (!getResult())
    return DiagnosedSilenceableFailure::success();

  // Capture the static loop range of each selected dimension.
  SmallVector<int64_t, 4> loopRanges = structuredOp.getStaticLoopRanges();
  Builder builder(current);
  SmallVector<Attribute> captured;
  captured.reserve(selectedDims.size());
  for (int64_t dim : selectedDims)
    captured.push_back(builder.getI64IntegerAttr(loopRanges[dim]));
  results.setParams(cast<OpResult>(getResult()), captured);
  return DiagnosedSilenceableFailure::success();
}
/// Expands the dimension specification of this op (all / inverted / raw
/// list) into `dims`, using the number of loops of `op` as the upper bound.
/// When the expansion fails with a silenceable error, a note pointing at the
/// payload operation is attached to it.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
  DiagnosedSilenceableFailure expansion =
      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
                                getRawDimList(), op.getNumLoops(), dims);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering dimensions of this payload operation";
  return expansion;
}
/// Verifies that the parallel and reduction kinds are not requested together,
/// then delegates the verification of the dimension specification to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredDimOp::verify() {
  const bool requestsBothKinds = getParallel() && getReduction();
  if (!requestsBothKinds) {
    return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
                                      getIsInverted(), getIsAll());
  }
  return emitOpError() << "cannot request the same dimension to be both "
                          "parallel and reduction";
}
//===----------------------------------------------------------------------===//
// MatchStructuredElementalBitwidthOp
//===----------------------------------------------------------------------===//
/// Associates the result of this op with the bitwidth of `current`: the
/// bitwidth of its type when that type is an integer or a float, or the
/// bitwidth of the element type when it is a shaped type with an integer or
/// float element type. Produces a silenceable error for any other type.
DiagnosedSilenceableFailure
transform::MatchStructuredElementalBitwidthOp::matchValue(
    Value current, transform::TransformResults &results,
    transform::TransformState &state) {
  Type type = current.getType();

  int64_t bitwidth = 0;
  if (type.isIntOrFloat()) {
    bitwidth = type.getIntOrFloatBitWidth();
  } else if (auto shapedType = dyn_cast<ShapedType>(type);
             shapedType && shapedType.getElementType().isIntOrFloat()) {
    bitwidth = shapedType.getElementTypeBitWidth();
  } else {
    return emitSilenceableError()
           << "unsupported type for bitwidth extraction: " << type;
  }

  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredInputOp
//===----------------------------------------------------------------------===//
/// Checks the selected DPS inputs of the structured payload operation
/// `current`. When requested, the indexing map of each selected input must be
/// a permutation or a projected permutation, otherwise a silenceable error is
/// produced. When the op has a result, what is captured depends on the result
/// type: the indexing map for an affine map parameter, the operand value for
/// a value handle, and the operation defining the operand otherwise (a
/// silenceable error is produced if there is no such operation).
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> inputPositions;
  DiagnosedSilenceableFailure selection =
      getPositionsFor(structuredOp, inputPositions);
  if (!selection.succeeded())
    return selection;

  SmallVector<MappedValue> captured;
  captured.reserve(inputPositions.size());
  for (int64_t position : inputPositions) {
    OpOperand *inputOperand = structuredOp.getDpsInputOperand(position);
    AffineMap indexingMap = structuredOp.getMatchingIndexingMap(inputOperand);

    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for input #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for input #" << position
             << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!getResult())
      continue;

    // The kind of entity to capture is driven by the result type.
    Type resultType = getResult().getType();
    if (isa<AffineMapParamType>(resultType)) {
      captured.emplace_back(AffineMapAttr::get(indexingMap));
    } else if (isa<TransformValueHandleTypeInterface>(resultType)) {
      Value operand = inputOperand->get();
      captured.emplace_back(operand);
    } else {
      Operation *operandProducer = inputOperand->get().getDefiningOp();
      if (!operandProducer) {
        return emitSilenceableError()
               << "input #" << position << " is not produced by an operation";
      }
      captured.emplace_back(operandProducer);
    }
  }

  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), captured);
  return DiagnosedSilenceableFailure::success();
}
/// Expands the position specification of this op (all / inverted / raw list)
/// into `positions`, using the number of DPS inputs of `op` as the upper
/// bound. When the expansion fails with a silenceable error, a note pointing
/// at the payload operation is attached to it.
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure expansion = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInputs(), positions);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering DPS inputs of this payload operation";
  return expansion;
}
/// Shared verification for matcher ops on structured inputs or outputs:
/// rejects ops that request both the permutation and the projected
/// permutation checks, and ops that have a result while listing more than one
/// raw position.
template <typename OpTy>
LogicalResult verifyStructuredOperandOp(OpTy op) {
  const bool hasBothPermutationKinds =
      op.getPermutation() && op.getProjectedPermutation();
  if (hasBothPermutationKinds) {
    return op.emitOpError()
           << op.getPermutationAttrName() << " and "
           << op.getProjectedPermutationAttrName() << " are mutually exclusive";
  }

  const bool listsSeveralPositions = op.getRawPositionList().size() > 1;
  if (listsSeveralPositions && op.getResult()) {
    return op.emitOpError()
           << "cannot bind multiple inputs/inits to the same value";
  }
  return success();
}
/// Runs the verification shared by structured operand matchers, then
/// delegates the verification of the position specification to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredInputOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
//===----------------------------------------------------------------------===//
// MatchStructuredInitOp
//===----------------------------------------------------------------------===//
/// Checks the selected DPS inits (outputs) of the structured payload
/// operation `current`. When requested, the indexing map of each selected
/// init must be a permutation or a projected permutation, otherwise a
/// silenceable error is produced. When the op has a result, what is captured
/// depends on the result type: the indexing map for an affine map parameter,
/// the operand value for a value handle, and the operation defining the
/// operand otherwise (a silenceable error is produced if there is no such
/// operation).
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  SmallVector<MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    OpOperand *initOperand = linalgOp.getDpsInitOperand(position);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(initOperand);
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for output(init) #"
                                    << position << " is not a permutation";
    }
    // Name the property that was actually checked so that this failure can be
    // told apart from the plain permutation one above.
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for output(init) #" << position
             << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!getResult())
      continue;

    // Capturing the indexing map as a parameter.
    if (isa<AffineMapParamType>(getResult().getType())) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    // Capturing the operand as a value handle.
    Value operand = initOperand->get();
    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
      operandMapping.emplace_back(operand);
      continue;
    }

    // Capturing the operation that defines the operand.
    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError() << "output(init) #" << position
                                    << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }

  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}
/// Expands the position specification of this op (all / inverted / raw list)
/// into `positions`, using the number of DPS inits of `op` as the upper
/// bound. When the expansion fails with a silenceable error, a note pointing
/// at the payload operation is attached to it.
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure expansion = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInits(), positions);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering DPS inits (outputs) of this payload operation";
  return expansion;
}
/// Runs the verification shared by structured operand matchers, then
/// delegates the verification of the position specification to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredInitOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
//===----------------------------------------------------------------------===//
// MatchStructuredNumInputsOp
//===----------------------------------------------------------------------===//
/// Associates the result of this op with the number of DPS inputs of the
/// structured payload operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInputsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  const int64_t numInputs = structuredOp.getNumDpsInputs();
  Builder builder(current);
  Attribute attr = builder.getI64IntegerAttr(numInputs);
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredNumInitsOp
//===----------------------------------------------------------------------===//
/// Associates the result of this op with the number of DPS inits of the
/// structured payload operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInitsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  const int64_t numInits = structuredOp.getNumDpsInits();
  Builder builder(current);
  Attribute attr = builder.getI64IntegerAttr(numInits);
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredRankOp
//===----------------------------------------------------------------------===//
/// Associates the `rank` result of this op with the number of loops of the
/// structured payload operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  Builder builder(structuredOp->getContext());
  Attribute attr = builder.getI64IntegerAttr(
      static_cast<int64_t>(structuredOp.getNumLoops()));
  results.setParams(cast<OpResult>(getRank()), {attr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredResultOp
//===----------------------------------------------------------------------===//
/// Captures the result of the structured payload operation `op` at the
/// configured position, or one of its users. For a value handle result type,
/// the payload result itself is captured. Otherwise, the first user of the
/// payload result is captured: with `any` unconditionally, with `single` only
/// when it is the sole user. Produces a silenceable error when the position
/// is invalid, when there are no users, or when `single` is requested and
/// there are several users; produces a definite failure when neither `any`
/// nor `single` is specified.
DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(op);
  int64_t position;
  DiagnosedSilenceableFailure positionDiag =
      getPositionFor(structuredOp, position);
  if (!positionDiag.succeeded())
    return positionDiag;

  Value payloadResult =
      structuredOp.getTiedOpResult(structuredOp.getDpsInitOperand(position));
  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
    results.setValues(cast<OpResult>(getResult()), {payloadResult});
    return DiagnosedSilenceableFailure::success();
  }

  // From here on, a user of the payload result is captured.
  if (payloadResult.getUsers().empty()) {
    return emitSilenceableError()
           << "no users of the result #" << getPosition();
  }
  Operation *firstUser = *payloadResult.getUsers().begin();

  const bool acceptsAnyUser = getAny();
  if (!acceptsAnyUser && !getSingle())
    return emitDefiniteFailure() << "unknown sub-predicate";

  // The uniqueness of the user only matters when `any` was not requested.
  if (!acceptsAnyUser && !llvm::hasSingleElement(payloadResult.getUsers())) {
    return emitSilenceableError()
           << "more than one result user with single user requested";
  }
  results.set(cast<OpResult>(getResult()), {firstUser});
  return DiagnosedSilenceableFailure::success();
}
/// Computes the position of the payload result selected by this op and
/// stores it in `position`. A negative raw position is offset by the number
/// of DPS inits of `op`. Produces a silenceable error when the resulting
/// position is negative or not less than the number of DPS inits; `position`
/// is written in that case too.
DiagnosedSilenceableFailure
transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
                                                   int64_t &position) {
  auto rawPosition = static_cast<int64_t>(getPosition());
  const int64_t numInits = op.getNumDpsInits();
  position = rawPosition < 0 ? numInits + rawPosition : rawPosition;
  if (position >= numInits || position < 0) {
    return emitSilenceableError()
           << "position " << rawPosition
           << " overflows the number of results(inits) of the payload operation";
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verifies that the `any`/`single` keywords are used if and only if the
/// result type implements TransformHandleTypeInterface, and that `any` and
/// `single` are not specified together.
LogicalResult transform::MatchStructuredResultOp::verify() {
  const bool hasUserKeyword = getAny() || getSingle();
  const bool hasHandleResult =
      isa<TransformHandleTypeInterface>(getResult().getType());
  if (hasUserKeyword != hasHandleResult) {
    return emitOpError() << "expects either the any/single keyword or the type "
                            "value handle result type";
  }

  if (getAny() && getSingle())
    return emitOpError() << "'any' and 'single' are mutually exclusive";
  return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredYieldOp
//===----------------------------------------------------------------------===//
/// Populates `effects` with the memory effects of this op by calling, in
/// order, `onlyReadsHandle` on the yielded handles and `onlyReadsPayload`.
void transform::MatchStructuredYieldOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getHandles(), effects);
onlyReadsPayload(effects);
}
/// Convenience builder without operands: forwards to the `build` overload
/// that takes a value range, passing an empty `ValueRange`.
void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
OperationState &state) {
build(builder, state, ValueRange());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"