bolt/deps/llvm-18.1.8/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
2025-02-14 19:21:04 +01:00

301 lines
12 KiB
C++

//===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LoopAnnotationTranslation.h"
#include "llvm/IR/DebugInfoMetadata.h"
using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
namespace {
/// Helper class that keeps the state of one attribute to metadata conversion.
///
/// One instance translates a single LoopAnnotationAttr: the per-option helpers
/// append operands to `metadataNodes`, and `convert()` assembles them into the
/// final self-referencing loop metadata node.
struct LoopAnnotationConversion {
  LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
                           LoopAnnotationTranslation &loopAnnotationTranslation,
                           llvm::LLVMContext &ctx)
      : attr(attr), op(op),
        loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}

  /// Converts this struct's loop annotation into a corresponding LLVMIR
  /// metadata representation.
  llvm::MDNode *convert();

  /// Conversion functions for different payload attribute kinds.
  /// - addUnitNode: appends `!{!"name"}` (the BoolAttr overload only when the
  ///   attribute is present and true).
  /// - addI32NodeWithVal / convertI32Node: append `!{!"name", i32 val}`.
  /// - convertBoolNode: appends `!{!"name", i1 val}`, optionally negated.
  /// - convertFollowupNode: appends `!{!"name", <translated loop annotation>}`.
  /// - convertLocation: appends the translated debug location, if any.
  void addUnitNode(StringRef name);
  void addUnitNode(StringRef name, BoolAttr attr);
  void addI32NodeWithVal(StringRef name, uint32_t val);
  void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false);
  void convertI32Node(StringRef name, IntegerAttr attr);
  void convertFollowupNode(StringRef name, LoopAnnotationAttr attr);
  void convertLocation(FusedLoc attr);

  /// Conversion functions for each loop annotation sub-attribute.
  void convertLoopOptions(LoopVectorizeAttr options);
  void convertLoopOptions(LoopInterleaveAttr options);
  void convertLoopOptions(LoopUnrollAttr options);
  void convertLoopOptions(LoopUnrollAndJamAttr options);
  void convertLoopOptions(LoopLICMAttr options);
  void convertLoopOptions(LoopDistributeAttr options);
  void convertLoopOptions(LoopPipelineAttr options);
  void convertLoopOptions(LoopPeeledAttr options);
  void convertLoopOptions(LoopUnswitchAttr options);

  /// The loop annotation being translated.
  LoopAnnotationAttr attr;
  /// Operation carrying the annotation; forwarded when translating followups.
  Operation *op;
  /// Owning translation state, used for nested followup annotations, access
  /// groups, and debug-info translation.
  LoopAnnotationTranslation &loopAnnotationTranslation;
  /// Context in which all LLVM metadata is created.
  llvm::LLVMContext &ctx;
  /// Operands of the loop metadata node under construction; `convert()`
  /// reserves operand 0 for the self-reference.
  llvm::SmallVector<llvm::Metadata *> metadataNodes;
};
} // namespace
/// Appends a payload-less option node of the form `!{!"name"}`.
void LoopAnnotationConversion::addUnitNode(StringRef name) {
  llvm::Metadata *optionName = llvm::MDString::get(ctx, name);
  llvm::MDNode *unitNode = llvm::MDNode::get(ctx, {optionName});
  metadataNodes.push_back(unitNode);
}
/// Appends `!{!"name"}` only when `attr` is present and holds `true`.
void LoopAnnotationConversion::addUnitNode(StringRef name, BoolAttr attr) {
  // A missing or false flag contributes nothing to the loop metadata.
  if (!attr || !attr.getValue())
    return;
  addUnitNode(name);
}
/// Appends an option node `!{!"name", i32 val}`, with `val` treated as
/// unsigned.
void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) {
  llvm::IntegerType *i32Type = llvm::IntegerType::get(ctx, /*NumBits=*/32);
  llvm::Constant *payload =
      llvm::ConstantInt::get(i32Type, val, /*isSigned=*/false);
  llvm::Metadata *operands[] = {llvm::MDString::get(ctx, name),
                                llvm::ConstantAsMetadata::get(payload)};
  metadataNodes.push_back(llvm::MDNode::get(ctx, operands));
}
/// Appends `!{!"name", i1 value}` when `attr` is present. With `negated` set,
/// the stored boolean is inverted, so a "disable" attribute can populate an
/// "enable" key.
void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr,
                                               bool negated) {
  if (attr) {
    // Equivalent to XOR: flip the value exactly when `negated` is true.
    const bool flag = attr.getValue() != negated;
    llvm::Constant *payload = llvm::ConstantInt::getBool(ctx, flag);
    llvm::Metadata *operands[] = {llvm::MDString::get(ctx, name),
                                  llvm::ConstantAsMetadata::get(payload)};
    metadataNodes.push_back(llvm::MDNode::get(ctx, operands));
  }
}
/// Appends `!{!"name", i32 value}` when the integer attribute is present.
void LoopAnnotationConversion::convertI32Node(StringRef name,
                                              IntegerAttr attr) {
  // Unset options are skipped; otherwise forward the integer payload.
  if (attr)
    addI32NodeWithVal(name, attr.getInt());
}
/// Appends `!{!"name", <followup loop ID>}` when a followup annotation is
/// present. The nested annotation is translated through the shared translation
/// state.
void LoopAnnotationConversion::convertFollowupNode(StringRef name,
                                                   LoopAnnotationAttr attr) {
  if (attr) {
    llvm::MDNode *followupLoopID =
        loopAnnotationTranslation.translateLoopAnnotation(attr, op);
    llvm::Metadata *operands[] = {llvm::MDString::get(ctx, name),
                                  followupLoopID};
    metadataNodes.push_back(llvm::MDNode::get(ctx, operands));
  }
}
/// Emits the `llvm.loop.vectorize.*` options, in a fixed order.
void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options) {
  // MLIR stores "disable", but LLVM expects "enable", so the value is negated.
  BoolAttr disable = options.getDisable();
  convertBoolNode("llvm.loop.vectorize.enable", disable, /*negated=*/true);

  BoolAttr predicateEnable = options.getPredicateEnable();
  convertBoolNode("llvm.loop.vectorize.predicate.enable", predicateEnable);

  BoolAttr scalableEnable = options.getScalableEnable();
  convertBoolNode("llvm.loop.vectorize.scalable.enable", scalableEnable);

  IntegerAttr width = options.getWidth();
  convertI32Node("llvm.loop.vectorize.width", width);

  // Annotations to attach to the loops produced by vectorization.
  LoopAnnotationAttr vectorized = options.getFollowupVectorized();
  convertFollowupNode("llvm.loop.vectorize.followup_vectorized", vectorized);
  LoopAnnotationAttr epilogue = options.getFollowupEpilogue();
  convertFollowupNode("llvm.loop.vectorize.followup_epilogue", epilogue);
  LoopAnnotationAttr all = options.getFollowupAll();
  convertFollowupNode("llvm.loop.vectorize.followup_all", all);
}
/// Emits the `llvm.loop.interleave.count` option when it is set.
void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options) {
  IntegerAttr interleaveCount = options.getCount();
  convertI32Node("llvm.loop.interleave.count", interleaveCount);
}
/// Emits the `llvm.loop.unroll.*` options, in a fixed order.
void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options) {
  // A present "disable" flag becomes either the disable or the enable unit
  // node, depending on its value.
  if (auto disable = options.getDisable()) {
    if (disable.getValue())
      addUnitNode("llvm.loop.unroll.disable");
    else
      addUnitNode("llvm.loop.unroll.enable");
  }

  IntegerAttr count = options.getCount();
  convertI32Node("llvm.loop.unroll.count", count);

  BoolAttr runtimeDisable = options.getRuntimeDisable();
  convertBoolNode("llvm.loop.unroll.runtime.disable", runtimeDisable);

  BoolAttr full = options.getFull();
  addUnitNode("llvm.loop.unroll.full", full);

  // Annotations to attach to the loops produced by unrolling.
  LoopAnnotationAttr unrolled = options.getFollowupUnrolled();
  convertFollowupNode("llvm.loop.unroll.followup_unrolled", unrolled);
  LoopAnnotationAttr remainder = options.getFollowupRemainder();
  convertFollowupNode("llvm.loop.unroll.followup_remainder", remainder);
  LoopAnnotationAttr all = options.getFollowupAll();
  convertFollowupNode("llvm.loop.unroll.followup_all", all);
}
/// Emits the `llvm.loop.unroll_and_jam.*` options, in a fixed order.
void LoopAnnotationConversion::convertLoopOptions(
    LoopUnrollAndJamAttr options) {
  // A present "disable" flag selects between the disable and enable unit nodes.
  if (auto disable = options.getDisable()) {
    if (disable.getValue())
      addUnitNode("llvm.loop.unroll_and_jam.disable");
    else
      addUnitNode("llvm.loop.unroll_and_jam.enable");
  }

  IntegerAttr count = options.getCount();
  convertI32Node("llvm.loop.unroll_and_jam.count", count);

  // Annotations to attach to the loops produced by unroll-and-jam.
  LoopAnnotationAttr outer = options.getFollowupOuter();
  convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer", outer);
  LoopAnnotationAttr inner = options.getFollowupInner();
  convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner", inner);
  LoopAnnotationAttr remainderOuter = options.getFollowupRemainderOuter();
  convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
                      remainderOuter);
  LoopAnnotationAttr remainderInner = options.getFollowupRemainderInner();
  convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
                      remainderInner);
  LoopAnnotationAttr all = options.getFollowupAll();
  convertFollowupNode("llvm.loop.unroll_and_jam.followup_all", all);
}
/// Emits the LICM-related unit options when their flags are set to true.
void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options) {
  BoolAttr licmDisable = options.getDisable();
  BoolAttr versioningDisable = options.getVersioningDisable();
  addUnitNode("llvm.licm.disable", licmDisable);
  addUnitNode("llvm.loop.licm_versioning.disable", versioningDisable);
}
/// Emits the `llvm.loop.distribute.*` options, in a fixed order.
void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options) {
  // MLIR stores "disable", but LLVM expects "enable", so the value is negated.
  BoolAttr disable = options.getDisable();
  convertBoolNode("llvm.loop.distribute.enable", disable, /*negated=*/true);

  // Annotations to attach to the loops produced by distribution.
  LoopAnnotationAttr coincident = options.getFollowupCoincident();
  convertFollowupNode("llvm.loop.distribute.followup_coincident", coincident);
  LoopAnnotationAttr sequential = options.getFollowupSequential();
  convertFollowupNode("llvm.loop.distribute.followup_sequential", sequential);
  LoopAnnotationAttr fallback = options.getFollowupFallback();
  convertFollowupNode("llvm.loop.distribute.followup_fallback", fallback);
  LoopAnnotationAttr all = options.getFollowupAll();
  convertFollowupNode("llvm.loop.distribute.followup_all", all);
}
/// Emits the software-pipelining options (disable flag, initiation interval).
void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options) {
  BoolAttr disable = options.getDisable();
  IntegerAttr initiationInterval = options.getInitiationinterval();
  convertBoolNode("llvm.loop.pipeline.disable", disable);
  convertI32Node("llvm.loop.pipeline.initiationinterval", initiationInterval);
}
/// Emits the `llvm.loop.peeled.count` option when it is set.
void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options) {
  IntegerAttr peeledCount = options.getCount();
  convertI32Node("llvm.loop.peeled.count", peeledCount);
}
/// Emits `llvm.loop.unswitch.partial.disable` when the flag is set to true.
void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
  BoolAttr partialDisable = options.getPartialDisable();
  addUnitNode("llvm.loop.unswitch.partial.disable", partialDisable);
}
/// Appends the debug location described by `location` to the loop metadata.
/// The fused location's metadata must be a DILocalScopeAttr that translates to
/// an llvm::DILocalScope; otherwise nothing is emitted.
void LoopAnnotationConversion::convertLocation(FusedLoc location) {
  auto localScopeAttr =
      dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
  if (!localScopeAttr)
    return;
  auto *localScope = dyn_cast<llvm::DILocalScope>(
      loopAnnotationTranslation.moduleTranslation.translateDebugInfo(
          localScopeAttr));
  if (!localScope)
    return;
  llvm::Metadata *loc =
      loopAnnotationTranslation.moduleTranslation.translateLoc(location,
                                                               localScope);
  // The translated location may be null (presumably when none of the fused
  // sub-locations is translatable). Pushing it would leave a null operand in
  // the loop ID node. LLVM consumers dyn_cast each loop-ID operand and do not
  // expect null, so skip it instead.
  if (!loc)
    return;
  metadataNodes.push_back(loc);
}
/// Builds the loop metadata node for `attr`.
///
/// Layout of the resulting node:
/// - operand 0: a self-reference (patched in at the end);
/// - the optional start and end debug locations;
/// - one operand per option that is present.
///
/// The statement order below fixes the operand order of the emitted node, so
/// reordering these calls changes the produced IR.
llvm::MDNode *LoopAnnotationConversion::convert() {
  // Reserve operand 0 for loop id self reference.
  // The temporary placeholder is replaced below and destroyed when `dummy`
  // goes out of scope at the end of this function.
  auto dummy = llvm::MDNode::getTemporary(ctx, std::nullopt);
  metadataNodes.push_back(dummy.get());
  if (FusedLoc startLoc = attr.getStartLoc())
    convertLocation(startLoc);
  if (FusedLoc endLoc = attr.getEndLoc())
    convertLocation(endLoc);
  addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced());
  addUnitNode("llvm.loop.mustprogress", attr.getMustProgress());
  // "isvectorized" is encoded as an i32 value.
  if (BoolAttr isVectorized = attr.getIsVectorized())
    addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue());
  // Each transformation sub-attribute is optional; only present ones emit
  // operands.
  if (auto options = attr.getVectorize())
    convertLoopOptions(options);
  if (auto options = attr.getInterleave())
    convertLoopOptions(options);
  if (auto options = attr.getUnroll())
    convertLoopOptions(options);
  if (auto options = attr.getUnrollAndJam())
    convertLoopOptions(options);
  if (auto options = attr.getLicm())
    convertLoopOptions(options);
  if (auto options = attr.getDistribute())
    convertLoopOptions(options);
  if (auto options = attr.getPipeline())
    convertLoopOptions(options);
  if (auto options = attr.getPeeled())
    convertLoopOptions(options);
  if (auto options = attr.getUnswitch())
    convertLoopOptions(options);
  // Emit `!{!"llvm.loop.parallel_accesses", <group>...}` with one access-group
  // node per referenced AccessGroupAttr.
  ArrayRef<AccessGroupAttr> parallelAccessGroups = attr.getParallelAccesses();
  if (!parallelAccessGroups.empty()) {
    SmallVector<llvm::Metadata *> parallelAccess;
    parallelAccess.push_back(
        llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
    for (AccessGroupAttr accessGroupAttr : parallelAccessGroups)
      parallelAccess.push_back(
          loopAnnotationTranslation.getAccessGroup(accessGroupAttr));
    metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
  }
  // Create loop options and set the first operand to itself.
  llvm::MDNode *loopMD = llvm::MDNode::get(ctx, metadataNodes);
  loopMD->replaceOperandWith(0, loopMD);
  return loopMD;
}
/// Translates `attr` into LLVM loop metadata and returns nullptr for a null
/// attribute. Results are cached per attribute via lookupLoopMetadata and
/// mapLoopMetadata, so repeated requests reuse the first translation.
llvm::MDNode *
LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
                                                   Operation *op) {
  if (!attr)
    return nullptr;

  if (llvm::MDNode *cached = lookupLoopMetadata(attr))
    return cached;

  llvm::LLVMContext &llvmContext = llvmModule.getContext();
  LoopAnnotationConversion conversion(attr, op, *this, llvmContext);
  llvm::MDNode *translated = conversion.convert();

  // Remember the result in case the same attribute is encountered again.
  mapLoopMetadata(attr, translated);
  return translated;
}
/// Returns the metadata node for `accessGroupAttr`. A fresh distinct, empty
/// node is created the first time a given attribute is requested.
llvm::MDNode *
LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr) {
  auto insertion =
      accessGroupMetadataMapping.insert({accessGroupAttr, nullptr});
  auto &groupNode = insertion.first->second;
  // Only a newly inserted entry needs its node materialized.
  if (insertion.second)
    groupNode = llvm::MDNode::getDistinct(llvmModule.getContext(), {});
  return groupNode;
}
/// Returns the access-group metadata for `op`:
/// - nullptr when the op has no access groups;
/// - the group node itself when there is exactly one;
/// - otherwise a tuple listing all group nodes.
llvm::MDNode *
LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op) {
  ArrayAttr groupAttrs = op.getAccessGroupsOrNull();
  if (!groupAttrs || groupAttrs.empty())
    return nullptr;

  SmallVector<llvm::Metadata *> groupNodes;
  for (AccessGroupAttr groupAttr : groupAttrs.getAsRange<AccessGroupAttr>())
    groupNodes.push_back(getAccessGroup(groupAttr));

  return groupNodes.size() == 1
             ? llvm::cast<llvm::MDNode>(groupNodes.front())
             : llvm::MDNode::get(llvmModule.getContext(), groupNodes);
}