//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTOPENACCTOSCF #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; //===----------------------------------------------------------------------===// // Conversion patterns //===----------------------------------------------------------------------===// namespace { /// Pattern to transform the `getIfCond` on operation without region into a /// scf.if and move the operation into the `then` region. template class ExpandIfCondition : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { // Early exit if there is no condition. if (!op.getIfCond()) return failure(); IntegerAttr constAttr; if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) { auto ifOp = rewriter.create(op.getLoc(), TypeRange(), op.getIfCond(), false); rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener()); thenBodyBuilder.clone(*op.getOperation()); rewriter.eraseOp(op); } else { if (constAttr.getInt()) rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else rewriter.eraseOp(op); } return success(); } }; } // namespace void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) { patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); } namespace { struct ConvertOpenACCToSCFPass : public impl::ConvertOpenACCToSCFBase { void runOnOperation() override; }; } // namespace void ConvertOpenACCToSCFPass::runOnOperation() { auto op = getOperation(); auto *context = op.getContext(); RewritePatternSet patterns(context); ConversionTarget target(*context); populateOpenACCToSCFConversionPatterns(patterns); target.addLegalDialect(); target.addLegalDialect(); target.addDynamicallyLegalOp( [](acc::EnterDataOp op) { return !op.getIfCond(); }); target.addDynamicallyLegalOp( [](acc::ExitDataOp op) { return !op.getIfCond(); }); target.addDynamicallyLegalOp( [](acc::UpdateOp op) { return !op.getIfCond(); }); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertOpenACCToSCFPass() { return std::make_unique(); }