//===- Spmdization.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Support/MathExtras.h"
|
|
#include "llvm/ADT/ADL.h"
|
|
#include "llvm/ADT/APInt.h"
|
|
#include "llvm/ADT/DenseSet.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include <algorithm>
|
|
#include <iterator>
|
|
#include <numeric>
|
|
#include <optional>
|
|
#include <tuple>
|
|
#include <type_traits>
|
|
|
|
namespace mlir {
|
|
namespace mesh {
|
|
|
|
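// Size along one dimension of a single shard when that dimension is split
// over `shardCount` devices. A dynamic dimension or shard count stays dynamic.
// e.g. shardDimension(6, 2) == 3.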
int64_t shardDimension(int64_t dim, int64_t shardCount) {
  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
    return ShapedType::kDynamic;

  assert(dim % shardCount == 0);
  return ceilDiv(dim, shardCount);
}

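// Inverse of shardDimension: size of the full (unsharded) dimension given the
// per-shard size and the number of shards.
// e.g. unshardDimension(3, 2) == 6.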
int64_t unshardDimension(int64_t dim, int64_t shardCount) {
  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
    return ShapedType::kDynamic;

  return dim * shardCount;
}

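// Total number of shards along a single tensor axis, i.e. the product of the
// sizes of the mesh axes in `splitAxes`. Dynamic if any of those mesh
// dimensions is dynamic.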
template <typename MeshShape, typename SplitAxes>
int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
  int64_t res = 1;
  for (auto splitAxis : splitAxes) {
    int64_t meshDimSize = meshShape[splitAxis];
    if (ShapedType::isDynamic(meshDimSize)) {
      return ShapedType::kDynamic;
    }
    res *= meshDimSize;
  }
  return res;
}

// Compute the shape for the tensor on each device in the mesh.
// Example:
// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
// would result in a shape for each shard of ?x2x?.
template <typename InShape, typename MeshShape, typename SplitAxes,
          typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                       const SplitAxes &splitAxes, OutShape &outShape) {
  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
            llvm::adl_begin(outShape));
  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
    outShape[tensorAxis] =
        shardDimension(inShape[tensorAxis],
                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
  }
}

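// Shard the given tensor type according to `sharding` over `mesh` and return
// the resulting per-device type.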
ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
                           MeshShardingAttr sharding) {
  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
  SmallVector<Dim> resShapeArr(shape.getShape().size());
  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
             resShapeArr);
  return shape.clone(resShapeArr);
}

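// The target partial axes must be a subset of the source partial axes:
// a resharding may drop partial axes (by all-reducing over them) but may not
// introduce new ones.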
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}

// Return the reduced value and its corresponding sharding.
// Example:
// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
// targetSharding = <@mesh_1d, [[]]>
// An all-reduce is then applied to the source value
// and it is returned with the sharding <@mesh_1d, [[0]]>.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshShardingAttr sourceSharding,
                                  MeshShardingAttr targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  assert(targetSharding.getPartialAxes().empty() ||
         (!sourceSharding.getPartialAxes().empty() &&
          sourceSharding.getPartialType() == targetSharding.getPartialType()));
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue =
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               sourceSharding.getCluster().getLeafReference(),
                               allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult()
          .cast<TypedValue<ShapedType>>();

  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  MeshShardingAttr resultSharding =
      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
                            sourceSharding.getSplitAxes(), remainingPartialAxes,
                            sourceSharding.getPartialType());
  return {resultValue, resultSharding};
}

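// Construct the sharding that results from additionally splitting tensor axis
// `splitTensorAxis` over mesh axis `splitMeshAxis`,
// e.g. for splitTensorAxis = 0 and splitMeshAxis = 2: [[0, 1]] -> [[0, 1, 2]].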
static MeshShardingAttr
targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                              int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

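// Per-shard shape after splitting `splitTensorAxis` into `splitCount` pieces.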
static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
                                             int64_t splitTensorAxis,
                                             int64_t splitCount) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      shardDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
//
// The implementation is to extract the tensor slice corresponding
// to the current device.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshShardingAttr sourceSharding,
                          TypedValue<ShapedType> sourceShard, ClusterOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));

  Value processIndexAlongAxis =
      builder
          .create<ProcessMultiIndexOp>(mesh.getSymName(),
                                       SmallVector<MeshAxis>({splitMeshAxis}))
          .getResult()[0];

  MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
  ShapedType targetShape = targetShapeInSplitLastAxis(
      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);

  Value meshAxisSize =
      builder
          .create<ClusterShapeOp>(mesh.getSymName(),
                                  SmallVector<MeshAxis>({splitMeshAxis}))
          .getResult()[0];

  Value sourceAxisSize =
      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
  Value sourceAxisSizeModMeshAxisSize =
      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
  builder.create<cf::AssertOp>(
      isTargetShapeExactlyDivisible,
      "Sharding a tensor with axis size that is not exactly divisible by the "
      "mesh axis size is not supported.");
  Value targetAxisSize =
      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
  Value axisOffset =
      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
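  // Each process extracts the slice of length targetAxisSize along
  // splitTensorAxis that starts at its own axisOffset; all other dimensions
  // are taken whole.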
  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
  DenseI64ArrayAttr staticOffsetsAttr =
      DenseI64ArrayAttr::get(ctx, staticOffsets);
  SmallVector<Value> dynamicOffsets(1, axisOffset);

  DenseI64ArrayAttr staticSizesAttr =
      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
      if (i == splitTensorAxis) {
        dynamicSizes.push_back(targetAxisSize);
      } else {
        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
        dynamicSizes.push_back(dimSize);
      }
    }
  }

  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
  TypedValue<RankedTensorType> targetShard =
      builder
          .create<tensor::ExtractSliceOp>(
              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
              staticStridesAttr)
          .getResult();
  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
}

// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
                                MeshShardingAttr targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                             MeshShardingAttr sourceSharding,
                             MeshShardingAttr targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
                                     tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
                                  MeshShardingAttr targetSharding) {
  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
       ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}

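// Construct the sharding that results from removing the last mesh axis from
// the split axes of tensor axis `splitTensorAxis`,
// e.g. for splitTensorAxis = 0: [[0, 1, 2]] -> [[0, 1]].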
static MeshShardingAttr
targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                MeshShardingAttr sourceSharding,
                                int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());

  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

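// Shape of the all-gather result: `splitTensorAxis` grows back by a factor of
// `splitCount`, all other dimensions are unchanged.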
static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      unshardDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

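// Unsplit a tensor axis along a mesh axis, e.g. [[0, 1, 2]] -> [[0, 1]].
// The shards are reassembled with an all-gather over the mesh axis that is
// being removed. Returns the resulting value together with its sharding.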
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshShardingAttr sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshShardingAttr targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard =
      builder.create<tensor::CastOp>(targetShape, allGatherResult)
          .getResult()
          .cast<TypedValue<ShapedType>>();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                               MeshShardingAttr sourceSharding,
                               MeshShardingAttr targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Detect if the resharding is of type e.g.
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
// target_tensor_axis, mesh_axis) tuple.
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
                                    MeshShardingAttr targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}

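// Construct the sharding that results from moving the last split mesh axis of
// tensor axis `sourceTensorAxis` to the end of the split axes of tensor axis
// `targetTensorAxis`,
// e.g. for sourceTensorAxis = 0 and targetTensorAxis = 1: [[0]] -> [[], [0]].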
static MeshShardingAttr
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
                             int64_t sourceTensorAxis,
                             int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshShardingAttr::get(
      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}

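// Shape of the all-to-all result: `sourceTensorAxis` is gathered (grows by a
// factor of `splitCount`) while `targetTensorAxis` is scattered (shrinks by
// the same factor).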
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      unshardDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}

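// Move the last split mesh axis from one tensor axis to another,
// e.g. [[0]] -> [[], [0]].
// This is implemented as a single all-to-all over the moved mesh axis: the
// source tensor axis is gathered while the target tensor axis is scattered.
// Returns the resulting value together with its sharding.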
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                              MeshShardingAttr sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard =
      builder.create<tensor::CastOp>(targetShape, allToAllResult)
          .getResult()
          .cast<TypedValue<ShapedType>>();
  return {targetShard, targetSharding};
}

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                                 MeshShardingAttr sourceSharding,
                                 MeshShardingAttr targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }

  return std::nullopt;
}

// Handles only resharding on a 1D mesh.
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
static TypedValue<ShapedType>
reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                MeshShardingAttr sourceSharding,
                MeshShardingAttr targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
                                        sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshShardingAttr actualTargetSharding;
  if (auto tryRes = tryMoveLastSplitAxisInResharding(
          builder, mesh, reducedSourceSharding, targetSharding,
          sourceUnshardedValue.getType(), reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else if (auto tryRes = trySplitLastAxisInResharding(
                 builder, mesh, reducedSourceSharding, targetSharding,
                 reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                 builder, mesh, reducedSourceSharding, targetSharding,
                 sourceUnshardedValue.getType(), reducedSourceShard)) {
    std::tie(targetShard, actualTargetSharding) = tryRes.value();
  } else {
    assert(false && "Did not find any pattern to apply.");
  }

  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}

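// Reshard `sourceShard`, the shard of `sourceUnshardedValue` on this device
// under `sourceSharding`, into the shard corresponding to `targetSharding`.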
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                               MeshShardingAttr sourceSharding,
                               MeshShardingAttr targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // Resort to handling only 1D meshes since the general case is complicated if
  // it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}

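// Reshard between a producer-annotated shard op (`source`) and the matching
// user-annotated shard op (`target`) that consumes its result.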
TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
                               ShardOp source, ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(!source.getAnnotateForUsers());
  assert(target.getAnnotateForUsers());
  assert(source.getResult() == target.getOperand());
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(
      implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
      source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
}

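// Register the dialects whose ops the resharding lowering above may create.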
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
                  cf::ControlFlowDialect>();
}

} // namespace mesh
} // namespace mlir