//===- ShardingPropagation.cpp ------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" #include namespace mlir { namespace mesh { #define GEN_PASS_DEF_SHARDINGPROPAGATION #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" } // namespace mesh } // namespace mlir #define DEBUG_TYPE "sharding-propagation" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::mesh; namespace { //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// // This method retrieves all potential sharding attributes, prioritizing // specific shardings. For example, mustShardings = [shard0, None] and // optionalShardings = [None, shard1], the result will be [[shard0, shard1], // [shard0, None]] static SmallVector> getOrderedPossibleShardingAttrs(ArrayRef mustShardings, ArrayRef optionalShardings) { SmallVector> allShardingAttrs; SmallVector curShardingAttrs; std::function dfsCreateShardingAttrs = [&](size_t i) { if (i == mustShardings.size()) { allShardingAttrs.push_back( SmallVector(curShardingAttrs)); return; } if (mustShardings[i]) { curShardingAttrs.push_back(mustShardings[i]); dfsCreateShardingAttrs(i + 1); curShardingAttrs.pop_back(); return; } if (optionalShardings[i]) { curShardingAttrs.push_back(optionalShardings[i]); dfsCreateShardingAttrs(i + 1); curShardingAttrs.pop_back(); curShardingAttrs.push_back(nullptr); dfsCreateShardingAttrs(i + 1); curShardingAttrs.pop_back(); return; } curShardingAttrs.push_back(nullptr); dfsCreateShardingAttrs(i + 1); curShardingAttrs.pop_back(); }; dfsCreateShardingAttrs(0); return allShardingAttrs; } // For each operation that implements the ShardingInterface, infer the sharding // option of the operation from its operands and/or results using the // `getShardingOption` method. If the inferred sharding option is not empty, add // a `mesh.shard` operation for all remaining operands and results that do not // have sharding annotations. LogicalResult visitOp(Operation *op, OpBuilder &builder) { if (op->hasTrait() || llvm::isa(op)) return success(); ShardingInterface shardingOp = llvm::dyn_cast(op); if (!shardingOp) { op->emitOpError() << "sharding interface is not implemented."; return failure(); } // collect MeshShardingAttr from results SmallVector allowConflictsResultShardings; allowConflictsResultShardings.resize(op->getNumResults()); SmallVector resultMustShardings; resultMustShardings.resize(op->getNumResults()); for (OpResult result : op->getResults()) { FailureOr> maybeShardAttr = getMeshShardingAttr(result); if (failed(maybeShardAttr)) continue; if (!maybeShardAttr->first) resultMustShardings[result.getResultNumber()] = maybeShardAttr->second; else allowConflictsResultShardings[result.getResultNumber()] = maybeShardAttr->second; } // collect MeshShardingAttr from operands SmallVector allowConflictsOperandShardings; allowConflictsOperandShardings.resize(op->getNumOperands()); SmallVector operandMustShardings; operandMustShardings.resize(op->getNumOperands()); for (OpOperand &opOperand : op->getOpOperands()) { FailureOr> maybeShardAttr = getMeshShardingAttr(opOperand); if (failed(maybeShardAttr)) continue; if (maybeShardAttr->first) operandMustShardings[opOperand.getOperandNumber()] = maybeShardAttr->second; else allowConflictsOperandShardings[opOperand.getOperandNumber()] = maybeShardAttr->second; } // try to get the sharding option SmallVector> possibleOperandShardingAttrs = getOrderedPossibleShardingAttrs(operandMustShardings, allowConflictsOperandShardings); SmallVector> possibleResultShardingAttrs = getOrderedPossibleShardingAttrs(resultMustShardings, allowConflictsResultShardings); FailureOr finalShardingOption = failure(); for (ArrayRef resultShardings : possibleResultShardingAttrs) { if (succeeded(finalShardingOption)) break; for (ArrayRef operandShardings : possibleOperandShardingAttrs) { FailureOr shardingOption = shardingOp.getShardingOption(operandShardings, resultShardings); if (succeeded(shardingOption)) { finalShardingOption = shardingOption; break; } } } if (failed(finalShardingOption)) { op->emitOpError() << "fail to get sharding option."; return failure(); } // sharding info is empty, return immediately if (finalShardingOption->empty) return success(); if (failed( shardingOp.addShardingAnnotations(builder, *finalShardingOption))) { op->emitOpError() << "fail to set sharding annotations."; return failure(); } return success(); } //===----------------------------------------------------------------------===// // ShardingPropagation //===----------------------------------------------------------------------===// struct ShardingPropagation : public mesh::impl::ShardingPropagationBase { void runOnOperation() override { func::FuncOp funcOp = getOperation(); MLIRContext *ctx = funcOp.getContext(); Region ®ion = funcOp.getBody(); OpBuilder builder(ctx); if (!region.hasOneBlock()) { funcOp.emitOpError() << "only one block is supported!"; signalPassFailure(); } Block &block = region.front(); LLVM_DEBUG( DBGS() << "print all the ops' iterator types and indexing maps in the " "block.\n"; for (Operation &op : block.getOperations()) { if (auto shardingOp = llvm::dyn_cast(&op)) shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); }); // 1. propagate in reversed order for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) if (failed(visitOp(&op, builder))) return signalPassFailure(); LLVM_DEBUG(DBGS() << "After reversed order propagation:\n" << funcOp << "\n"); // 2. propagate in original order for (Operation &op : llvm::make_early_inc_range(block)) if (failed(visitOp(&op, builder))) return signalPassFailure(); } }; } // namespace