//===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/FileUtilities.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/CrashRecoveryContext.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" #include "llvm/Support/ToolOutputFile.h" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // RecoveryReproducerContext //===----------------------------------------------------------------------===// namespace mlir { namespace detail { /// This class contains all of the context for generating a recovery reproducer. /// Each recovery context is registered globally to allow for generating /// reproducers when a signal is raised, such as a segfault. struct RecoveryReproducerContext { RecoveryReproducerContext(std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses); ~RecoveryReproducerContext(); /// Generate a reproducer with the current context. void generate(std::string &description); /// Disable this reproducer context. This prevents the context from generating /// a reproducer in the result of a crash. void disable(); /// Enable a previously disabled reproducer context. void enable(); private: /// This function is invoked in the event of a crash. static void crashHandler(void *); /// Register a signal handler to run in the event of a crash. static void registerSignalHandler(); /// The textual description of the currently executing pipeline. std::string pipelineElements; /// The MLIR operation representing the IR before the crash. Operation *preCrashOperation; /// The factory for the reproducer output stream to use when generating the /// reproducer. ReproducerStreamFactory &streamFactory; /// Various pass manager and context flags. bool disableThreads; bool verifyPasses; /// The current set of active reproducer contexts. This is used in the event /// of a crash. This is not thread_local as the pass manager may produce any /// number of child threads. This uses a set to allow for multiple MLIR pass /// managers to be running at the same time. static llvm::ManagedStatic> reproducerMutex; static llvm::ManagedStatic< llvm::SmallSetVector> reproducerSet; }; } // namespace detail } // namespace mlir llvm::ManagedStatic> RecoveryReproducerContext::reproducerMutex; llvm::ManagedStatic> RecoveryReproducerContext::reproducerSet; RecoveryReproducerContext::RecoveryReproducerContext( std::string passPipelineStr, Operation *op, ReproducerStreamFactory &streamFactory, bool verifyPasses) : pipelineElements(std::move(passPipelineStr)), preCrashOperation(op->clone()), streamFactory(streamFactory), disableThreads(!op->getContext()->isMultithreadingEnabled()), verifyPasses(verifyPasses) { enable(); } RecoveryReproducerContext::~RecoveryReproducerContext() { // Erase the cloned preCrash IR that we cached. preCrashOperation->erase(); disable(); } static void appendReproducer(std::string &description, Operation *op, const ReproducerStreamFactory &factory, const std::string &pipelineElements, bool disableThreads, bool verifyPasses) { llvm::raw_string_ostream descOS(description); // Try to create a new output stream for this crash reproducer. std::string error; std::unique_ptr stream = factory(error); if (!stream) { descOS << "failed to create output stream: " << error; return; } descOS << "reproducer generated at `" << stream->description() << "`"; std::string pipeline = (op->getName().getStringRef() + "(" + pipelineElements + ")").str(); AsmState state(op); state.attachResourcePrinter( "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) { builder.buildString("pipeline", pipeline); builder.buildBool("disable_threading", disableThreads); builder.buildBool("verify_each", verifyPasses); }); // Output the .mlir module. op->print(stream->os(), state); } void RecoveryReproducerContext::generate(std::string &description) { appendReproducer(description, preCrashOperation, streamFactory, pipelineElements, disableThreads, verifyPasses); } void RecoveryReproducerContext::disable() { llvm::sys::SmartScopedLock lock(*reproducerMutex); reproducerSet->remove(this); if (reproducerSet->empty()) llvm::CrashRecoveryContext::Disable(); } void RecoveryReproducerContext::enable() { llvm::sys::SmartScopedLock lock(*reproducerMutex); if (reproducerSet->empty()) llvm::CrashRecoveryContext::Enable(); registerSignalHandler(); reproducerSet->insert(this); } void RecoveryReproducerContext::crashHandler(void *) { // Walk the current stack of contexts and generate a reproducer for each one. // We can't know for certain which one was the cause, so we need to generate // a reproducer for all of them. for (RecoveryReproducerContext *context : *reproducerSet) { std::string description; context->generate(description); // Emit an error using information only available within the context. emitError(context->preCrashOperation->getLoc()) << "A signal was caught while processing the MLIR module:" << description << "; marking pass as failed"; } } void RecoveryReproducerContext::registerSignalHandler() { // Ensure that the handler is only registered once. static bool registered = (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); (void)registered; } //===----------------------------------------------------------------------===// // PassCrashReproducerGenerator //===----------------------------------------------------------------------===// struct PassCrashReproducerGenerator::Impl { Impl(ReproducerStreamFactory &streamFactory, bool localReproducer) : streamFactory(streamFactory), localReproducer(localReproducer) {} /// The factory to use when generating a crash reproducer. ReproducerStreamFactory streamFactory; /// Flag indicating if reproducer generation should be localized to the /// failing pass. bool localReproducer = false; /// A record of all of the currently active reproducer contexts. SmallVector> activeContexts; /// The set of all currently running passes. Note: This is not populated when /// `localReproducer` is true, as each pass will get its own recovery context. SetVector> runningPasses; /// Various pass manager flags that get emitted when generating a reproducer. bool pmFlagVerifyPasses = false; }; PassCrashReproducerGenerator::PassCrashReproducerGenerator( ReproducerStreamFactory &streamFactory, bool localReproducer) : impl(std::make_unique(streamFactory, localReproducer)) {} PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default; void PassCrashReproducerGenerator::initialize( iterator_range passes, Operation *op, bool pmFlagVerifyPasses) { assert((!impl->localReproducer || !op->getContext()->isMultithreadingEnabled()) && "expected multi-threading to be disabled when generating a local " "reproducer"); llvm::CrashRecoveryContext::Enable(); impl->pmFlagVerifyPasses = pmFlagVerifyPasses; // If we aren't generating a local reproducer, prepare a reproducer for the // given top-level operation. if (!impl->localReproducer) prepareReproducerFor(passes, op); } static void formatPassOpReproducerMessage(Diagnostic &os, std::pair passOpPair) { os << "`" << passOpPair.first->getName() << "` on " << "'" << passOpPair.second->getName() << "' operation"; if (SymbolOpInterface symbol = dyn_cast(passOpPair.second)) os << ": @" << symbol.getName(); } void PassCrashReproducerGenerator::finalize(Operation *rootOp, LogicalResult executionResult) { // Don't generate a reproducer if we have no active contexts. if (impl->activeContexts.empty()) return; // If the pass manager execution succeeded, we don't generate any reproducers. if (succeeded(executionResult)) return impl->activeContexts.clear(); InFlightDiagnostic diag = emitError(rootOp->getLoc()) << "Failures have been detected while " "processing an MLIR pass pipeline"; // If we are generating a global reproducer, we include all of the running // passes in the error message for the only active context. if (!impl->localReproducer) { assert(impl->activeContexts.size() == 1 && "expected one active context"); // Generate the reproducer. std::string description; impl->activeContexts.front()->generate(description); // Emit an error to the user. Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ["; llvm::interleaveComma(impl->runningPasses, note, [&](const std::pair &value) { formatPassOpReproducerMessage(note, value); }); note << "]: " << description; impl->runningPasses.clear(); impl->activeContexts.clear(); return; } // If we were generating a local reproducer, we generate a reproducer for the // most recently executing pass using the matching entry from `runningPasses` // to generate a localized diagnostic message. assert(impl->activeContexts.size() == impl->runningPasses.size() && "expected running passes to match active contexts"); // Generate the reproducer. RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); std::string description; reproducerContext.generate(description); // Emit an error to the user. Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing "; formatPassOpReproducerMessage(note, impl->runningPasses.back()); note << ": " << description; impl->activeContexts.clear(); impl->runningPasses.clear(); } void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, Operation *op) { // If not tracking local reproducers, we simply remember that this pass is // running. impl->runningPasses.insert(std::make_pair(pass, op)); if (!impl->localReproducer) return; // Disable the current pass recovery context, if there is one. This may happen // in the case of dynamic pass pipelines. if (!impl->activeContexts.empty()) impl->activeContexts.back()->disable(); // Collect all of the parent scopes of this operation. SmallVector scopes; while (Operation *parentOp = op->getParentOp()) { scopes.push_back(op->getName()); op = parentOp; } // Emit a pass pipeline string for the current pass running on the current // operation type. std::string passStr; llvm::raw_string_ostream passOS(passStr); for (OperationName scope : llvm::reverse(scopes)) passOS << scope << "("; pass->printAsTextualPipeline(passOS); for (unsigned i = 0, e = scopes.size(); i < e; ++i) passOS << ")"; impl->activeContexts.push_back(std::make_unique( passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); } void PassCrashReproducerGenerator::prepareReproducerFor( iterator_range passes, Operation *op) { std::string passStr; llvm::raw_string_ostream passOS(passStr); llvm::interleaveComma( passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); }); impl->activeContexts.push_back(std::make_unique( passOS.str(), op, impl->streamFactory, impl->pmFlagVerifyPasses)); } void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, Operation *op) { // We only pop the active context if we are tracking local reproducers. impl->runningPasses.remove(std::make_pair(pass, op)); if (impl->localReproducer) { impl->activeContexts.pop_back(); // Re-enable the previous pass recovery context, if there was one. This may // happen in the case of dynamic pass pipelines. if (!impl->activeContexts.empty()) impl->activeContexts.back()->enable(); } } //===----------------------------------------------------------------------===// // CrashReproducerInstrumentation //===----------------------------------------------------------------------===// namespace { struct CrashReproducerInstrumentation : public PassInstrumentation { CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator) : generator(generator) {} ~CrashReproducerInstrumentation() override = default; void runBeforePass(Pass *pass, Operation *op) override { if (!isa(pass)) generator.prepareReproducerFor(pass, op); } void runAfterPass(Pass *pass, Operation *op) override { if (!isa(pass)) generator.removeLastReproducerFor(pass, op); } void runAfterPassFailed(Pass *pass, Operation *op) override { // Only generate one reproducer per crash reproducer instrumentation. if (alreadyFailed) return; alreadyFailed = true; generator.finalize(op, /*executionResult=*/failure()); } private: /// The generator used to create crash reproducers. PassCrashReproducerGenerator &generator; bool alreadyFailed = false; }; } // namespace //===----------------------------------------------------------------------===// // FileReproducerStream //===----------------------------------------------------------------------===// namespace { /// This class represents a default instance of mlir::ReproducerStream /// that is backed by a file. struct FileReproducerStream : public mlir::ReproducerStream { FileReproducerStream(std::unique_ptr outputFile) : outputFile(std::move(outputFile)) {} ~FileReproducerStream() override { outputFile->keep(); } /// Returns a description of the reproducer stream. StringRef description() override { return outputFile->getFilename(); } /// Returns the stream on which to output the reproducer. raw_ostream &os() override { return outputFile->os(); } private: /// ToolOutputFile corresponding to opened `filename`. std::unique_ptr outputFile = nullptr; }; } // namespace //===----------------------------------------------------------------------===// // PassManager //===----------------------------------------------------------------------===// LogicalResult PassManager::runWithCrashRecovery(Operation *op, AnalysisManager am) { crashReproGenerator->initialize(getPasses(), op, verifyPasses); // Safely invoke the passes within a recovery context. LogicalResult passManagerResult = failure(); llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread( [&] { passManagerResult = runPasses(op, am); }); crashReproGenerator->finalize(op, passManagerResult); return passManagerResult; } static ReproducerStreamFactory makeReproducerStreamFactory(StringRef outputFile) { // Capture the filename by value in case outputFile is out of scope when // invoked. std::string filename = outputFile.str(); return [filename](std::string &error) -> std::unique_ptr { std::unique_ptr outputFile = mlir::openOutputFile(filename, &error); if (!outputFile) { error = "Failed to create reproducer stream: " + error; return nullptr; } return std::make_unique(std::move(outputFile)); }; } void printAsTextualPipeline( raw_ostream &os, StringRef anchorName, const llvm::iterator_range &passes); std::string mlir::makeReproducer( StringRef anchorName, const llvm::iterator_range &passes, Operation *op, StringRef outputFile, bool disableThreads, bool verifyPasses) { std::string description; std::string pipelineStr; llvm::raw_string_ostream passOS(pipelineStr); ::printAsTextualPipeline(passOS, anchorName, passes); appendReproducer(description, op, makeReproducerStreamFactory(outputFile), pipelineStr, disableThreads, verifyPasses); return description; } void PassManager::enableCrashReproducerGeneration(StringRef outputFile, bool genLocalReproducer) { enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile), genLocalReproducer); } void PassManager::enableCrashReproducerGeneration( ReproducerStreamFactory factory, bool genLocalReproducer) { assert(!crashReproGenerator && "crash reproducer has already been initialized"); if (genLocalReproducer && getContext()->isMultithreadingEnabled()) llvm::report_fatal_error( "Local crash reproduction can't be setup on a " "pass-manager without disabling multi-threading first."); crashReproGenerator = std::make_unique( factory, genLocalReproducer); addInstrumentation( std::make_unique(*crashReproGenerator)); } //===----------------------------------------------------------------------===// // Asm Resource //===----------------------------------------------------------------------===// void PassReproducerOptions::attachResourceParser(ParserConfig &config) { auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult { if (entry.getKey() == "pipeline") { FailureOr value = entry.parseAsString(); if (succeeded(value)) this->pipeline = std::move(*value); return value; } if (entry.getKey() == "disable_threading") { FailureOr value = entry.parseAsBool(); if (succeeded(value)) this->disableThreading = *value; return value; } if (entry.getKey() == "verify_each") { FailureOr value = entry.parseAsBool(); if (succeeded(value)) this->verifyEach = *value; return value; } return entry.emitError() << "unknown 'mlir_reproducer' resource key '" << entry.getKey() << "'"; }; config.attachResourceParser("mlir_reproducer", parseFn); } LogicalResult PassReproducerOptions::apply(PassManager &pm) const { if (pipeline.has_value()) { FailureOr reproPm = parsePassPipeline(*pipeline); if (failed(reproPm)) return failure(); static_cast(pm) = std::move(*reproPm); } if (disableThreading.has_value()) pm.getContext()->disableMultithreading(*disableThreading); if (verifyEach.has_value()) pm.enableVerifier(*verifyEach); return success(); }