//===- IRDLOps.cpp - IRDL dialect -------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/IR/ValueRange.h" #include using namespace mlir; using namespace mlir::irdl; /// Maps given `args` to the index in the `valueToConstr` static SmallVector getConstraintIndicesForArgs(mlir::OperandRange args, ArrayRef valueToConstr) { SmallVector constraints; for (Value arg : args) { for (auto [i, value] : enumerate(valueToConstr)) { if (value == arg) { constraints.push_back(i); break; } } } return constraints; } std::unique_ptr IsOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { return std::make_unique(getExpectedAttr()); } std::unique_ptr BaseOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { MLIRContext *ctx = getContext(); // Case where the input is a symbol reference. // This corresponds to the case where the base is an IRDL type or attribute. if (auto baseRef = getBaseRef()) { Operation *defOp = SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value()); // Type case. if (auto typeOp = dyn_cast(defOp)) { DynamicTypeDefinition *typeDef = types.at(typeOp).get(); auto name = StringAttr::get(ctx, typeDef->getDialect()->getNamespace() + "." + typeDef->getName().str()); return std::make_unique(typeDef->getTypeID(), name); } // Attribute case. auto attrOp = cast(defOp); DynamicAttrDefinition *attrDef = attrs.at(attrOp).get(); auto name = StringAttr::get(ctx, attrDef->getDialect()->getNamespace() + "." + attrDef->getName().str()); return std::make_unique(attrDef->getTypeID(), name); } // Case where the input is string literal. // This corresponds to the case where the base is a registered type or // attribute. StringRef baseName = getBaseName().value(); // Type case. if (baseName[0] == '!') { auto abstractType = AbstractType::lookup(baseName.drop_front(1), ctx); if (!abstractType) { emitError() << "no registered type with name " << baseName; return nullptr; } return std::make_unique(abstractType->get().getTypeID(), abstractType->get().getName()); } auto abstractAttr = AbstractAttribute::lookup(baseName.drop_front(1), ctx); if (!abstractAttr) { emitError() << "no registered attribute with name " << baseName; return nullptr; } return std::make_unique(abstractAttr->get().getTypeID(), abstractAttr->get().getName()); } std::unique_ptr ParametricOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { SmallVector constraints = getConstraintIndicesForArgs(getArgs(), valueToConstr); // Symbol reference case for the base SymbolRefAttr symRef = getBaseType(); Operation *defOp = SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef); if (!defOp) { emitError() << symRef << " does not refer to any existing symbol"; return nullptr; } if (auto typeOp = dyn_cast(defOp)) return std::make_unique(types.at(typeOp).get(), constraints); if (auto attrOp = dyn_cast(defOp)) return std::make_unique(attrs.at(attrOp).get(), constraints); llvm_unreachable("verifier should ensure that the referenced operation is " "either a type or an attribute definition"); } std::unique_ptr AnyOfOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { return std::make_unique( getConstraintIndicesForArgs(getArgs(), valueToConstr)); } std::unique_ptr AllOfOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { return std::make_unique( getConstraintIndicesForArgs(getArgs(), valueToConstr)); } std::unique_ptr AnyOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { return std::make_unique(); } std::unique_ptr RegionOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, DenseMap> const &attrs) { return std::make_unique( getConstrainedArguments() ? std::optional{getConstraintIndicesForArgs( getEntryBlockArgs(), valueToConstr)} : std::nullopt, getNumberOfBlocks()); }