bolt/deps/llvm-18.1.8/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp

98 lines
3.7 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- TestEmulateWideInt.cpp - Test Wide Int Emulation --------*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for integration testing of wide integer
// emulation patterns. Applies conversion patterns only to functions whose
// names start with a specified prefix.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Test-only function pass that drives the arith wide-integer emulation
/// patterns. Functions whose symbol name does not begin with the
/// `function-prefix` option are left untouched; matching functions go through
/// a partial conversion using a `WideIntEmulationConverter` configured from
/// the `widest-int-supported` option.
struct TestEmulateWideIntPass
    : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)

  TestEmulateWideIntPass() = default;

  /// Copying only forwards to the base class. The option members are not
  /// copied here; they are set up by their in-class initializers below.
  TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
      : PassWrapper(pass) {}

  /// Registers every dialect whose operations this pass may create or touch.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect>();
    registry.insert<LLVM::LLVMDialect, vector::VectorDialect>();
  }

  /// Command-line name of the pass.
  StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }

  /// One-line summary of the pass.
  StringRef getDescription() const final {
    return "Function pass to test Wide Integer Emulation";
  }

  /// Validates the options, filters the function by name prefix, and then
  /// applies the wide-integer emulation patterns as a partial conversion.
  void runOnOperation() override {
    // The supported width has to be a power of two and at least 2; any other
    // value fails the pass before the function is inspected at all.
    const unsigned maxWidth = widestIntSupported;
    const bool widthIsUsable = llvm::isPowerOf2_32(maxWidth) && maxWidth >= 2;
    if (!widthIsUsable) {
      signalPassFailure();
      return;
    }

    // Only functions carrying the configured name prefix are processed.
    func::FuncOp funcOp = getOperation();
    if (!funcOp.getSymName().starts_with(testFunctionPrefix))
      return;

    MLIRContext *context = funcOp.getContext();
    arith::WideIntEmulationConverter typeConverter(maxWidth);

    // Use `llvm.bitcast` as the bridge so that we can preserve the function
    // argument and return types of the processed function.
    // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
    // casts (and vice versa) and using it instead of `llvm.bitcast`.
    auto bridgeWithBitcast = [](OpBuilder &builder, Type type,
                                ValueRange inputs,
                                Location loc) -> std::optional<Value> {
      auto bitcastOp = builder.create<LLVM::BitcastOp>(loc, type, inputs);
      Value bridged = bitcastOp->getResult(0);
      return bridged;
    };
    typeConverter.addSourceMaterialization(bridgeWithBitcast);
    typeConverter.addTargetMaterialization(bridgeWithBitcast);

    // Arith and vector ops are legal exactly when the type converter accepts
    // them as they are.
    auto acceptedByConverter = [&typeConverter](Operation *op) {
      return typeConverter.isLegal(op);
    };
    ConversionTarget target(*context);
    target
        .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
            acceptedByConverter);

    RewritePatternSet patterns(context);
    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);

    const LogicalResult status =
        applyPartialConversion(funcOp, target, std::move(patterns));
    if (failed(status))
      signalPassFailure();
  }

  /// Name prefix that selects which functions the pass rewrites.
  Option<std::string> testFunctionPrefix{
      *this, "function-prefix",
      llvm::cl::desc("Prefix of functions to run the emulation pass on"),
      llvm::cl::init("emulate_")};

  /// Widest integer bit width handed to the emulation type converter.
  Option<unsigned> widestIntSupported{
      *this, "widest-int-supported",
      llvm::cl::desc("Maximum integer bit width supported by the target"),
      llvm::cl::init(32)};
};
} // namespace
namespace mlir::test {
void registerTestArithEmulateWideIntPass() {
PassRegistration<TestEmulateWideIntPass>();
}
} // namespace mlir::test