//===- LICMTest.cpp - LICM unit tests -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "gtest/gtest.h" #include namespace llvm { static std::unique_ptr initTM() { LLVMInitializeX86TargetInfo(); LLVMInitializeX86Target(); LLVMInitializeX86TargetMC(); auto TT(Triple::normalize("x86_64--")); std::string Error; const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error); return std::unique_ptr(static_cast( TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt, std::nullopt, CodeGenOptLevel::Default))); } struct TernTester { unsigned NElem; unsigned ElemWidth; std::mt19937_64 Rng; unsigned ImmVal; SmallVector VecElems[3]; void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; } void updateNElem(unsigned NewNElem) { NElem = NewNElem; for (unsigned I = 0; I < 3; ++I) { VecElems[I].resize(NElem); } } void updateElemWidth(unsigned NewElemWidth) { ElemWidth = NewElemWidth; assert(ElemWidth == 32 || ElemWidth == 64); } uint64_t getElemMask() const { return (~uint64_t(0)) >> ((ElemWidth - 0) % 64); } void RandomizeVecArgs() { uint64_t ElemMask = getElemMask(); for (unsigned I = 0; I < 3; ++I) { for (unsigned J = 0; J < NElem; ++J) { VecElems[I][J] = Rng() & ElemMask; } } } std::pair getScalarInfo() const { switch (ElemWidth) { case 32: return {"i32", "d"}; case 64: return {"i64", "q"}; default: llvm_unreachable("Invalid ElemWidth"); } } std::string getScalarType() const { return getScalarInfo().first; } std::string getScalarExt() const { return getScalarInfo().second; } std::string getVecType() const { return "<" + Twine(NElem).str() + " x " + getScalarType() + ">"; }; std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); } std::string getFunctionName() const { return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth(); } std::string getFunctionDecl() const { return "declare " + getVecType() + getFunctionName() + "(" + getVecType() + ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)"; } std::string getVecN(unsigned N) const { assert(N < 3); std::string VecStr = getVecType() + " <"; for (unsigned I = 0; I < VecElems[N].size(); ++I) { if (I != 0) VecStr += ", "; VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str(); } return VecStr + ">"; } std::string getFunctionCall() const { return "tail call " + getVecType() + " " + getFunctionName() + "(" + getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " + Twine(ImmVal).str() + ")"; } std::string getTestText() const { return getFunctionDecl() + "\ndefine " + getVecType() + "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() + " %r\n}\n"; } void checkResult(const Value *V) { auto GetValElem = [&](unsigned Idx) -> uint64_t { if (auto *CV = dyn_cast(V)) return CV->getElementAsInteger(Idx); auto *C = dyn_cast(V); assert(C); if (C->isNullValue()) return 0; if (C->isAllOnesValue()) return ((~uint64_t(0)) >> (ElemWidth % 64)); if (C->isOneValue()) return 1; llvm_unreachable("Unknown constant type"); }; auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t { unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1); return (ImmVal >> BitIdx) & 1; }; for (unsigned I = 0; I < NElem; ++I) { uint64_t Expec = 0; uint64_t AEle = VecElems[0][I]; uint64_t BEle = VecElems[1][I]; uint64_t CEle = VecElems[2][I]; for (unsigned J = 0; J < ElemWidth; ++J) { Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J; } ASSERT_EQ(Expec, GetValElem(I)); } } void check(LLVMContext &Ctx, FunctionPassManager &FPM, FunctionAnalysisManager &FAM) { SMDiagnostic Error; std::unique_ptr M = parseAssemblyString(getTestText(), Error, Ctx); ASSERT_TRUE(M); Function *F = M->getFunction("foo"); ASSERT_TRUE(F); ASSERT_EQ(F->getInstructionCount(), 2u); FPM.run(*F, FAM); ASSERT_EQ(F->getInstructionCount(), 1u); ASSERT_EQ(F->size(), 1u); const Instruction *I = F->begin()->getTerminator(); ASSERT_TRUE(I); ASSERT_EQ(I->getNumOperands(), 1u); checkResult(I->getOperand(0)); } }; TEST(TernlogTest, TestConstantFolding) { LLVMContext Ctx; FunctionAnalysisManager FAM; FunctionPassManager FPM; PassBuilder PB; LoopAnalysisManager LAM; CGSCCAnalysisManager CGAM; ModuleAnalysisManager MAM; TargetIRAnalysis TIRA = TargetIRAnalysis( [&](const Function &F) { return initTM()->getTargetTransformInfo(F); }); FAM.registerPass([&] { return TIRA; }); PB.registerModuleAnalyses(MAM); PB.registerCGSCCAnalyses(CGAM); PB.registerFunctionAnalyses(FAM); PB.registerLoopAnalyses(LAM); PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); FPM.addPass(InstCombinePass()); TernTester TT; for (unsigned NElem = 2; NElem < 16; NElem += NElem) { TT.updateNElem(NElem); for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) { if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128) continue; TT.updateElemWidth(ElemWidth); TT.RandomizeVecArgs(); for (unsigned Imm = 0; Imm < 256; ++Imm) { TT.updateImm(Imm); TT.check(Ctx, FPM, FAM); } } } } } // namespace llvm