713 lines
28 KiB
C++
713 lines
28 KiB
C++
|
//===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
|
||
|
#include "mlir/Analysis/SliceAnalysis.h"
|
||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
||
|
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
|
||
|
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
|
||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||
|
#include "mlir/Interfaces/FunctionImplementation.h"
|
||
|
#include "llvm/Support/Debug.h"
|
||
|
#include "llvm/Support/FormatVariadic.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
|
||
|
#define DEBUG_TYPE "linalg-transforms"
|
||
|
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// StructuredMatchOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Applies the nested matcher operations of this op to the payload operation
/// `current`.
///
/// The payload must be a `linalg::LinalgOp`. It is bound to the first body
/// block argument and the non-terminator ops of the body are applied in order
/// as transform ops. When all of them succeed, the terminator operands are
/// forwarded as the results of this op.
///
/// When the failure propagation mode is `Propagate` (also used when the
/// attribute is absent), a silenceable failure from the payload check or from
/// a nested op is returned to the caller. Otherwise the failure is printed to
/// the debug stream and silenced, the results that are already known are set,
/// the remaining ones are set to empty, and success is returned. Definite
/// failures are always returned as is.
DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  // The propagation mode is a property of this op; query it once up front
  // instead of at every failure site.
  bool propagateFailures = getFailurePropagationMode().value_or(
                               FailurePropagationMode::Propagate) ==
                           FailurePropagationMode::Propagate;

  // First, check if the payload operation is a structured Linalg operation.
  if (!isa<linalg::LinalgOp>(current)) {
    if (propagateFailures)
      return emitSilenceableError() << "expected a Linalg op";

    // If errors are suppressed, succeed and set all results to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op\n");
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Bind `current` to the block argument.
  auto scope = state.make_region_scope(getBodyRegion());
  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
                                    MappedValue(current)))) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  for (Operation &nested : getBody()->without_terminator()) {
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<TransformOpInterface>(nested));
    if (diag.isDefiniteFailure())
      return diag;
    if (diag.succeeded())
      continue;

    // If propagating errors, do this immediately.
    assert(diag.isSilenceableFailure());
    if (propagateFailures)
      return diag;

    // If suppressing errors, print the message into the debug stream before
    // silencing it. Then set all results value that are already known.
    // Results come from the terminator operands, which may be defined in the
    // (single) block of this operation or above it. When they are defined
    // above, they are known to be mapped at this point per SSA dominance.
    // When they are defined in this block, we additionally check if we have
    // already applied the operation that defines them. If not, the
    // corresponding results will be set to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
                      << "\n");
    (void)diag.silence();
    SmallVector<OpOperand *> undefinedOperands;
    for (OpOperand &terminatorOperand :
         getBody()->getTerminator()->getOpOperands()) {
      Operation *definingOp = terminatorOperand.get().getDefiningOp();
      // Block arguments have no defining op and are treated as defined.
      if (!definingOp)
        continue;
      // Values defined outside of the body block are treated as defined.
      if (definingOp->getBlock() != getBody())
        continue;
      // Ops preceding the failed one have already been applied.
      if (definingOp->isBeforeInBlock(&nested))
        continue;

      undefinedOperands.push_back(&terminatorOperand);
    }

    SmallVector<SmallVector<transform::MappedValue>> mappings;
    auto filtered = llvm::make_filter_range(
        getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
          return !llvm::is_contained(undefinedOperands, &opOperand);
        });
    SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
    detail::prepareValueMappings(mappings, definedOperands, state);
    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
      results.setMappedValues(getResults()[operand.getOperandNumber()],
                              mapping);
    }
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Set the results.
  detail::forwardTerminatorOperands(getBody(), state, results);
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Populates `effects` with the memory effects of this op by calling, in
/// order, `onlyReadsHandle` on the `current` operand, `onlyReadsPayload`, and
/// `producesHandle` on the outputs.
void transform::MatchStructuredOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Effect on the handle this op is applied to.
  onlyReadsHandle(getCurrent(), effects);
  // Effect on the payload IR.
  onlyReadsPayload(effects);
  // Effect on the handles returned by this op.
  producesHandle(getOutputs(), effects);
}
|
||
|
|
||
|
/// Verifies the structure of the body: it must have exactly one block
/// argument, that argument must have a type implementing
/// `TransformHandleTypeInterface`, and every non-terminator operation in the
/// body must implement `MatchOpInterface`.
LogicalResult transform::MatchStructuredOp::verify() {
  Block *bodyBlock = getBody();
  if (bodyBlock->getNumArguments() != 1)
    return emitOpError() << "expected one body argument";

  Type argumentType = bodyBlock->getArgument(0).getType();
  if (!isa<TransformHandleTypeInterface>(argumentType)) {
    return emitOpError() << "expected body argument to implement "
                            "TransformHandleTypeInterface";
  }

  for (Operation &child : bodyBlock->without_terminator()) {
    if (!isa<MatchOpInterface>(child)) {
      // Report on this op and point at the first offending nested op.
      InFlightDiagnostic error =
          emitOpError()
          << "expects nested operations to implement MatchOpInterface";
      error.attachNote(child.getLoc()) << "offending operation";
      return error;
    }
  }
  return success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// StructuredOpPredicateOpTrait
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Verifies that `op` is directly nested in a `MatchStructuredOp` and that
/// `structuredOpHandle` is the first argument of the first block in the
/// parent's first region.
LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
    Operation *op, Value structuredOpHandle) {
  Operation *parent = op->getParentOp();
  if (!isa_and_nonnull<MatchStructuredOp>(parent)) {
    return op->emitOpError() << "expects parent op to be '"
                             << MatchStructuredOp::getOperationName() << "'";
  }

  // Bail out here, let the verifier of the parent complain.
  if (parent->getNumRegions() < 1)
    return success();
  Region &parentRegion = parent->getRegion(0);
  if (parentRegion.empty() || parentRegion.front().getNumArguments() < 1)
    return success();

  if (structuredOpHandle == parentRegion.front().getArgument(0))
    return success();
  return op->emitOpError()
         << "expected predicate to apply to the surrounding structured op";
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredBodyOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Checks the body of the payload Linalg operation `current` against the
/// condition selected on this op. The conditions are tried in this order and
/// the first one that is set decides the outcome:
///   - reduction position: `matchReduction` must succeed on the region output
///     arguments at that position and report exactly one combiner op;
///   - passthrough: the operands of the body terminator must be equal to the
///     region input arguments;
///   - contraction: `linalg::detail::isContractionBody` must accept the body,
///     with the names of the elementwise and reduction ops equal to the first
///     and second string of the attribute, respectively.
/// A mismatch produces a silenceable error. If no condition is set, a
/// definite failure is emitted.
DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  if (std::optional<uint64_t> position = getReductionPosition()) {
    SmallVector<Operation *> combinerOps;
    if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
                        combinerOps)) {
      return emitSilenceableError() << "could not match reduction";
    }
    if (combinerOps.size() != 1) {
      return emitSilenceableError() << "reduction combiner is not a single op";
    }
    return DiagnosedSilenceableFailure::success();
  }
  if (getPassthrough()) {
    Block &body = linalgOp->getRegion(0).front();
    if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
      return emitSilenceableError() << "not a passthrough";
    }
    return DiagnosedSilenceableFailure::success();
  }
  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
    Block &body = linalgOp->getRegion(0).front();
    // Collects the explanation produced by `isContractionBody` on mismatch.
    std::string message;
    llvm::raw_string_ostream os(message);
    bool result = linalg::detail::isContractionBody(
        body,
        [&](Operation *elem, Operation *red) {
          // Use the free-function casts rather than the deprecated member
          // `Attribute::cast`, consistently with the rest of this file.
          return elem->getName().getStringRef() ==
                     cast<StringAttr>((*contractionOps)[0]).getValue() &&
                 red->getName().getStringRef() ==
                     cast<StringAttr>((*contractionOps)[1]).getValue();
        },
        os);
    if (result)
      return DiagnosedSilenceableFailure::success();
    return emitSilenceableError() << "contraction: " << os.str();
  }
  return emitDefiniteFailure() << "unknown body condition";
}
|
||
|
|
||
|
/// Verifies that at most one of the reduction position, passthrough and
/// contraction conditions is specified, and that the contraction attribute,
/// when present, contains exactly two elements.
LogicalResult transform::MatchStructuredBodyOp::verify() {
  // Count how many of the mutually exclusive conditions are present.
  int64_t numOptions = 0;
  if (getReductionPosition().has_value())
    ++numOptions;
  if (getPassthrough())
    ++numOptions;
  if (getContraction().has_value())
    ++numOptions;

  if (numOptions > 1) {
    // List the names of all three attributes in the error message.
    std::string attributeNames;
    llvm::raw_string_ostream os(attributeNames);
    llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
                                               getPassthroughAttrName(),
                                               getContractionAttrName()},
                          os);
    return emitOpError() << "only one of {" << os.str() << "} is allowed";
  }

  std::optional<ArrayAttr> contractionAttr = getContraction();
  if (contractionAttr && contractionAttr->size() != 2) {
    return emitOpError() << "expects " << getContractionAttrName()
                         << " to contain two elements";
  }
  return success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredClassifyContractionDimsOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Infers the contraction dimensions of the payload Linalg operation
