//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
//  * Positions
//    - A position refers to a specific location on the input DAG, i.e. an
//      existing MLIR entity being matched. These can be attributes, operands,
//      operations, results, and types. Each position also defines a relation to
//      its parent. For example, the operand `[0] -> 1` has a parent operation
//      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
//      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
//      `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
//      without a parent is `[0]`, which refers to the root operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//    example, an operation name question checks the name of an operation
//    position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//    matching an operation with the name "foo.op". The question would be an
//    operation name question, with an expected answer of "foo.op".
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
#define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

#include <optional>
#include <tuple>
#include <utility>

namespace mlir {
namespace pdl_to_pdl_interp {
namespace Predicates {
/// An enumeration of the kinds of predicates.
enum Kind : unsigned { /// Positions, ordered by decreasing priority. OperationPos, OperandPos, OperandGroupPos, AttributePos, ResultPos, ResultGroupPos, TypePos, AttributeLiteralPos, TypeLiteralPos, UsersPos, ForEachPos, // Questions, ordered by dependency and decreasing priority. IsNotNullQuestion, OperationNameQuestion, TypeQuestion, AttributeQuestion, OperandCountAtLeastQuestion, OperandCountQuestion, ResultCountAtLeastQuestion, ResultCountQuestion, EqualToQuestion, ConstraintQuestion, // Answers. AttributeAnswer, FalseAnswer, OperationNameAnswer, TrueAnswer, TypeAnswer, UnsignedAnswer, }; } // namespace Predicates /// Base class for all predicates, used to allow efficient pointer comparison. template class PredicateBase : public BaseT { public: using KeyTy = Key; using Base = PredicateBase; template explicit PredicateBase(KeyT &&key) : BaseT(Kind), key(std::forward(key)) {} /// Get an instance of this position. template static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { return uniquer.get(/*initFn=*/{}, std::forward(args)...); } /// Construct an instance with the given storage allocator. template static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, KeyT &&key) { return new (alloc.allocate()) ConcreteT(std::forward(key)); } /// Utility methods required by the storage allocator. bool operator==(const KeyTy &key) const { return this->key == key; } static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } /// Return the key value of this predicate. const KeyTy &getValue() const { return key; } protected: KeyTy key; }; /// Base storage for simple predicates that only unique with the kind. 
template class PredicateBase : public BaseT { public: using Base = PredicateBase; explicit PredicateBase() : BaseT(Kind) {} static ConcreteT *get(StorageUniquer &uniquer) { return uniquer.get(); } static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } }; //===----------------------------------------------------------------------===// // Positions //===----------------------------------------------------------------------===// struct OperationPosition; /// A position describes a value on the input IR on which a predicate may be /// applied, such as an operation or attribute. This enables re-use between /// predicates, and assists generating bytecode and memory management. /// /// Operation positions form the base of other positions, which are formed /// relative to a parent operation. Operations are anchored at Operand nodes, /// except for the root operation which is parentless. class Position : public StorageUniquer::BaseStorage { public: explicit Position(Predicates::Kind kind) : kind(kind) {} virtual ~Position(); /// Returns the depth of the first ancestor operation position. unsigned getOperationDepth() const; /// Returns the parent position. The root operation position has no parent. Position *getParent() const { return parent; } /// Returns the kind of this position. Predicates::Kind getKind() const { return kind; } protected: /// Link to the parent position. Position *parent = nullptr; private: /// The kind of this position. Predicates::Kind kind; }; //===----------------------------------------------------------------------===// // AttributePosition /// A position describing an attribute of an operation. struct AttributePosition : public PredicateBase, Predicates::AttributePos> { explicit AttributePosition(const KeyTy &key); /// Returns the attribute name of this position. 
StringAttr getName() const { return key.second; } }; //===----------------------------------------------------------------------===// // AttributeLiteralPosition /// A position describing a literal attribute. struct AttributeLiteralPosition : public PredicateBase { using PredicateBase::PredicateBase; }; //===----------------------------------------------------------------------===// // ForEachPosition /// A position describing an iterative choice of an operation. struct ForEachPosition : public PredicateBase, Predicates::ForEachPos> { explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } /// Returns the ID, for differentiating various loops. /// For upward traversals, this is the index of the root. unsigned getID() const { return key.second; } }; //===----------------------------------------------------------------------===// // OperandPosition /// A position describing an operand of an operation. struct OperandPosition : public PredicateBase, Predicates::OperandPos> { explicit OperandPosition(const KeyTy &key); /// Returns the operand number of this position. unsigned getOperandNumber() const { return key.second; } }; //===----------------------------------------------------------------------===// // OperandGroupPosition /// A position describing an operand group of an operation. struct OperandGroupPosition : public PredicateBase< OperandGroupPosition, Position, std::tuple, bool>, Predicates::OperandGroupPos> { explicit OperandGroupPosition(const KeyTy &key); /// Returns a hash suitable for the given keytype. static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } /// Returns the group number of this position. If std::nullopt, this group /// refers to all operands. std::optional getOperandGroupNumber() const { return std::get<1>(key); } /// Returns if the operand group has unknown size. If false, the operand group /// has at max one element. 
bool isVariadic() const { return std::get<2>(key); } }; //===----------------------------------------------------------------------===// // OperationPosition /// An operation position describes an operation node in the IR. Other position /// kinds are formed with respect to an operation position. struct OperationPosition : public PredicateBase, Predicates::OperationPos> { explicit OperationPosition(const KeyTy &key) : Base(key) { parent = key.first; } /// Returns a hash suitable for the given keytype. static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } /// Gets the root position. static OperationPosition *getRoot(StorageUniquer &uniquer) { return Base::get(uniquer, nullptr, 0); } /// Gets an operation position with the given parent. static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { return Base::get(uniquer, parent, parent->getOperationDepth() + 1); } /// Returns the depth of this position. unsigned getDepth() const { return key.second; } /// Returns if this operation position corresponds to the root. bool isRoot() const { return getDepth() == 0; } /// Returns if this operation represents an operand defining op. bool isOperandDefiningOp() const; }; //===----------------------------------------------------------------------===// // ResultPosition /// A position describing a result of an operation. struct ResultPosition : public PredicateBase, Predicates::ResultPos> { explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } /// Returns the result number of this position. unsigned getResultNumber() const { return key.second; } }; //===----------------------------------------------------------------------===// // ResultGroupPosition /// A position describing a result group of an operation. 
struct ResultGroupPosition : public PredicateBase< ResultGroupPosition, Position, std::tuple, bool>, Predicates::ResultGroupPos> { explicit ResultGroupPosition(const KeyTy &key) : Base(key) { parent = std::get<0>(key); } /// Returns a hash suitable for the given keytype. static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } /// Returns the group number of this position. If std::nullopt, this group /// refers to all results. std::optional getResultGroupNumber() const { return std::get<1>(key); } /// Returns if the result group has unknown size. If false, the result group /// has at max one element. bool isVariadic() const { return std::get<2>(key); } }; //===----------------------------------------------------------------------===// // TypePosition /// A position describing the result type of an entity, i.e. an Attribute, /// Operand, Result, etc. struct TypePosition : public PredicateBase { explicit TypePosition(const KeyTy &key) : Base(key) { assert((isa(key)) && "expected parent to be an attribute, operand, or result"); parent = key; } }; //===----------------------------------------------------------------------===// // TypeLiteralPosition /// A position describing a literal type or type range. The value is stored as /// either a TypeAttr, or an ArrayAttr of TypeAttr. struct TypeLiteralPosition : public PredicateBase { using PredicateBase::PredicateBase; }; //===----------------------------------------------------------------------===// // UsersPosition /// A position describing the users of a value or a range of values. The second /// value in the key indicates whether we choose users of a representative for /// a range (this is true, e.g., in the upward traversals). struct UsersPosition : public PredicateBase, Predicates::UsersPos> { explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } /// Returns a hash suitable for the given keytype. 
static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } /// Indicates whether to compute a range of a representative. bool useRepresentative() const { return key.second; } }; //===----------------------------------------------------------------------===// // Qualifiers //===----------------------------------------------------------------------===// /// An ordinal predicate consists of a "Question" and a set of acceptable /// "Answers" (later converted to ordinal values). A predicate will query some /// property of a positional value and decide what to do based on the result. /// /// This makes top-level predicate representations ordinal (SwitchOp). Later, /// predicates that end up with only one acceptable answer (including all /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the /// matcher. /// /// For simplicity, both are represented as "qualifiers", with a base kind and /// perhaps additional properties. For example, all OperationName predicates ask /// the same question, but GenericConstraint predicates may ask different ones. class Qualifier : public StorageUniquer::BaseStorage { public: explicit Qualifier(Predicates::Kind kind) : kind(kind) {} /// Returns the kind of this qualifier. Predicates::Kind getKind() const { return kind; } private: /// The kind of this position. Predicates::Kind kind; }; //===----------------------------------------------------------------------===// // Answers /// An Answer representing an `Attribute` value. struct AttributeAnswer : public PredicateBase { using Base::Base; }; /// An Answer representing an `OperationName` value. struct OperationNameAnswer : public PredicateBase { using Base::Base; }; /// An Answer representing a boolean `true` value. struct TrueAnswer : PredicateBase { using Base::Base; }; /// An Answer representing a boolean 'false' value. struct FalseAnswer : PredicateBase { using Base::Base; }; /// An Answer representing a `Type` value. 
The value is stored as either a /// TypeAttr, or an ArrayAttr of TypeAttr. struct TypeAnswer : public PredicateBase { using Base::Base; }; /// An Answer representing an unsigned value. struct UnsignedAnswer : public PredicateBase { using Base::Base; }; //===----------------------------------------------------------------------===// // Questions /// Compare an `Attribute` to a constant value. struct AttributeQuestion : public PredicateBase {}; /// Apply a parameterized constraint to multiple position values. struct ConstraintQuestion : public PredicateBase, bool>, Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. StringRef getName() const { return std::get<0>(key); } /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } /// Return the negation status of the constraint. bool getIsNegated() const { return std::get<2>(key); } /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), alloc.copyInto(std::get<1>(key)), std::get<2>(key)}); } /// Returns a hash suitable for the given keytype. static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); } }; /// Compare the equality of two values. struct EqualToQuestion : public PredicateBase { using Base::Base; }; /// Compare a positional value with null, i.e. check if it exists. struct IsNotNullQuestion : public PredicateBase {}; /// Compare the number of operands of an operation with a known value. struct OperandCountQuestion : public PredicateBase {}; struct OperandCountAtLeastQuestion : public PredicateBase {}; /// Compare the name of an operation with a known value. struct OperationNameQuestion : public PredicateBase {}; /// Compare the number of results of an operation with a known value. 
struct ResultCountQuestion : public PredicateBase {}; struct ResultCountAtLeastQuestion : public PredicateBase {}; /// Compare the type of an attribute or value with a known type. struct TypeQuestion : public PredicateBase {}; //===----------------------------------------------------------------------===// // PredicateUniquer //===----------------------------------------------------------------------===// /// This class provides a storage uniquer that is used to allocate predicate /// instances. class PredicateUniquer : public StorageUniquer { public: PredicateUniquer() { // Register the types of Positions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); // Register the types of Questions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); // Register the types of Answers with the uniquer. registerParametricStorageType(); registerParametricStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); registerSingletonStorageType(); } }; //===----------------------------------------------------------------------===// // PredicateBuilder //===----------------------------------------------------------------------===// /// This class provides utilities for constructing predicates. 
class PredicateBuilder { public: PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) : uniquer(uniquer), ctx(ctx) {} //===--------------------------------------------------------------------===// // Positions //===--------------------------------------------------------------------===// /// Returns the root operation position. Position *getRoot() { return OperationPosition::getRoot(uniquer); } /// Returns the parent position defining the value held by the given operand. OperationPosition *getOperandDefiningOp(Position *p) { assert((isa(p)) && "expected operand position"); return OperationPosition::get(uniquer, p); } /// Returns the operation position equivalent to the given position. OperationPosition *getPassthroughOp(Position *p) { assert((isa(p)) && "expected users position"); return OperationPosition::get(uniquer, p); } /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); } /// Returns an attribute position for the given attribute. Position *getAttributeLiteral(Attribute attr) { return AttributeLiteralPosition::get(uniquer, attr); } Position *getForEach(Position *p, unsigned id) { return ForEachPosition::get(uniquer, p, id); } /// Returns an operand position for an operand of the given operation. Position *getOperand(OperationPosition *p, unsigned operand) { return OperandPosition::get(uniquer, p, operand); } /// Returns a position for a group of operands of the given operation. Position *getOperandGroup(OperationPosition *p, std::optional group, bool isVariadic) { return OperandGroupPosition::get(uniquer, p, group, isVariadic); } Position *getAllOperands(OperationPosition *p) { return getOperandGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true); } /// Returns a result position for a result of the given operation. 
Position *getResult(OperationPosition *p, unsigned result) { return ResultPosition::get(uniquer, p, result); } /// Returns a position for a group of results of the given operation. Position *getResultGroup(OperationPosition *p, std::optional group, bool isVariadic) { return ResultGroupPosition::get(uniquer, p, group, isVariadic); } Position *getAllResults(OperationPosition *p) { return getResultGroup(p, /*group=*/std::nullopt, /*isVariadic=*/true); } /// Returns a type position for the given entity. Position *getType(Position *p) { return TypePosition::get(uniquer, p); } /// Returns a type position for the given type value. The value is stored /// as either a TypeAttr, or an ArrayAttr of TypeAttr. Position *getTypeLiteral(Attribute attr) { return TypeLiteralPosition::get(uniquer, attr); } /// Returns the users of a position using the value at the given operand. UsersPosition *getUsers(Position *p, bool useRepresentative) { assert((isa(p)) && "expected result position"); return UsersPosition::get(uniquer, p, useRepresentative); } //===--------------------------------------------------------------------===// // Qualifiers //===--------------------------------------------------------------------===// /// An ordinal predicate consists of a "Question" and a set of acceptable /// "Answers" (later converted to ordinal values). A predicate will query some /// property of a positional value and decide what to do based on the result. using Predicate = std::pair; /// Create a predicate comparing an attribute to a known value. Predicate getAttributeConstraint(Attribute attr) { return {AttributeQuestion::get(uniquer), AttributeAnswer::get(uniquer, attr)}; } /// Create a predicate checking if two values are equal. Predicate getEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; } /// Create a predicate checking if two values are not equal. 
Predicate getNotEqualTo(Position *pos) { return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; } /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef pos, bool isNegated) { return { ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)), TrueAnswer::get(uniquer)}; } /// Create a predicate comparing a value with null. Predicate getIsNotNull() { return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; } /// Create a predicate comparing the number of operands of an operation to a /// known value. Predicate getOperandCount(unsigned count) { return {OperandCountQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } Predicate getOperandCountAtLeast(unsigned count) { return {OperandCountAtLeastQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } /// Create a predicate comparing the name of an operation to a known value. Predicate getOperationName(StringRef name) { return {OperationNameQuestion::get(uniquer), OperationNameAnswer::get(uniquer, OperationName(name, ctx))}; } /// Create a predicate comparing the number of results of an operation to a /// known value. Predicate getResultCount(unsigned count) { return {ResultCountQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } Predicate getResultCountAtLeast(unsigned count) { return {ResultCountAtLeastQuestion::get(uniquer), UnsignedAnswer::get(uniquer, count)}; } /// Create a predicate comparing the type of an attribute or value to a known /// type. The value is stored as either a TypeAttr, or an ArrayAttr of /// TypeAttr. Predicate getTypeConstraint(Attribute type) { return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; } private: /// The uniquer used when allocating predicate nodes. PredicateUniquer &uniquer; /// The current MLIR context. MLIRContext *ctx; }; } // namespace pdl_to_pdl_interp } // namespace mlir #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_