259 lines
8.9 KiB
C++
259 lines
8.9 KiB
C++
|
//===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
|
||
|
//
|
||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// Verifiers for objects declared by IRDL.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
|
||
|
#include "mlir/IR/Attributes.h"
|
||
|
#include "mlir/IR/Block.h"
|
||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||
|
#include "mlir/IR/Diagnostics.h"
|
||
|
#include "mlir/IR/ExtensibleDialect.h"
|
||
|
#include "mlir/IR/Location.h"
|
||
|
#include "mlir/IR/Region.h"
|
||
|
#include "mlir/IR/Value.h"
|
||
|
#include "mlir/Support/LogicalResult.h"
|
||
|
#include "llvm/Support/FormatVariadic.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::irdl;
|
||
|
|
||
|
/// Builds a verifier over the given list of constraints.
///
/// The assignment table is sized to hold one entry per constraint, so that
/// `verify` can index it with a constraint variable.
ConstraintVerifier::ConstraintVerifier(
    ArrayRef<std::unique_ptr<Constraint>> constraints)
    : constraints(constraints) {
  // One slot per constraint variable; slots are default-constructed here and
  // filled in by `verify` once a variable has been successfully checked.
  const size_t numVariables = this->constraints.size();
  assigned.resize(numVariables);
}
|
||
|
|
||
|
/// Checks `attr` against the constraint variable `variable`.
///
/// The first time a variable is checked, its constraint is run and, on
/// success, `attr` is recorded as the value of that variable. Every later check
/// of the same variable succeeds only if `attr` equals the recorded value.
///
/// `emitError` may be null; in that case failures are reported through the
/// return value only and no diagnostic is produced by this function.
LogicalResult
ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
                           Attribute attr, unsigned variable) {
  assert(variable < constraints.size() && "invalid constraint variable");

  // Unassigned variable: run the constraint itself and remember the attribute
  // only when the constraint accepted it.
  if (!assigned[variable].has_value()) {
    LogicalResult status =
        constraints[variable]->verify(emitError, attr, *this);
    if (succeeded(status))
      assigned[variable] = attr;
    return status;
  }

  // Already assigned: the new attribute has to be the very same one.
  if (assigned[variable].value() == attr)
    return success();

  if (!emitError)
    return failure();
  return emitError() << "expected '" << assigned[variable].value()
                     << "' but got '" << attr << "'";
}
|
||
|
|
||
|
/// Succeeds only when `attr` is equal to `expectedAttribute`.
///
/// On mismatch a diagnostic naming both attributes is emitted if `emitError`
/// is non-null; otherwise the failure is silent. `context` is not used.
LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                                   Attribute attr,
                                   ConstraintVerifier &context) const {
  const bool matches = attr == expectedAttribute;
  if (matches)
    return success();

  // Mismatch: stay silent unless the caller supplied a diagnostic emitter.
  if (!emitError)
    return failure();
  return emitError() << "expected '" << expectedAttribute << "' but got '"
                     << attr << "'";
}
|
||
|
|
||
|
/// Succeeds only when the TypeID of `attr` is `baseTypeID`.
///
/// On mismatch, and only if `emitError` is non-null, a diagnostic is emitted
/// that names the expected base (`baseName`) and the name of the abstract
/// attribute actually received. `context` is not used.
LogicalResult
BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                           Attribute attr, ConstraintVerifier &context) const {
  const bool hasExpectedBase = attr.getTypeID() == baseTypeID;
  if (hasExpectedBase)
    return success();

  // Wrong base: stay silent unless the caller supplied a diagnostic emitter.
  if (!emitError)
    return failure();
  return emitError() << "expected base attribute '" << baseName
                     << "' but got '" << attr.getAbstractAttribute().getName()
                     << "'";
}
|
||
|
|
||
|
/// Succeeds only when `attr` is a `TypeAttr` whose wrapped type has TypeID
/// `baseTypeID`.
///
/// Two distinct failures are reported (when `emitError` is non-null):
///  - `attr` is not a `TypeAttr` at all;
///  - the wrapped type has a different base than `baseName`.
/// `context` is not used.
LogicalResult
BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                           Attribute attr, ConstraintVerifier &context) const {
  auto typeAttr = dyn_cast<TypeAttr>(attr);
  if (!typeAttr) {
    // The offending attribute is quoted on both sides, like every other
    // diagnostic emitted by the verifiers in this file.
    if (emitError)
      return emitError() << "expected type, got attribute '" << attr << "'";
    return failure();
  }

  Type type = typeAttr.getValue();
  if (type.getTypeID() == baseTypeID)
    return success();

  if (emitError)
    return emitError() << "expected base type '" << baseName << "' but got '"
                       << type.getAbstractType().getName() << "'";
  return failure();
}
|
||
|
|
||
|
/// Verifies that `attr` is a `DynamicAttr` built from `attrDef` and that each
/// of its parameters satisfies the constraint variable at the same position
/// in `constraints`.
///
/// Verification stops at the first failure. Diagnostics for a wrong base or a
/// parameter-count mismatch are emitted only when `emitError` is non-null;
/// parameter diagnostics are delegated to `context.verify`.
LogicalResult DynParametricAttrConstraint::verify(
    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
    ConstraintVerifier &context) const {

  // Check that the base is the expected one.
  auto dynAttr = dyn_cast<DynamicAttr>(attr);
  if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
    if (emitError) {
      StringRef dialectName = attrDef->getDialect()->getNamespace();
      StringRef attrName = attrDef->getName();
      // Print the qualified name as `dialect.name`, the same order used by
      // the parameter-count diagnostic below.
      return emitError() << "expected base attribute '" << dialectName << '.'
                         << attrName << "' but got '" << attr << "'";
    }
    return failure();
  }

  // Check that the parameters satisfy the constraints.
  ArrayRef<Attribute> params = dynAttr.getParams();
  if (params.size() != constraints.size()) {
    if (emitError) {
      StringRef dialectName = attrDef->getDialect()->getNamespace();
      StringRef attrName = attrDef->getName();
      // The diagnostic is reported when the temporary returned by
      // `emitError()` goes out of scope at the end of this statement.
      emitError() << "attribute '" << dialectName << "." << attrName
                  << "' expects " << params.size() << " parameters but got "
                  << constraints.size();
    }
    return failure();
  }

  // Sizes match, so parameters and constraint variables pair up by index.
  for (size_t i = 0, s = params.size(); i < s; i++)
    if (failed(context.verify(emitError, params[i], constraints[i])))
      return failure();

  return success();
}
|
||
|
|
||
|
/// Verifies that `attr` is a `TypeAttr` wrapping a `DynamicType` built from
/// `typeDef`, and that each parameter of that type satisfies the constraint
/// variable at the same position in `constraints`.
///
/// Verification stops at the first failure. Diagnostics for a non-type
/// attribute, a wrong base, or a parameter-count mismatch are emitted only
/// when `emitError` is non-null; parameter diagnostics are delegated to
/// `context.verify`.
LogicalResult DynParametricTypeConstraint::verify(
    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
    ConstraintVerifier &context) const {
  // Check that the base is a TypeAttr.
  auto typeAttr = dyn_cast<TypeAttr>(attr);
  if (!typeAttr) {
    // The offending attribute is quoted on both sides, like every other
    // diagnostic emitted by the verifiers in this file.
    if (emitError)
      return emitError() << "expected type, got attribute '" << attr << "'";
    return failure();
  }

  // Check that the type base is the expected one.
  auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
  if (!dynType || dynType.getTypeDef() != typeDef) {
    if (emitError) {
      StringRef dialectName = typeDef->getDialect()->getNamespace();
      StringRef typeName = typeDef->getName();
      return emitError() << "expected base type '" << dialectName << '.'
                         << typeName << "' but got '" << attr << "'";
    }
    return failure();
  }

  // Check that the parameters satisfy the constraints.
  ArrayRef<Attribute> params = dynType.getParams();
  if (params.size() != constraints.size()) {
    if (emitError) {
      StringRef dialectName = typeDef->getDialect()->getNamespace();
      StringRef typeName = typeDef->getName();
      // NOTE(review): this message says "attribute" although a type is being
      // checked. The text is kept unchanged in case tests match on it;
      // confirm before rewording.
      // The diagnostic is reported when the temporary returned by
      // `emitError()` goes out of scope at the end of this statement.
      emitError() << "attribute '" << dialectName << "." << typeName
                  << "' expects " << params.size() << " parameters but got "
                  << constraints.size();
    }
    return failure();
  }

  // Sizes match, so parameters and constraint variables pair up by index.
  for (size_t i = 0, s = params.size(); i < s; i++)
    if (failed(context.verify(emitError, params[i], constraints[i])))
      return failure();

  return success();
}
|
||
|
|
||
|
/// Succeeds as soon as `attr` satisfies one of the constraint variables listed
/// in `constraints`, trying them in order.
///
/// If none of them accepts `attr`, a single diagnostic is emitted (when
/// `emitError` is non-null) and failure is returned.
LogicalResult
AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                        Attribute attr, ConstraintVerifier &context) const {
  // Each alternative is probed with a null emitter: an individual alternative
  // failing is not an error, only all of them failing is.
  for (unsigned alternative : constraints)
    if (succeeded(context.verify({}, attr, alternative)))
      return success();

  if (!emitError)
    return failure();
  return emitError() << "'" << attr << "' does not satisfy the constraint";
}
|
||
|
|
||
|
/// Succeeds only when `attr` satisfies every constraint variable listed in
/// `constraints`.
///
/// Variables are checked in order and checking stops at the first failure.
/// `emitError` is forwarded unchanged to each nested check.
LogicalResult
AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                        Attribute attr, ConstraintVerifier &context) const {
  for (unsigned required : constraints) {
    LogicalResult status = context.verify(emitError, attr, required);
    if (failed(status))
      return failure();
  }

  return success();
}
|
||
|
|
||
|
/// Accepts any attribute: always returns success and never emits a
/// diagnostic. None of the parameters is inspected.
LogicalResult
AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
                               Attribute attr,
                               ConstraintVerifier &context) const {
  // Nothing to check; every attribute satisfies this constraint.
  return success();
}
|
||
|
|
||
|
/// Verifies `region` against this region constraint.
///
/// Two optional checks are performed, in this order:
///  - if `blockCount` is set, the region must contain exactly that many
///    blocks;
///  - if `argumentConstraints` is set, the region must have exactly that many
///    arguments, and the type of each argument (wrapped in a `TypeAttr`) must
///    satisfy the constraint variable at the same position, as checked by
///    `constraintContext`.
///
/// Returns failure after emitting a diagnostic for the first violated check,
/// and success otherwise.
LogicalResult RegionConstraint::verify(mlir::Region &region,
                                       ConstraintVerifier &constraintContext) {
  // Builds, for a given location, a callable that emits an error at that
  // location and attaches a note pointing at the operation owning the region.
  const auto emitError = [parentOp = region.getParentOp()](mlir::Location loc) {
    return [loc, parentOp] {
      InFlightDiagnostic diag = mlir::emitError(loc);
      // If we already have been given location of the parent operation, which
      // might happen when the region location is passed, we do not want to
      // produce the note on the same location
      if (loc != parentOp->getLoc())
        diag.attachNote(parentOp->getLoc()).append("see the operation");
      return diag;
    };
  };

  if (blockCount.has_value() && *blockCount != region.getBlocks().size()) {
    return emitError(region.getLoc())()
           << "expected region " << region.getRegionNumber() << " to have "
           << *blockCount << " block(s) but got " << region.getBlocks().size();
  }

  if (argumentConstraints.has_value()) {
    auto actualArgs = region.getArguments();
    if (actualArgs.size() != argumentConstraints->size()) {
      // Report at the first argument when there is one, otherwise at the
      // region itself.
      const mlir::Location firstArgLoc =
          actualArgs.empty() ? region.getLoc() : actualArgs.front().getLoc();
      return emitError(firstArgLoc)()
             << "expected region " << region.getRegionNumber() << " to have "
             << argumentConstraints->size() << " arguments but got "
             << actualArgs.size();
    }

    // Sizes match at this point, so arguments and constraints pair up 1:1.
    for (auto [arg, constraint] : llvm::zip(actualArgs, *argumentConstraints)) {
      // Constraints operate on attributes, so the argument type is wrapped.
      mlir::Attribute type = TypeAttr::get(arg.getType());
      if (failed(constraintContext.verify(emitError(arg.getLoc()), type,
                                          constraint))) {
        return failure();
      }
    }
  }
  return success();
}
|