//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tosa-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tosa;
using namespace mlir::mesh;

namespace {

// Sharding model for elementwise ops: every loop is parallel, and every
// operand and result uses the rank-`r` identity indexing map.
template <typename ElemwiseOp>
struct ElemwiseSharding
    : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
                                              ElemwiseOp> {
  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
    Value val = op->getOperand(0);
    auto type = val.getType().dyn_cast<RankedTensorType>();
    if (!type)
      return {};
    SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    MLIRContext *ctx = op->getContext();
    Value val = op->getOperand(0);
    auto type = val.getType().dyn_cast<RankedTensorType>();
    if (!type)
      return {};
    int64_t rank = type.getRank();
    int64_t num = op->getNumOperands() + op->getNumResults();
    SmallVector<AffineMap> maps(num,
                                AffineMap::getMultiDimIdentityMap(rank, ctx));
    return maps;
  }
};

// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
// (d0, d1, d2, d3) -> (d0, d1, d3)
// (d0, d1, d2, d3) -> (d0, d3, d2)
// (d0, d1, d2, d3) -> (d0, d1, d2)
struct MatMulOpSharding
    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
  SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!tensorType)
      return {};
    // One loop per result dimension, plus one trailing loop for the
    // contraction (reduction) dimension.
    SmallVector<IteratorType> types(tensorType.getRank() + 1,
                                    IteratorType::Parallel);
    types[tensorType.getRank()] = IteratorType::ReductionSum;
    return types;
  }

  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!tensorType)
      return {};
    MLIRContext *ctx = op->getContext();
    SmallVector<AffineMap> maps;
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
    return maps;
  }
};

template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
  OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
}

/// Variadic helper function.
template <typename... OpTypes>
static void registerElemwiseAll(MLIRContext *ctx) {
  (registerElemwiseOne<OpTypes>(ctx), ...);
}

} // namespace

void mlir::tosa::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
    registerElemwiseAll<ClampOp, SigmoidOp, TanhOp, AddOp,
                        ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp,
                        BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
                        LogicalRightShiftOp, LogicalOrOp, LogicalXorOp,
                        MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
                        BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp,
                        LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp,
                        EqualOp, GreaterOp, GreaterEqualOp>(ctx);
    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
  });
}
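
// A minimal usage sketch (illustrative, not part of the upstream file):
// a client that wants these sharding models available would register them
// on a DialectRegistry before constructing the MLIRContext, roughly as
// below. The exact set of dialects inserted here is an assumption.
//
//   DialectRegistry registry;
//   registry.insert<tosa::TosaDialect, mesh::MeshDialect>();
//   tosa::registerShardingInterfaceExternalModels(registry);
//   MLIRContext context(registry);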