//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Number of bits that need to be excluded when building the matrix
/// descriptor for wgmma operations.
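/// For example, a byte offset of 512 is encoded in a descriptor field as
/// 512 >> exclude4LSB == 32 (an illustrative value).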
constexpr int exclude4LSB = 4;

/// GPUs have 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
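/// For illustration (an assumed sample of the cases handled below):
/// !llvm.array<2 x vector<2xf16>> maps to
/// !llvm.struct<(vector<2xf16>, vector<2xf16>)>, while
/// !llvm.array<2 x vector<2xf64>> maps to !llvm.struct<(f64, f64)>.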
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
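/// For illustration (assumed types): a struct result !llvm.struct<(f64, f64)>
/// with a desired !llvm.array<1 x vector<2xf64>> is rebuilt by extracting both
/// f64 scalars and inserting them into a single vector<2xf64> element.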
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
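/// For illustration (assumed operands): an !llvm.array<2 x vector<2xf16>>
/// operand yields two vector<2xf16> values unchanged, whereas an
/// !llvm.array<4 x vector<1xf32>> operand with the tf32 PTX type yields four
/// i32 values produced by bitcasts.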
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns the memref type of the mbarrier object. The type is defined in the
/// MBarrierGroupType.
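/// For example (illustrative), a barrier group with 4 barriers placed in
/// shared memory converts to a memref such as memref<4xi64, 3>, assuming the
/// dialect numbers the shared memory space as 3.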
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32bit integer
    // registers if more than one 32bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister
    // is the vector type of the result and always 32 bits long. We bitcast
    // the result of the NVVM::LdMatrix to this vector type.
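    // For illustration (an assumed case): a vector<4xvector<2xf16>> result has
    // num32BitRegs == 4, so the NVVM op returns a struct of four i32 values
    // that are bitcast back to vector<2xf16> below.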
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
                    arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    /// device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
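    // Illustrative example (assuming kWgmmaSizeM == 64): a 128x128 f32
    // accumulator fragment converts below to a struct of two inner structs,
    // each holding sizeN / 2 == 64 f32 members.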
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

/// Returns the constraints for the sparse MMA inline assembly instruction.
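/// For example, with assumed sizes matASize = 2, matBSize = 2, matCSize = 4,
/// the result is "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r": four outputs, eight inputs,
/// and one trailing operand for the sparsity metadata.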
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata.
  // The sparsity selector appears as a direct literal.
  ss << "r";
  ss.flush();
  return str;
}

/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation;
/// it's expected that the provided parameters correspond to a valid
/// instruction.
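/// For illustration only (the exact register counts depend on the fragment
/// sizes), the result resembles:
///   mma.sp.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
///     {$0,$1,$2,$3},{$4,$5},{$6,$7},{$8,$9,$10,$11},$12,0x0;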
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  ss.flush();
  return asmStr;
}

/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // When the optional SrcElements argument is present, the source (global
      // memory) of CpAsyncOp is read only for SrcElements number of elements.
      // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other
    // than 16 dst bytes.
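    // For illustration (element type assumed): copying 8 f16 elements gives
    // sizeInBytes == 16, so with bypassL1 set the CG modifier is chosen; a
    // 4-element f16 copy (8 bytes) always uses CA.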
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservatively correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Creates an mbarrier object in shared memory.
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};

/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};

/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                        adaptor.getPredicate());
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};

struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};

struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase = truncToI32(b, adaptor.getPhase());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};

struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        adaptor.getPredicate());
    return success();
  }
};

struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
    return success();
  }
};

struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;
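    // Worked example (values assumed): a 64x64 tile with 128B swizzle gives
    // layout == 128, so strideDimVal == (128 << 3) >> 4 == 64 and
    // leadDimVal == (64 * 128) >> 4 == 512.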

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Just use 14 bits for the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base_offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start_address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");

    rewriter.replaceOp(op, dsc);
    return success();
  }
};

static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}

/// Returns a Value that holds the data type enum expected by the CUDA driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // The enum is from the CUDA driver API:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}

struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set the arguments for the function call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set the data types of the arguments.
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};

struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// This is a helper class to generate the NVVM ops required for warp-group
  /// level matrix multiplication.
  /// When the given GEMM shape is larger than the shape of a wgmma
  /// instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
  /// Op(s), group and execute them asynchronously. The class also handles
  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
  /// create descriptors for each instruction.
  ///
  /// For example, this is the case when the shape of the GEMM is 128x128x128:
  ///
  ///   nvvm.wgmma.fence.aligned
  ///
  ///   nvvm.wgmma.mma.async descA, descB
  ///   iterate(descA, descB)
  ///   nvvm.wgmma.mma.async descA, descB
  ///   [6x times more]
  ///
  ///   nvvm.wgmma.group.sync.aligned
  ///   nvvm.wgmma.wait.group.sync [groupId]
  ///
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given Op.
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction.
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for the GEMM.
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// Returns the shape of the wgmma instruction that is defined in the
    /// PTX programming guide:
    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
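    /// For example (illustrative), f16 inputs with sizeN == 128 select the
    /// m64.n128.k16 shape, while tf32 inputs select k == 8.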
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (inputElemType.isFloat8E4M3FN() ||
                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    }

    /// Generates WGMMATypesAttr from an MLIR Type.
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (elemType.isFloat8E4M3FN())
          return NVVM::WGMMATypes::e4m3;
        if (elemType.isFloat8E5M2())
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates the layout attribute for an input matrix of a wgmma
    /// instruction.
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates the shape attribute for a wgmma instruction.
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates the scale attribute of the output matrix for a wgmma
    /// instruction.
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }
    /// Generates the scale attribute of an input matrix for a wgmma
    /// instruction.
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }

    /// Basic helper to generate an Add.
    Value makeAdd(Value lhs, Value rhs) {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
    /// Currently, it only handles row-major.
    ///
    /// It moves the pointer like below for a [128][64] size:
    ///                   +2 +4 +6
    ///                    ↓  ↓  ↓
    /// descA     ---> +--+--+--+--+
    ///                |->|->|->|->|
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    /// descA+512 ---> +-----------+
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    ///                +-----------+
    ///
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
                        << "] [wgmma descriptors] Descriptor A + "
                        << incrementVal << " | \t ");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
    /// Currently, it only handles column-major.
    ///
    /// It moves the pointer like below for a [128][64] size:
    /// descB ---> +--+--+--+--+--+--+--+--+
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            +--+--+--+--+--+--+--+--+
    ///
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }
|
|
|
|
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
|
|
/// descriptors and arranges them based on induction variables: i, j, and k.
|
|
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (i * wgmmaM) << ":"
                        << (i * wgmmaM) + wgmmaM << "]["
                        << (k * wgmmaK) << ":" << (k * wgmmaK + wgmmaK)
                        << "] * "
                        << " B[" << (k * wgmmaK) << ":" << (k * wgmmaK + wgmmaK)
                        << "][" << 0 << ":" << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape.
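    ///
    /// The accumulator (and the result) is a struct with `iterationM` inner
    /// structs; the i-th inner struct carries the per-thread registers of the
    /// i-th wgmmaM-row tile of matrix D.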
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      // Perform GEMM
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM Shape
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape for one wgmma instruction
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts to complete the given shape with the wgmma shape
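      // For example (illustrative values, not taken from the op): a
      // 128x128x64 f16 GEMM where findWgmmaShape selects m64n128k16 yields
      // iterationM = 2, iterationN = 1, iterationK = 4, i.e. 8 wgmma
      // instructions in total.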
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }

    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
    /// generates a fence op (WgmmaFenceAlignedOp) before the wgmma
    /// instructions, then synchronizes the group (WgmmaGroupSyncAlignedOp)
    /// and waits for its completion (WgmmaWaitGroupSyncOp) after them.
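    ///
    /// Roughly, and assuming the usual NVVM-to-PTX lowering, the emitted
    /// sequence corresponds to:
    ///   wgmma.fence.sync.aligned;
    ///   wgmma.mma_async.sync.aligned.* (one per generated WgmmaMmaAsyncOp);
    ///   wgmma.commit_group.sync.aligned;
    ///   wgmma.wait_group.sync.aligned <waitGroup>;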
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class that knows the GEMM and wgmma shapes.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the wgmma instructions covering the entire GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op results with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers, carried as
  /// a single struct value.
  /// Here is what each thread (T) holds; each `d` is a struct element
  /// identified by its index.
  ///
  /// Threads in the warp-group (128 threads) and what they own in matrix D:
  /// 0-31   Warp-0 -> MatrixD[0:15 ][0:N]
  /// 32-63  Warp-1 -> MatrixD[16:31][0:N]
  /// 64-95  Warp-2 -> MatrixD[32:47][0:N]
  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  /// +______________________________________________________________________+
  /// |    0-1   |   2-3   |   4-5   |   6-7   |  8-9   |  10-11 |..| N-8,N-7 |
  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// +______________________________________________________________________+
  ///
  /// \param b: The ImplicitLocOpBuilder used to create the ops.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  /// matrix). It is held by each thread as a struct with 64 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: the offset within the memref where the registers will be
  /// stored.
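  ///
  /// For illustration (with offset 0): thread 5 (warp 0, lane 5) computes
  /// lane4Id = 1 and lane4modId = 1, so ti = 1 and tj = 2. It stores d0/d1 to
  /// dstMemref[1][2:3], d2/d3 to dstMemref[9][2:3], d4/d5 to
  /// dstMemref[1][10:11], and so on, matching the table above.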
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of 32-bit registers owned per thread
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
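  /// Lowers nvgpu.warpgroup.mma.init.accumulator by materializing the
  /// converted accumulator struct-of-structs and zero-filling every element
  /// through llvm.insertvalue chains. For example (illustrative, assuming the
  /// standard type conversion of a 128x128 f32 accumulator), the result is a
  /// two-element struct whose inner structs each hold 64 f32 values, all 0.0.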
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs and set all values to zero
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

} // namespace

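// Typical usage of the function below (an illustrative sketch, not part of
// this file): a conversion pass collects these patterns alongside an
// LLVMTypeConverter and runs a partial dialect conversion, e.g.:
//
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   populateNVGPUToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(getContext());
//   target.addLegalDialect<NVVM::NVVMDialect>();
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();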
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,            // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,              // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,            // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering,  // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,          // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,     // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,            // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,           // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,     // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,             // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,    // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,            // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,       // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering>(converter);
}