608 lines
23 KiB
C++
608 lines
23 KiB
C++
|
//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "IRNumbering.h"
|
||
|
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||
|
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||
|
#include "mlir/Bytecode/Encoding.h"
|
||
|
#include "mlir/IR/AsmState.h"
|
||
|
#include "mlir/IR/BuiltinTypes.h"
|
||
|
#include "mlir/IR/OpDefinition.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::bytecode::detail;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// NumberingDialectWriter
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
|
||
|
NumberingDialectWriter(
|
||
|
IRNumberingState &state,
|
||
|
llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
|
||
|
: state(state), dialectVersionMap(dialectVersionMap) {}
|
||
|
|
||
|
void writeAttribute(Attribute attr) override { state.number(attr); }
|
||
|
void writeOptionalAttribute(Attribute attr) override {
|
||
|
if (attr)
|
||
|
state.number(attr);
|
||
|
}
|
||
|
void writeType(Type type) override { state.number(type); }
|
||
|
void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
|
||
|
state.number(resource.getDialect(), resource);
|
||
|
}
|
||
|
|
||
|
/// Stubbed out methods that are not used for numbering.
|
||
|
void writeVarInt(uint64_t) override {}
|
||
|
void writeSignedVarInt(int64_t value) override {}
|
||
|
void writeAPIntWithKnownWidth(const APInt &value) override {}
|
||
|
void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
|
||
|
void writeOwnedString(StringRef) override {
|
||
|
// TODO: It might be nice to prenumber strings and sort by the number of
|
||
|
// references. This could potentially be useful for optimizing things like
|
||
|
// file locations.
|
||
|
}
|
||
|
void writeOwnedBlob(ArrayRef<char> blob) override {}
|
||
|
void writeOwnedBool(bool value) override {}
|
||
|
|
||
|
int64_t getBytecodeVersion() const override {
|
||
|
return state.getDesiredBytecodeVersion();
|
||
|
}
|
||
|
|
||
|
FailureOr<const DialectVersion *>
|
||
|
getDialectVersion(StringRef dialectName) const override {
|
||
|
auto dialectEntry = dialectVersionMap.find(dialectName);
|
||
|
if (dialectEntry == dialectVersionMap.end())
|
||
|
return failure();
|
||
|
return dialectEntry->getValue().get();
|
||
|
}
|
||
|
|
||
|
/// The parent numbering state that is populated by this writer.
|
||
|
IRNumberingState &state;
|
||
|
|
||
|
/// A map containing dialect version information for each dialect to emit.
|
||
|
llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
|
||
|
};
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// IR Numbering
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Group and sort the elements of the given range by their parent dialect. This
|
||
|
/// grouping is applied to sub-sections of the ranged defined by how many bytes
|
||
|
/// it takes to encode a varint index to that sub-section.
|
||
|
template <typename T>
|
||
|
static void groupByDialectPerByte(T range) {
|
||
|
if (range.empty())
|
||
|
return;
|
||
|
|
||
|
// A functor used to sort by a given dialect, with a desired dialect to be
|
||
|
// ordered first (to better enable sharing of dialects across byte groups).
|
||
|
auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
|
||
|
const auto &rhs) {
|
||
|
if (lhs->dialect->number == dialectToOrderFirst)
|
||
|
return rhs->dialect->number != dialectToOrderFirst;
|
||
|
if (rhs->dialect->number == dialectToOrderFirst)
|
||
|
return false;
|
||
|
return lhs->dialect->number < rhs->dialect->number;
|
||
|
};
|
||
|
|
||
|
unsigned dialectToOrderFirst = 0;
|
||
|
size_t elementsInByteGroup = 0;
|
||
|
auto iterRange = range;
|
||
|
for (unsigned i = 1; i < 9; ++i) {
|
||
|
// Update the number of elements in the current byte grouping. Reminder
|
||
|
// that varint encodes 7-bits per byte, so that's how we compute the
|
||
|
// number of elements in each byte grouping.
|
||
|
elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
|
||
|
|
||
|
// Slice out the sub-set of elements that are in the current byte grouping
|
||
|
// to be sorted.
|
||
|
auto byteSubRange = iterRange.take_front(elementsInByteGroup);
|
||
|
iterRange = iterRange.drop_front(byteSubRange.size());
|
||
|
|
||
|
// Sort the sub range for this byte.
|
||
|
llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
|
||
|
return sortByDialect(dialectToOrderFirst, lhs, rhs);
|
||
|
});
|
||
|
|
||
|
// Update the dialect to order first to be the dialect at the end of the
|
||
|
// current grouping. This seeks to allow larger dialect groupings across
|
||
|
// byte boundaries.
|
||
|
dialectToOrderFirst = byteSubRange.back()->dialect->number;
|
||
|
|
||
|
// If the data range is now empty, we are done.
|
||
|
if (iterRange.empty())
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
// Assign the entry numbers based on the sort order.
|
||
|
for (auto [idx, value] : llvm::enumerate(range))
|
||
|
value->number = idx;
|
||
|
}
|
||
|
|
||
|
IRNumberingState::IRNumberingState(Operation *op,
|
||
|
const BytecodeWriterConfig &config)
|
||
|
: config(config) {
|
||
|
computeGlobalNumberingState(op);
|
||
|
|
||
|
// Number the root operation.
|
||
|
number(*op);
|
||
|
|
||
|
// A worklist of region contexts to number and the next value id before that
|
||
|
// region.
|
||
|
SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
|
||
|
|
||
|
// Functor to push the regions of the given operation onto the numbering
|
||
|
// context.
|
||
|
auto addOpRegionsToNumber = [&](Operation *op) {
|
||
|
MutableArrayRef<Region> regions = op->getRegions();
|
||
|
if (regions.empty())
|
||
|
return;
|
||
|
|
||
|
// Isolated regions don't share value numbers with their parent, so we can
|
||
|
// start numbering these regions at zero.
|
||
|
unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
|
||
|
for (Region ®ion : regions)
|
||
|
numberContext.emplace_back(®ion, opFirstValueID);
|
||
|
};
|
||
|
addOpRegionsToNumber(op);
|
||
|
|
||
|
// Iteratively process each of the nested regions.
|
||
|
while (!numberContext.empty()) {
|
||
|
Region *region;
|
||
|
std::tie(region, nextValueID) = numberContext.pop_back_val();
|
||
|
number(*region);
|
||
|
|
||
|
// Traverse into nested regions.
|
||
|
for (Operation &op : region->getOps())
|
||
|
addOpRegionsToNumber(&op);
|
||
|
}
|
||
|
|
||
|
// Number each of the dialects. For now this is just in the order they were
|
||
|
// found, given that the number of dialects on average is small enough to fit
|
||
|
// within a singly byte (128). If we ever have real world use cases that have
|
||
|
// a huge number of dialects, this could be made more intelligent.
|
||
|
for (auto [idx, dialect] : llvm::enumerate(dialects))
|
||
|
dialect.second->number = idx;
|
||
|
|
||
|
// Number each of the recorded components within each dialect.
|
||
|
|
||
|
// First sort by ref count so that the most referenced elements are first. We
|
||
|
// try to bias more heavily used elements to the front. This allows for more
|
||
|
// frequently referenced things to be encoded using smaller varints.
|
||
|
auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
|
||
|
return lhs->refCount > rhs->refCount;
|
||
|
};
|
||
|
llvm::stable_sort(orderedAttrs, sortByRefCountFn);
|
||
|
llvm::stable_sort(orderedOpNames, sortByRefCountFn);
|
||
|
llvm::stable_sort(orderedTypes, sortByRefCountFn);
|
||
|
|
||
|
// After that, we apply a secondary ordering based on the parent dialect. This
|
||
|
// ordering is applied to sub-sections of the element list defined by how many
|
||
|
// bytes it takes to encode a varint index to that sub-section. This allows
|
||
|
// for more efficiently encoding components of the same dialect (e.g. we only
|
||
|
// have to encode the dialect reference once).
|
||
|
groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs));
|
||
|
groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames));
|
||
|
groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes));
|
||
|
|
||
|
// Finalize the numbering of the dialect resources.
|
||
|
finalizeDialectResourceNumberings(op);
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
|
||
|
// A simple state struct tracking data used when walking operations.
|
||
|
struct StackState {
|
||
|
/// The operation currently being walked.
|
||
|
Operation *op;
|
||
|
|
||
|
/// The numbering of the operation.
|
||
|
OperationNumbering *numbering;
|
||
|
|
||
|
/// A flag indicating if the current state or one of its parents has
|
||
|
/// unresolved isolation status. This is tracked separately from the
|
||
|
/// isIsolatedFromAbove bit on `numbering` because we need to be able to
|
||
|
/// handle the given case:
|
||
|
/// top.op {
|
||
|
/// %value = ...
|
||
|
/// middle.op {
|
||
|
/// %value2 = ...
|
||
|
/// inner.op {
|
||
|
/// // Here we mark `inner.op` as not isolated. Note `middle.op`
|
||
|
/// // isn't known not isolated yet.
|
||
|
/// use.op %value2
|
||
|
///
|
||
|
/// // Here inner.op is already known to be non-isolated, but
|
||
|
/// // `middle.op` is now also discovered to be non-isolated.
|
||
|
/// use.op %value
|
||
|
/// }
|
||
|
/// }
|
||
|
/// }
|
||
|
bool hasUnresolvedIsolation;
|
||
|
};
|
||
|
|
||
|
// Compute a global operation ID numbering according to the pre-order walk of
|
||
|
// the IR. This is used as reference to construct use-list orders.
|
||
|
unsigned operationID = 0;
|
||
|
|
||
|
// Walk each of the operations within the IR, tracking a stack of operations
|
||
|
// as we recurse into nested regions. This walk method hooks in at two stages
|
||
|
// during the walk:
|
||
|
//
|
||
|
// BeforeAllRegions:
|
||
|
// Here we generate a numbering for the operation and push it onto the
|
||
|
// stack if it has regions. We also compute the isolation status of parent
|
||
|
// regions at this stage. This is done by checking the parent regions of
|
||
|
// operands used by the operation, and marking each region between the
|
||
|
// the operand region and the current as not isolated. See
|
||
|
// StackState::hasUnresolvedIsolation above for an example.
|
||
|
//
|
||
|
// AfterAllRegions:
|
||
|
// Here we pop the operation from the stack, and if it hasn't been marked
|
||
|
// as non-isolated, we mark it as so. A non-isolated use would have been
|
||
|
// found while walking the regions, so it is safe to mark the operation at
|
||
|
// this point.
|
||
|
//
|
||
|
SmallVector<StackState> opStack;
|
||
|
rootOp->walk([&](Operation *op, const WalkStage &stage) {
|
||
|
// After visiting all nested regions, we pop the operation from the stack.
|
||
|
if (op->getNumRegions() && stage.isAfterAllRegions()) {
|
||
|
// If no non-isolated uses were found, we can safely mark this operation
|
||
|
// as isolated from above.
|
||
|
OperationNumbering *numbering = opStack.pop_back_val().numbering;
|
||
|
if (!numbering->isIsolatedFromAbove.has_value())
|
||
|
numbering->isIsolatedFromAbove = true;
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
// When visiting before nested regions, we process "IsolatedFromAbove"
|
||
|
// checks and compute the number for this operation.
|
||
|
if (!stage.isBeforeAllRegions())
|
||
|
return;
|
||
|
// Update the isolation status of parent regions if any have yet to be
|
||
|
// resolved.
|
||
|
if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
|
||
|
Region *parentRegion = op->getParentRegion();
|
||
|
for (Value operand : op->getOperands()) {
|
||
|
Region *operandRegion = operand.getParentRegion();
|
||
|
if (operandRegion == parentRegion)
|
||
|
continue;
|
||
|
// We've found a use of an operand outside of the current region,
|
||
|
// walk the operation stack searching for the parent operation,
|
||
|
// marking every region on the way as not isolated.
|
||
|
Operation *operandContainerOp = operandRegion->getParentOp();
|
||
|
auto it = std::find_if(
|
||
|
opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
|
||
|
// We only need to mark up to the container region, or the first
|
||
|
// that has an unresolved status.
|
||
|
return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
|
||
|
});
|
||
|
assert(it != opStack.rend() && "expected to find the container");
|
||
|
for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
|
||
|
// If we stopped at a region that knows its isolation status, we can
|
||
|
// stop updating the isolation status for the parent regions.
|
||
|
state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
|
||
|
state.numbering->isIsolatedFromAbove = false;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Compute the number for this op and push it onto the stack.
|
||
|
auto *numbering =
|
||
|
new (opAllocator.Allocate()) OperationNumbering(operationID++);
|
||
|
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
|
||
|
numbering->isIsolatedFromAbove = true;
|
||
|
operations.try_emplace(op, numbering);
|
||
|
if (op->getNumRegions()) {
|
||
|
opStack.emplace_back(StackState{
|
||
|
op, numbering, !numbering->isIsolatedFromAbove.has_value()});
|
||
|
}
|
||
|
});
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Attribute attr) {
|
||
|
auto it = attrs.insert({attr, nullptr});
|
||
|
if (!it.second) {
|
||
|
++it.first->second->refCount;
|
||
|
return;
|
||
|
}
|
||
|
auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
|
||
|
it.first->second = numbering;
|
||
|
orderedAttrs.push_back(numbering);
|
||
|
|
||
|
// Check for OpaqueAttr, which is a dialect-specific attribute that didn't
|
||
|
// have a registered dialect when it got created. We don't want to encode this
|
||
|
// as the builtin OpaqueAttr, we want to encode it as if the dialect was
|
||
|
// actually loaded.
|
||
|
if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
|
||
|
numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
|
||
|
return;
|
||
|
}
|
||
|
numbering->dialect = &numberDialect(&attr.getDialect());
|
||
|
|
||
|
// If this attribute will be emitted using the bytecode format, perform a
|
||
|
// dummy writing to number any nested components.
|
||
|
// TODO: We don't allow custom encodings for mutable attributes right now.
|
||
|
if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
|
||
|
// Try overriding emission with callbacks.
|
||
|
for (const auto &callback : config.getAttributeWriterCallbacks()) {
|
||
|
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
|
||
|
// The client has the ability to override the group name through the
|
||
|
// callback.
|
||
|
std::optional<StringRef> groupNameOverride;
|
||
|
if (succeeded(callback->write(attr, groupNameOverride, writer))) {
|
||
|
if (groupNameOverride.has_value())
|
||
|
numbering->dialect = &numberDialect(*groupNameOverride);
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (const auto *interface = numbering->dialect->interface) {
|
||
|
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
|
||
|
if (succeeded(interface->writeAttribute(attr, writer)))
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
// If this attribute will be emitted using the fallback, number the nested
|
||
|
// dialect resources. We don't number everything (e.g. no nested
|
||
|
// attributes/types), because we don't want to encode things we won't decode
|
||
|
// (the textual format can't really share much).
|
||
|
AsmState tempState(attr.getContext());
|
||
|
llvm::raw_null_ostream dummyOS;
|
||
|
attr.print(dummyOS, tempState);
|
||
|
|
||
|
// Number the used dialect resources.
|
||
|
for (const auto &it : tempState.getDialectResources())
|
||
|
number(it.getFirst(), it.getSecond().getArrayRef());
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Block &block) {
|
||
|
// Number the arguments of the block.
|
||
|
for (BlockArgument arg : block.getArguments()) {
|
||
|
valueIDs.try_emplace(arg, nextValueID++);
|
||
|
number(arg.getLoc());
|
||
|
number(arg.getType());
|
||
|
}
|
||
|
|
||
|
// Number the operations in this block.
|
||
|
unsigned &numOps = blockOperationCounts[&block];
|
||
|
for (Operation &op : block) {
|
||
|
number(op);
|
||
|
++numOps;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
|
||
|
DialectNumbering *&numbering = registeredDialects[dialect];
|
||
|
if (!numbering) {
|
||
|
numbering = &numberDialect(dialect->getNamespace());
|
||
|
numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
|
||
|
numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
|
||
|
}
|
||
|
return *numbering;
|
||
|
}
|
||
|
|
||
|
auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
|
||
|
DialectNumbering *&numbering = dialects[dialect];
|
||
|
if (!numbering) {
|
||
|
numbering = new (dialectAllocator.Allocate())
|
||
|
DialectNumbering(dialect, dialects.size() - 1);
|
||
|
}
|
||
|
return *numbering;
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Region ®ion) {
|
||
|
if (region.empty())
|
||
|
return;
|
||
|
size_t firstValueID = nextValueID;
|
||
|
|
||
|
// Number the blocks within this region.
|
||
|
size_t blockCount = 0;
|
||
|
for (auto it : llvm::enumerate(region)) {
|
||
|
blockIDs.try_emplace(&it.value(), it.index());
|
||
|
number(it.value());
|
||
|
++blockCount;
|
||
|
}
|
||
|
|
||
|
// Remember the number of blocks and values in this region.
|
||
|
regionBlockValueCounts.try_emplace(®ion, blockCount,
|
||
|
nextValueID - firstValueID);
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Operation &op) {
|
||
|
// Number the components of an operation that won't be numbered elsewhere
|
||
|
// (e.g. we don't number operands, regions, or successors here).
|
||
|
number(op.getName());
|
||
|
for (OpResult result : op.getResults()) {
|
||
|
valueIDs.try_emplace(result, nextValueID++);
|
||
|
number(result.getType());
|
||
|
}
|
||
|
|
||
|
// Only number the operation's dictionary if it isn't empty.
|
||
|
DictionaryAttr dictAttr = op.getDiscardableAttrDictionary();
|
||
|
// Prior to a version with native property encoding, or when properties are
|
||
|
// not used, we need to number also the merged dictionary containing both the
|
||
|
// inherent and discardable attribute.
|
||
|
if (config.getDesiredBytecodeVersion() <
|
||
|
bytecode::kNativePropertiesEncoding ||
|
||
|
!op.getPropertiesStorage()) {
|
||
|
dictAttr = op.getAttrDictionary();
|
||
|
}
|
||
|
if (!dictAttr.empty())
|
||
|
number(dictAttr);
|
||
|
|
||
|
// Visit the operation properties (if any) to make sure referenced attributes
|
||
|
// are numbered.
|
||
|
if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding &&
|
||
|
op.getPropertiesStorageSize()) {
|
||
|
if (op.isRegistered()) {
|
||
|
// Operation that have properties *must* implement this interface.
|
||
|
auto iface = cast<BytecodeOpInterface>(op);
|
||
|
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
|
||
|
iface.writeProperties(writer);
|
||
|
} else {
|
||
|
// Unregistered op are storing properties as an optional attribute.
|
||
|
if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
|
||
|
number(prop);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
number(op.getLoc());
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(OperationName opName) {
|
||
|
OpNameNumbering *&numbering = opNames[opName];
|
||
|
if (numbering) {
|
||
|
++numbering->refCount;
|
||
|
return;
|
||
|
}
|
||
|
DialectNumbering *dialectNumber = nullptr;
|
||
|
if (Dialect *dialect = opName.getDialect())
|
||
|
dialectNumber = &numberDialect(dialect);
|
||
|
else
|
||
|
dialectNumber = &numberDialect(opName.getDialectNamespace());
|
||
|
|
||
|
numbering =
|
||
|
new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
|
||
|
orderedOpNames.push_back(numbering);
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Type type) {
|
||
|
auto it = types.insert({type, nullptr});
|
||
|
if (!it.second) {
|
||
|
++it.first->second->refCount;
|
||
|
return;
|
||
|
}
|
||
|
auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
|
||
|
it.first->second = numbering;
|
||
|
orderedTypes.push_back(numbering);
|
||
|
|
||
|
// Check for OpaqueType, which is a dialect-specific type that didn't have a
|
||
|
// registered dialect when it got created. We don't want to encode this as the
|
||
|
// builtin OpaqueType, we want to encode it as if the dialect was actually
|
||
|
// loaded.
|
||
|
if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
|
||
|
numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
|
||
|
return;
|
||
|
}
|
||
|
numbering->dialect = &numberDialect(&type.getDialect());
|
||
|
|
||
|
// If this type will be emitted using the bytecode format, perform a dummy
|
||
|
// writing to number any nested components.
|
||
|
// TODO: We don't allow custom encodings for mutable types right now.
|
||
|
if (!type.hasTrait<TypeTrait::IsMutable>()) {
|
||
|
// Try overriding emission with callbacks.
|
||
|
for (const auto &callback : config.getTypeWriterCallbacks()) {
|
||
|
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
|
||
|
// The client has the ability to override the group name through the
|
||
|
// callback.
|
||
|
std::optional<StringRef> groupNameOverride;
|
||
|
if (succeeded(callback->write(type, groupNameOverride, writer))) {
|
||
|
if (groupNameOverride.has_value())
|
||
|
numbering->dialect = &numberDialect(*groupNameOverride);
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If this attribute will be emitted using the bytecode format, perform a
|
||
|
// dummy writing to number any nested components.
|
||
|
if (const auto *interface = numbering->dialect->interface) {
|
||
|
NumberingDialectWriter writer(*this, config.getDialectVersionMap());
|
||
|
if (succeeded(interface->writeType(type, writer)))
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
// If this type will be emitted using the fallback, number the nested dialect
|
||
|
// resources. We don't number everything (e.g. no nested attributes/types),
|
||
|
// because we don't want to encode things we won't decode (the textual format
|
||
|
// can't really share much).
|
||
|
AsmState tempState(type.getContext());
|
||
|
llvm::raw_null_ostream dummyOS;
|
||
|
type.print(dummyOS, tempState);
|
||
|
|
||
|
// Number the used dialect resources.
|
||
|
for (const auto &it : tempState.getDialectResources())
|
||
|
number(it.getFirst(), it.getSecond().getArrayRef());
|
||
|
}
|
||
|
|
||
|
void IRNumberingState::number(Dialect *dialect,
|
||
|
ArrayRef<AsmDialectResourceHandle> resources) {
|
||
|
DialectNumbering &dialectNumber = numberDialect(dialect);
|
||
|
assert(
|
||
|
dialectNumber.asmInterface &&
|
||
|
"expected dialect owning a resource to implement OpAsmDialectInterface");
|
||
|
|
||
|
for (const auto &resource : resources) {
|
||
|
// Check if this is a newly seen resource.
|
||
|
if (!dialectNumber.resources.insert(resource))
|
||
|
return;
|
||
|
|
||
|
auto *numbering =
|
||
|
new (resourceAllocator.Allocate()) DialectResourceNumbering(
|
||
|
dialectNumber.asmInterface->getResourceKey(resource));
|
||
|
dialectNumber.resourceMap.insert({numbering->key, numbering});
|
||
|
dialectResources.try_emplace(resource, numbering);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
int64_t IRNumberingState::getDesiredBytecodeVersion() const {
|
||
|
return config.getDesiredBytecodeVersion();
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
/// A dummy resource builder used to number dialect resources.
|
||
|
struct NumberingResourceBuilder : public AsmResourceBuilder {
|
||
|
NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
|
||
|
: dialect(dialect), nextResourceID(nextResourceID) {}
|
||
|
~NumberingResourceBuilder() override = default;
|
||
|
|
||
|
void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
|
||
|
numberEntry(key);
|
||
|
}
|
||
|
void buildBool(StringRef key, bool) final { numberEntry(key); }
|
||
|
void buildString(StringRef key, StringRef) final {
|
||
|
// TODO: We could pre-number the value string here as well.
|
||
|
numberEntry(key);
|
||
|
}
|
||
|
|
||
|
/// Number the dialect entry for the given key.
|
||
|
void numberEntry(StringRef key) {
|
||
|
// TODO: We could pre-number resource key strings here as well.
|
||
|
|
||
|
auto *it = dialect->resourceMap.find(key);
|
||
|
if (it != dialect->resourceMap.end()) {
|
||
|
it->second->number = nextResourceID++;
|
||
|
it->second->isDeclaration = false;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
DialectNumbering *dialect;
|
||
|
unsigned &nextResourceID;
|
||
|
};
|
||
|
} // namespace
|
||
|
|
||
|
void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
|
||
|
unsigned nextResourceID = 0;
|
||
|
for (DialectNumbering &dialect : getDialects()) {
|
||
|
if (!dialect.asmInterface)
|
||
|
continue;
|
||
|
NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
|
||
|
dialect.asmInterface->buildResources(rootOp, dialect.resources,
|
||
|
entryBuilder);
|
||
|
|
||
|
// Number any resources that weren't added by the dialect. This can happen
|
||
|
// if there was no backing data to the resource, but we still want these
|
||
|
// resource references to roundtrip, so we number them and indicate that the
|
||
|
// data is missing.
|
||
|
for (const auto &it : dialect.resourceMap)
|
||
|
if (it.second->isDeclaration)
|
||
|
it.second->number = nextResourceID++;
|
||
|
}
|
||
|
}
|