//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDL.h"
|
||
|
#include "mlir/IR/Builders.h"
|
||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||
|
#include "mlir/IR/Diagnostics.h"
|
||
|
#include "mlir/IR/DialectImplementation.h"
|
||
|
#include "mlir/IR/ExtensibleDialect.h"
|
||
|
#include "mlir/IR/OpDefinition.h"
|
||
|
#include "mlir/IR/OpImplementation.h"
|
||
|
#include "mlir/IR/Operation.h"
|
||
|
#include "mlir/Support/LLVM.h"
|
||
|
#include "mlir/Support/LogicalResult.h"
|
||
|
#include "llvm/ADT/STLExtras.h"
|
||
|
#include "llvm/ADT/TypeSwitch.h"
|
||
|
#include "llvm/IR/Metadata.h"
|
||
|
#include "llvm/Support/Casting.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::irdl;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// IRDL dialect.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
|
||
|
|
||
|
/// Register with the dialect every operation, type and attribute listed in the
/// generated `.cpp.inc` files. Each `GET_*_LIST` macro makes the following
/// include expand to a comma-separated list of class names.
void IRDLDialect::initialize() {
  // Operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();

  // Types.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();

  // Attributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Parsing/Printing
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Parse a region, and add a single block if the region is empty.
|
||
|
/// If no region is parsed, create a new region with a single empty block.
|
||
|
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region ®ion) {
|
||
|
auto regionParseRes = p.parseOptionalRegion(region);
|
||
|
if (regionParseRes.has_value() && failed(regionParseRes.value()))
|
||
|
return failure();
|
||
|
|
||
|
// If the region is empty, add a single empty block.
|
||
|
if (region.empty())
|
||
|
region.push_back(new Block());
|
||
|
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op,
|
||
|
Region ®ion) {
|
||
|
if (!region.getBlocks().front().empty())
|
||
|
p.printRegion(region);
|
||
|
}
|
||
|
|
||
|
/// Accept the op only if its name is a valid dialect namespace, as decided by
/// `Dialect::isValidNamespace`.
LogicalResult DialectOp::verify() {
  if (Dialect::isValidNamespace(getName()))
    return success();
  return emitOpError("invalid dialect name");
}
|
||
|
|
||
|
/// Check that every operand of the op has exactly one matching entry in the
/// variadicity array attribute.
LogicalResult OperandsOp::verify() {
  const size_t operandCount = getNumOperands();
  const size_t variadicityCount = getVariadicity().size();

  if (operandCount == variadicityCount)
    return success();

  return emitOpError()
         << "the number of operands and their variadicities must be "
            "the same, but got "
         << operandCount << " and " << variadicityCount << " respectively";
}
|
||
|
|
||
|
/// Check that every operand of the op (one per described result) has exactly
/// one matching entry in the variadicity array attribute.
LogicalResult ResultsOp::verify() {
  const size_t operandCount = getNumOperands();
  const size_t variadicityCount = getVariadicity().size();

  if (operandCount == variadicityCount)
    return success();

  return emitOpError()
         << "the number of operands and their variadicities must be "
            "the same, but got "
         << operandCount << " and " << variadicityCount << " respectively";
}
|
||
|
|
||
|
/// Check that the list of attribute names and the list of attribute
/// constraint values have the same length, so they can be paired up.
LogicalResult AttributesOp::verify() {
  const size_t nameCount = getAttributeValueNames().size();
  const size_t valueCount = getAttributeValues().size();

  if (nameCount == valueCount)
    return success();

  return emitOpError()
         << "the number of attribute names and their constraints must be "
            "the same but got "
         << nameCount << " and " << valueCount << " respectively";
}
|
||
|
|
||
|
/// Verify the two mutually exclusive ways of naming the base:
///  - exactly one of `base_name` / `base_ref` must be set, and
///  - a `base_name`, when used, must be non-empty and begin with `!` or `#`.
LogicalResult BaseOp::verify() {
  const bool hasName = getBaseName().has_value();
  const bool hasRef = getBaseRef().has_value();

  // Both set or both missing is an error.
  if (hasName == hasRef)
    return emitOpError() << "the base type or attribute should be specified by "
                            "either a name or a reference";

  // A symbol reference is checked later, in verifySymbolUses.
  if (!hasName)
    return success();

  StringRef name = *getBaseName();
  const bool hasSigil =
      !name.empty() && (name.front() == '!' || name.front() == '#');
  if (!hasSigil)
    return emitOpError() << "the base type or attribute name should start with "
                            "'!' or '#'";

  return success();
}
|
||
|
|
||
|
/// When the base is given as a symbol reference, require that it resolves
/// (from this op, via the nearest symbol table) to either a `TypeOp` or an
/// `AttributeOp`. A name-based base needs no symbol check.
LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  std::optional<SymbolRefAttr> ref = getBaseRef();
  if (!ref)
    return success();

  // Try a type definition first, then an attribute definition.
  if (symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *ref) ||
      symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *ref))
    return success();

  return emitOpError() << "'" << *ref
                       << "' does not refer to a type or attribute definition";
}
|
||
|
|
||
|
/// Parse a value with its variadicity first. By default, the variadicity is
|
||
|
/// single.
|
||
|
///
|
||
|
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
|
||
|
static ParseResult
|
||
|
parseValueWithVariadicity(OpAsmParser &p,
|
||
|
OpAsmParser::UnresolvedOperand &operand,
|
||
|
VariadicityAttr &variadicityAttr) {
|
||
|
MLIRContext *ctx = p.getBuilder().getContext();
|
||
|
|
||
|
// Parse the variadicity, if present
|
||
|
if (p.parseOptionalKeyword("single").succeeded()) {
|
||
|
variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
|
||
|
} else if (p.parseOptionalKeyword("optional").succeeded()) {
|
||
|
variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
|
||
|
} else if (p.parseOptionalKeyword("variadic").succeeded()) {
|
||
|
variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
|
||
|
} else {
|
||
|
variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
|
||
|
}
|
||
|
|
||
|
// Parse the value
|
||
|
if (p.parseOperand(operand))
|
||
|
return failure();
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Parse a list of values with their variadicities first. By default, the
|
||
|
/// variadicity is single.
|
||
|
///
|
||
|
/// values-with-variadicity ::=
|
||
|
/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
|
||
|
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
|
||
|
static ParseResult parseValuesWithVariadicity(
|
||
|
OpAsmParser &p, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
|
||
|
VariadicityArrayAttr &variadicityAttr) {
|
||
|
Builder &builder = p.getBuilder();
|
||
|
MLIRContext *ctx = builder.getContext();
|
||
|
SmallVector<VariadicityAttr> variadicities;
|
||
|
|
||
|
// Parse a single value with its variadicity
|
||
|
auto parseOne = [&] {
|
||
|
OpAsmParser::UnresolvedOperand operand;
|
||
|
VariadicityAttr variadicity;
|
||
|
if (parseValueWithVariadicity(p, operand, variadicity))
|
||
|
return failure();
|
||
|
operands.push_back(operand);
|
||
|
variadicities.push_back(variadicity);
|
||
|
return success();
|
||
|
};
|
||
|
|
||
|
if (p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOne))
|
||
|
return failure();
|
||
|
variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Print a list of values with their variadicities first. By default, the
|
||
|
/// variadicity is single.
|
||
|
///
|
||
|
/// values-with-variadicity ::=
|
||
|
/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
|
||
|
/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
|
||
|
static void printValuesWithVariadicity(OpAsmPrinter &p, Operation *op,
|
||
|
OperandRange operands,
|
||
|
VariadicityArrayAttr variadicityAttr) {
|
||
|
p << "(";
|
||
|
interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
|
||
|
Variadicity variadicity = variadicityAttr[i].getValue();
|
||
|
if (variadicity != Variadicity::single) {
|
||
|
p << stringifyVariadicity(variadicity) << " ";
|
||
|
}
|
||
|
p << operands[i];
|
||
|
});
|
||
|
p << ")";
|
||
|
}
|
||
|
|
||
|
/// Parse the optional attribute dictionary of an `AttributesOp`:
///
///   (`{` attribute `=` ssa-value (`,` attribute `=` ssa-value)* `}`)?
///
/// The names are collected into `attrNamesAttr` and the values are appended
/// to `attrOperands`, in the same order. When no `{` is present the result is
/// an empty name list and no operands.
static ParseResult
parseAttributesOp(OpAsmParser &p,
                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
                  ArrayAttr &attrNamesAttr) {
  SmallVector<Attribute> names;

  // One `name = %value` entry; the name and value slots are appended before
  // being filled in by the parser.
  auto parseEntry = [&]() -> ParseResult {
    if (p.parseAttribute(names.emplace_back()) || p.parseEqual() ||
        p.parseOperand(attrOperands.emplace_back()))
      return failure();
    return success();
  };

  // Only a present `{` commits us to parsing a non-empty list and its `}`.
  if (succeeded(p.parseOptionalLBrace()) &&
      (p.parseCommaSeparatedList(parseEntry) || p.parseRBrace()))
    return failure();

  attrNamesAttr = p.getBuilder().getArrayAttr(names);
  return success();
}
|
||
|
|
||
|
/// Print the attribute-name/value pairs of an `AttributesOp` as
/// `{name = %value, ...}`. Nothing is printed when there are no names, which
/// `parseAttributesOp` reads back as an empty list. `op` is unused but
/// required by the custom-directive signature.
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
                              OperandRange attrArgs, ArrayAttr attrNames) {
  const unsigned count = attrNames.size();
  if (count == 0)
    return;

  p << "{";
  for (unsigned idx = 0; idx != count; ++idx) {
    if (idx != 0)
      p << ", ";
    p << attrNames[idx] << " = " << attrArgs[idx];
  }
  p << '}';
}
|
||
|
|
||
|
/// When the optional number-of-blocks attribute is set, require it to be at
/// least 1. An absent attribute imposes no constraint.
LogicalResult RegionOp::verify() {
  IntegerAttr numberOfBlocks = getNumberOfBlocksAttr();
  if (!numberOfBlocks)
    return success();

  const int64_t number = numberOfBlocks.getInt();
  if (number > 0)
    return success();

  return emitOpError("the number of blocks is expected to be >= 1 but got ")
         << number;
}
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
|
||
|
|
||
|
#define GET_TYPEDEF_CLASSES
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
|
||
|
|
||
|
#define GET_ATTRDEF_CLASSES
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
|
||
|
|
||
|
#define GET_OP_CLASSES
|
||
|
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
|