//===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "LoopAnnotationImporter.h" #include "llvm/IR/Constants.h" using namespace mlir; using namespace mlir::LLVM; using namespace mlir::LLVM::detail; namespace { /// Helper class that keeps the state of one metadata to attribute conversion. struct LoopMetadataConversion { LoopMetadataConversion(const llvm::MDNode *node, Location loc, LoopAnnotationImporter &loopAnnotationImporter) : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), ctx(loc->getContext()){}; /// Converts this structs loop metadata node into a LoopAnnotationAttr. LoopAnnotationAttr convert(); /// Initializes the shared state for the conversion member functions. LogicalResult initConversionState(); /// Helper function to get and erase a property. const llvm::MDNode *lookupAndEraseProperty(StringRef name); /// Helper functions to lookup and convert MDNodes into a specifc attribute /// kind. These functions return null-attributes if there is no node with the /// specified name, or failure, if the node is ill-formatted. FailureOr lookupUnitNode(StringRef name); FailureOr lookupBoolNode(StringRef name, bool negated = false); FailureOr lookupIntNodeAsBoolAttr(StringRef name); FailureOr lookupIntNode(StringRef name); FailureOr lookupMDNode(StringRef name); FailureOr> lookupMDNodes(StringRef name); FailureOr lookupFollowupNode(StringRef name); FailureOr lookupBooleanUnitNode(StringRef enableName, StringRef disableName, bool negated = false); /// Conversion functions for sub-attributes. FailureOr convertVectorizeAttr(); FailureOr convertInterleaveAttr(); FailureOr convertUnrollAttr(); FailureOr convertUnrollAndJamAttr(); FailureOr convertLICMAttr(); FailureOr convertDistributeAttr(); FailureOr convertPipelineAttr(); FailureOr convertPeeledAttr(); FailureOr convertUnswitchAttr(); FailureOr> convertParallelAccesses(); FusedLoc convertStartLoc(); FailureOr convertEndLoc(); llvm::SmallVector locations; llvm::StringMap propertyMap; const llvm::MDNode *node; Location loc; LoopAnnotationImporter &loopAnnotationImporter; MLIRContext *ctx; }; } // namespace LogicalResult LoopMetadataConversion::initConversionState() { // Check if it's a valid node. if (node->getNumOperands() == 0 || dyn_cast(node->getOperand(0)) != node) return emitWarning(loc) << "invalid loop node"; for (const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) { if (auto *diLoc = dyn_cast(operand)) { locations.push_back(diLoc); continue; } auto *property = dyn_cast(operand); if (!property) return emitWarning(loc) << "expected all loop properties to be either " "debug locations or metadata nodes"; if (property->getNumOperands() == 0) return emitWarning(loc) << "cannot import empty loop property"; auto *nameNode = dyn_cast(property->getOperand(0)); if (!nameNode) return emitWarning(loc) << "cannot import loop property without a name"; StringRef name = nameNode->getString(); bool succ = propertyMap.try_emplace(name, property).second; if (!succ) return emitWarning(loc) << "cannot import loop properties with duplicated names " << name; } return success(); } const llvm::MDNode * LoopMetadataConversion::lookupAndEraseProperty(StringRef name) { auto it = propertyMap.find(name); if (it == propertyMap.end()) return nullptr; const llvm::MDNode *property = it->getValue(); propertyMap.erase(it); return property; } FailureOr LoopMetadataConversion::lookupUnitNode(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) return BoolAttr(nullptr); if (property->getNumOperands() != 1) return emitWarning(loc) << "expected metadata node " << name << " to hold no value"; return BoolAttr::get(ctx, true); } FailureOr LoopMetadataConversion::lookupBooleanUnitNode( StringRef enableName, StringRef disableName, bool negated) { auto enable = lookupUnitNode(enableName); auto disable = lookupUnitNode(disableName); if (failed(enable) || failed(disable)) return failure(); if (*enable && *disable) return emitWarning(loc) << "expected metadata nodes " << enableName << " and " << disableName << " to be mutually exclusive."; if (*enable) return BoolAttr::get(ctx, !negated); if (*disable) return BoolAttr::get(ctx, negated); return BoolAttr(nullptr); } FailureOr LoopMetadataConversion::lookupBoolNode(StringRef name, bool negated) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) return BoolAttr(nullptr); auto emitNodeWarning = [&]() { return emitWarning(loc) << "expected metadata node " << name << " to hold a boolean value"; }; if (property->getNumOperands() != 2) return emitNodeWarning(); llvm::ConstantInt *val = llvm::mdconst::dyn_extract(property->getOperand(1)); if (!val || val->getBitWidth() != 1) return emitNodeWarning(); return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated); } FailureOr LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) return BoolAttr(nullptr); auto emitNodeWarning = [&]() { return emitWarning(loc) << "expected metadata node " << name << " to hold an integer value"; }; if (property->getNumOperands() != 2) return emitNodeWarning(); llvm::ConstantInt *val = llvm::mdconst::dyn_extract(property->getOperand(1)); if (!val || val->getBitWidth() != 32) return emitNodeWarning(); return BoolAttr::get(ctx, val->getValue().getLimitedValue(1)); } FailureOr LoopMetadataConversion::lookupIntNode(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) return IntegerAttr(nullptr); auto emitNodeWarning = [&]() { return emitWarning(loc) << "expected metadata node " << name << " to hold an i32 value"; }; if (property->getNumOperands() != 2) return emitNodeWarning(); llvm::ConstantInt *val = llvm::mdconst::dyn_extract(property->getOperand(1)); if (!val || val->getBitWidth() != 32) return emitNodeWarning(); return IntegerAttr::get(IntegerType::get(ctx, 32), val->getValue().getLimitedValue()); } FailureOr LoopMetadataConversion::lookupMDNode(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) return nullptr; auto emitNodeWarning = [&]() { return emitWarning(loc) << "expected metadata node " << name << " to hold an MDNode"; }; if (property->getNumOperands() != 2) return emitNodeWarning(); auto *node = dyn_cast(property->getOperand(1)); if (!node) return emitNodeWarning(); return node; } FailureOr> LoopMetadataConversion::lookupMDNodes(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); SmallVector res; if (!property) return res; auto emitNodeWarning = [&]() { return emitWarning(loc) << "expected metadata node " << name << " to hold one or multiple MDNodes"; }; if (property->getNumOperands() < 2) return emitNodeWarning(); for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) { auto *node = dyn_cast(property->getOperand(i)); if (!node) return emitNodeWarning(); res.push_back(node); } return res; } FailureOr LoopMetadataConversion::lookupFollowupNode(StringRef name) { auto node = lookupMDNode(name); if (failed(node)) return failure(); if (*node == nullptr) return LoopAnnotationAttr(nullptr); return loopAnnotationImporter.translateLoopAnnotation(*node, loc); } static bool isEmptyOrNull(const Attribute attr) { return !attr; } template static bool isEmptyOrNull(const SmallVectorImpl &vec) { return vec.empty(); } /// Helper function that only creates and attribute of type T if all argument /// conversion were successfull and at least one of them holds a non-null value. template static T createIfNonNull(MLIRContext *ctx, const P &...args) { bool anyFailed = (failed(args) || ...); if (anyFailed) return {}; bool allEmpty = (isEmptyOrNull(*args) && ...); if (allEmpty) return {}; return T::get(ctx, *args...); } FailureOr LoopMetadataConversion::convertVectorizeAttr() { FailureOr enable = lookupBoolNode("llvm.loop.vectorize.enable", true); FailureOr predicateEnable = lookupBoolNode("llvm.loop.vectorize.predicate.enable"); FailureOr scalableEnable = lookupBoolNode("llvm.loop.vectorize.scalable.enable"); FailureOr width = lookupIntNode("llvm.loop.vectorize.width"); FailureOr followupVec = lookupFollowupNode("llvm.loop.vectorize.followup_vectorized"); FailureOr followupEpi = lookupFollowupNode("llvm.loop.vectorize.followup_epilogue"); FailureOr followupAll = lookupFollowupNode("llvm.loop.vectorize.followup_all"); return createIfNonNull(ctx, enable, predicateEnable, scalableEnable, width, followupVec, followupEpi, followupAll); } FailureOr LoopMetadataConversion::convertInterleaveAttr() { FailureOr count = lookupIntNode("llvm.loop.interleave.count"); return createIfNonNull(ctx, count); } FailureOr LoopMetadataConversion::convertUnrollAttr() { FailureOr disable = lookupBooleanUnitNode( "llvm.loop.unroll.enable", "llvm.loop.unroll.disable", /*negated=*/true); FailureOr count = lookupIntNode("llvm.loop.unroll.count"); FailureOr runtimeDisable = lookupUnitNode("llvm.loop.unroll.runtime.disable"); FailureOr full = lookupUnitNode("llvm.loop.unroll.full"); FailureOr followupUnrolled = lookupFollowupNode("llvm.loop.unroll.followup_unrolled"); FailureOr followupRemainder = lookupFollowupNode("llvm.loop.unroll.followup_remainder"); FailureOr followupAll = lookupFollowupNode("llvm.loop.unroll.followup_all"); return createIfNonNull(ctx, disable, count, runtimeDisable, full, followupUnrolled, followupRemainder, followupAll); } FailureOr LoopMetadataConversion::convertUnrollAndJamAttr() { FailureOr disable = lookupBooleanUnitNode( "llvm.loop.unroll_and_jam.enable", "llvm.loop.unroll_and_jam.disable", /*negated=*/true); FailureOr count = lookupIntNode("llvm.loop.unroll_and_jam.count"); FailureOr followupOuter = lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer"); FailureOr followupInner = lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner"); FailureOr followupRemainderOuter = lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer"); FailureOr followupRemainderInner = lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner"); FailureOr followupAll = lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all"); return createIfNonNull( ctx, disable, count, followupOuter, followupInner, followupRemainderOuter, followupRemainderInner, followupAll); } FailureOr LoopMetadataConversion::convertLICMAttr() { FailureOr disable = lookupUnitNode("llvm.licm.disable"); FailureOr versioningDisable = lookupUnitNode("llvm.loop.licm_versioning.disable"); return createIfNonNull(ctx, disable, versioningDisable); } FailureOr LoopMetadataConversion::convertDistributeAttr() { FailureOr disable = lookupBoolNode("llvm.loop.distribute.enable", true); FailureOr followupCoincident = lookupFollowupNode("llvm.loop.distribute.followup_coincident"); FailureOr followupSequential = lookupFollowupNode("llvm.loop.distribute.followup_sequential"); FailureOr followupFallback = lookupFollowupNode("llvm.loop.distribute.followup_fallback"); FailureOr followupAll = lookupFollowupNode("llvm.loop.distribute.followup_all"); return createIfNonNull(ctx, disable, followupCoincident, followupSequential, followupFallback, followupAll); } FailureOr LoopMetadataConversion::convertPipelineAttr() { FailureOr disable = lookupBoolNode("llvm.loop.pipeline.disable"); FailureOr initiationinterval = lookupIntNode("llvm.loop.pipeline.initiationinterval"); return createIfNonNull(ctx, disable, initiationinterval); } FailureOr LoopMetadataConversion::convertPeeledAttr() { FailureOr count = lookupIntNode("llvm.loop.peeled.count"); return createIfNonNull(ctx, count); } FailureOr LoopMetadataConversion::convertUnswitchAttr() { FailureOr partialDisable = lookupUnitNode("llvm.loop.unswitch.partial.disable"); return createIfNonNull(ctx, partialDisable); } FailureOr> LoopMetadataConversion::convertParallelAccesses() { FailureOr> nodes = lookupMDNodes("llvm.loop.parallel_accesses"); if (failed(nodes)) return failure(); SmallVector refs; for (llvm::MDNode *node : *nodes) { FailureOr> accessGroups = loopAnnotationImporter.lookupAccessGroupAttrs(node); if (failed(accessGroups)) { emitWarning(loc) << "could not lookup access group"; continue; } llvm::append_range(refs, *accessGroups); } return refs; } FusedLoc LoopMetadataConversion::convertStartLoc() { if (locations.empty()) return {}; return dyn_cast( loopAnnotationImporter.moduleImport.translateLoc(locations[0])); } FailureOr LoopMetadataConversion::convertEndLoc() { if (locations.size() < 2) return FusedLoc(); if (locations.size() > 2) return emitError(loc) << "expected loop metadata to have at most two DILocations"; return dyn_cast( loopAnnotationImporter.moduleImport.translateLoc(locations[1])); } LoopAnnotationAttr LoopMetadataConversion::convert() { if (failed(initConversionState())) return {}; FailureOr disableNonForced = lookupUnitNode("llvm.loop.disable_nonforced"); FailureOr vecAttr = convertVectorizeAttr(); FailureOr interleaveAttr = convertInterleaveAttr(); FailureOr unrollAttr = convertUnrollAttr(); FailureOr unrollAndJamAttr = convertUnrollAndJamAttr(); FailureOr licmAttr = convertLICMAttr(); FailureOr distributeAttr = convertDistributeAttr(); FailureOr pipelineAttr = convertPipelineAttr(); FailureOr peeledAttr = convertPeeledAttr(); FailureOr unswitchAttr = convertUnswitchAttr(); FailureOr mustProgress = lookupUnitNode("llvm.loop.mustprogress"); FailureOr isVectorized = lookupIntNodeAsBoolAttr("llvm.loop.isvectorized"); FailureOr> parallelAccesses = convertParallelAccesses(); // Drop the metadata if there are parts that cannot be imported. if (!propertyMap.empty()) { for (auto name : propertyMap.keys()) emitWarning(loc) << "unknown loop annotation " << name; return {}; } FailureOr startLoc = convertStartLoc(); FailureOr endLoc = convertEndLoc(); return createIfNonNull( ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr, unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr, unswitchAttr, mustProgress, isVectorized, startLoc, endLoc, parallelAccesses); } LoopAnnotationAttr LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, Location loc) { if (!node) return {}; // Note: This check is necessary to distinguish between failed translations // and not yet attempted translations. auto it = loopMetadataMapping.find(node); if (it != loopMetadataMapping.end()) return it->getSecond(); LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); mapLoopMetadata(node, attr); return attr; } LogicalResult LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node, Location loc) { SmallVector accessGroups; if (!node->getNumOperands()) accessGroups.push_back(node); for (const llvm::MDOperand &operand : node->operands()) { auto *childNode = dyn_cast(operand); if (!childNode) return failure(); accessGroups.push_back(cast(operand.get())); } // Convert all entries of the access group list to access group operations. for (const llvm::MDNode *accessGroup : accessGroups) { if (accessGroupMapping.count(accessGroup)) continue; // Verify the access group node is distinct and empty. if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) return emitWarning(loc) << "expected an access group node to be empty and distinct"; // Add a mapping from the access group node to the newly created attribute. accessGroupMapping[accessGroup] = builder.getAttr(); } return success(); } FailureOr> LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { // An access group node is either a single access group or an access group // list. SmallVector accessGroups; if (!node->getNumOperands()) accessGroups.push_back(accessGroupMapping.lookup(node)); for (const llvm::MDOperand &operand : node->operands()) { auto *node = cast(operand.get()); accessGroups.push_back(accessGroupMapping.lookup(node)); } // Exit if one of the access group node lookups failed. if (llvm::is_contained(accessGroups, nullptr)) return failure(); return accessGroups; }