//===- TestWideIntEmulation.cpp - Test Wide Int Emulation ------*- c++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for integration testing of wide integer
// emulation patterns. The conversion patterns are applied only to functions
// whose names start with a specified prefix.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/MathExtras.h"

#include <optional>
#include <string>

using namespace mlir;

namespace {
|
||
|
struct TestEmulateWideIntPass
|
||
|
: public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
|
||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
|
||
|
|
||
|
TestEmulateWideIntPass() = default;
|
||
|
TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
|
||
|
: PassWrapper(pass) {}
|
||
|
|
||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||
|
registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect,
|
||
|
vector::VectorDialect>();
|
||
|
}
|
||
|
StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
|
||
|
StringRef getDescription() const final {
|
||
|
return "Function pass to test Wide Integer Emulation";
|
||
|
}
|
||
|
|
||
|
void runOnOperation() override {
|
||
|
if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
|
||
|
signalPassFailure();
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
func::FuncOp op = getOperation();
|
||
|
if (!op.getSymName().starts_with(testFunctionPrefix))
|
||
|
return;
|
||
|
|
||
|
MLIRContext *ctx = op.getContext();
|
||
|
arith::WideIntEmulationConverter typeConverter(widestIntSupported);
|
||
|
|
||
|
// Use `llvm.bitcast` as the bridge so that we can use preserve the
|
||
|
// function argument and return types of the processed function.
|
||
|
// TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
|
||
|
// casts (and vice versa) and using it insted of `llvm.bitcast`.
|
||
|
auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
|
||
|
Location loc) -> std::optional<Value> {
|
||
|
auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
|
||
|
return cast->getResult(0);
|
||
|
};
|
||
|
typeConverter.addSourceMaterialization(addBitcast);
|
||
|
typeConverter.addTargetMaterialization(addBitcast);
|
||
|
|
||
|
ConversionTarget target(*ctx);
|
||
|
target
|
||
|
.addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
|
||
|
[&typeConverter](Operation *op) {
|
||
|
return typeConverter.isLegal(op);
|
||
|
});
|
||
|
|
||
|
RewritePatternSet patterns(ctx);
|
||
|
arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
|
||
|
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
||
|
signalPassFailure();
|
||
|
}
|
||
|
|
||
|
Option<std::string> testFunctionPrefix{
|
||
|
*this, "function-prefix",
|
||
|
llvm::cl::desc("Prefix of functions to run the emulation pass on"),
|
||
|
llvm::cl::init("emulate_")};
|
||
|
Option<unsigned> widestIntSupported{
|
||
|
*this, "widest-int-supported",
|
||
|
llvm::cl::desc("Maximum integer bit width supported by the target"),
|
||
|
llvm::cl::init(32)};
|
||
|
};
|
||
|
} // namespace
|
||
|
|
||
|
namespace mlir::test {
|
||
|
void registerTestArithEmulateWideIntPass() {
|
||
|
PassRegistration<TestEmulateWideIntPass>();
|
||
|
}
|
||
|
} // namespace mlir::test
|