bolt/deps/llvm-18.1.8/llvm/unittests/Target/X86/TernlogTest.cpp
2025-02-14 19:21:04 +01:00

201 lines
6.2 KiB
C++

//===- TernlogTest.cpp - X86 VPTERNLOG constant folding tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "gtest/gtest.h"
#include <random>
namespace llvm {
static std::unique_ptr<LLVMTargetMachine> initTM() {
LLVMInitializeX86TargetInfo();
LLVMInitializeX86Target();
LLVMInitializeX86TargetMC();
auto TT(Triple::normalize("x86_64--"));
std::string Error;
const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error);
return std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt,
std::nullopt, CodeGenOptLevel::Default)));
}
struct TernTester {
unsigned NElem;
unsigned ElemWidth;
std::mt19937_64 Rng;
unsigned ImmVal;
SmallVector<uint64_t, 16> VecElems[3];
void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; }
void updateNElem(unsigned NewNElem) {
NElem = NewNElem;
for (unsigned I = 0; I < 3; ++I) {
VecElems[I].resize(NElem);
}
}
void updateElemWidth(unsigned NewElemWidth) {
ElemWidth = NewElemWidth;
assert(ElemWidth == 32 || ElemWidth == 64);
}
uint64_t getElemMask() const {
return (~uint64_t(0)) >> ((ElemWidth - 0) % 64);
}
void RandomizeVecArgs() {
uint64_t ElemMask = getElemMask();
for (unsigned I = 0; I < 3; ++I) {
for (unsigned J = 0; J < NElem; ++J) {
VecElems[I][J] = Rng() & ElemMask;
}
}
}
std::pair<std::string, std::string> getScalarInfo() const {
switch (ElemWidth) {
case 32:
return {"i32", "d"};
case 64:
return {"i64", "q"};
default:
llvm_unreachable("Invalid ElemWidth");
}
}
std::string getScalarType() const { return getScalarInfo().first; }
std::string getScalarExt() const { return getScalarInfo().second; }
std::string getVecType() const {
return "<" + Twine(NElem).str() + " x " + getScalarType() + ">";
};
std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); }
std::string getFunctionName() const {
return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth();
}
std::string getFunctionDecl() const {
return "declare " + getVecType() + getFunctionName() + "(" + getVecType() +
", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)";
}
std::string getVecN(unsigned N) const {
assert(N < 3);
std::string VecStr = getVecType() + " <";
for (unsigned I = 0; I < VecElems[N].size(); ++I) {
if (I != 0)
VecStr += ", ";
VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str();
}
return VecStr + ">";
}
std::string getFunctionCall() const {
return "tail call " + getVecType() + " " + getFunctionName() + "(" +
getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " +
Twine(ImmVal).str() + ")";
}
std::string getTestText() const {
return getFunctionDecl() + "\ndefine " + getVecType() +
"@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() +
" %r\n}\n";
}
void checkResult(const Value *V) {
auto GetValElem = [&](unsigned Idx) -> uint64_t {
if (auto *CV = dyn_cast<ConstantDataVector>(V))
return CV->getElementAsInteger(Idx);
auto *C = dyn_cast<Constant>(V);
assert(C);
if (C->isNullValue())
return 0;
if (C->isAllOnesValue())
return ((~uint64_t(0)) >> (ElemWidth % 64));
if (C->isOneValue())
return 1;
llvm_unreachable("Unknown constant type");
};
auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t {
unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1);
return (ImmVal >> BitIdx) & 1;
};
for (unsigned I = 0; I < NElem; ++I) {
uint64_t Expec = 0;
uint64_t AEle = VecElems[0][I];
uint64_t BEle = VecElems[1][I];
uint64_t CEle = VecElems[2][I];
for (unsigned J = 0; J < ElemWidth; ++J) {
Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J;
}
ASSERT_EQ(Expec, GetValElem(I));
}
}
void check(LLVMContext &Ctx, FunctionPassManager &FPM,
FunctionAnalysisManager &FAM) {
SMDiagnostic Error;
std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx);
ASSERT_TRUE(M);
Function *F = M->getFunction("foo");
ASSERT_TRUE(F);
ASSERT_EQ(F->getInstructionCount(), 2u);
FPM.run(*F, FAM);
ASSERT_EQ(F->getInstructionCount(), 1u);
ASSERT_EQ(F->size(), 1u);
const Instruction *I = F->begin()->getTerminator();
ASSERT_TRUE(I);
ASSERT_EQ(I->getNumOperands(), 1u);
checkResult(I->getOperand(0));
}
};
/// Checks InstCombine's constant folding of the X86 pternlog intrinsics.
/// Every immediate 0-255 is tried for each 128/256/512-bit shape built from
/// 32- or 64-bit lanes, using random constant operands.
TEST(TernlogTest, TestConstantFolding) {
  // Build the target machine once, before the analysis managers, and keep it
  // alive for the whole test. The TTI callback returns results derived from
  // it, and FAM may hold on to them.
  std::unique_ptr<LLVMTargetMachine> TM = initTM();
  ASSERT_TRUE(TM);

  LLVMContext Ctx;
  FunctionAnalysisManager FAM;
  FunctionPassManager FPM;
  PassBuilder PB;
  LoopAnalysisManager LAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  TargetIRAnalysis TIRA(
      [&TM](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&] { return TIRA; });
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
  FPM.addPass(InstCombinePass());

  TernTester TT;
  // Lane counts 2..16 times lane widths 32/64. Only 128-, 256- and 512-bit
  // vectors have a pternlog intrinsic. The upper bound is inclusive so that
  // <16 x i32> (pternlog.d.512) is covered too.
  for (unsigned NElem = 2; NElem <= 16; NElem *= 2) {
    TT.updateNElem(NElem);
    for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth *= 2) {
      const unsigned VecWidth = ElemWidth * NElem;
      if (VecWidth > 512 || VecWidth < 128)
        continue;
      TT.updateElemWidth(ElemWidth);
      TT.RandomizeVecArgs();
      for (unsigned Imm = 0; Imm < 256; ++Imm) {
        SCOPED_TRACE(::testing::Message() << "NElem=" << NElem << " ElemWidth="
                                          << ElemWidth << " Imm=" << Imm);
        TT.updateImm(Imm);
        ASSERT_NO_FATAL_FAILURE(TT.check(Ctx, FPM, FAM));
        // check() destroys its module on return. Drop the cached
        // per-function analyses so a later function allocated at the same
        // address cannot pick up stale results.
        FAM.clear();
      }
    }
  }
}
} // namespace llvm