//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; /// Custom constraint invoked from PDL. static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, Operation *rootOp) { return success(rootOp->getName().getStringRef() == "test.op"); } static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, Operation *root, Operation *rootCopy) { return customSingleEntityConstraint(rewriter, rootCopy); } static LogicalResult customMultiEntityVariadicConstraint( PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { if (operandValues.size() != 2 || typeValues.size() != 2) return failure(); return success(); } // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); } static auto customVariadicResultCreate(PatternRewriter &rewriter, Operation *root) { return std::make_pair(root->getOperands(), root->getOperands().getTypes()); } static Type customCreateType(PatternRewriter &rewriter) { return rewriter.getF32Type(); } static std::string customCreateStrAttr(PatternRewriter &rewriter) { return "test.str"; } /// Custom rewriter invoked from PDL. static void customRewriter(PatternRewriter &rewriter, Operation *root, Value input) { rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), input); rewriter.eraseOp(root); } namespace { struct TestPDLByteCodePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } StringRef getDescription() const final { return "Test PDL ByteCode functionality"; } void getDependentDialects(DialectRegistry ®istry) const override { // Mark the pdl_interp dialect as a dependent. This is needed, because we // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. registry.insert(); } void runOnOperation() final { ModuleOp module = getOperation(); // The test cases are encompassed via two modules, one containing the // patterns and one containing the operations to rewrite. ModuleOp patternModule = module.lookupSymbol( StringAttr::get(module->getContext(), "patterns")); ModuleOp irModule = module.lookupSymbol( StringAttr::get(module->getContext(), "ir")); if (!patternModule || !irModule) return; RewritePatternSet patternList(module->getContext()); // Register ahead of time to test when functions are registered without a // pattern. patternList.getPDLPatterns().registerConstraintFunction( "multi_entity_constraint", customMultiEntityConstraint); patternList.getPDLPatterns().registerConstraintFunction( "single_entity_constraint", customSingleEntityConstraint); // Process the pattern module. patternModule.getOperation()->remove(); PDLPatternModule pdlPattern(patternModule); // Note: This constraint was already registered, but we re-register here to // ensure that duplication registration is allowed (the duplicate mapping // will be ignored). This tests that we support separating the registration // of library functions from the construction of patterns, and also that we // allow multiple patterns to depend on the same library functions (without // asserting/crashing). pdlPattern.registerConstraintFunction("multi_entity_constraint", customMultiEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); pdlPattern.registerRewriteFunction("type_creator", customCreateType); pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr); pdlPattern.registerRewriteFunction("rewriter", customRewriter); patternList.add(std::move(pdlPattern)); // Invoke the pattern driver with the provided patterns. (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), std::move(patternList)); } }; } // namespace namespace mlir { namespace test { void registerTestPDLByteCodePass() { PassRegistration(); } } // namespace test } // namespace mlir