//===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a translation NVVM ops which is not supported in LLVM // core. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace NVVM; namespace { struct PtxLowering : public OpInterfaceRewritePattern { using OpInterfaceRewritePattern< BasicPtxBuilderInterface>::OpInterfaceRewritePattern; PtxLowering(MLIRContext *context, PatternBenefit benefit = 2) : OpInterfaceRewritePattern(context, benefit) {} LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); return failure(); } SmallVector> asmValues; LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); generator.insertValue(asmValue, modifier); } generator.buildAndReplaceOp(); return success(); } }; struct ConvertNVVMToLLVMPass : public impl::ConvertNVVMToLLVMPassBase { using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { ConversionTarget target(getContext()); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); RewritePatternSet pattern(&getContext()); mlir::populateNVVMToLLVMConversionPatterns(pattern); if (failed( applyPartialConversion(getOperation(), target, std::move(pattern)))) signalPassFailure(); } }; /// Implement the interface to convert NVVM to LLVM. struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateNVVMToLLVMConversionPatterns(patterns); } }; } // namespace void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) { dialect->addInterfaces(); }); }