//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { /// A pattern that converts the region arguments in a single-region OpenMP /// operation to the LLVM dialect. The body of the region is not modified and is /// expected to either be processed by the conversion infrastructure or already /// contain ops compatible with LLVM dialect types. template struct RegionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newOp = rewriter.create( curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), newOp.getRegion().end()); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *this->getTypeConverter()))) return failure(); rewriter.eraseOp(curOp); return success(); } }; template struct RegionLessOpWithVarOperandsConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); SmallVector convertedOperands; assert(curOp.getNumVariableOperands() == curOp.getOperation()->getNumOperands() && "unexpected non-variable operands"); for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); if (isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); } convertedOperands.emplace_back(adaptor.getOperands()[idx]); } rewriter.replaceOpWithNewOp(curOp, resTypes, convertedOperands, curOp->getAttrs()); return success(); } }; template struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); SmallVector convertedOperands; assert(curOp.getNumVariableOperands() == curOp.getOperation()->getNumOperands() && "unexpected non-variable operands"); for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) return failure(); if (isa(originalVariableOperand.getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); } convertedOperands.emplace_back(adaptor.getOperands()[idx]); } auto newOp = rewriter.create(curOp.getLoc(), resTypes, convertedOperands, curOp->getAttrs()); rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), newOp.getRegion().end()); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *this->getTypeConverter()))) return failure(); rewriter.eraseOp(curOp); return success(); } }; template struct RegionLessOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); rewriter.replaceOpWithNewOp(curOp, resTypes, adaptor.getOperands(), curOp->getAttrs()); return success(); } }; struct AtomicReadOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); Type curElementType = curOp.getElementType(); auto newOp = rewriter.create( curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType)); newOp.setElementTypeAttr(typeAttr); rewriter.eraseOp(curOp); return success(); } }; struct MapInfoOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); SmallVector resTypes; if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); // Copy attributes of the curOp except for the typeAttr which should // be converted SmallVector newAttrs; for (NamedAttribute attr : curOp->getAttrs()) { if (auto typeAttr = dyn_cast(attr.getValue())) { Type newAttr = converter->convertType(typeAttr.getValue()); newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); } else { newAttrs.push_back(attr); } } rewriter.replaceOpWithNewOp( curOp, resTypes, adaptor.getOperands(), newAttrs); return success(); } }; struct ReductionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (isa(curOp.getAccumulator().getType())) { // TODO: Support memref type in variable operands return rewriter.notifyMatchFailure(curOp, "memref is not supported yet"); } rewriter.replaceOpWithNewOp( curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs()); return success(); } }; struct ReductionDeclareOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(omp::ReductionDeclareOp curOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newOp = rewriter.create( curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(), TypeAttr::get(this->getTypeConverter()->convertType( curOp.getTypeAttr().getValue()))); for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) { rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx), newOp.getRegion(idx).end()); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx), *this->getTypeConverter()))) return failure(); } rewriter.eraseOp(curOp); return success(); } }; } // namespace void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp< mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::TargetOp, mlir::omp::DataOp, mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionOp, mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskGroupOp, mlir::omp::TaskOp>([&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)) && typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); target.addDynamicallyLegalOp< mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp, mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, mlir::omp::EnterDataOp, mlir::omp::ExitDataOp, mlir::omp::UpdateDataOp, mlir::omp::DataBoundsOp, mlir::omp::MapInfoOp>([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); target.addDynamicallyLegalOp([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()); }); target.addDynamicallyLegalOp( [&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)) && typeConverter.isLegal(&op->getRegion(1)) && typeConverter.isLegal(&op->getRegion(2)) && typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { // This type is allowed when converting OpenMP to LLVM Dialect, it carries // bounds information for map clauses and the operation and type are // discarded on lowering to LLVM-IR from the OpenMP dialect. converter.addConversion( [&](omp::DataBoundsType type) -> Type { return type; }); patterns.add< AtomicReadOpConversion, MapInfoOpConversion, ReductionOpConversion, ReductionDeclareOpConversion, RegionOpConversion, RegionOpConversion, ReductionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpWithVarOperandsConversion>(converter); } namespace { struct ConvertOpenMPToLLVMPass : public impl::ConvertOpenMPToLLVMPassBase { using Base::Base; void runOnOperation() override; }; } // namespace void ConvertOpenMPToLLVMPass::runOnOperation() { auto module = getOperation(); // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); arith::populateArithToLLVMConversionPatterns(converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addLegalOp(); configureOpenMPToLLVMConversionLegality(target, converter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); }