//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "ByteCode.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include using namespace mlir; // Include the PDL rewrite support. #if MLIR_ENABLE_PDL_IN_PATTERNMATCH #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule, DenseMap &configMap) { // Skip the conversion if the module doesn't contain pdl. if (pdlModule.getOps().empty()) return success(); // Simplify the provided PDL module. Note that we can't use the canonicalizer // here because it would create a cyclic dependency. auto simplifyFn = [](Operation *op) { // TODO: Add folding here if ever necessary. if (isOpTriviallyDead(op)) op->erase(); }; pdlModule.getBody()->walk(simplifyFn); /// Lower the PDL pattern module to the interpreter dialect. PassManager pdlPipeline(pdlModule->getName()); #ifdef NDEBUG // We don't want to incur the hit of running the verifier when in release // mode. pdlPipeline.enableVerifier(false); #endif pdlPipeline.addPass(createPDLToPDLInterpPass(configMap)); if (failed(pdlPipeline.run(pdlModule))) return failure(); // Simplify again after running the lowering pipeline. pdlModule.getBody()->walk(simplifyFn); return success(); } #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH //===----------------------------------------------------------------------===// // FrozenRewritePatternSet //===----------------------------------------------------------------------===// FrozenRewritePatternSet::FrozenRewritePatternSet() : impl(std::make_shared()) {} FrozenRewritePatternSet::FrozenRewritePatternSet( RewritePatternSet &&patterns, ArrayRef disabledPatternLabels, ArrayRef enabledPatternLabels) : impl(std::make_shared()) { DenseSet disabledPatterns, enabledPatterns; disabledPatterns.insert(disabledPatternLabels.begin(), disabledPatternLabels.end()); enabledPatterns.insert(enabledPatternLabels.begin(), enabledPatternLabels.end()); // Functor used to walk all of the operations registered in the context. This // is useful for patterns that get applied to multiple operations, such as // interface and trait based patterns. std::vector opInfos; auto addToOpsWhen = [&](std::unique_ptr &pattern, function_ref callbackFn) { if (opInfos.empty()) opInfos = pattern->getContext()->getRegisteredOperations(); for (RegisteredOperationName info : opInfos) if (callbackFn(info)) impl->nativeOpSpecificPatternMap[info].push_back(pattern.get()); impl->nativeOpSpecificPatternList.push_back(std::move(pattern)); }; for (std::unique_ptr &pat : patterns.getNativePatterns()) { // Don't add patterns that haven't been enabled by the user. if (!enabledPatterns.empty()) { auto isEnabledFn = [&](StringRef label) { return enabledPatterns.count(label); }; if (!isEnabledFn(pat->getDebugName()) && llvm::none_of(pat->getDebugLabels(), isEnabledFn)) continue; } // Don't add patterns that have been disabled by the user. if (!disabledPatterns.empty()) { auto isDisabledFn = [&](StringRef label) { return disabledPatterns.count(label); }; if (isDisabledFn(pat->getDebugName()) || llvm::any_of(pat->getDebugLabels(), isDisabledFn)) continue; } if (std::optional rootName = pat->getRootKind()) { impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get()); impl->nativeOpSpecificPatternList.push_back(std::move(pat)); continue; } if (std::optional interfaceID = pat->getRootInterfaceID()) { addToOpsWhen(pat, [&](RegisteredOperationName info) { return info.hasInterface(*interfaceID); }); continue; } if (std::optional traitID = pat->getRootTraitID()) { addToOpsWhen(pat, [&](RegisteredOperationName info) { return info.hasTrait(*traitID); }); continue; } impl->nativeAnyOpPatterns.push_back(std::move(pat)); } #if MLIR_ENABLE_PDL_IN_PATTERNMATCH // Generate the bytecode for the PDL patterns if any were provided. PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); ModuleOp pdlModule = pdlPatterns.getModule(); if (!pdlModule) return; DenseMap configMap = pdlPatterns.takeConfigMap(); if (failed(convertPDLToPDLInterp(pdlModule, configMap))) llvm::report_fatal_error( "failed to lower PDL pattern module to the PDL Interpreter"); // Generate the pdl bytecode. impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConfigs(), configMap, pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeRewriteFunctions()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH } FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;