/// `current` with `linalg::inferContractionDims` and associates the batch, m,
/// n and k dimension lists with the corresponding results of this op as lists
/// of 64-bit integer attributes. Emits a silenceable error when the inference
/// fails.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ContractionDimensions> contractionDims =
      linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
  if (failed(contractionDims))
    return emitSilenceableError() << "could not infer contraction dimensions";

  Builder builder(current->getContext());
  // Converts a list of dimension indices into a list of i64 attributes.
  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](unsigned value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  // Use the free-function `cast` rather than the deprecated member
  // `Value::cast`, consistently with the rest of this file.
  results.setParams(cast<OpResult>(getBatch()),
                    makeI64Attrs(contractionDims->batch));
  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredClassifyConvolutionDimsOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Infers the convolution dimensions of the payload Linalg operation
/// `current` with `linalg::inferConvolutionDims` and associates each
/// dimension list (batch, output image, output channel, filter loop, input
/// channel, depth) as well as the strides and dilations with the
/// corresponding result of this op as lists of 64-bit integer attributes.
/// Emits a silenceable error when the inference fails.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
      linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
  if (failed(convolutionDims))
    return emitSilenceableError() << "could not infer convolution dimensions";

  Builder builder(current->getContext());
  // Converts a list of dimension indices into a list of i64 attributes.
  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](unsigned value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  // Use the free-function `cast` rather than the deprecated member
  // `Value::cast`, consistently with the rest of this file.
  results.setParams(cast<OpResult>(getBatch()),
                    makeI64Attrs(convolutionDims->batch));
  results.setParams(cast<OpResult>(getOutputImage()),
                    makeI64Attrs(convolutionDims->outputImage));
  results.setParams(cast<OpResult>(getOutputChannel()),
                    makeI64Attrs(convolutionDims->outputChannel));
  results.setParams(cast<OpResult>(getFilterLoop()),
                    makeI64Attrs(convolutionDims->filterLoop));
  results.setParams(cast<OpResult>(getInputChannel()),
                    makeI64Attrs(convolutionDims->inputChannel));
  results.setParams(cast<OpResult>(getDepth()),
                    makeI64Attrs(convolutionDims->depth));

  // Strides and dilations are stored as signed 64-bit values and therefore
  // need a separate conversion helper.
  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
    return llvm::to_vector(
        llvm::map_range(values, [&](int64_t value) -> Attribute {
          return builder.getI64IntegerAttr(value);
        }));
  };
  results.setParams(cast<OpResult>(getStrides()),
                    makeI64AttrsFromI64(convolutionDims->strides));
  results.setParams(cast<OpResult>(getDilations()),
                    makeI64AttrsFromI64(convolutionDims->dilations));
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Utilities for structured match predicates.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Checks if all values from `list` are also contained in `reference`. Returns
|
||
|
/// a silenceable error with the given message at the given location when it is
|
||
|
/// not the case. The error message must contain the "{0}" placeholder that
|
||
|
/// will be substituted with the value from `list` that is not contained in
|
||
|
/// `reference`.
|
||
|
static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
|
||
|
ArrayRef<int64_t> list,
|
||
|
Location loc,
|
||
|
const char *message) {
|
||
|
for (int64_t value : list) {
|
||
|
if (llvm::any_of(reference, [&](unsigned ref) {
|
||
|
return static_cast<int64_t>(ref) == value;
|
||
|
})) {
|
||
|
continue;
|
||
|
}
|
||
|
return emitSilenceableFailure(loc) << llvm::formatv(message, value);
|
||
|
}
|
||
|
return DiagnosedSilenceableFailure::success();
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredDimOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Matches the dimensions of the payload Linalg operation `current` selected
/// by this op. When the parallel or reduction kind is requested, every
/// selected dimension must be of that kind, otherwise a silenceable error is
/// produced. When this op has a result, the static loop ranges of the
/// selected dimensions are associated with it as 64-bit integer attributes.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> dimensions;
  DiagnosedSilenceableFailure expansion =
      getDimensionsFor(linalgOp, dimensions);
  if (!expansion.succeeded())
    return expansion;

  // If asked to check for the kind of dimension, perform the check.
  if (getParallel() || getReduction()) {
    SmallVector<unsigned> reference;
    if (getParallel())
      linalgOp.getParallelDims(reference);
    else if (getReduction())
      linalgOp.getReductionDims(reference);

    const char *message = getParallel()
                              ? "expects dimension #{0} to be parallel"
                              : "expects dimension #{0} to be reduction";
    DiagnosedSilenceableFailure kindCheck =
        containsAll(reference, dimensions, getLoc(), message);
    if (!kindCheck.succeeded())
      return kindCheck;
  }

  // If not capturing, we are done here.
  if (!getResult())
    return expansion;

  // Capture the static loop range of each selected dimension.
  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
  Builder builder(current);
  SmallVector<Attribute> captured;
  captured.reserve(dimensions.size());
  for (int64_t dim : dimensions)
    captured.push_back(builder.getI64IntegerAttr(ranges[dim]));
  results.setParams(cast<OpResult>(getResult()), captured);
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Expands the dimension specification of this op into `dims` via
/// `expandTargetSpecification`, using the number of loops of `op` as the
/// upper bound argument. On a silenceable failure, a note pointing at the
/// payload operation is attached to the diagnostic before it is returned.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
  DiagnosedSilenceableFailure expansion =
      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
                                getRawDimList(), op.getNumLoops(), dims);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering dimensions of this payload operation";
  return expansion;
}
|
||
|
|
||
|
/// Rejects ops that request both the parallel and the reduction kind, then
/// delegates the verification of the dimension list to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredDimOp::verify() {
  if (!getParallel() || !getReduction()) {
    return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
                                      getIsInverted(), getIsAll());
  }
  return emitOpError() << "cannot request the same dimension to be both "
                          "parallel and reduction";
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredElementalBitwidthOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Associates the result of this op with the bitwidth of the payload value
/// `current` as a 64-bit integer attribute. The bitwidth is taken from the
/// type itself when it is an integer or floating-point type, and from the
/// element type when it is a shaped type with such an element type. Any other
/// type produces a silenceable error.
DiagnosedSilenceableFailure
transform::MatchStructuredElementalBitwidthOp::matchValue(
    Value current, transform::TransformResults &results,
    transform::TransformState &state) {
  Type type = current.getType();

  // Stays empty when the type is not supported.
  std::optional<int64_t> bitwidth;
  if (type.isIntOrFloat()) {
    bitwidth = type.getIntOrFloatBitWidth();
  } else if (auto shapedType = dyn_cast<ShapedType>(type)) {
    if (shapedType.getElementType().isIntOrFloat())
      bitwidth = shapedType.getElementTypeBitWidth();
  }

  if (!bitwidth) {
    return emitSilenceableError()
           << "unsupported type for bitwidth extraction: " << type;
  }

  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(*bitwidth);
  results.setParams(cast<OpResult>(getResult()), {attr});
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredInputOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Matches the DPS inputs of the payload Linalg operation `current` at the
/// positions selected by this op. For each position, the indexing map of the
/// input is checked to be a permutation or a projected permutation when
/// requested. When this op has a result, each input contributes, depending on
/// the result type, its indexing map (affine map parameter), the operand
/// value (value handle), or the operation defining the operand (otherwise; a
/// silenceable error is produced if there is no such operation).
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure expansion = getPositionsFor(linalgOp, positions);
  if (!expansion.succeeded())
    return expansion;

  SmallVector<MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    OpOperand *input = linalgOp.getDpsInputOperand(position);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(input);
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for input #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for input #" << position
             << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!getResult())
      continue;

    // The kind of the captured entity depends on the type of the result.
    Type resultType = getResult().getType();
    if (isa<AffineMapParamType>(resultType)) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = input->get();
    if (isa<TransformValueHandleTypeInterface>(resultType)) {
      operandMapping.emplace_back(operand);
      continue;
    }

    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError()
             << "input #" << position << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }

  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Expands the position specification of this op into `positions` via
