//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {

//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
/// Returns false if cannot.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
static bool createElementwiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, Type coopType, ValueRange operands) { assert((isa(coopType))); switch (op.getOpType()) { case gpu::MMAElementwiseOp::ADDF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::ADDI: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::SUBF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::SUBI: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVS: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::DIVU: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::NEGATEF: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::NEGATES: builder.replaceOpWithNewOp(op, coopType, operands); return true; case gpu::MMAElementwiseOp::EXTF: builder.replaceOpWithNewOp(op, coopType, operands); return true; default: break; } return false; } bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { assert(!operands.empty()); if (!llvm::all_equal( llvm::map_range(operands, [](Value v) { return v.getType(); }))) return false; return isa(operands.front().getType()); } namespace { /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative /// matrix ops. 
struct WmmaConstantOpToSPIRVLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Value cst = adaptor.getOperands().front(); auto coopType = getTypeConverter()->convertType(op.getType()); if (!coopType) return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.replaceOpWithNewOp(op, coopType, cst); return success(); } }; /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for /// the default case. struct WmmaElementwiseOpToSPIRVDefaultLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // All operands should be of cooperative matrix types. if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { return rewriter.notifyMatchFailure(op, "not all operands are coop matrices"); } auto coopType = getTypeConverter()->convertType(op.getType()); if (!coopType) return rewriter.notifyMatchFailure(op, "type conversion failed"); return success( createElementwiseOp(rewriter, op, coopType, adaptor.getOperands())); } }; /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for /// matrix times scalar case. struct WmmaElementwiseOpToSPIRVScalarMulLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (adaptor.getOperands().size() != 2) return failure(); // All operands should be of cooperative matrix types. 
if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { return rewriter.notifyMatchFailure(op, "not all operands are coop matrices"); } if (op.getOpType() != gpu::MMAElementwiseOp::MULF) return failure(); // Use the original operands to check whether one of the operands is a splat // scalar value. Value lhs = op.getOperands().front(); Value rhs = op.getOperands().back(); Value splat = nullptr; Value matrix = nullptr; if (lhs.getDefiningOp()) { splat = adaptor.getOperands().front(); matrix = adaptor.getOperands().back(); } else if (rhs.getDefiningOp()) { matrix = adaptor.getOperands().front(); splat = adaptor.getOperands().back(); } if (!splat || !matrix) return rewriter.notifyMatchFailure(op, "no splat operand"); // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops. Value scalar; auto cc = splat.getDefiningOp(); if (!cc) { return rewriter.notifyMatchFailure(op, "splat is not a composite construct"); } assert(cc.getConstituents().size() == 1); scalar = cc.getConstituents().front(); auto coopType = getTypeConverter()->convertType(op.getType()); if (!coopType) return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.replaceOpWithNewOp( op, coopType, ValueRange{matrix, scalar}); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // SPV_KHR_cooperative_matrix //===----------------------------------------------------------------------===// namespace khr { namespace { /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV /// dialect. 
struct WmmaLoadOpToSPIRVLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto &typeConverter = *getTypeConverter(); Location loc = op->getLoc(); auto retType = cast(op.getRes().getType()); MemRefType memrefType = op.getSrcMemref().getType(); Value bufferPtr = spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter); auto coopType = typeConverter.convertType(retType); if (!coopType) return rewriter.notifyMatchFailure(op, "type conversion failed"); int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor : spirv::CooperativeMatrixLayoutKHR::RowMajor; rewriter.replaceOpWithNewOp( op, coopType, bufferPtr, strideValue, layout); return success(); } }; /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV /// dialect. 
struct WmmaStoreOpToSPIRVLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto &typeConverter = *getTypeConverter(); Location loc = op->getLoc(); auto memrefType = cast(op.getDstMemref().getType()); Value bufferPtr = spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter); int64_t stride = op.getLeadDimension().getSExtValue(); IntegerType i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); bool isColMajor = op.getTranspose().value_or(false); auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor : spirv::CooperativeMatrixLayoutKHR::RowMajor; rewriter.replaceOpWithNewOp( op, bufferPtr, adaptor.getSrc(), strideValue, layout); return success(); } }; /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV /// dialect. struct WmmaMmaOpToSPIRVLowering final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(), adaptor.getOpC()); return success(); } }; } // namespace } // namespace khr } // namespace mlir void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( SPIRVTypeConverter &converter, RewritePatternSet &patterns) { using namespace mlir; MLIRContext *context = patterns.getContext(); patterns.add(converter, context); // Give the following patterns higher benefit to prevail over the default one. 
patterns.add(converter, context, /*benefit=*/2); } void mlir::populateMMAToSPIRVCoopMatrixTypeConversion( mlir::SPIRVTypeConverter &typeConverter) { typeConverter.addConversion([](gpu::MMAMatrixType type) { ArrayRef retTypeShape = type.getShape(); Type elementType = type.getElementType(); auto use = llvm::StringSwitch(type.getOperand()) .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA) .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB) .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc); return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0], retTypeShape[1], spirv::Scope::Subgroup, use); }); }