//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements passes to convert `gpu.launch_func` op into a sequence // of LLVM calls that emulate the host and device sides. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" namespace mlir { #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; static constexpr const char kSPIRVModule[] = "__spv__"; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns the string name of the `DescriptorSet` decoration. static std::string descriptorSetName() { return llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); } /// Returns the string name of the `Binding` decoration. static std::string bindingName() { return llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::Binding)); } /// Calculates the index of the kernel's operand that is represented by the /// given global variable with the `bind` attribute. We assume that the index of /// each kernel's operand is mapped to (descriptorSet, binding) by the map: /// i -> (0, i) /// which is implemented under `LowerABIAttributesPass`. static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { IntegerAttr binding = op->getAttrOfType(bindingName()); return binding.getInt(); } /// Copies the given number of bytes from src to dst pointers. static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder) { builder.create(loc, dst, src, size, /*isVolatile=*/false); } /// Encodes the binding and descriptor set numbers into a new symbolic name. /// The name is specified by /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and /// binding numbers. static std::string createGlobalVariableWithBindName(spirv::GlobalVariableOp op, StringRef kernelModuleName) { IntegerAttr descriptorSet = op->getAttrOfType(descriptorSetName()); IntegerAttr binding = op->getAttrOfType(bindingName()); return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", kernelModuleName.str(), op.getSymName().str(), std::to_string(descriptorSet.getInt()), std::to_string(binding.getInt())); } /// Returns true if the given global variable has both a descriptor set number /// and a binding number. static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { IntegerAttr descriptorSet = op->getAttrOfType(descriptorSetName()); IntegerAttr binding = op->getAttrOfType(bindingName()); return descriptorSet && binding; } /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel /// arguments from the given SPIR-V module. We assume that the module contains a /// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind /// attribute are kernel arguments. static LogicalResult getKernelGlobalVariables( spirv::ModuleOp module, DenseMap &globalVariableMap) { auto entryPoints = module.getOps(); if (!llvm::hasSingleElement(entryPoints)) { return module.emitError( "The module must contain exactly one entry point function"); } auto globalVariables = module.getOps(); for (auto globalOp : globalVariables) { if (hasDescriptorSetAndBinding(globalOp)) globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; } return success(); } /// Encodes the SPIR-V module's symbolic name into the name of the entry point /// function. static LogicalResult encodeKernelName(spirv::ModuleOp module) { StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule); // We already know that the module contains exactly one entry point function // based on `getKernelGlobalVariables()` call. Update this function's name // to: // {spv_module_name}_{function_name} auto entryPoints = module.getOps(); if (!llvm::hasSingleElement(entryPoints)) { return module.emitError( "The module must contain exactly one entry point function"); } spirv::EntryPointOp entryPoint = *entryPoints.begin(); StringRef funcName = entryPoint.getFn(); auto funcOp = module.lookupSymbol(entryPoint.getFnAttr()); StringAttr newFuncName = StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) return failure(); SymbolTable::setSymbolName(funcOp, newFuncName); return success(); } //===----------------------------------------------------------------------===// // Conversion patterns //===----------------------------------------------------------------------===// namespace { /// Structure to group information about the variables being copied. struct CopyInfo { Value dst; Value src; Value size; }; /// This pattern emulates a call to the kernel in LLVM dialect. For that, we /// copy the data to the global variable (emulating device side), call the /// kernel as a normal void LLVM function, and copy the data back (emulating the /// host side). class GPULaunchLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *op = launchOp.getOperation(); MLIRContext *context = rewriter.getContext(); auto module = launchOp->getParentOfType(); // Get the SPIR-V module that represents the gpu kernel module. The module // is named: // __spv__{kernel_module_name} // based on GPU to SPIR-V conversion. StringRef kernelModuleName = launchOp.getKernelModuleName().getValue(); std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); auto spvModule = module.lookupSymbol( StringAttr::get(context, spvModuleName)); if (!spvModule) { return launchOp.emitOpError("SPIR-V kernel module '") << spvModuleName << "' is not found"; } // Declare kernel function in the main module so that it later can be linked // with its definition from the kernel module. We know that the kernel // function would have no arguments and the data is passed via global // variables. The name of the kernel will be // {spv_module_name}_{kernel_function_name} // to avoid symbolic name conflicts. StringRef kernelFuncName = launchOp.getKernelName().getValue(); std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); auto kernelFunc = module.lookupSymbol( StringAttr::get(context, newKernelFuncName)); if (!kernelFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); kernelFunc = rewriter.create( rewriter.getUnknownLoc(), newKernelFuncName, LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), ArrayRef())); rewriter.setInsertionPoint(launchOp); } // Get all global variables associated with the kernel operands. DenseMap globalVariableMap; if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) return failure(); // Traverse kernel operands that were converted to MemRefDescriptors. For // each operand, create a global variable and copy data from operand to it. Location loc = launchOp.getLoc(); SmallVector copyInfo; auto numKernelOperands = launchOp.getNumKernelOperands(); auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); for (const auto &operand : llvm::enumerate(kernelOperands)) { // Check if the kernel's operand is a ranked memref. auto memRefType = dyn_cast( launchOp.getKernelOperand(operand.index()).getType()); if (!memRefType) return failure(); // Calculate the size of the memref and get the pointer to the allocated // buffer. SmallVector sizes; SmallVector strides; Value sizeBytes; getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides, sizeBytes); MemRefDescriptor descriptor(operand.value()); Value src = descriptor.allocatedPtr(rewriter, loc); // Get the global variable in the SPIR-V module that is associated with // the kernel operand. Construct its new name and create a corresponding // LLVM dialect global variable. spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; auto pointeeType = cast(spirvGlobal.getType()).getPointeeType(); auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); std::string name = createGlobalVariableWithBindName(spirvGlobal, spvModuleName); // Check if this variable has already been created. auto dstGlobal = module.lookupSymbol(name); if (!dstGlobal) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); dstGlobal = rewriter.create( loc, dstGlobalType, /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), /*alignment=*/0); rewriter.setInsertionPoint(launchOp); } // Copy the data from src operand pointer to dst global variable. Save // src, dst and size so that we can copy data back after emulating the // kernel call. Value dst = rewriter.create( loc, typeConverter->convertType(spirvGlobal.getType()), dstGlobal.getSymName()); copy(loc, dst, src, sizeBytes, rewriter); CopyInfo info; info.dst = dst; info.src = src; info.size = sizeBytes; copyInfo.push_back(info); } // Create a call to the kernel and copy the data back. rewriter.replaceOpWithNewOp(op, kernelFunc, ArrayRef()); for (CopyInfo info : copyInfo) copy(loc, info.src, info.dst, info.size, rewriter); return success(); } }; class LowerHostCodeToLLVM : public impl::LowerHostCodeToLLVMPassBase { public: using Base::Base; void runOnOperation() override { ModuleOp module = getOperation(); // Erase the GPU module. for (auto gpuModule : llvm::make_early_inc_range(module.getOps())) gpuModule.erase(); // Request C wrapper emission. for (auto func : module.getOps()) { func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), UnitAttr::get(&getContext())); } // Specify options to lower to LLVM and pull in the conversion patterns. LowerToLLVMOptions options(module.getContext()); auto *context = module.getContext(); RewritePatternSet patterns(context); LLVMTypeConverter typeConverter(context, options); mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); patterns.add(typeConverter); // Pull in SPIR-V type conversion patterns to convert SPIR-V global // variable's type to LLVM dialect type. populateSPIRVToLLVMTypeConversion(typeConverter); ConversionTarget target(*context); target.addLegalDialect(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); // Finally, modify the kernel function in SPIR-V modules to avoid symbolic // conflicts. for (auto spvModule : module.getOps()) { if (failed(encodeKernelName(spvModule))) { signalPassFailure(); return; } } } }; } // namespace