//===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; #define DEBUG_TYPE "linalg-transforms" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") //===----------------------------------------------------------------------===// // StructuredMatchOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { // First, check if the payload operation is a structured Linalg operation. if (!isa(current)) { if (getFailurePropagationMode().value_or( FailurePropagationMode::Propagate) == FailurePropagationMode::Propagate) { return emitSilenceableError() << "expected a Linalg op"; } // If errors are suppressed, succeed and set all results to empty lists. LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); results.setRemainingToEmpty(cast(getOperation())); return DiagnosedSilenceableFailure::success(); } // Bind `current` to the block argument. auto scope = state.make_region_scope(getBodyRegion()); if (failed(state.mapBlockArgument(getBody()->getArgument(0), MappedValue(current)))) { return DiagnosedSilenceableFailure::definiteFailure(); } for (Operation &nested : getBody()->without_terminator()) { DiagnosedSilenceableFailure diag = state.applyTransform(cast(nested)); if (diag.isDefiniteFailure()) return diag; if (diag.succeeded()) continue; // If propagating errors, do this immediately. assert(diag.isSilenceableFailure()); if (getFailurePropagationMode().value_or( FailurePropagationMode::Propagate) == FailurePropagationMode::Propagate) { return diag; } // If suppressing errors, print the message into the debug stream before // silencing it. Then set all results value that are already known. // Results come from the terminator operands, which may be defined in the // (single) block of this operation or above it. When they are defined // above, they are known to be mapped at this point per SSA dominance. // When they are defined in this block, we additionally check if we have // already applied the operation that defines them. If not, the // corresponding results will be set to empty lists. LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() << "\n"); (void)diag.silence(); SmallVector undefinedOperands; for (OpOperand &terminatorOperand : getBody()->getTerminator()->getOpOperands()) { Operation *definingOp = terminatorOperand.get().getDefiningOp(); if (!definingOp) continue; if (definingOp->getBlock() != getBody()) continue; if (definingOp->isBeforeInBlock(&nested)) continue; undefinedOperands.push_back(&terminatorOperand); } SmallVector> mappings; auto filtered = llvm::make_filter_range( getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) { return !llvm::is_contained(undefinedOperands, &opOperand); }); SmallVector definedOperands = llvm::to_vector(llvm::map_range( filtered, [](OpOperand &opOperand) { return opOperand.get(); })); detail::prepareValueMappings(mappings, definedOperands, state); for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) { results.setMappedValues(getResults()[operand.getOperandNumber()], mapping); } results.setRemainingToEmpty(cast(getOperation())); return DiagnosedSilenceableFailure::success(); } // Set the results. detail::forwardTerminatorOperands(getBody(), state, results); return DiagnosedSilenceableFailure::success(); } void transform::MatchStructuredOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getCurrent(), effects); onlyReadsPayload(effects); producesHandle(getOutputs(), effects); } LogicalResult transform::MatchStructuredOp::verify() { if (getBody()->getNumArguments() != 1) return emitOpError() << "expected one body argument"; if (!isa(getBody()->getArgument(0).getType())) { return emitOpError() << "expected body argument to implement " "TransformHandleTypeInterface"; } for (Operation &nested : getBody()->without_terminator()) { if (isa(nested)) continue; InFlightDiagnostic diag = emitOpError() << "expects nested operations to implement MatchOpInterface"; diag.attachNote(nested.getLoc()) << "offending operation"; return diag; } return success(); } //===----------------------------------------------------------------------===// // StructuredOpPredicateOpTrait //===----------------------------------------------------------------------===// LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait( Operation *op, Value structuredOpHandle) { if (!isa_and_nonnull(op->getParentOp())) { return op->emitOpError() << "expects parent op to be '" << MatchStructuredOp::getOperationName() << "'"; } // Bail out here, let the verifier of the parent complain. Operation *parent = op->getParentOp(); if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() || parent->getRegion(0).front().getNumArguments() < 1) return success(); if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) { return op->emitOpError() << "expected predicate to apply to the surrounding structured op"; } return success(); } //===----------------------------------------------------------------------===// // MatchStructuredBodyOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); if (std::optional position = getReductionPosition()) { SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), *position, combinerOps)) { return emitSilenceableError() << "could not match reduction"; } if (combinerOps.size() != 1) { return emitSilenceableError() << "reduction combiner is not a single op"; } return DiagnosedSilenceableFailure::success(); } if (getPassthrough()) { Block &body = linalgOp->getRegion(0).front(); if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) { return emitSilenceableError() << "not a passthrough"; } return DiagnosedSilenceableFailure::success(); } if (std::optional contractionOps = getContraction()) { Block &body = linalgOp->getRegion(0).front(); std::string message; llvm::raw_string_ostream os(message); bool result = linalg::detail::isContractionBody( body, [&](Operation *elem, Operation *red) { return elem->getName().getStringRef() == (*contractionOps)[0].cast().getValue() && red->getName().getStringRef() == (*contractionOps)[1].cast().getValue(); }, os); if (result) return DiagnosedSilenceableFailure::success(); return emitSilenceableError() << "contraction: " << os.str(); } return emitDefiniteFailure() << "unknown body condition"; } LogicalResult transform::MatchStructuredBodyOp::verify() { int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + getContraction().has_value(); if (numOptions > 1) { std::string attributeNames; llvm::raw_string_ostream os(attributeNames); llvm::interleaveComma(ArrayRef{getReductionPositionAttrName(), getPassthroughAttrName(), getContractionAttrName()}, os); return emitOpError() << "only one of {" << os.str() << "} is allowed"; } if (std::optional contractionAttr = getContraction()) { if (contractionAttr->size() != 2) { return emitOpError() << "expects " << getContractionAttrName() << " to contain two elements"; } } return success(); } //===----------------------------------------------------------------------===// // MatchStructuredClassifyContractionDimsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredClassifyContractionDimsOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { FailureOr contractionDims = linalg::inferContractionDims(cast(current)); if (failed(contractionDims)) return emitSilenceableError() << "could not infer contraction dimensions"; MLIRContext *context = current->getContext(); Builder builder(context); auto makeI64Attrs = [&](ArrayRef values) { return llvm::to_vector( llvm::map_range(values, [&](unsigned value) -> Attribute { return builder.getI64IntegerAttr(value); })); }; results.setParams(getBatch().cast(), makeI64Attrs(contractionDims->batch)); results.setParams(getM().cast(), makeI64Attrs(contractionDims->m)); results.setParams(getN().cast(), makeI64Attrs(contractionDims->n)); results.setParams(getK().cast(), makeI64Attrs(contractionDims->k)); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MatchStructuredClassifyConvolutionDimsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { FailureOr convolutionDims = linalg::inferConvolutionDims(cast(current)); if (failed(convolutionDims)) return emitSilenceableError() << "could not infer convolution dimensions"; MLIRContext *context = current->getContext(); Builder builder(context); auto makeI64Attrs = [&](ArrayRef values) { return llvm::to_vector( llvm::map_range(values, [&](unsigned value) -> Attribute { return builder.getI64IntegerAttr(value); })); }; results.setParams(getBatch().cast(), makeI64Attrs(convolutionDims->batch)); results.setParams(getOutputImage().cast(), makeI64Attrs(convolutionDims->outputImage)); results.setParams(getOutputChannel().cast(), makeI64Attrs(convolutionDims->outputChannel)); results.setParams(getFilterLoop().cast(), makeI64Attrs(convolutionDims->filterLoop)); results.setParams(getInputChannel().cast(), makeI64Attrs(convolutionDims->inputChannel)); results.setParams(getDepth().cast(), makeI64Attrs(convolutionDims->depth)); auto makeI64AttrsFromI64 = [&](ArrayRef values) { return llvm::to_vector( llvm::map_range(values, [&](int64_t value) -> Attribute { return builder.getI64IntegerAttr(value); })); }; results.setParams(getStrides().cast(), makeI64AttrsFromI64(convolutionDims->strides)); results.setParams(getDilations().cast(), makeI64AttrsFromI64(convolutionDims->dilations)); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Utilities for structured match predicates. //===----------------------------------------------------------------------===// /// Checks if all values from `list` are also contained in `reference`. Returns /// a silenceable error with the given message at the given location when it is /// not the case. The error message must contain the "{0}" placeholder that /// will be substituted with the value from `list` that is not contained in /// `reference`. static DiagnosedSilenceableFailure containsAll(ArrayRef reference, ArrayRef list, Location loc, const char *message) { for (int64_t value : list) { if (llvm::any_of(reference, [&](unsigned ref) { return static_cast(ref) == value; })) { continue; } return emitSilenceableFailure(loc) << llvm::formatv(message, value); } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MatchStructuredDimOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); SmallVector dimensions; DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions); if (!diag.succeeded()) return diag; // If asked to check for the kind of dimension, perform the check. if (getParallel() || getReduction()) { SmallVector reference; if (getParallel()) linalgOp.getParallelDims(reference); else if (getReduction()) linalgOp.getReductionDims(reference); DiagnosedSilenceableFailure diag = containsAll(reference, dimensions, getLoc(), getParallel() ? "expects dimension #{0} to be parallel" : "expects dimension #{0} to be reduction"); if (!diag.succeeded()) return diag; } // If not capturing, we are done here. if (!getResult()) return diag; SmallVector ranges = linalgOp.getStaticLoopRanges(); Builder builder(current); SmallVector captured = llvm::to_vector( llvm::map_range(dimensions, [&](int64_t dim) -> Attribute { return builder.getI64IntegerAttr(ranges[dim]); })); results.setParams(cast(getResult()), captured); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor( linalg::LinalgOp op, SmallVectorImpl &dims) { DiagnosedSilenceableFailure diag = expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(), getRawDimList(), op.getNumLoops(), dims); if (diag.isSilenceableFailure()) { diag.attachNote(op->getLoc()) << "while considering dimensions of this payload operation"; } return diag; } LogicalResult transform::MatchStructuredDimOp::verify() { if (getParallel() && getReduction()) { return emitOpError() << "cannot request the same dimension to be both " "parallel and reduction"; } return verifyTransformMatchDimsOp(getOperation(), getRawDimList(), getIsInverted(), getIsAll()); } //===----------------------------------------------------------------------===// // MatchStructuredElementalBitwidthOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredElementalBitwidthOp::matchValue( Value current, transform::TransformResults &results, transform::TransformState &state) { auto setupResult = [&](int64_t bitwidth) { Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth); results.setParams(cast(getResult()), {attr}); return DiagnosedSilenceableFailure::success(); }; Type type = current.getType(); if (type.isIntOrFloat()) return setupResult(type.getIntOrFloatBitWidth()); if (auto shapedType = dyn_cast(type)) { if (shapedType.getElementType().isIntOrFloat()) return setupResult(shapedType.getElementTypeBitWidth()); } return emitSilenceableError() << "unsupported type for bitwidth extraction: " << type; } //===----------------------------------------------------------------------===// // MatchStructuredInputOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); SmallVector positions; DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); if (!diag.succeeded()) return diag; SmallVector operandMapping; operandMapping.reserve(positions.size()); for (int64_t position : positions) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position)); if (getPermutation() && !indexingMap.isPermutation()) { return emitSilenceableError() << "the indexing map for input #" << position << " is not a permutation"; } if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { return emitSilenceableError() << "the indexing map for input #" << position << " is not a projected permutation"; } // If capture not requested, skip it. if (!getResult()) continue; if (isa(getResult().getType())) { operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); continue; } Value operand = linalgOp.getDpsInputOperand(position)->get(); if (isa(getResult().getType())) { operandMapping.emplace_back(operand); continue; } Operation *operandProducer = operand.getDefiningOp(); if (!operandProducer) { return emitSilenceableError() << "input #" << position << " is not produced by an operation"; } operandMapping.emplace_back(operandProducer); } if (getResult()) results.setMappedValues(cast(getResult()), operandMapping); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor( linalg::LinalgOp op, SmallVectorImpl &positions) { DiagnosedSilenceableFailure diag = expandTargetSpecification( getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), op.getNumDpsInputs(), positions); if (diag.isSilenceableFailure()) { diag.attachNote(op->getLoc()) << "while considering DPS inputs of this payload operation"; } return diag; } /// Verifies a matcher op for structured input or output, specifically the /// attributes specifying the operand positions. template LogicalResult verifyStructuredOperandOp(OpTy op) { if (op.getPermutation() && op.getProjectedPermutation()) { return op.emitOpError() << op.getPermutationAttrName() << " and " << op.getProjectedPermutationAttrName() << " are mutually exclusive"; } if (op.getRawPositionList().size() > 1 && op.getResult()) { return op.emitOpError() << "cannot bind multiple inputs/inits to the same value"; } return success(); } LogicalResult transform::MatchStructuredInputOp::verify() { if (failed(verifyStructuredOperandOp(*this))) return failure(); return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), getIsInverted(), getIsAll()); } //===----------------------------------------------------------------------===// // MatchStructuredInitOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); SmallVector positions; DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions); if (!diag.succeeded()) return diag; SmallVector operandMapping; operandMapping.reserve(positions.size()); for (int64_t position : positions) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position)); if (getPermutation() && !indexingMap.isPermutation()) { return emitSilenceableError() << "the indexing map for output(init) #" << position << " is not a permutation"; } if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) { return emitSilenceableError() << "the indexing map for output(init) #" << position << " is not a permutation"; } // If capture not requested, skip it. if (!getResult()) continue; if (isa(getResult().getType())) { operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); continue; } Value operand = linalgOp.getDpsInitOperand(position)->get(); if (isa(getResult().getType())) { operandMapping.emplace_back(operand); continue; } Operation *operandProducer = operand.getDefiningOp(); if (!operandProducer) { return emitSilenceableError() << "output(init) #" << position << " is not produced by an operation"; } operandMapping.emplace_back(operandProducer); } if (getResult()) results.setMappedValues(cast(getResult()), operandMapping); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor( linalg::LinalgOp op, SmallVectorImpl &positions) { DiagnosedSilenceableFailure diag = expandTargetSpecification( getLoc(), getIsAll(), getIsInverted(), getRawPositionList(), op.getNumDpsInits(), positions); if (diag.isSilenceableFailure()) { diag.attachNote(op->getLoc()) << "while considering DPS inits (outputs) of this payload operation"; } return diag; } LogicalResult transform::MatchStructuredInitOp::verify() { if (failed(verifyStructuredOperandOp(*this))) return failure(); return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(), getIsInverted(), getIsAll()); } //===----------------------------------------------------------------------===// // MatchStructuredNumInputsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredNumInputsOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); Attribute attr = Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs()); results.setParams(cast(getResult()), {attr}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MatchStructuredNumInitsOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredNumInitsOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); Attribute attr = Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits()); results.setParams(cast(getResult()), {attr}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MatchStructuredRankOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation( Operation *current, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(current); int64_t numLoops = linalgOp.getNumLoops(); Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops); results.setParams(cast(getRank()), {attr}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MatchStructuredResultOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( Operation *op, transform::TransformResults &results, transform::TransformState &state) { auto linalgOp = cast(op); int64_t position; DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position); if (!diag.succeeded()) return diag; Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); if (isa(getResult().getType())) { results.setValues(cast(getResult()), {result}); return DiagnosedSilenceableFailure::success(); } if (result.getUsers().empty()) { return emitSilenceableError() << "no users of the result #" << getPosition(); } Operation *firstUser = *result.getUsers().begin(); if (getAny()) { results.set(cast(getResult()), {firstUser}); return DiagnosedSilenceableFailure::success(); } if (getSingle()) { if (!llvm::hasSingleElement(result.getUsers())) { return emitSilenceableError() << "more than one result user with single user requested"; } results.set(cast(getResult()), {firstUser}); return DiagnosedSilenceableFailure::success(); } return emitDefiniteFailure() << "unknown sub-predicate"; } DiagnosedSilenceableFailure transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, int64_t &position) { auto rawPosition = static_cast(getPosition()); position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition; if (position >= op.getNumDpsInits() || position < 0) { return emitSilenceableError() << "position " << rawPosition << " overflows the number of results(ints) of the payload operation"; } return DiagnosedSilenceableFailure::success(); } LogicalResult transform::MatchStructuredResultOp::verify() { if ((getAny() || getSingle()) ^ isa(getResult().getType())) { return emitOpError() << "expects either the any/single keyword or the type " "value handle result type"; } if (getAny() && getSingle()) { return emitOpError() << "'any' and 'single' are mutually exclusive"; } return success(); } //===----------------------------------------------------------------------===// // MatchStructuredYieldOp //===----------------------------------------------------------------------===// void transform::MatchStructuredYieldOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getHandles(), effects); onlyReadsPayload(effects); } void transform::MatchStructuredYieldOp::build(OpBuilder &builder, OperationState &state) { build(builder, state, ValueRange()); } #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"