//===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is a command line utility that executes an MLIR file on the Vulkan by // translating MLIR GPU module to SPIR-V and host part to LLVM IR before // JIT-compiling and executing the latter. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; namespace { struct VulkanRunnerOptions { llvm::cl::OptionCategory category{"mlir-vulkan-runner options"}; llvm::cl::opt spirvWebGPUPrepare{ "vulkan-runner-spirv-webgpu-prepare", llvm::cl::desc("Run MLIR transforms used when targetting WebGPU"), llvm::cl::cat(category)}; }; } // namespace static LogicalResult runMLIRPasses(Operation *op, VulkanRunnerOptions &options) { auto module = dyn_cast(op); if (!module) return op->emitOpError("expected a 'builtin.module' op"); PassManager passManager(module.getContext()); if (failed(applyPassManagerCLOptions(passManager))) return failure(); passManager.addPass(createGpuKernelOutliningPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); OpPassManager &modulePM = passManager.nest(); modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass()); modulePM.addPass(spirv::createSPIRVUpdateVCEPass()); if (options.spirvWebGPUPrepare) modulePM.addPass(spirv::createSPIRVWebGPUPreparePass()); passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); passManager.addPass(createFinalizeMemRefToLLVMConversionPass()); passManager.addPass(createConvertVectorToLLVMPass()); passManager.nest().addPass(LLVM::createRequestCWrappersPass()); ConvertFuncToLLVMPassOptions funcToLLVMOptions{}; funcToLLVMOptions.indexBitwidth = DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext())); passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions)); passManager.addPass(createReconcileUnrealizedCastsPass()); passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass()); return passManager.run(module); } int main(int argc, char **argv) { llvm::llvm_shutdown_obj x; registerPassManagerCLOptions(); llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); // Initialize runner-specific CLI options. These will be parsed and // initialzied in `JitRunnerMain`. VulkanRunnerOptions options; auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) { return runMLIRPasses(op, options); }; mlir::JitRunnerConfig jitRunnerConfig; jitRunnerConfig.mlirTransformer = runPassesWithOptions; mlir::DialectRegistry registry; registry.insert(); mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); }