//====----- OutlineShapeComputation.cpp -----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Debug.h" #include #include #include namespace mlir { #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "outline-shape-computation" using namespace mlir; namespace { // A Value is an input of the cluster if it is an operand of an operation in the // cluster and its defining operation is not in the cluster. SmallVector getInputsOfCluster(const llvm::SmallVector &cluster) { SmallVector inputs; llvm::SmallDenseSet inputSet; llvm::SmallDenseSet opSet; for (Operation *op : cluster) { bool inserted = opSet.insert(op).second; (void)inserted; assert(inserted && "cluster contains duplicate operations"); } for (Operation *op : cluster) { for (Value operand : op->getOperands()) { Operation *operandOp = operand.getDefiningOp(); if (opSet.contains(operandOp)) { // Skip if defining op is in the cluster. continue; } if (inputSet.insert(operand).second) inputs.push_back(operand); } } return inputs; } // Create a shape.func representing the shape computation for `shape`. std::pair> createFuncFromCluster(OpBuilder &b, const SmallVector &cluster, Value shape, StringRef fnName, Location loc) { SmallVector inputs = getInputsOfCluster(cluster); auto fnType = cluster.empty() ? b.getFunctionType(shape.getType(), shape.getType()) : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType()); shape::FuncOp fnOp = b.create(loc, fnName, fnType); Block *block = fnOp.addEntryBlock(); b.setInsertionPoint(block, block->end()); IRMapping bvm; if (cluster.empty()) { bvm.map(shape, fnOp.getArgument(0)); } else { for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments())) bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg)); } for (Operation *op : cluster) b.clone(*op, bvm); llvm::SmallVector fnReturns; fnReturns.push_back(bvm.lookupOrDefault(shape)); b.create(loc, fnReturns); fnOp.setPrivate(); return std::make_pair(fnOp, inputs); } // The operations in the cluster might be unsorted, which could be inconvenient // when creating shape.func op. DenseMap> getOrderedClusters(const DenseMap> &clusters, func::FuncOp funcOp) { // Compute all clusters that each operation is in DenseMap> op2Shapes; for (const auto &it : clusters) { Value shape = it.first; const DenseSet &cluster = it.second; for (Operation *cOp : cluster) op2Shapes[cOp].push_back(shape); } // Iterate through all operations in order. Get all the clusters `cOp` belongs // to and construct the new ordered cluster as it traverses. DenseMap> orderedClusters; funcOp.walk([&](Operation *op) { auto it = op2Shapes.find(op); if (it != op2Shapes.end()) { Operation *cOp = it->first; for (Value shape : it->second) orderedClusters[shape].push_back(cOp); } }); return orderedClusters; } void constructShapeFunc( const std::vector &allWithOps, MLIRContext *context, DenseMap> &clusters, SymbolTable &symbolTable, DenseMap &dynShape2ShapeFunc, func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) { std::string shapeCalculationNamePrefix = "shape_cal_"; int shapeCalculationNameIdx = 0; OpBuilder builder(context); // Construct a shape function for (shape::WithOp withOp : allWithOps) { Value value = withOp.getOperand(); Value shape = withOp.getShape(); RankedTensorType rankedType = dyn_cast(value.getType()); if (rankedType == nullptr) continue; const SmallVector &cluster = clusters[shape]; shape::ShapeMappingValue shapeMappingValue; auto it = dynShape2ShapeFunc.find(shape); if (it == dynShape2ShapeFunc.end()) { std::string name = shapeCalculationNamePrefix + std::to_string(shapeCalculationNameIdx++); Location loc = value.getLoc(); builder.setInsertionPointAfter(funcOp); auto pair = createFuncFromCluster(builder, cluster, shape, name, loc); const SmallVector &inputs = pair.second; shape::FuncOp shapeFuncOp = pair.first; StringAttr insertedName = symbolTable.insert(shapeFuncOp); auto symbol = FlatSymbolRefAttr::get(context, insertedName); shapeMappingValue.funcSymbol = symbol; shapeMappingValue.inputs = inputs; } else { shapeMappingValue = it->second; } dynShape2ShapeFunc[shape] = shapeMappingValue; shapeMappingAnalysis.shapeMapping.insert( std::make_pair(value, shapeMappingValue)); } } struct OutlineShapeComputationPass : public impl::OutlineShapeComputationBase { void runOnOperation() override; private: bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput); void getClusterFromValue(Value shape, DenseMap> &clusters); DenseMap> constructClustersForEachShape(const std::vector &allWithOps, func::FuncOp funcOp); DenseSet onlyUsedByWithShapes; }; class TensorDimOpRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::DimOp op, PatternRewriter &rewriter) const override { auto shapeOf = rewriter.create(op.getLoc(), op.getSource()); rewriter.replaceOpWithNewOp(op, op.getType(), shapeOf, op.getIndex()); return success(); } }; void OutlineShapeComputationPass::runOnOperation() { ModuleOp moduleOp = getOperation(); SymbolTable symbolTable(moduleOp); DenseMap dynShape2ShapeFunc; auto &shapeMappingAnalysis = getAnalysis(); // TODO: This is as we populate this analysis during a pass that mutates. This // pass currently requires 1 single module being compiled. shapeMappingAnalysis.shapeMapping.clear(); markAnalysesPreserved(); moduleOp.walk([&](func::FuncOp funcOp) { MLIRContext *context = funcOp.getContext(); RewritePatternSet prevPatterns(context); prevPatterns.insert(context); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns)))) return signalPassFailure(); // initialize class member `onlyUsedByWithShapes` onlyUsedByWithShapes.clear(); funcOp.walk([&](Operation *op) { calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr); }); LLVM_DEBUG({ llvm::dbgs() << "onlyUsedByWithShapes table: \n"; for (auto it : onlyUsedByWithShapes) llvm::dbgs() << *it << "\n"; }); // collect all the shape.with_shape ops. std::vector allWithOps; funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); }); DenseMap> clusters = constructClustersForEachShape(allWithOps, funcOp); constructShapeFunc(allWithOps, context, clusters, symbolTable, dynShape2ShapeFunc, funcOp, shapeMappingAnalysis); for (shape::WithOp withOp : allWithOps) { Value value = withOp.getOperand(); for (Operation *user : llvm::make_early_inc_range(withOp.getResult().getUsers())) { if (auto valueOf = llvm::dyn_cast(user)) { // For pattern like // %1 = shape.with_shape %arg1, %0 // %2 = shape.value_of %1 // because shape.value doesn't care the shape, the shape.with_shape is // redundant. // If type of %arg1 and %2 has same type, just // replaced %2 with %arg1. // If type of %arg1 has different type like !shape.value_shape, // transform into // %2 = shape.value_of %arg1 if (valueOf.getType() == value.getType()) valueOf.replaceAllUsesWith(value); else valueOf.setOperand(value); } } } // Apply patterns, note this also performs DCE. if (failed(applyPatternsAndFoldGreedily(funcOp, {}))) return signalPassFailure(); }); } DenseMap> OutlineShapeComputationPass::constructClustersForEachShape( const std::vector &allWithOps, func::FuncOp funcOp) { DenseMap> clusters; for (shape::WithOp withOp : allWithOps) { Value shape = withOp.getShape(); if (clusters.count(shape) == 0) getClusterFromValue(shape, clusters); } return getOrderedClusters(clusters, funcOp); } // The output of a cluster is the `shape`, and the inputs are the outputs of // operations who are not in `onlyUsedByWithShapes` void OutlineShapeComputationPass::getClusterFromValue( Value shape, DenseMap> &clusters) { DenseSet cluster; DenseSet visited; std::queue queue; // defOp == nullptr means shape is the argument of the func op if (Operation *defOp = shape.getDefiningOp()) { visited.insert(defOp); queue.push(defOp); } while (!queue.empty()) { Operation *op = queue.front(); queue.pop(); if (onlyUsedByWithShapes.contains(op)) { cluster.insert(op); for (Value inp : op->getOperands()) { Operation *inpDefOp = inp.getDefiningOp(); if (nullptr != inpDefOp && !visited.contains(inpDefOp)) { visited.insert(inpDefOp); queue.push(inpDefOp); } } } } clusters[shape] = std::move(cluster); } // Returns whether `op` is a shape.with_shape, or all the users' of `op` // eventually point to the shape operand of shape.with_shape ops bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively( Operation *op, Value prevOutput) { if (onlyUsedByWithShapes.contains(op)) return true; if (auto withOp = llvm::dyn_cast(op)) return withOp.getShape() == prevOutput; if (op->use_empty()) return false; for (Value oup : op->getResults()) for (Operation *user : oup.getUsers()) if (!calOnlyUsedByWithShapesRecursively(user, oup)) return false; onlyUsedByWithShapes.insert(op); return true; } } // namespace std::unique_ptr> mlir::createOutlineShapeComputationPass() { return std::make_unique(); }