/// `expandTargetSpecification`, using the number of DPS inputs of `op` as the
/// upper bound argument. On a silenceable failure, a note pointing at the
/// payload operation is attached to the diagnostic before it is returned.
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure expansion = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInputs(), positions);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering DPS inputs of this payload operation";
  return expansion;
}
|
||
|
|
||
|
/// Verifies a matcher op for structured input or output, specifically the
|
||
|
/// attributes specifying the operand positions.
|
||
|
template <typename OpTy>
|
||
|
LogicalResult verifyStructuredOperandOp(OpTy op) {
|
||
|
if (op.getPermutation() && op.getProjectedPermutation()) {
|
||
|
return op.emitOpError()
|
||
|
<< op.getPermutationAttrName() << " and "
|
||
|
<< op.getProjectedPermutationAttrName() << " are mutually exclusive";
|
||
|
}
|
||
|
if (op.getRawPositionList().size() > 1 && op.getResult()) {
|
||
|
return op.emitOpError()
|
||
|
<< "cannot bind multiple inputs/inits to the same value";
|
||
|
}
|
||
|
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Runs the shared structured operand verifier first and, if it succeeds,
/// delegates the verification of the position list to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredInputOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredInitOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Matches the DPS inits (outputs) of the payload Linalg operation `current`
/// at the positions selected by this op. For each position, the indexing map
/// of the init is checked to be a permutation or a projected permutation when
/// requested. When this op has a result, each init contributes, depending on
/// the result type, its indexing map (affine map parameter), the operand
/// value (value handle), or the operation defining the operand (otherwise; a
/// silenceable error is produced if there is no such operation).
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  SmallVector<MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    AffineMap indexingMap =
        linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for output(init) #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      // Name the property that was actually checked, matching the wording
      // used by the corresponding check on inputs.
      return emitSilenceableError()
             << "the indexing map for output(init) #" << position
             << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!getResult())
      continue;

    if (isa<AffineMapParamType>(getResult().getType())) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = linalgOp.getDpsInitOperand(position)->get();
    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
      operandMapping.emplace_back(operand);
      continue;
    }

    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError() << "output(init) #" << position
                                    << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }
  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Expands the position specification of this op into `positions` via
