// bolt/deps/llvm-18.1.8/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp
// 2025-02-14 19:21:04 +01:00
//
// 529 lines
// 20 KiB
// C++

//===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LoopAnnotationImporter.h"
#include "llvm/IR/Constants.h"
using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
namespace {
/// Helper class that keeps the state of one metadata to attribute conversion.
/// The lookup helpers consume entries from `propertyMap` as they recognize
/// them; `convert` treats any entry left over afterwards as an unknown
/// property.
struct LoopMetadataConversion {
  LoopMetadataConversion(const llvm::MDNode *node, Location loc,
                         LoopAnnotationImporter &loopAnnotationImporter)
      : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
        ctx(loc->getContext()) {}

  /// Converts this struct's loop metadata node into a LoopAnnotationAttr.
  /// Returns a null attribute if the node cannot be imported.
  LoopAnnotationAttr convert();

  /// Initializes the shared state for the conversion member functions by
  /// splitting the loop node's operands into debug locations and named
  /// properties.
  LogicalResult initConversionState();

  /// Helper function to get and erase a property. Returns null if no property
  /// with the given name exists.
  const llvm::MDNode *lookupAndEraseProperty(StringRef name);

  /// Helper functions to look up and convert MDNodes into a specific attribute
  /// kind. These functions return null attributes if there is no node with the
  /// specified name, or failure if the node is ill-formed.
  FailureOr<BoolAttr> lookupUnitNode(StringRef name);
  FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false);
  FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
  FailureOr<IntegerAttr> lookupIntNode(StringRef name);
  FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
  FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
  FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
  FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
                                            StringRef disableName,
                                            bool negated = false);

  /// Conversion functions for sub-attributes.
  FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
  FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
  FailureOr<LoopUnrollAttr> convertUnrollAttr();
  FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
  FailureOr<LoopLICMAttr> convertLICMAttr();
  FailureOr<LoopDistributeAttr> convertDistributeAttr();
  FailureOr<LoopPipelineAttr> convertPipelineAttr();
  FailureOr<LoopPeeledAttr> convertPeeledAttr();
  FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
  FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
  FusedLoc convertStartLoc();
  FailureOr<FusedLoc> convertEndLoc();

  /// Debug locations found among the loop node's operands, in operand order.
  llvm::SmallVector<llvm::DILocation *, 2> locations;
  /// Named loop properties that have not yet been consumed by a lookup.
  llvm::StringMap<const llvm::MDNode *> propertyMap;
  /// The loop metadata node being converted.
  const llvm::MDNode *node;
  /// Location attached to the diagnostics emitted during the conversion.
  Location loc;
  /// Importer used for recursive followup translation, access group lookups
  /// and debug location translation.
  LoopAnnotationImporter &loopAnnotationImporter;
  /// Context in which the resulting attributes are created.
  MLIRContext *ctx;
};
} // namespace
LogicalResult LoopMetadataConversion::initConversionState() {
  // A valid loop node is non-empty and its first operand refers back to the
  // node itself.
  bool isLoopNode = node->getNumOperands() != 0 &&
                    dyn_cast<llvm::MDNode>(node->getOperand(0)) == node;
  if (!isLoopNode)
    return emitWarning(loc) << "invalid loop node";

  // The remaining operands are either debug locations or named properties of
  // the form `!{!"name", ...}`.
  for (const llvm::MDOperand &entry : llvm::drop_begin(node->operands())) {
    if (auto *debugLoc = dyn_cast<llvm::DILocation>(entry)) {
      locations.push_back(debugLoc);
      continue;
    }

    auto *prop = dyn_cast<llvm::MDNode>(entry);
    if (!prop)
      return emitWarning(loc) << "expected all loop properties to be either "
                                 "debug locations or metadata nodes";
    if (prop->getNumOperands() == 0)
      return emitWarning(loc) << "cannot import empty loop property";

    auto *propName = dyn_cast<llvm::MDString>(prop->getOperand(0));
    if (!propName)
      return emitWarning(loc) << "cannot import loop property without a name";

    // Properties are later looked up by name, so duplicates are rejected.
    StringRef key = propName->getString();
    if (!propertyMap.try_emplace(key, prop).second)
      return emitWarning(loc)
             << "cannot import loop properties with duplicated names " << key;
  }
  return success();
}
const llvm::MDNode *
LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
  // Consuming the entry keeps it from being reported as unknown later on.
  const llvm::MDNode *found = nullptr;
  auto entry = propertyMap.find(name);
  if (entry != propertyMap.end()) {
    found = entry->getValue();
    propertyMap.erase(entry);
  }
  return found;
}
FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
  // A unit property consists only of its name; being present means "true".
  if (const llvm::MDNode *prop = lookupAndEraseProperty(name)) {
    if (prop->getNumOperands() == 1)
      return BoolAttr::get(ctx, true);
    return emitWarning(loc)
           << "expected metadata node " << name << " to hold no value";
  }
  return BoolAttr(nullptr);
}
FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
    StringRef enableName, StringRef disableName, bool negated) {
  // Both lookups run unconditionally so that both properties get consumed.
  FailureOr<BoolAttr> enabled = lookupUnitNode(enableName);
  FailureOr<BoolAttr> disabled = lookupUnitNode(disableName);
  if (failed(enabled) || failed(disabled))
    return failure();

  BoolAttr hasEnable = *enabled;
  BoolAttr hasDisable = *disabled;
  // Neither property present: nothing to report.
  if (!hasEnable && !hasDisable)
    return BoolAttr(nullptr);
  if (hasEnable && hasDisable)
    return emitWarning(loc)
           << "expected metadata nodes " << enableName << " and " << disableName
           << " to be mutually exclusive.";
  // Exactly one is present; `negated` flips the meaning of the result.
  return BoolAttr::get(ctx, hasEnable ? !negated : negated);
}
FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
                                                           bool negated) {
  const llvm::MDNode *prop = lookupAndEraseProperty(name);
  if (!prop)
    return BoolAttr(nullptr);

  // Accept only `!{!"name", i1 <value>}`.
  llvm::ConstantInt *flag = nullptr;
  if (prop->getNumOperands() == 2)
    flag = llvm::mdconst::dyn_extract<llvm::ConstantInt>(prop->getOperand(1));
  if (!flag || flag->getBitWidth() != 1)
    return emitWarning(loc)
           << "expected metadata node " << name << " to hold a boolean value";
  return BoolAttr::get(ctx, flag->getValue().getLimitedValue(1) ^ negated);
}
FailureOr<BoolAttr>
LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
  const llvm::MDNode *prop = lookupAndEraseProperty(name);
  if (!prop)
    return BoolAttr(nullptr);

  // Accept only `!{!"name", i32 <value>}`.
  llvm::ConstantInt *value = nullptr;
  if (prop->getNumOperands() == 2)
    value = llvm::mdconst::dyn_extract<llvm::ConstantInt>(prop->getOperand(1));
  if (!value || value->getBitWidth() != 32)
    return emitWarning(loc)
           << "expected metadata node " << name << " to hold an integer value";
  // getLimitedValue(1) clamps to 1, so any non-zero value maps to true.
  return BoolAttr::get(ctx, value->getValue().getLimitedValue(1));
}
FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
  const llvm::MDNode *prop = lookupAndEraseProperty(name);
  if (!prop)
    return IntegerAttr(nullptr);

  // Accept only `!{!"name", i32 <value>}`.
  llvm::ConstantInt *value = nullptr;
  if (prop->getNumOperands() == 2)
    value = llvm::mdconst::dyn_extract<llvm::ConstantInt>(prop->getOperand(1));
  if (!value || value->getBitWidth() != 32)
    return emitWarning(loc)
           << "expected metadata node " << name << " to hold an i32 value";

  IntegerType i32Type = IntegerType::get(ctx, 32);
  return IntegerAttr::get(i32Type, value->getValue().getLimitedValue());
}
FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
  const llvm::MDNode *prop = lookupAndEraseProperty(name);
  // A missing property yields a successful null result, not a failure.
  if (!prop)
    return nullptr;

  // Accept only `!{!"name", !<node>}`.
  llvm::MDNode *payload = nullptr;
  if (prop->getNumOperands() == 2)
    payload = dyn_cast<llvm::MDNode>(prop->getOperand(1));
  if (!payload)
    return emitWarning(loc)
           << "expected metadata node " << name << " to hold an MDNode";
  return payload;
}
FailureOr<SmallVector<llvm::MDNode *>>
LoopMetadataConversion::lookupMDNodes(StringRef name) {
  SmallVector<llvm::MDNode *> children;
  const llvm::MDNode *prop = lookupAndEraseProperty(name);
  // A missing property is not an error; it yields an empty list.
  if (!prop)
    return children;

  auto reportMalformed = [&]() {
    return emitWarning(loc) << "expected metadata node " << name
                            << " to hold one or multiple MDNodes";
  };
  // The name must be followed by at least one node operand.
  if (prop->getNumOperands() < 2)
    return reportMalformed();
  for (const llvm::MDOperand &operand : llvm::drop_begin(prop->operands())) {
    auto *child = dyn_cast<llvm::MDNode>(operand);
    if (!child)
      return reportMalformed();
    children.push_back(child);
  }
  return children;
}
FailureOr<LoopAnnotationAttr>
LoopMetadataConversion::lookupFollowupNode(StringRef name) {
  FailureOr<llvm::MDNode *> followup = lookupMDNode(name);
  if (failed(followup))
    return failure();
  // An absent followup property yields a null attribute rather than failure.
  if (!*followup)
    return LoopAnnotationAttr(nullptr);
  // Followup nodes are loop annotations themselves, so translate them through
  // the importer, which caches results per node.
  return loopAnnotationImporter.translateLoopAnnotation(*followup, loc);
}
/// Returns true if the attribute is null.
static bool isEmptyOrNull(const Attribute attr) {
  return !static_cast<bool>(attr);
}
/// Returns true if the vector holds no elements.
template <typename T>
static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) {
  return vec.size() == 0;
}
/// Builds an attribute of type T from a set of conversion results. Returns a
/// null T if any conversion failed, or if every converted value is null or
/// empty. Otherwise the unwrapped values are forwarded to `T::get`.
template <typename T, typename... P>
static T createIfNonNull(MLIRContext *ctx, const P &...args) {
  // Short-circuiting guarantees no failed result is dereferenced.
  if ((failed(args) || ...) || (isEmptyOrNull(*args) && ...))
    return {};
  return T::get(ctx, *args...);
}
FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
  // The enable flag is looked up negated, so this value has "disable"
  // semantics.
  auto disable =
      lookupBoolNode("llvm.loop.vectorize.enable", /*negated=*/true);
  auto predicateEnable =
      lookupBoolNode("llvm.loop.vectorize.predicate.enable");
  auto scalableEnable =
      lookupBoolNode("llvm.loop.vectorize.scalable.enable");
  auto width = lookupIntNode("llvm.loop.vectorize.width");
  auto followupVectorized =
      lookupFollowupNode("llvm.loop.vectorize.followup_vectorized");
  auto followupEpilogue =
      lookupFollowupNode("llvm.loop.vectorize.followup_epilogue");
  auto followupAll = lookupFollowupNode("llvm.loop.vectorize.followup_all");
  return createIfNonNull<LoopVectorizeAttr>(
      ctx, disable, predicateEnable, scalableEnable, width, followupVectorized,
      followupEpilogue, followupAll);
}
FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
  // The interleave annotation carries only an i32 count.
  auto interleaveCount = lookupIntNode("llvm.loop.interleave.count");
  return createIfNonNull<LoopInterleaveAttr>(ctx, interleaveCount);
}
FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
  // The enable/disable unit properties collapse into one negated flag.
  auto disable = lookupBooleanUnitNode("llvm.loop.unroll.enable",
                                       "llvm.loop.unroll.disable",
                                       /*negated=*/true);
  auto count = lookupIntNode("llvm.loop.unroll.count");
  auto runtimeDisable = lookupUnitNode("llvm.loop.unroll.runtime.disable");
  auto full = lookupUnitNode("llvm.loop.unroll.full");
  auto followupUnrolled =
      lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
  auto followupRemainder =
      lookupFollowupNode("llvm.loop.unroll.followup_remainder");
  auto followupAll = lookupFollowupNode("llvm.loop.unroll.followup_all");
  return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
                                         full, followupUnrolled,
                                         followupRemainder, followupAll);
}
FailureOr<LoopUnrollAndJamAttr>
LoopMetadataConversion::convertUnrollAndJamAttr() {
  // The enable/disable unit properties collapse into one negated flag.
  auto disable = lookupBooleanUnitNode("llvm.loop.unroll_and_jam.enable",
                                       "llvm.loop.unroll_and_jam.disable",
                                       /*negated=*/true);
  auto count = lookupIntNode("llvm.loop.unroll_and_jam.count");
  auto followupOuter =
      lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer");
  auto followupInner =
      lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner");
  auto followupRemainderOuter =
      lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer");
  auto followupRemainderInner =
      lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner");
  auto followupAll =
      lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all");
  return createIfNonNull<LoopUnrollAndJamAttr>(
      ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
      followupRemainderInner, followupAll);
}
FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
  // Both LICM-related properties are plain unit flags.
  auto licmDisable = lookupUnitNode("llvm.licm.disable");
  auto licmVersioningDisable =
      lookupUnitNode("llvm.loop.licm_versioning.disable");
  return createIfNonNull<LoopLICMAttr>(ctx, licmDisable,
                                       licmVersioningDisable);
}
FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
  // The enable flag is looked up negated, so this value has "disable"
  // semantics.
  auto disable =
      lookupBoolNode("llvm.loop.distribute.enable", /*negated=*/true);
  auto followupCoincident =
      lookupFollowupNode("llvm.loop.distribute.followup_coincident");
  auto followupSequential =
      lookupFollowupNode("llvm.loop.distribute.followup_sequential");
  auto followupFallback =
      lookupFollowupNode("llvm.loop.distribute.followup_fallback");
  auto followupAll = lookupFollowupNode("llvm.loop.distribute.followup_all");
  return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
                                             followupSequential,
                                             followupFallback, followupAll);
}
FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
  auto pipelineDisable = lookupBoolNode("llvm.loop.pipeline.disable");
  auto initiationInterval =
      lookupIntNode("llvm.loop.pipeline.initiationinterval");
  return createIfNonNull<LoopPipelineAttr>(ctx, pipelineDisable,
                                           initiationInterval);
}
FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
  // The peeled annotation carries only an i32 count.
  auto peeledCount = lookupIntNode("llvm.loop.peeled.count");
  return createIfNonNull<LoopPeeledAttr>(ctx, peeledCount);
}
FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
  // The unswitch annotation carries a single unit flag.
  auto partialUnswitchDisable =
      lookupUnitNode("llvm.loop.unswitch.partial.disable");
  return createIfNonNull<LoopUnswitchAttr>(ctx, partialUnswitchDisable);
}
FailureOr<SmallVector<AccessGroupAttr>>
LoopMetadataConversion::convertParallelAccesses() {
  FailureOr<SmallVector<llvm::MDNode *>> groupNodes =
      lookupMDNodes("llvm.loop.parallel_accesses");
  if (failed(groupNodes))
    return failure();

  // Collect the attributes of every access group that can be resolved; groups
  // that cannot be resolved are reported and skipped.
  SmallVector<AccessGroupAttr> result;
  for (llvm::MDNode *groupNode : *groupNodes) {
    FailureOr<SmallVector<AccessGroupAttr>> groupAttrs =
        loopAnnotationImporter.lookupAccessGroupAttrs(groupNode);
    if (failed(groupAttrs))
      emitWarning(loc) << "could not lookup access group";
    else
      llvm::append_range(result, *groupAttrs);
  }
  return result;
}
FusedLoc LoopMetadataConversion::convertStartLoc() {
  // The first debug location, if present, is translated as the start location.
  // A translated location that is not a FusedLoc yields a null result.
  if (!locations.empty())
    return dyn_cast<FusedLoc>(
        loopAnnotationImporter.moduleImport.translateLoc(locations.front()));
  return {};
}
FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
  // The second debug location, if present, is translated as the end location.
  switch (locations.size()) {
  case 0:
  case 1:
    return FusedLoc();
  case 2:
    return dyn_cast<FusedLoc>(
        loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
  default:
    // NOTE(review): unlike the other malformed-metadata paths in this file,
    // this emits an error rather than a warning — confirm that is intended.
    return emitError(loc)
           << "expected loop metadata to have at most two DILocations";
  }
}
LoopAnnotationAttr LoopMetadataConversion::convert() {
  if (failed(initConversionState()))
    return {};

  // Every lookup consumes the properties it recognizes from propertyMap. The
  // lookups all run eagerly, in this order, so each recognized property is
  // consumed even when an earlier one turned out to be malformed.
  FailureOr<BoolAttr> disableNonForced =
      lookupUnitNode("llvm.loop.disable_nonforced");
  FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
  FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
  FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
  FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
  FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
  FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
  FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
  FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
  FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
  FailureOr<BoolAttr> mustProgress = lookupUnitNode("llvm.loop.mustprogress");
  FailureOr<BoolAttr> isVectorized =
      lookupIntNodeAsBoolAttr("llvm.loop.isvectorized");
  FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
      convertParallelAccesses();

  // Any property still in the map is not understood by this importer; report
  // each one and drop the whole annotation instead of importing it partially.
  for (const auto &leftover : propertyMap)
    emitWarning(loc) << "unknown loop annotation " << leftover.getKey();
  if (!propertyMap.empty())
    return {};

  FailureOr<FusedLoc> startLoc = convertStartLoc();
  FailureOr<FusedLoc> endLoc = convertEndLoc();
  return createIfNonNull<LoopAnnotationAttr>(
      ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
      unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
      unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
      parallelAccesses);
}
LoopAnnotationAttr
LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
                                                Location loc) {
  if (!node)
    return {};

  // Previously attempted translations are served from the cache. This includes
  // failed ones, which are cached as null attributes, so a cache hit must not
  // be confused with "not yet translated".
  if (auto cached = loopMetadataMapping.find(node);
      cached != loopMetadataMapping.end())
    return cached->getSecond();

  // NOTE(review): followup properties recurse back into this function before
  // the mapping below is recorded, so a followup that (transitively) refers to
  // `node` itself would not be caught by the cache — confirm such input is
  // rejected elsewhere.
  LoopAnnotationAttr translated =
      LoopMetadataConversion(node, loc, *this).convert();
  mapLoopMetadata(node, translated);
  return translated;
}
/// Creates an AccessGroupAttr for every access group referenced by `node` that
/// has not been translated yet.
///
/// `node` is either a single access group (a node without operands) or a list
/// whose operands are access groups. Returns failure without emitting a
/// diagnostic if a list operand is not an MDNode. Emits a warning and fails if
/// a referenced access group is not both empty and distinct.
LogicalResult
LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node,
                                             Location loc) {
  SmallVector<const llvm::MDNode *> accessGroups;
  if (!node->getNumOperands())
    accessGroups.push_back(node);
  for (const llvm::MDOperand &operand : node->operands()) {
    auto *childNode = dyn_cast<llvm::MDNode>(operand);
    if (!childNode)
      return failure();
    // `childNode` is already known to be an MDNode; no second cast needed.
    accessGroups.push_back(childNode);
  }
  // Convert all entries of the access group list to access group attributes.
  for (const llvm::MDNode *accessGroup : accessGroups) {
    // Groups shared between several lists are translated only once.
    if (accessGroupMapping.count(accessGroup))
      continue;
    // Verify the access group node is distinct and empty.
    if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
      return emitWarning(loc)
             << "expected an access group node to be empty and distinct";
    // Add a mapping from the access group node to the newly created attribute.
    accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>();
  }
  return success();
}
/// Returns the AccessGroupAttrs previously created for the access group(s)
/// referenced by `node`.
///
/// `node` is either a single access group (a node without operands) or a list
/// whose operands are access groups. Returns failure if a list operand is not
/// an MDNode, or if any referenced group has no mapped attribute.
FailureOr<SmallVector<AccessGroupAttr>>
LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
  SmallVector<AccessGroupAttr> accessGroups;
  if (!node->getNumOperands())
    accessGroups.push_back(accessGroupMapping.lookup(node));
  for (const llvm::MDOperand &operand : node->operands()) {
    // `node` is not guaranteed to have been validated by translateAccessGroup
    // (e.g. nodes reached via `llvm.loop.parallel_accesses` are only checked to
    // be MDNodes themselves), so reject non-MDNode operands with a failure
    // instead of relying on an unchecked cast.
    auto *childNode = dyn_cast<llvm::MDNode>(operand);
    if (!childNode)
      return failure();
    accessGroups.push_back(accessGroupMapping.lookup(childNode));
  }
  // Exit if one of the access group node lookups failed.
  if (llvm::is_contained(accessGroups, nullptr))
    return failure();
  return accessGroups;
}