//===- PatternBenefit.cpp - RewritePattern benefit unit tests -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/PatternApplicator.h" #include "gtest/gtest.h" using namespace mlir; namespace { TEST(PatternBenefitTest, BenefitOrder) { // There was a bug which caused low-benefit op-specific patterns to never be // called in presence of high-benefit op-agnostic pattern MLIRContext context; OpBuilder builder(&context); OwningOpRef module = ModuleOp::create(builder.getUnknownLoc()); struct Pattern1 : public OpRewritePattern { Pattern1(mlir::MLIRContext *context, bool *called) : OpRewritePattern(context, /*benefit*/ 1), called(called) {} mlir::LogicalResult matchAndRewrite(ModuleOp /*op*/, mlir::PatternRewriter & /*rewriter*/) const override { *called = true; return failure(); } private: bool *called; }; struct Pattern2 : public RewritePattern { Pattern2(MLIRContext *context, bool *called) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/2, context), called(called) {} mlir::LogicalResult matchAndRewrite(Operation * /*op*/, mlir::PatternRewriter & /*rewriter*/) const override { *called = true; return failure(); } private: bool *called; }; RewritePatternSet patterns(&context); bool called1 = false; bool called2 = false; patterns.add(&context, &called1); patterns.add(&context, &called2); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); PatternApplicator pa(frozenPatterns); pa.applyDefaultCostModel(); class MyPatternRewriter : public PatternRewriter { public: MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} }; MyPatternRewriter rewriter(&context); (void)pa.matchAndRewrite(*module, rewriter); EXPECT_TRUE(called1); EXPECT_TRUE(called2); } } // namespace