//===- TestIntRangeInference.cpp - Create consts from range inference ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // TODO: This pass is needed to test integer range inference until that // functionality has been integrated into SCCP. //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/TypeID.h" #include "mlir/Transforms/FoldUtils.h" #include using namespace mlir; using namespace mlir::dataflow; /// Patterned after SCCP static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, OperationFolder &folder, Value value) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) return failure(); const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); std::optional maybeConstValue = inferredRange.getConstantValue(); if (!maybeConstValue.has_value()) return failure(); Operation *maybeDefiningOp = value.getDefiningOp(); Dialect *valueDialect = maybeDefiningOp ? maybeDefiningOp->getDialect() : value.getParentRegion()->getParentOp()->getDialect(); Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); Value constant = folder.getOrCreateConstant( b.getInsertionBlock(), valueDialect, constAttr, value.getType()); if (!constant) return failure(); value.replaceAllUsesWith(constant); return success(); } static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef initialRegions) { SmallVector worklist; auto addToWorklist = [&](MutableArrayRef regions) { for (Region ®ion : regions) for (Block &block : llvm::reverse(region)) worklist.push_back(&block); }; OpBuilder builder(context); OperationFolder folder(context); addToWorklist(initialRegions); while (!worklist.empty()) { Block *block = worklist.pop_back_val(); for (Operation &op : llvm::make_early_inc_range(*block)) { builder.setInsertionPoint(&op); // Replace any result with constants. bool replacedAll = op.getNumResults() != 0; for (Value res : op.getResults()) replacedAll &= succeeded(replaceWithConstant(solver, builder, folder, res)); // If all of the results of the operation were replaced, try to erase // the operation completely. if (replacedAll && wouldOpBeTriviallyDead(&op)) { assert(op.use_empty() && "expected all uses to be replaced"); op.erase(); continue; } // Add any the regions of this operation to the worklist. addToWorklist(op.getRegions()); } // Replace any block arguments with constants. builder.setInsertionPointToStart(block); for (BlockArgument arg : block->getArguments()) (void)replaceWithConstant(solver, builder, folder, arg); } } namespace { struct TestIntRangeInference : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) StringRef getArgument() const final { return "test-int-range-inference"; } StringRef getDescription() const final { return "Test integer range inference analysis"; } void runOnOperation() override { Operation *op = getOperation(); DataFlowSolver solver; solver.load(); solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); rewrite(solver, op->getContext(), op->getRegions()); } }; } // end anonymous namespace namespace mlir { namespace test { void registerTestIntRangeInference() { PassRegistration(); } } // end namespace test } // end namespace mlir