//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/ADL.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>

namespace mlir {
namespace mesh {

int64_t shardDimension(int64_t dim, int64_t shardCount) {
  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
    return ShapedType::kDynamic;

  assert(dim % shardCount == 0);
  return ceilDiv(dim, shardCount);
}

int64_t unshardDimension(int64_t dim, int64_t shardCount) {
  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
    return ShapedType::kDynamic;

  return dim * shardCount;
}

template <typename MeshShape, typename SplitAxes>
int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
  int64_t res = 1;
  for (auto splitAxis : splitAxes) {
    int64_t meshDimSize = meshShape[splitAxis];
    if (ShapedType::isDynamic(meshDimSize)) {
      return ShapedType::kDynamic;
    }
    res *= meshDimSize;
  }
  return res;
}

// Compute the shape for the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1 would
// result in a shape for each shard of ?x2x?.
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape) {
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
    outShape[tensorAxis] =
        shardDimension(inShape[tensorAxis],
                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
  }
}

ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
                           MeshShardingAttr sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr);
  return shape.clone(resShapeArr);
}

template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}
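// Example (illustrative): composing the helpers above on a 2x4 mesh, a tensor
// of shape 8x8 with split axes = [[0], [1]] is sharded to 4x2 per device:
//   shardCount({2, 4}, /*splitAxes=*/{0}) == 2, shardDimension(8, 2) == 4
//   shardCount({2, 4}, /*splitAxes=*/{1}) == 4, shardDimension(8, 4) == 2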
// Return the reduced value and its corresponding sharding.
// Example:
//   sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
//   targetSharding = <@mesh_1d, [[]]>
// Since the target has no partial axes, an all-reduce over mesh axis 0 is
// applied to the source value, which is then returned with the sharding
// <@mesh_1d, [[0]]>.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshShardingAttr sourceSharding,
                                  MeshShardingAttr targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  // Partial axes that must be reduced away: present in the source but not in
  // the target.
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue =
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getCluster().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult()
          .cast<TypedValue<ShapedType>>();

  // Partial axes that the target sharding retains.
  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshShardingAttr resultSharding =
      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
                            sourceSharding.getSplitAxes(), remainingPartialAxes,
                            sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}

static MeshShardingAttr
targetShardingInSplitLastAxis(MLIRContext *ctx,
                              MeshShardingAttr sourceSharding,
                              int64_t splitTensorAxis,
                              MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
                                             int64_t splitTensorAxis,
                                             int64_t splitCount) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      shardDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
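// Example (illustrative): with sourceSharding split axes [[0], [1]],
// splitTensorAxis = 1 and splitMeshAxis = 2, targetShardingInSplitLastAxis
// produces split axes [[0], [1, 2]], and targetShapeInSplitLastAxis divides
// the corresponding tensor axis by the mesh axis size, e.g. 8 -> 2 for a
// mesh axis of size 4.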
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
//
// The implementation is to extract the tensor slice corresponding
// to the current device.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshShardingAttr sourceSharding,
                          TypedValue<ShapedType> sourceShard, ClusterOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));

  Value processIndexAlongAxis =
      builder
          .create<ProcessMultiIndexOp>(mesh.getSymName(),
                                       SmallVector<MeshAxis>({splitMeshAxis}))
          .getResult()[0];

  MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
  ShapedType targetShape = targetShapeInSplitLastAxis(
      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);

  Value meshAxisSize =
      builder
          .create<ClusterShapeOp>(mesh.getSymName(),
                                  SmallVector<MeshAxis>({splitMeshAxis}))
          .getResult()[0];

  Value sourceAxisSize =
      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
  Value sourceAxisSizeModMeshAxisSize =
      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
  builder.create<cf::AssertOp>(
      isTargetShapeExactlyDivisible,
      "Sharding a tensor with axis size that is not exactly divisible by the "
      "mesh axis size is not supported.");
  Value targetAxisSize =
      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
  Value axisOffset =
      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
  DenseI64ArrayAttr staticOffsetsAttr =
      DenseI64ArrayAttr::get(ctx, staticOffsets);
  SmallVector<Value> dynamicOffsets(1, axisOffset);

  DenseI64ArrayAttr staticSizesAttr =
      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
      if (i == splitTensorAxis) {
        dynamicSizes.push_back(targetAxisSize);
      } else {
        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
        dynamicSizes.push_back(dimSize);
      }
    }
  }

  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
  TypedValue<RankedTensorType> targetShard =
      builder
          .create<tensor::ExtractSliceOp>(
              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
              staticStridesAttr)
          .getResult();
  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
}
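// Schematically (illustrative, 1D mesh of size 4, splitting tensor axis 0),
// the IR built above computes this process's slice of the source shard:
//   %i      = process index along the split mesh axis
//   %size   = tensor.dim of the source shard / mesh axis size
//             (exact divisibility is checked with cf.assert)
//   %offset = %size * %i
//   %shard  = tensor.extract_slice %source[%offset] [%size] [1]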
// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding tensor axis mesh axis pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
                                MeshShardingAttr targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                             MeshShardingAttr sourceSharding,
                             MeshShardingAttr targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard,
                                     mesh, tensorAxis, meshAxis);
  }
  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding tensor axis mesh axis pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
                                  MeshShardingAttr targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static MeshShardingAttr
targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                MeshShardingAttr sourceSharding,
                                int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      unshardDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
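// Example (illustrative): un-splitting tensor axis 0 across a mesh axis of
// size 4 all-gathers the shards along that tensor axis, so a per-device
// shard of shape 2x3 becomes 8x3 (unshardDimension(2, 4) == 8).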
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshShardingAttr sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshShardingAttr targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard =
      builder.create<tensor::CastOp>(targetShape, allGatherResult)
          .getResult()
          .cast<TypedValue<ShapedType>>();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                               MeshShardingAttr sourceSharding,
                               MeshShardingAttr targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }
  return std::nullopt;
}

// Detect if the resharding moves the last split axis of one tensor axis to
// become the last split axis of another tensor axis, e.g.
// [[0]] -> [[], [0]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
                                    MeshShardingAttr targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
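// Example (illustrative): for a detected move of mesh axis 0 from tensor
// axis 0 to tensor axis 1, targetShardingInMoveLastAxis below rewrites
// split axes [[0]] to [[], [0]]: the mesh axis is popped from the source
// tensor axis's list and appended to the target tensor axis's list.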
static MeshShardingAttr
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                             int64_t sourceTensorAxis,
                             int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      unshardDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                              MeshShardingAttr sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard =
      builder.create<tensor::CastOp>(targetShape, allToAllResult)
          .getResult()
          .cast<TypedValue<ShapedType>>();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                                 MeshShardingAttr sourceSharding,
                                 MeshShardingAttr targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding,
                                                           targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }
  return std::nullopt;
}
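// Example (illustrative): with a mesh axis of size 4, moving the split from
// tensor axis 0 to tensor axis 1 turns a 2x8 per-device shard into an 8x2
// one: the all-to-all concatenates along tensor axis 0
// (unshardDimension(2, 4) == 8) and scatters along tensor axis 1
// (shardDimension(8, 4) == 2).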
// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                MeshShardingAttr sourceSharding,
                MeshShardingAttr targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh,
                         sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding,
                                        targetSharding, sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshShardingAttr actualTargetSharding;
  if (auto tryRes = tryMoveLastSplitAxisInResharding(
          builder, mesh, reducedSourceSharding, targetSharding,
          sourceUnshardedValue.getType(), reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else if (auto tryRes = trySplitLastAxisInResharding(
                 builder, mesh, reducedSourceSharding, targetSharding,
                 reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                 builder, mesh, reducedSourceSharding, targetSharding,
                 sourceUnshardedValue.getType(), reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else {
    assert(false && "Did not find any pattern to apply.");
  }

  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                               MeshShardingAttr sourceSharding,
                               MeshShardingAttr targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // Resort to handling only 1D meshes since the general case is complicated
  // if it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
                               ShardOp source, ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(!source.getAnnotateForUsers());
  assert(target.getAnnotateForUsers());
  assert(source.getResult() == target.getOperand());
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(
      implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
      source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
}

void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  // Dialects whose ops are created while materializing a resharding:
  // arith (index arithmetic), cf (divisibility assert), tensor (dim,
  // extract_slice, cast) and mesh (collectives).
  registry.insert<arith::ArithDialect, cf::ControlFlowDialect,
                  mesh::MeshDialect, tensor::TensorDialect>();
}

} // namespace mesh
} // namespace mlir