//===- TestSimplification.cpp - Test simplification -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::mesh; namespace { struct TestMeshReshardingRewritePattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShardOp op, PatternRewriter &rewriter) const override { if (op.getAnnotateForUsers()) { return failure(); } SymbolTableCollection symbolTable; mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom( op, op.getShard().getCluster()); bool foundUser = false; for (auto user : op->getUsers()) { if (auto targetShardOp = llvm::dyn_cast(user)) { if (targetShardOp.getAnnotateForUsers() && mesh == symbolTable.lookupNearestSymbolFrom( targetShardOp, targetShardOp.getShard().getCluster())) { foundUser = true; break; } } } if (!foundUser) { return failure(); } for (auto user : op->getUsers()) { auto targetShardOp = llvm::dyn_cast(user); if (!targetShardOp || !targetShardOp.getAnnotateForUsers() || symbolTable.lookupNearestSymbolFrom( targetShardOp, targetShardOp.getShard().getCluster()) != mesh) { continue; } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); ShapedType sourceShardShape = shardShapedType(op.getResult().getType(), mesh, op.getShard()); TypedValue sourceShard = builder .create(sourceShardShape, op.getOperand()) ->getResult(0) .cast>(); TypedValue targetShard = reshard(builder, mesh, op, targetShardOp, sourceShard); Value newTargetUnsharded = builder .create( targetShardOp.getResult().getType(), targetShard) ->getResult(0); rewriter.replaceAllUsesWith(targetShardOp.getResult(), newTargetUnsharded); } return success(); } }; struct TestMeshReshardingPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass) void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(), std::move(patterns)))) { return signalPassFailure(); } } void getDependentDialects(DialectRegistry ®istry) const override { reshardingRegisterDependentDialects(registry); registry.insert(); } StringRef getArgument() const final { return "test-mesh-resharding-spmdization"; } StringRef getDescription() const final { return "Test Mesh dialect resharding spmdization."; } }; } // namespace namespace mlir { namespace test { void registerTestMeshReshardingSpmdizationPass() { PassRegistration(); } } // namespace test } // namespace mlir