/// `expandTargetSpecification`, using the number of DPS inits of `op` as the
/// upper bound argument. On a silenceable failure, a note pointing at the
/// payload operation is attached to the diagnostic before it is returned.
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure expansion = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInits(), positions);
  if (!expansion.isSilenceableFailure())
    return expansion;

  expansion.attachNote(op->getLoc())
      << "while considering DPS inits (outputs) of this payload operation";
  return expansion;
}
|
||
|
|
||
|
/// Runs the shared structured operand verifier first and, if it succeeds,
/// delegates the verification of the position list to
/// `verifyTransformMatchDimsOp`.
LogicalResult transform::MatchStructuredInitOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredNumInputsOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Associates the result of this op with the number of DPS inputs of the
/// payload Linalg operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInputsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  Builder builder(current);
  Attribute numInputsAttr =
      builder.getI64IntegerAttr(structuredOp.getNumDpsInputs());
  results.setParams(cast<OpResult>(getResult()), {numInputsAttr});
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredNumInitsOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Associates the result of this op with the number of DPS inits of the
/// payload Linalg operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInitsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  Builder builder(current);
  Attribute numInitsAttr =
      builder.getI64IntegerAttr(structuredOp.getNumDpsInits());
  results.setParams(cast<OpResult>(getResult()), {numInitsAttr});
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredRankOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Associates the rank result of this op with the number of loops of the
/// payload Linalg operation `current`, as a 64-bit integer attribute.
DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto structuredOp = cast<linalg::LinalgOp>(current);
  int64_t loopCount = structuredOp.getNumLoops();
  Builder builder(structuredOp->getContext());
  Attribute rankAttr = builder.getI64IntegerAttr(loopCount);
  results.setParams(cast<OpResult>(getRank()), {rankAttr});
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredResultOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Matches the result of the payload Linalg operation `op` tied to the DPS
/// init at the position selected by this op. When the result type of this op
/// is a value handle, the payload result itself is associated with it.
/// Otherwise a user of the payload result is associated with it: the first
/// user when `any` is set, or the only user when `single` is set. A
/// silenceable error is produced when there are no users, or when `single` is
/// set and there is more than one user; a definite failure is produced when
/// neither keyword is set.
DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  int64_t position;
  DiagnosedSilenceableFailure positionDiag = getPositionFor(linalgOp, position);
  if (!positionDiag.succeeded())
    return positionDiag;

  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
    results.setValues(cast<OpResult>(getResult()), {result});
    return DiagnosedSilenceableFailure::success();
  }

  // From here on, a user of the payload result is being captured.
  auto users = result.getUsers();
  if (users.empty()) {
    return emitSilenceableError()
           << "no users of the result #" << getPosition();
  }
  Operation *firstUser = *users.begin();

  if (!getAny() && !getSingle())
    return emitDefiniteFailure() << "unknown sub-predicate";

  // `any` accepts the first user unconditionally; `single` additionally
  // requires it to be the only user.
  if (!getAny() && !llvm::hasSingleElement(users)) {
    return emitSilenceableError()
           << "more than one result user with single user requested";
  }
  results.set(cast<OpResult>(getResult()), {firstUser});
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Computes in `position` the index of the DPS init of `op` referred to by
/// the position attribute of this op. A negative raw position is interpreted
/// relative to the number of DPS inits. Emits a silenceable error when the
/// resulting index is negative or not less than the number of DPS inits; note
/// that `position` has been written in that case as well.
DiagnosedSilenceableFailure
transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
                                                   int64_t &position) {
  auto rawPosition = static_cast<int64_t>(getPosition());
  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
  if (position >= op.getNumDpsInits() || position < 0) {
    return emitSilenceableError()
           << "position " << rawPosition
           << " overflows the number of results(inits) of the payload "
              "operation";
  }
  return DiagnosedSilenceableFailure::success();
}
|
||
|
|
||
|
/// Verifies that the `any`/`single` keywords are present exactly when the
/// result type implements `TransformHandleTypeInterface`, and that the two
/// keywords are not both present.
LogicalResult transform::MatchStructuredResultOp::verify() {
  bool hasUserKeyword = getAny() || getSingle();
  bool hasHandleResult =
      isa<TransformHandleTypeInterface>(getResult().getType());
  // The two conditions must either both hold or both not hold.
  if (hasUserKeyword != hasHandleResult) {
    return emitOpError() << "expects either the any/single keyword or the type "
                            "value handle result type";
  }

  bool hasBothKeywords = getAny() && getSingle();
  if (hasBothKeywords)
    return emitOpError() << "'any' and 'single' are mutually exclusive";
  return success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MatchStructuredYieldOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Populates `effects` with the memory effects of this op by calling
/// `onlyReadsHandle` on the yielded handles followed by `onlyReadsPayload`.
void transform::MatchStructuredYieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Effect on the operands of this op.
  onlyReadsHandle(getHandles(), effects);
  // Effect on the payload IR.
  onlyReadsPayload(effects);
}
|
||
|
|
||
|
/// Convenience builder that forwards to the overload taking a value range,
/// passing an empty range.
void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
                                              OperationState &state) {
  ValueRange noOperands = ValueRange();
  build(builder, state, noOperands);
}
|
||
|
|
||
|
#define GET_OP_CLASSES
|
||
|
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
|