//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" using namespace mlir; //===----------------------------------------------------------------------===// // Printing and parsing for match ops. //===----------------------------------------------------------------------===// /// Keyword syntax for positional specification inversion. constexpr const static llvm::StringLiteral kDimExceptKeyword = "except"; /// Keyword syntax for full inclusion in positional specification. constexpr const static llvm::StringLiteral kDimAllKeyword = "all"; ParseResult transform::parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll) { Builder &builder = parser.getBuilder(); if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) { rawDimList = builder.getDenseI64ArrayAttr({}); isInverted = nullptr; isAll = builder.getUnitAttr(); return success(); } isAll = nullptr; isInverted = nullptr; if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) { isInverted = builder.getUnitAttr(); } if (isInverted) { if (parser.parseLParen().failed()) return failure(); } SmallVector values; ParseResult listResult = parser.parseCommaSeparatedList( [&]() { return parser.parseInteger(values.emplace_back()); }); if (listResult.failed()) return failure(); rawDimList = builder.getDenseI64ArrayAttr(values); if (isInverted) { if (parser.parseRParen().failed()) return failure(); } return success(); } void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll) { if (isAll) { printer << kDimAllKeyword; return; } if (isInverted) { printer << kDimExceptKeyword << "("; } llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(), [&](int64_t value) { printer << value; }); if (isInverted) { printer << ")"; } } LogicalResult transform::verifyTransformMatchDimsOp(Operation *op, ArrayRef raw, bool inverted, bool all) { if (all) { if (inverted) { return op->emitOpError() << "cannot request both 'all' and 'inverted' values in the list"; } if (!raw.empty()) { return op->emitOpError() << "cannot both request 'all' and specific values in the list"; } } if (!all && raw.empty()) { return op->emitOpError() << "must request specific values in the list if " "'all' is not specified"; } SmallVector rawVector = llvm::to_vector(raw); auto *it = std::unique(rawVector.begin(), rawVector.end()); if (it != rawVector.end()) return op->emitOpError() << "expected the listed values to be unique"; return success(); } DiagnosedSilenceableFailure transform::expandTargetSpecification( Location loc, bool isAll, bool isInverted, ArrayRef rawList, int64_t maxNumber, SmallVectorImpl &result) { assert(maxNumber > 0 && "expected size to be positive"); assert(!(isAll && isInverted) && "cannot invert all"); if (isAll) { result = llvm::to_vector(llvm::seq(0, maxNumber)); return DiagnosedSilenceableFailure::success(); } SmallVector expanded; llvm::SmallDenseSet visited; expanded.reserve(rawList.size()); SmallVectorImpl &target = isInverted ? expanded : result; for (int64_t raw : rawList) { int64_t updated = raw < 0 ? maxNumber + raw : raw; if (updated >= maxNumber) { return emitSilenceableFailure(loc) << "position overflow " << updated << " (updated from " << raw << ") for maximum " << maxNumber; } if (updated < 0) { return emitSilenceableFailure(loc) << "position underflow " << updated << " (updated from " << raw << ")"; } if (!visited.insert(updated).second) { return emitSilenceableFailure(loc) << "repeated position " << updated << " (updated from " << raw << ")"; } target.push_back(updated); } if (!isInverted) return DiagnosedSilenceableFailure::success(); result.reserve(result.size() + (maxNumber - expanded.size())); for (int64_t candidate : llvm::seq(0, maxNumber)) { if (llvm::is_contained(expanded, candidate)) continue; result.push_back(candidate); } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Generated interface implementation. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc"