//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" using namespace mlir; /// Helper function to call the `makeRegionIsolatedFromAbove` to convert /// `test.one_region_op` to `test.isolated_one_region_op`. static LogicalResult makeIsolatedFromAboveImpl(RewriterBase &rewriter, test::OneRegionWithOperandsOp regionOp, llvm::function_ref callBack) { Region ®ion = regionOp.getRegion(); SmallVector capturedValues = makeRegionIsolatedFromAbove(rewriter, region, callBack); SmallVector operands = regionOp.getOperands(); operands.append(capturedValues); auto isolatedRegionOp = rewriter.create(regionOp.getLoc(), operands); rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(), isolatedRegionOp.getRegion().begin()); rewriter.eraseOp(regionOp); return success(); } namespace { /// Simple test for making region isolated from above without cloning any /// operations. struct SimpleMakeIsolatedFromAbove : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, PatternRewriter &rewriter) const override { return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *) { return false; }); } }; /// Test for making region isolated from above while clong operations /// with no operands. struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, PatternRewriter &rewriter) const override { return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) { return op->getNumOperands() == 0; }); } }; /// Test for making region isolated from above while clong operations /// with no operands. struct MakeIsolatedFromAboveAndCloneOpsWithOperands : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp, PatternRewriter &rewriter) const override { return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) { return true; }); } }; /// Test pass for testing the `makeIsolatedFromAbove` function. struct TestMakeIsolatedFromAbovePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass) TestMakeIsolatedFromAbovePass() = default; TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass) : PassWrapper(pass) {} StringRef getArgument() const final { return "test-make-isolated-from-above"; } StringRef getDescription() const final { return "Test making a region isolated from above"; } Option simple{ *this, "simple", llvm::cl::desc("Test simple case with no cloning of operations"), llvm::cl::init(false)}; Option cloneOpsWithNoOperands{ *this, "clone-ops-with-no-operands", llvm::cl::desc("Test case with cloning of operations with no operands"), llvm::cl::init(false)}; Option cloneOpsWithOperands{ *this, "clone-ops-with-operands", llvm::cl::desc("Test case with cloning of operations with no operands"), llvm::cl::init(false)}; void runOnOperation() override; }; } // namespace void TestMakeIsolatedFromAbovePass::runOnOperation() { MLIRContext *context = &getContext(); func::FuncOp funcOp = getOperation(); if (simple) { RewritePatternSet patterns(context); patterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; } if (cloneOpsWithNoOperands) { RewritePatternSet patterns(context); patterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; } if (cloneOpsWithOperands) { RewritePatternSet patterns(context); patterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; } } namespace mlir { namespace test { void registerTestMakeIsolatedFromAbovePass() { PassRegistration(); } } // namespace test } // namespace mlir