//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" #include namespace mlir { #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::amdgpu; static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { Type llvmI32 = rewriter.getI32Type(); return rewriter.create(loc, llvmI32, value); } static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value) { Type llvmI1 = rewriter.getI1Type(); return rewriter.createOrFold(loc, llvmI1, value); } namespace { /// Define lowering patterns for raw buffer ops template struct RawBufferOpLowering : public ConvertOpToLLVMPattern { RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; static constexpr uint32_t maxVectorOpWidth = 128; LogicalResult matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = gpuOp.getLoc(); Value memref = adaptor.getMemref(); Value unconvertedMemref = gpuOp.getMemref(); MemRefType memrefType = cast(unconvertedMemref.getType()); if (chipset.majorVersion < 9) return gpuOp.emitOpError("raw buffer ops require GCN or higher"); Value storeData = adaptor.getODSOperands(0)[0]; if (storeData == memref) // no write component to this op storeData = Value(); Type wantedDataType; if (storeData) wantedDataType = storeData.getType(); else wantedDataType = gpuOp.getODSResults(0)[0].getType(); Value atomicCmpData = Value(); // Operand index 1 of a load is the indices, trying to read them can crash. if (storeData) { Value maybeCmpData = adaptor.getODSOperands(1)[0]; if (maybeCmpData != memref) atomicCmpData = maybeCmpData; } Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType); Type i32 = rewriter.getI32Type(); Type llvmI32 = this->typeConverter->convertType(i32); Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type()); int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8; Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); // If we want to load a vector with total size <= 32 // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 // and the total load size is >= 32, use a vector load of N / (bitsize(T) / // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, // so bitcast any floats to integers. On top of all this, cast bfloat // (vectors) to i16 since the backend doesn't currently support bfloat on // these operations. Type llvmBufferValType = llvmWantedDataType; if (wantedDataType.isBF16()) llvmBufferValType = rewriter.getI16Type(); if (auto wantedVecType = dyn_cast(wantedDataType)) if (wantedVecType.getElementType().isBF16()) llvmBufferValType = wantedVecType.clone(rewriter.getI16Type()); if (atomicCmpData) { if (isa(wantedDataType)) return gpuOp.emitOpError("vector compare-and-swap does not exist"); if (auto floatType = dyn_cast(wantedDataType)) llvmBufferValType = this->getTypeConverter()->convertType( rewriter.getIntegerType(floatType.getWidth())); } if (auto dataVector = dyn_cast(wantedDataType)) { uint32_t elemBits = dataVector.getElementTypeBitWidth(); uint32_t totalBits = elemBits * dataVector.getNumElements(); if (totalBits > maxVectorOpWidth) return gpuOp.emitOpError( "Total width of loads or stores must be no more than " + Twine(maxVectorOpWidth) + " bits, but we call for " + Twine(totalBits) + " bits. This should've been caught in validation"); if (elemBits < 32) { if (totalBits > 32) { if (totalBits % 32 != 0) return gpuOp.emitOpError("Load or store of more than 32-bits that " "doesn't fit into words. Can't happen\n"); llvmBufferValType = this->typeConverter->convertType( VectorType::get(totalBits / 32, i32)); } else { llvmBufferValType = this->typeConverter->convertType( rewriter.getIntegerType(totalBits)); } } } SmallVector args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { Value castForStore = rewriter.create(loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); } } if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { Value castForCmp = rewriter.create( loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); } } // Construct buffer descriptor from memref, attributes int64_t offset = 0; SmallVector strides; if (failed(getStridesAndOffset(memrefType, strides, offset))) return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); MemRefDescriptor memrefDescriptor(memref); Value ptr = memrefDescriptor.alignedPtr(rewriter, loc); // The stride value is always 0 for raw buffers. This also disables // swizling. Value stride = rewriter.createOrFold( loc, llvmI16, rewriter.getI16IntegerAttr(0)); Value numRecords; if (memrefType.hasStaticShape()) { numRecords = createI32Constant( rewriter, loc, static_cast(memrefType.getNumElements() * elementByteWidth)); } else { Value maxIndex; for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); stride = rewriter.create(loc, stride, byteWidthConst); Value maxThisDim = rewriter.create(loc, size, stride); maxIndex = maxIndex ? rewriter.create(loc, maxIndex, maxThisDim) : maxThisDim; } numRecords = rewriter.create(loc, llvmI32, maxIndex); } // Flag word: // bits 0-11: dst sel, ignored by these intrinsics // bits 12-14: data format (ignored, must be nonzero, 7=float) // bits 15-18: data format (ignored, must be nonzero, 4=32bit) // bit 19: In nested heap (0 here) // bit 20: Behavior on unmap (0 means "return 0 / ignore") // bits 21-22: Index stride for swizzles (N/A) // bit 23: Add thread ID (0) // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) // bits 25-26: Reserved (0) // bit 27: Buffer is non-volatile (CDNA only) // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 = // none, 3 = either swizzles or testing against offset field) RDNA only // bits 30-31: Type (must be 0) uint32_t flags = (7 << 12) | (4 << 15); if (chipset.majorVersion >= 10) { flags |= (1 << 24); uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2; flags |= (oob << 28); } Value flagsConst = createI32Constant(rewriter, loc, flags); Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); Value resource = rewriter.createOrFold( loc, rsrcType, ptr, stride, numRecords, flagsConst); args.push_back(resource); // Indexing (voffset) Value voffset = createI32Constant(rewriter, loc, 0); for (auto pair : llvm::enumerate(adaptor.getIndices())) { size_t i = pair.index(); Value index = pair.value(); Value strideOp; if (ShapedType::isDynamic(strides[i])) { strideOp = rewriter.create( loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst); } else { strideOp = createI32Constant(rewriter, loc, strides[i] * elementByteWidth); } index = rewriter.create(loc, index, strideOp); voffset = rewriter.create(loc, voffset, index); } if (adaptor.getIndexOffset()) { int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth; Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset); voffset = voffset ? rewriter.create(loc, voffset, extraOffsetConst) : extraOffsetConst; } args.push_back(voffset); Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); if (ShapedType::isDynamic(offset)) sgprOffset = rewriter.create( loc, memrefDescriptor.offset(rewriter, loc), sgprOffset); else if (offset > 0) sgprOffset = rewriter.create( loc, sgprOffset, createI32Constant(rewriter, loc, offset)); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) // bits 1-2: SLC, DLC = 0 (similarly) // bit 3: swizzled (0 for raw) args.push_back(createI32Constant(rewriter, loc, 0)); llvm::SmallVector resultTypes(gpuOp->getNumResults(), llvmBufferValType); Operation *lowered = rewriter.create(loc, resultTypes, args, ArrayRef()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { replacement = rewriter.create(loc, llvmWantedDataType, replacement); } rewriter.replaceOp(gpuOp, replacement); } else { rewriter.eraseOp(gpuOp); } return success(); } }; struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), LLVM::AsmDialect::AD_ATT); const char *asmStr = "s_waitcnt lgkmcnt(0)\ns_barrier"; const char *constraints = ""; rewriter.replaceOpWithNewOp( op, /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(), /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true, /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); return success(); } }; } // namespace /// If `input` is a vector of bytes, concatentate those bytes in little-endian /// order to form a single integer of size 8 * [vector length]. This works /// around a wart in the AMDGPU intrinsics where operations that logically take /// vectors of bytes instead integers. Since we do not want to expose this /// implementation detail to MLIR, we correct for it here. /// /// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU /// MFMA intrinsics pre-date the bfloat type. static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, Location loc, Value input) { Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { if (vectorType.getElementType().isBF16()) return rewriter.create( loc, vectorType.clone(rewriter.getI16Type()), input); if (!vectorType.getElementType().isInteger(8)) return input; int64_t numBytes = vectorType.getNumElements(); Type destType = rewriter.getIntegerType(numBytes * 8); Value result = rewriter.create( loc, destType, rewriter.getIntegerAttr(destType, 0)); for (int64_t i = 0; i < numBytes; ++i) { Value idxConst = createI32Constant(rewriter, loc, i); Value element = rewriter.create(loc, input, idxConst); Value extended = rewriter.create(loc, destType, element); Value shiftConst = rewriter.create( loc, destType, rewriter.getIntegerAttr(destType, i * 8)); Value shifted = rewriter.create(loc, extended, shiftConst); result = rewriter.create(loc, result, shifted); } return result; } return input; } /// Push an input operand. If it is a float type, nothing to do. If it is /// an integer type, then we need to also push its signdness (1 for signed, 0 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 /// vector. We also need to convert bfloat inputs to i16 to account for the lack /// of bfloat support in the WMMA intrinsics themselves. static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, SmallVector &operands) { Type inputType = llvmInput.getType(); auto vectorType = inputType.dyn_cast(); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) llvmInput = rewriter.create( loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (!elemType.isInteger(8)) { operands.push_back(llvmInput); return; } int64_t numBytes = vectorType.getNumElements(); Type i32 = rewriter.getI32Type(); VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); Value result = rewriter.createOrFold( loc, llvmVectorType32bits, llvmInput); // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag bool localIsUnsigned = isUnsigned; if (elemType.isUnsignedInteger(8)) { localIsUnsigned = true; } else if (elemType.isSignedInteger(8)) { localIsUnsigned = false; } Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); operands.push_back(sign); operands.push_back(result); } /// Push the output operand. For many cases this is only pushing the output in /// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics, /// since the same numbers of VGPRs is used, we need to decide if to store the /// result in the upper 16 bits of the VGPRs or in the lower part. To store the /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will /// be stored it in the upper part static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector &operands) { Type inputType = output.getType(); auto vectorType = inputType.dyn_cast(); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) output = rewriter.create( loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); } else if (elemType.isInteger(32)) { operands.push_back(createI1Constant(rewriter, loc, clamp)); } } /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. static std::optional mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) { uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(), b = mfma.getBlocks(); Type sourceElem = mfma.getSourceA().getType(); if (auto sourceType = dyn_cast(sourceElem)) sourceElem = sourceType.getElementType(); Type destElem = mfma.getDestC().getType(); if (auto destType = dyn_cast(destElem)) destElem = destType.getElementType(); if (sourceElem.isF32() && destElem.isF32()) { if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) { if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4_xf32::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8_xf32::getOperationName(); } if (m == 32 && n == 32 && k == 1 && b == 2) return ROCDL::mfma_f32_32x32x1f32::getOperationName(); if (m == 16 && n == 16 && k == 1 && b == 4) return ROCDL::mfma_f32_16x16x1f32::getOperationName(); if (m == 4 && n == 4 && k == 1 && b == 16) return ROCDL::mfma_f32_4x4x1f32::getOperationName(); if (m == 32 && n == 32 && k == 2 && b == 1) return ROCDL::mfma_f32_32x32x2f32::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f32_16x16x4f32::getOperationName(); } if (sourceElem.isF16() && destElem.isF32()) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4f16::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4f16::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4f16::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8f16::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16f16::getOperationName(); } if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName(); } if (sourceElem.isBF16() && destElem.isF32()) { if (m == 32 && n == 32 && k == 2 && b == 2) return ROCDL::mfma_f32_32x32x2bf16::getOperationName(); if (m == 16 && n == 16 && k == 2 && b == 4) return ROCDL::mfma_f32_16x16x2bf16::getOperationName(); if (m == 4 && n == 4 && k == 2 && b == 16) return ROCDL::mfma_f32_4x4x2bf16::getOperationName(); if (m == 32 && n == 32 && k == 4 && b == 1) return ROCDL::mfma_f32_32x32x4bf16::getOperationName(); if (m == 16 && n == 16 && k == 8 && b == 1) return ROCDL::mfma_f32_16x16x8bf16::getOperationName(); } if (isa(sourceElem) && destElem.isInteger(32)) { if (m == 32 && n == 32 && k == 4 && b == 2) return ROCDL::mfma_i32_32x32x4i8::getOperationName(); if (m == 16 && n == 16 && k == 4 && b == 4) return ROCDL::mfma_i32_16x16x4i8::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 16) return ROCDL::mfma_i32_4x4x4i8::getOperationName(); if (m == 32 && n == 32 && k == 8 && b == 1) return ROCDL::mfma_i32_32x32x8i8::getOperationName(); if (m == 16 && n == 16 && k == 16 && b == 1) return ROCDL::mfma_i32_16x16x16i8::getOperationName(); if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40) return ROCDL::mfma_i32_32x32x16_i8::getOperationName(); if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40) return ROCDL::mfma_i32_16x16x32_i8::getOperationName(); } if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) { if (m == 16 && n == 16 && k == 4 && b == 1) return ROCDL::mfma_f64_16x16x4f64::getOperationName(); if (m == 4 && n == 4 && k == 4 && b == 4) return ROCDL::mfma_f64_4x4x4f64::getOperationName(); } if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset.minorVersion >= 0x40) { // Known to be correct because there are no scalar f8 instructions and // because a length mismatch will have been caught by the verifier. Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); if (sourceBElem.isFloat8E4M3FNUZ()) return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName(); if (sourceBElem.isFloat8E4M3FNUZ()) return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName(); } } if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset.minorVersion >= 0x40) { Type sourceBElem = cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); if (sourceBElem.isFloat8E4M3FNUZ()) return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName(); } if (m == 32 && n == 32 && k == 16 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName(); if (sourceBElem.isFloat8E4M3FNUZ()) return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(); } } return std::nullopt; } /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. static std::optional wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { auto sourceVectorType = wmma.getSourceA().getType().dyn_cast(); auto destVectorType = wmma.getDestC().getType().dyn_cast(); auto elemSourceType = sourceVectorType.getElementType(); auto elemDestType = destVectorType.getElementType(); if (elemSourceType.isF16() && elemDestType.isF32()) { return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); } if (elemSourceType.isBF16() && elemDestType.isF32()) { return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); } else if (elemSourceType.isF16() && elemDestType.isF16()) { return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); } else if (elemSourceType.isBF16() && elemDestType.isBF16()) { return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); } return std::nullopt; } namespace { struct MFMAOpLowering : public ConvertOpToLLVMPattern { MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); Type intrinsicOutType = outType; if (auto outVecType = dyn_cast(outType)) if (outVecType.getElementType().isBF16()) intrinsicOutType = outVecType.clone(rewriter.getI16Type()); if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08) return op->emitOpError("MFMA only supported on gfx908+"); uint32_t getBlgpField = static_cast(op.getBlgp()); if (op.getNegateA() || op.getNegateB() || op.getNegateC()) { if (chipset.minorVersion < 0x40) return op.emitOpError("negation unsupported on older than gfx840"); getBlgpField |= op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2); } std::optional maybeIntrinsic = mfmaOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching MFMA size on given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands( {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()), mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()), adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()), createI32Constant(rewriter, loc, op.getAbid()), createI32Constant(rewriter, loc, getBlgpField)}); Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) lowered = rewriter.create(loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } }; struct WMMAOpLowering : public ConvertOpToLLVMPattern { WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); if (chipset.majorVersion != 11) return op->emitOpError("WMMA only supported on gfx11"); std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching WMMA on the given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); loweredOp.addTypes(outType); SmallVector operands; wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), adaptor.getSourceA(), operands); wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), adaptor.getSourceB(), operands); wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), op.getSubwordOffset(), op.getClamp(), operands); loweredOp.addOperands(operands); Operation *lowered = rewriter.create(loweredOp); rewriter.replaceOp(op, lowered->getResults()); return success(); } }; namespace { struct ExtPackedFp8OpLowering final : public ConvertOpToLLVMPattern { ExtPackedFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern { PackedTrunc2xFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct PackedStochRoundFp8OpLowering final : public ConvertOpToLLVMPattern { PackedStochRoundFp8OpLowering(LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} Chipset chipset; LogicalResult matchAndRewrite(PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // end namespace LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type v4i8 = getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type())); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type f32 = getTypeConverter()->convertType(op.getResult().getType()); Value source = adaptor.getSource(); auto sourceVecType = op.getSource().getType().dyn_cast(); Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { Value longVec = rewriter.create(loc, v4i8); if (!sourceVecType) { longVec = rewriter.create( loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); Value elem = rewriter.create(loc, source, idx); longVec = rewriter.create(loc, longVec, elem, idx); } } source = longVec; } Value i32Source = rewriter.create(loc, i32, source); Value wordSel = createI32Constant(rewriter, loc, op.getIndex()); if (sourceElemType.isFloat8E5M2FNUZ()) { rewriter.replaceOpWithNewOp(op, f32, i32Source, wordSel); } else if (sourceElemType.isFloat8E4M3FNUZ()) { rewriter.replaceOpWithNewOp(op, f32, i32Source, wordSel); } return success(); } LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type resultType = op.getResult().getType(); Type resultElemType = getElementTypeOrSelf(resultType); Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) sourceB = rewriter.create(loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) existing = rewriter.create(loc, i32, existing); else existing = rewriter.create(loc, i32); Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); Value result; if (resultElemType.isFloat8E5M2FNUZ()) result = rewriter.create(loc, i32, sourceA, sourceB, existing, wordSel); else if (resultElemType.isFloat8E4M3FNUZ()) result = rewriter.create(loc, i32, sourceA, sourceB, existing, wordSel); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); return success(); } LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); if (chipset.majorVersion != 9 || chipset.minorVersion < 0x40) return rewriter.notifyMatchFailure( loc, "Fp8 conversion instructions are not available on target " "architecture and their emulation is not implemented"); Type i32 = getTypeConverter()->convertType(rewriter.getI32Type()); Type resultType = op.getResult().getType(); Type resultElemType = getElementTypeOrSelf(resultType); Value source = adaptor.getSource(); Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) existing = rewriter.create(loc, i32, existing); else existing = rewriter.create(loc, i32); Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); Value result; if (resultElemType.isFloat8E5M2FNUZ()) result = rewriter.create(loc, i32, source, stoch, existing, byteSel); else if (resultElemType.isFloat8E4M3FNUZ()) result = rewriter.create(loc, i32, source, stoch, existing, byteSel); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); return success(); } struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLBase { ConvertAMDGPUToROCDLPass() = default; void runOnOperation() override { MLIRContext *ctx = &getContext(); FailureOr maybeChipset = Chipset::parse(chipset); if (failed(maybeChipset)) { emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset); return signalPassFailure(); } RewritePatternSet patterns(ctx); LLVMTypeConverter converter(ctx); populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset); LLVMConversionTarget target(getContext()); target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>(); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { converter.addConversion([](BFloat16Type t) -> Type { return IntegerType::get(t.getContext(), 16); }); converter.addConversion([&converter](VectorType t) -> std::optional { if (!t.getElementType().isBF16()) return std::nullopt; return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16))); }); patterns.add(converter); patterns .add, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, RawBufferOpLowering, MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter, chipset); } std::unique_ptr mlir::createConvertAMDGPUToROCDLPass() { return std::make_unique(); }