//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
// Printing and parsing for match ops.
//===----------------------------------------------------------------------===//
|
|
|
|
/// Keyword syntax for positional specification inversion.
|
|
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
|
|
|
|
/// Keyword syntax for full inclusion in positional specification.
|
|
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
|
|
|
|
/// Parses a positional specification in one of three forms:
///
///   all
///   except(<int>, <int>, ...)
///   <int>, <int>, ...
///
/// On success, `rawDimList` holds the listed integers (empty for `all`),
/// `isInverted` is a unit attribute iff `except` was used and `isAll` is a unit
/// attribute iff `all` was used; the flags that do not apply are reset to
/// null. On a parse error, failure is returned and `rawDimList` is left
/// untouched.
ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
                                               DenseI64ArrayAttr &rawDimList,
                                               UnitAttr &isInverted,
                                               UnitAttr &isAll) {
  Builder &attrBuilder = parser.getBuilder();

  // `all` is a complete specification on its own: nothing else follows.
  if (succeeded(parser.parseOptionalKeyword(kDimAllKeyword))) {
    rawDimList = attrBuilder.getDenseI64ArrayAttr({});
    isInverted = nullptr;
    isAll = attrBuilder.getUnitAttr();
    return success();
  }

  isAll = nullptr;
  isInverted = nullptr;
  const bool sawExcept =
      succeeded(parser.parseOptionalKeyword(kDimExceptKeyword));
  if (sawExcept)
    isInverted = attrBuilder.getUnitAttr();

  // The inverted form wraps the list in parentheses.
  if (sawExcept && failed(parser.parseLParen()))
    return failure();

  SmallVector<int64_t> positions;
  auto parseOnePosition = [&]() {
    return parser.parseInteger(positions.emplace_back());
  };
  if (failed(parser.parseCommaSeparatedList(parseOnePosition)))
    return failure();

  rawDimList = attrBuilder.getDenseI64ArrayAttr(positions);

  if (sawExcept && failed(parser.parseRParen()))
    return failure();

  return success();
}
|
|
|
|
/// Prints a positional specification in the form accepted by
/// `parseTransformMatchDims`: `all` when `isAll` is set, `except(<list>)` when
/// `isInverted` is set, and the bare comma-separated list otherwise. `op` is
/// not used by the printer.
void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
                                        DenseI64ArrayAttr rawDimList,
                                        UnitAttr isInverted, UnitAttr isAll) {
  if (isAll) {
    printer << kDimAllKeyword;
    return;
  }

  // Emits the comma-separated integers shared by both remaining forms.
  auto emitPositions = [&]() {
    llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
                          [&](int64_t position) { printer << position; });
  };

  if (!isInverted) {
    emitPositions();
    return;
  }

  printer << kDimExceptKeyword << "(";
  emitPositions();
  printer << ")";
}
|
|
|
|
/// Verifies the positional specification attached to a match op.
///
/// \param op        operation on which diagnostics are emitted.
/// \param raw       positions exactly as listed on the op.
/// \param inverted  whether the list is an exclusion list.
/// \param all       whether every position is requested.
///
/// Emits an op error and returns failure when:
///   - `all` is combined with `inverted` or with a non-empty list;
///   - `all` is not set and the list is empty;
///   - the same value is listed more than once, at any distance in the list.
///
/// Values are compared as written: no normalization of negative values happens
/// here, so two different spellings of one position are not reported by this
/// function.
LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
                                                    ArrayRef<int64_t> raw,
                                                    bool inverted, bool all) {
  if (all) {
    if (inverted) {
      return op->emitOpError()
             << "cannot request both 'all' and 'inverted' values in the list";
    }
    if (!raw.empty()) {
      return op->emitOpError()
             << "cannot both request 'all' and specific values in the list";
    }
  }
  if (!all && raw.empty()) {
    return op->emitOpError() << "must request specific values in the list if "
                                "'all' is not specified";
  }

  // A set-based check catches duplicates wherever they occur. Running
  // std::unique on the unsorted list would only collapse *adjacent* equal
  // values and silently accept e.g. `1, 2, 1`.
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t value : raw) {
    if (!seen.insert(value).second)
      return op->emitOpError() << "expected the listed values to be unique";
  }

  return success();
}
|
|
|
|
/// Expands a positional specification into an explicit list of non-negative
/// positions in the range [0, maxNumber).
///
/// \param loc         location used for emitted diagnostics.
/// \param isAll       request every position; must not be combined with
///                    `isInverted` (asserted).
/// \param isInverted  treat `rawList` as the positions to exclude.
/// \param rawList     listed positions; a negative value `v` is interpreted as
///                    `maxNumber + v`.
/// \param maxNumber   number of available positions; must be positive
///                    (asserted).
/// \param result      receives the expanded positions.
///
/// Effects on `result`:
///   - `isAll`: `result` is *overwritten* with 0, 1, ..., maxNumber - 1;
///   - plain list: the normalized positions are *appended* in list order;
///   - inverted list: every position of [0, maxNumber) that is not listed is
///     *appended* in increasing order.
///
/// Returns a silenceable failure if a normalized position is >= `maxNumber`,
/// is still negative, or repeats an earlier one. In the plain-list mode the
/// positions preceding the offending one have already been appended to
/// `result` at that point; in the inverted mode `result` is untouched on
/// failure.
DiagnosedSilenceableFailure transform::expandTargetSpecification(
    Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
    int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
  assert(maxNumber > 0 && "expected size to be positive");
  assert(!(isAll && isInverted) && "cannot invert all");
  if (isAll) {
    result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
    return DiagnosedSilenceableFailure::success();
  }

  // Normalized positions seen so far. Used both for duplicate detection and,
  // in the inverted mode, as the exclusion set.
  llvm::SmallDenseSet<int64_t> visited;
  if (!isInverted)
    result.reserve(result.size() + rawList.size());

  for (int64_t raw : rawList) {
    // Negative values count from the end.
    int64_t updated = raw < 0 ? maxNumber + raw : raw;
    if (updated >= maxNumber) {
      return emitSilenceableFailure(loc)
             << "position overflow " << updated << " (updated from " << raw
             << ") for maximum " << maxNumber;
    }
    if (updated < 0) {
      return emitSilenceableFailure(loc) << "position underflow " << updated
                                         << " (updated from " << raw << ")";
    }
    if (!visited.insert(updated).second) {
      return emitSilenceableFailure(loc) << "repeated position " << updated
                                         << " (updated from " << raw << ")";
    }
    // In the inverted mode the listed positions are only recorded in
    // `visited`; the complement is produced below once the list is validated.
    if (!isInverted)
      result.push_back(updated);
  }

  if (!isInverted)
    return DiagnosedSilenceableFailure::success();

  // Every entry of `visited` is unique and lies in [0, maxNumber), so the
  // subtraction below cannot wrap around.
  result.reserve(result.size() +
                 (static_cast<size_t>(maxNumber) - visited.size()));
  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
    // Constant-time set lookup instead of a linear scan per candidate.
    if (visited.contains(candidate))
      continue;
    result.push_back(candidate);
  }

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc"
|