//===- TestPrintDefUse.cpp - Passes to illustrate the IR def-use chains ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include #include using namespace mlir; namespace { /// This pass tests that: /// 1) we can shuffle use-lists correctly; /// 2) use-list orders are preserved after a roundtrip to bytecode. class TestPreserveUseListOrders : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders) TestPreserveUseListOrders() = default; TestPreserveUseListOrders(const TestPreserveUseListOrders &pass) : PassWrapper(pass) {} StringRef getArgument() const final { return "test-verify-uselistorder"; } StringRef getDescription() const final { return "Verify that roundtripping the IR to bytecode preserves the order " "of the uselists"; } Option rngSeed{*this, "rng-seed", llvm::cl::desc("Specify an input random seed"), llvm::cl::init(1)}; LogicalResult initialize(MLIRContext *context) override { rng.seed(static_cast(rngSeed)); return success(); } void runOnOperation() override { // Clone the module so that we can plug in this pass to any other // independently. OwningOpRef cloneModule = getOperation().clone(); // 1. Compute the op numbering of the module. computeOpNumbering(*cloneModule); // 2. Loop over all the values and shuffle the uses. While doing so, check // that each shuffle is correct. if (failed(shuffleUses(*cloneModule))) return signalPassFailure(); // 3. Do a bytecode roundtrip to version 3, which supports use-list order // preservation. auto roundtripModuleOr = doRoundtripToBytecode(*cloneModule, 3); // If the bytecode roundtrip failed, try to roundtrip the original module // to version 2, which does not support use-list. If this also fails, the // original module had an issue unrelated to uselists. if (failed(roundtripModuleOr)) { auto testModuleOr = doRoundtripToBytecode(getOperation(), 2); if (failed(testModuleOr)) return; return signalPassFailure(); } // 4. Recompute the op numbering on the new module. The numbering should be // the same as (1), but on the new operation pointers. computeOpNumbering(roundtripModuleOr->get()); // 5. Loop over all the values and verify that the use-list is consistent // with the post-shuffle order of step (2). if (failed(verifyUseListOrders(roundtripModuleOr->get()))) return signalPassFailure(); } private: FailureOr> doRoundtripToBytecode(Operation *module, uint32_t version) { std::string str; llvm::raw_string_ostream m(str); BytecodeWriterConfig config; config.setDesiredBytecodeVersion(version); if (failed(writeBytecodeToFile(module, m, config))) return failure(); ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true); auto newModuleOp = parseSourceString(StringRef(str), parseConfig); if (!newModuleOp.get()) return failure(); return newModuleOp; } /// Compute an ordered numbering for all the operations in the IR. void computeOpNumbering(Operation *topLevelOp) { uint32_t operationID = 0; opNumbering.clear(); topLevelOp->walk( [&](Operation *op) { opNumbering.try_emplace(op, operationID++); }); } template SmallVector getUseIDs(ValueT val) { return SmallVector(llvm::map_range(val.getUses(), [&](auto &use) { return bytecode::getUseID(use, opNumbering.at(use.getOwner())); })); } LogicalResult shuffleUses(Operation *topLevelOp) { uint32_t valueID = 0; /// Permute randomly the use-list of each value. It is guaranteed that at /// least one pair of the use list is permuted. auto doShuffleForRange = [&](ValueRange range) -> LogicalResult { for (auto val : range) { if (val.use_empty() || val.hasOneUse()) continue; /// Get a valid index permutation for the uses of value. SmallVector permutation = getRandomPermutation(val); /// Store original order and verify that the shuffle was applied /// correctly. auto useIDs = getUseIDs(val); /// Apply shuffle to the uselist. val.shuffleUseList(permutation); /// Get the new order and verify the shuffle happened correctly. auto permutedIDs = getUseIDs(val); if (permutedIDs.size() != useIDs.size()) return failure(); for (size_t idx = 0; idx < permutation.size(); idx++) if (useIDs[idx] != permutedIDs[permutation[idx]]) return failure(); referenceUseListOrder.try_emplace( valueID++, llvm::map_range(val.getUses(), [&](auto &use) { return bytecode::getUseID(use, opNumbering.at(use.getOwner())); })); } return success(); }; return walkOverValues(topLevelOp, doShuffleForRange); } LogicalResult verifyUseListOrders(Operation *topLevelOp) { uint32_t valueID = 0; /// Check that the use-list for the value range matches the one stored in /// the reference. auto doValidationForRange = [&](ValueRange range) -> LogicalResult { for (auto val : range) { if (val.use_empty() || val.hasOneUse()) continue; auto referenceOrder = referenceUseListOrder.at(valueID++); for (auto [use, referenceID] : llvm::zip(val.getUses(), referenceOrder)) { uint64_t uniqueID = bytecode::getUseID(use, opNumbering.at(use.getOwner())); if (uniqueID != referenceID) { use.getOwner()->emitError() << "found use-list order mismatch for value: " << val; return failure(); } } } return success(); }; return walkOverValues(topLevelOp, doValidationForRange); } /// Walk over blocks and operations and execute a callable over the ranges of /// operands/results respectively. template LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) { auto blockWalk = topLevelOp->walk([&](Block *block) { if (failed(callable(block->getArguments()))) return WalkResult::interrupt(); return WalkResult::advance(); }); if (blockWalk.wasInterrupted()) return failure(); auto resultsWalk = topLevelOp->walk([&](Operation *op) { if (failed(callable(op->getResults()))) return WalkResult::interrupt(); return WalkResult::advance(); }); return failure(resultsWalk.wasInterrupted()); } /// Creates a random permutation of the uselist order chain of the provided /// value. SmallVector getRandomPermutation(Value value) { size_t numUses = std::distance(value.use_begin(), value.use_end()); SmallVector permutation(numUses); unsigned zero = 0; std::iota(permutation.begin(), permutation.end(), zero); std::shuffle(permutation.begin(), permutation.end(), rng); return permutation; } /// Map each value to its use-list order encoded with unique use IDs. DenseMap> referenceUseListOrder; /// Map each operation to its global ID. DenseMap opNumbering; std::default_random_engine rng; }; } // namespace namespace mlir { void registerTestPreserveUseListOrders() { PassRegistration(); } } // namespace mlir