//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains various utilities that number IR structures in preparation // for bytecode emission. // //===----------------------------------------------------------------------===// #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" #include namespace mlir { class BytecodeDialectInterface; class BytecodeWriterConfig; namespace bytecode { namespace detail { struct DialectNumbering; //===----------------------------------------------------------------------===// // Attribute and Type Numbering //===----------------------------------------------------------------------===// /// This class represents a numbering entry for an Attribute or Type. struct AttrTypeNumbering { AttrTypeNumbering(PointerUnion value) : value(value) {} /// The concrete value. PointerUnion value; /// The number assigned to this value. unsigned number = 0; /// The number of references to this value. unsigned refCount = 1; /// The dialect of this value. DialectNumbering *dialect = nullptr; }; struct AttributeNumbering : public AttrTypeNumbering { AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} Attribute getValue() const { return value.get(); } }; struct TypeNumbering : public AttrTypeNumbering { TypeNumbering(Type value) : AttrTypeNumbering(value) {} Type getValue() const { return value.get(); } }; //===----------------------------------------------------------------------===// // OpName Numbering //===----------------------------------------------------------------------===// /// This class represents the numbering entry of an operation name. struct OpNameNumbering { OpNameNumbering(DialectNumbering *dialect, OperationName name) : dialect(dialect), name(name) {} /// The dialect of this value. DialectNumbering *dialect; /// The concrete name. OperationName name; /// The number assigned to this name. unsigned number = 0; /// The number of references to this name. unsigned refCount = 1; }; //===----------------------------------------------------------------------===// // Dialect Resource Numbering //===----------------------------------------------------------------------===// /// This class represents a numbering entry for a dialect resource. struct DialectResourceNumbering { DialectResourceNumbering(std::string key) : key(std::move(key)) {} /// The key used to reference this resource. std::string key; /// The number assigned to this resource. unsigned number = 0; /// A flag indicating if this resource is only a declaration, not a full /// definition. bool isDeclaration = true; }; //===----------------------------------------------------------------------===// // Dialect Numbering //===----------------------------------------------------------------------===// /// This class represents a numbering entry for an Dialect. struct DialectNumbering { DialectNumbering(StringRef name, unsigned number) : name(name), number(number) {} /// The namespace of the dialect. StringRef name; /// The number assigned to the dialect. unsigned number; /// The bytecode dialect interface of the dialect if defined. const BytecodeDialectInterface *interface = nullptr; /// The asm dialect interface of the dialect if defined. const OpAsmDialectInterface *asmInterface = nullptr; /// The referenced resources of this dialect. SetVector resources; /// A mapping from resource key to the corresponding resource numbering entry. llvm::MapVector resourceMap; }; //===----------------------------------------------------------------------===// // Operation Numbering //===----------------------------------------------------------------------===// /// This class represents the numbering entry of an operation. struct OperationNumbering { OperationNumbering(unsigned number) : number(number) {} /// The number assigned to this operation. unsigned number; /// A flag indicating if this operation's regions are isolated. If unset, the /// operation isn't yet known to be isolated. std::optional isIsolatedFromAbove; }; //===----------------------------------------------------------------------===// // IRNumberingState //===----------------------------------------------------------------------===// /// This class manages numbering IR entities in preparation of bytecode /// emission. class IRNumberingState { public: IRNumberingState(Operation *op, const BytecodeWriterConfig &config); /// Return the numbered dialects. auto getDialects() { return llvm::make_pointee_range(llvm::make_second_range(dialects)); } auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); } auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); } auto getTypes() { return llvm::make_pointee_range(orderedTypes); } /// Return the number for the given IR unit. unsigned getNumber(Attribute attr) { assert(attrs.count(attr) && "attribute not numbered"); return attrs[attr]->number; } unsigned getNumber(Block *block) { assert(blockIDs.count(block) && "block not numbered"); return blockIDs[block]; } unsigned getNumber(Operation *op) { assert(operations.count(op) && "operation not numbered"); return operations[op]->number; } unsigned getNumber(OperationName opName) { assert(opNames.count(opName) && "opName not numbered"); return opNames[opName]->number; } unsigned getNumber(Type type) { assert(types.count(type) && "type not numbered"); return types[type]->number; } unsigned getNumber(Value value) { assert(valueIDs.count(value) && "value not numbered"); return valueIDs[value]; } unsigned getNumber(const AsmDialectResourceHandle &resource) { assert(dialectResources.count(resource) && "resource not numbered"); return dialectResources[resource]->number; } /// Return the block and value counts of the given region. std::pair getBlockValueCount(Region *region) { assert(regionBlockValueCounts.count(region) && "value not numbered"); return regionBlockValueCounts[region]; } /// Return the number of operations in the given block. unsigned getOperationCount(Block *block) { assert(blockOperationCounts.count(block) && "block not numbered"); return blockOperationCounts[block]; } /// Return if the given operation is isolated from above. bool isIsolatedFromAbove(Operation *op) { assert(operations.count(op) && "operation not numbered"); return operations[op]->isIsolatedFromAbove.value_or(false); } /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; private: /// This class is used to provide a fake dialect writer for numbering nested /// attributes and types. struct NumberingDialectWriter; /// Compute the global numbering state for the given root operation. void computeGlobalNumberingState(Operation *rootOp); /// Number the given IR unit for bytecode emission. void number(Attribute attr); void number(Block &block); DialectNumbering &numberDialect(Dialect *dialect); DialectNumbering &numberDialect(StringRef dialect); void number(Operation &op); void number(OperationName opName); void number(Region ®ion); void number(Type type); /// Number the given dialect resources. void number(Dialect *dialect, ArrayRef resources); /// Finalize the numberings of any dialect resources. void finalizeDialectResourceNumberings(Operation *rootOp); /// Mapping from IR to the respective numbering entries. DenseMap attrs; DenseMap operations; DenseMap opNames; DenseMap types; DenseMap registeredDialects; llvm::MapVector dialects; std::vector orderedAttrs; std::vector orderedOpNames; std::vector orderedTypes; /// A mapping from dialect resource handle to the numbering for the referenced /// resource. llvm::DenseMap dialectResources; /// Allocators used for the various numbering entries. llvm::SpecificBumpPtrAllocator attrAllocator; llvm::SpecificBumpPtrAllocator dialectAllocator; llvm::SpecificBumpPtrAllocator opAllocator; llvm::SpecificBumpPtrAllocator opNameAllocator; llvm::SpecificBumpPtrAllocator resourceAllocator; llvm::SpecificBumpPtrAllocator typeAllocator; /// The value ID for each Block and Value. DenseMap blockIDs; DenseMap valueIDs; /// The number of operations in each block. DenseMap blockOperationCounts; /// A map from region to the number of blocks and values within that region. DenseMap> regionBlockValueCounts; /// The next value ID to assign when numbering. unsigned nextValueID = 0; // Configuration: useful to query the required version to emit. const BytecodeWriterConfig &config; }; } // namespace detail } // namespace bytecode } // namespace mlir #endif