//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512]-bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512]-bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128-bit square, a thread is responsible for one 32-bit
  // register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
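
// For example, a 16x16 vector<f16> matrix-A operand has a tile width of 128
// bits, so inferNumRegistersPerMatrixFragment returns
// (16 / 8) * (16 * 16) / 128 = 4, while a 16x8 vector<f32> accumulator has a
// tile width of 256 bits and yields (16 / 8) * (8 * 32) / 256 = 2.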
FailureOr<nvgpu::FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operands
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operands
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operands
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32-bit accumulator operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32-bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
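
// For example, for a 16x16 vector<f16> matrix-A operand (two f16 elements per
// 32-bit register), the map built by getLaneIdAndValueIdToOperandCoord sends
// (laneId, valueId) = (1, 0) to element (0, 2) of the fragment and (0, 2) to
// element (8, 0), matching the mma.sync f16 fragment layout.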
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128-bit row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // The affine expressions in the strided and contiguous dimensions encode the
  // coordinate mapping for the element a thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC. This is when the memory layout in `srcMemref` matches
  // the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
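
// For example, a non-transposed 16x16 vector<f16> operand gives numTiles = 4
// (an x4 ldmatrix); the map built by getLaneIdToLdMatrixMatrixCoord points
// lanes 0-15 at rows 0-15 of the first group of 8 contiguous elements (one
// 128-bit line) and lanes 16-31 at rows 0-15 starting at column 8.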
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure a canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}
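
// Typical flow in a conversion pattern (sketch): guard the rewrite with
// canLowerToWarpMatrixOperation, derive the operand's WarpMatrixInfo with
// getWarpMatrixInfo, query getMmaSyncRegisterType for the per-thread fragment
// layout, and build per-lane element indices with
// getLaneIdAndValueIdToOperandCoord or, for ldmatrix-based loads, with
// getLdMatrixParams followed by getLaneIdToLdMatrixMatrixCoord.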