//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Rewrites functions (and their call sites / address-of uses) whose result is
// "abstract" according to fir::hasAbstractResult so that the result is passed
// through an extra leading argument instead of being returned. Results that
// are the builtin C_PTR type are handled differently: the function returns
// the C_PTR's first component value directly (see getCPtrFunctionType).
//
// NOTE(review): In this copy of the file the contents of most template
// argument lists (`<...>`) appear to have been lost. For example, `Op` is used
// in CallConversion but never declared, `llvm::SmallVector` has no element
// type, and calls such as `rewriter.create(...)`, `mlir::dyn_cast(...)` and
// `target.addIllegalOp()` name no operation/type. The code does not compile
// as written; restore the template arguments from upstream flang before use.
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir {
// Pull in the TableGen-generated base classes for the two passes below.
#define GEN_PASS_DEF_ABSTRACTRESULTONFUNCOPT
#define GEN_PASS_DEF_ABSTRACTRESULTONGLOBALOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-abstract-result-opt"

using namespace mlir;

namespace fir {
namespace {

/// Returns the type of the extra argument through which a function result of
/// type \p resultType is passed.
/// - For types matched by the first `Case`: a fir.box of the result when
///   \p shouldBoxResult is set, a fir.ref of the result otherwise.
/// - For types matched by the second `Case`: always a fir.ref of the result.
/// - Any other type is a programming error (llvm_unreachable).
/// NOTE(review): the type lists of the two `Case`s are missing in this copy.
static mlir::Type getResultArgumentType(mlir::Type resultType,
                                        bool shouldBoxResult) {
  return llvm::TypeSwitch(resultType)
      .Case(
          [&](mlir::Type type) -> mlir::Type {
            if (shouldBoxResult)
              return fir::BoxType::get(type);
            return fir::ReferenceType::get(type);
          })
      .Case([](mlir::Type type) -> mlir::Type {
        return fir::ReferenceType::get(type);
      })
      .Default([](mlir::Type) -> mlir::Type {
        llvm_unreachable("bad abstract result type");
      });
}

/// Returns \p funcTy with its result at index 0 turned into a new leading
/// input (typed by getResultArgumentType); the returned function type has no
/// results at all.
static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
                                             bool shouldBoxResult) {
  auto resultType = funcTy.getResult(0);
  auto argTy = getResultArgumentType(resultType, shouldBoxResult);
  // The result argument goes first, followed by the original inputs.
  llvm::SmallVector newInputTypes = {argTy};
  newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
  return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
                                 /*resultTypes=*/{});
}

/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
/// Follow the ABI for interoperability with C.
/// The inputs are kept unchanged and the single result becomes the type of
/// the first component of the C_PTR record (getTypeList()[0]).
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
  auto resultType = funcTy.getResult(0);
  assert(fir::isa_builtin_cptr_type(resultType));
  llvm::SmallVector outputTypes;
  auto recTy = resultType.dyn_cast();
  outputTypes.emplace_back(recTy.getTypeList()[0].second);
  return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
                                 outputTypes);
}

/// True when the result buffer must be passed as a box: \p shouldBoxResult is
/// set and \p resultType is one of the types tested by the `isa` below.
/// NOTE(review): the `isa` type list is missing in this copy.
static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
  return resultType.isa() && shouldBoxResult;
}

/// Rewrites a call-like op (`Op`) whose result is abstract. The op-specific
/// parts are selected with `if constexpr` for fir::CallOp and fir::DispatchOp.
/// Requirements: the result has exactly one user, and that user is a
/// fir.save_result; otherwise an error is emitted and the match fails.
/// The save_result buffer (emboxed first when mustEmboxResult says so) is
/// passed to the new op as a leading argument. For a builtin C_PTR result no
/// argument is added. Instead, the new op returns the C_PTR's first component
/// value, which is stored through the address genCPtrOrCFunptrAddr computes
/// from the save_result buffer. The original op is erased.
template class CallConversion : public mlir::OpRewritePattern {
public:
  using mlir::OpRewritePattern::OpRewritePattern;

  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern(context, 1), shouldBoxResult{shouldBoxResult} {}

  mlir::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto result = op->getResult(0);
    if (!result.hasOneUse()) {
      mlir::emitError(loc,
                      "calls with abstract result must have exactly one user");
      return mlir::failure();
    }
    auto saveResult = mlir::dyn_cast(result.use_begin().getUser());
    if (!saveResult) {
      mlir::emitError(
          loc, "calls with abstract result must be used in fir.save_result");
      return mlir::failure();
    }
    auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
    // The save_result buffer becomes the storage the callee writes into.
    auto buffer = saveResult.getMemref();
    mlir::Value arg = buffer;
    if (mustEmboxResult(result.getType(), shouldBoxResult))
      // Box the buffer, forwarding the save_result shape and type parameters.
      arg = rewriter.create(
          loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
          saveResult.getTypeparams());

    // Stays empty (no results) unless the result is a builtin C_PTR.
    llvm::SmallVector newResultTypes;
    // TODO: This should be generalized for derived types, and it is
    // architecture and OS dependent.
    bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
    // Assigned by the op-specific branches below.
    Op newOp;
    if (isResultBuiltinCPtr) {
      // Return the C_PTR's first component by value instead of adding an
      // argument (mirrors getCPtrFunctionType).
      auto recTy = result.getType().template dyn_cast();
      newResultTypes.emplace_back(recTy.getTypeList()[0].second);
    }

    // fir::CallOp specific handling.
    if constexpr (std::is_same_v) {
      if (op.getCallee()) {
        // Direct call: prepend the result argument to the operands.
        llvm::SmallVector newOperands;
        if (!isResultBuiltinCPtr)
          newOperands.emplace_back(arg);
        newOperands.append(op.getOperands().begin(), op.getOperands().end());
        newOp = rewriter.create(loc, *op.getCallee(), newResultTypes,
                                newOperands);
      } else {
        // Indirect calls.
        // Operand 0 is the callee value. Build the converted function type
        // from the remaining operands and cast the callee to it.
        llvm::SmallVector newInputTypes;
        if (!isResultBuiltinCPtr)
          newInputTypes.emplace_back(argType);
        for (auto operand : op.getOperands().drop_front())
          newInputTypes.push_back(operand.getType());
        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
                                                 newResultTypes);

        llvm::SmallVector newOperands;
        newOperands.push_back(
            rewriter.create(loc, newFuncTy, op.getOperand(0)));
        if (!isResultBuiltinCPtr)
          newOperands.push_back(arg);
        newOperands.append(op.getOperands().begin() + 1,
                           op.getOperands().end());
        newOp = rewriter.create(loc, mlir::SymbolRefAttr{}, newResultTypes,
                                newOperands);
      }
    }

    // fir::DispatchOp specific handling.
    if constexpr (std::is_same_v) {
      // Operand 0 is passed separately below. The remaining operands get the
      // result argument prepended.
      llvm::SmallVector newOperands;
      if (!isResultBuiltinCPtr)
        newOperands.emplace_back(arg);
      // Number of operands prepended (0 or 1), used to shift the pass-arg index.
      unsigned passArgShift = newOperands.size();
      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());

      // NOTE(review): `newDispatchOp` is never used.
      fir::DispatchOp newDispatchOp;
      if (op.getPassArgPos())
        newOp = rewriter.create(
            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
            op.getOperands()[0], newOperands,
            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
      else
        newOp = rewriter.create(
            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
            op.getOperands()[0], newOperands, nullptr);
    }

    if (isResultBuiltinCPtr) {
      // Write the returned component value back into the save_result buffer.
      mlir::Value save = saveResult.getMemref();
      auto module = op->template getParentOfType();
      FirOpBuilder builder(rewriter, module);
      mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
          builder, loc, save, result.getType());
      rewriter.create(loc, newOp->getResult(0), saveAddr);
    }
    // Drop operand references before erasing the original op.
    op->dropAllReferences();
    rewriter.eraseOp(op);
    return mlir::success();
  }

private:
  /// Mirrors the pass option: pass array/record results as a box.
  bool shouldBoxResult;
};

/// Erases fir.save_result ops. CallConversion has already consumed their
/// buffer as the result argument, and the pass marks save_result illegal.
class SaveResultOpConversion : public mlir::OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  SaveResultOpConversion(mlir::MLIRContext *context)
      : OpRewritePattern(context) {}
  mlir::LogicalResult
  matchAndRewrite(fir::SaveResultOp op,
                  mlir::PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return mlir::success();
  }
};

/// Rewrites mlir::func::ReturnOp in a function whose abstract result was
/// moved to the argument \p newArg. \p newArg is a null Value when
/// AbstractResultOnFuncOpt registers this pattern for the builtin C_PTR case.
/// - If the returned value is a load from some storage (possibly behind a
///   fir.declare):
///   - For a C_PTR result, the load is erased. The value at the address
///     genCPtrOrCFunptrAddr computes from that storage is loaded and
///     returned.
///   - Otherwise, all uses of the storage are redirected to \p newArg, and
///     the storage's defining op is erased if it becomes unused.
/// - Otherwise, the returned value is stored into \p newArg.
/// Except in the C_PTR path, the return is replaced by an operand-less return.
class ReturnOpConversion : public mlir::OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
      : OpRewritePattern(context), newArg{newArg} {}
  mlir::LogicalResult
  matchAndRewrite(mlir::func::ReturnOp ret,
                  mlir::PatternRewriter &rewriter) const override {
    auto loc = ret.getLoc();
    rewriter.setInsertionPoint(ret);
    auto returnedValue = ret.getOperand(0);
    bool replacedStorage = false;
    if (auto *op = returnedValue.getDefiningOp())
      if (auto load = mlir::dyn_cast(op)) {
        auto resultStorage = load.getMemref();
        // The result alloca may be behind a fir.declare, if any.
        if (auto declare = mlir::dyn_cast_or_null(
                resultStorage.getDefiningOp()))
          resultStorage = declare.getMemref();
        // TODO: This should be generalized for derived types, and it is
        // architecture and OS dependent.
        if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
          rewriter.eraseOp(load);
          auto module = ret->getParentOfType();
          FirOpBuilder builder(rewriter, module);
          mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
              builder, loc, resultStorage, returnedValue.getType());
          mlir::Value retValue = rewriter.create(
              loc, fir::unwrapRefType(retAddr.getType()), retAddr);
          rewriter.replaceOpWithNewOp(ret, mlir::ValueRange{retValue});
          return mlir::success();
        }
        // Make the body write directly into the caller-provided argument.
        resultStorage.replaceAllUsesWith(newArg);
        replacedStorage = true;
        if (auto *alloc = resultStorage.getDefiningOp())
          if (alloc->use_empty())
            rewriter.eraseOp(alloc);
      }
    // The result storage may have been optimized out by a memory to
    // register pass; this is possible for fir.box results, or fir.record
    // with no length parameters. Simply store the result in the result storage
    // at the return point.
    // NOTE(review): in the C_PTR registration, newArg is null. A returned
    // C_PTR value that is not produced by a load would reach this store with
    // a null address. Confirm that this cannot happen.
    if (!replacedStorage)
      rewriter.create(loc, returnedValue, newArg);
    rewriter.replaceOpWithNewOp(ret);
    return mlir::success();
  }

private:
  /// Argument (or a reference derived from it) that receives the result.
  mlir::Value newArg;
};

/// Rewrites fir::AddrOfOp of a function with an abstract result. A new
/// address-of is created with the converted function type: the C_PTR variant
/// when the result is a builtin C_PTR, otherwise the result-as-argument
/// variant. It is then converted back to the original type.
class AddrOfOpConversion : public mlir::OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
  mlir::LogicalResult
  matchAndRewrite(fir::AddrOfOp addrOf,
                  mlir::PatternRewriter &rewriter) const override {
    auto oldFuncTy = addrOf.getType().cast();
    mlir::FunctionType newFuncTy;
    // TODO: This should be generalized for derived types, and it is
    // architecture and OS dependent.
    if (oldFuncTy.getNumResults() != 0 &&
        fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
      newFuncTy = getCPtrFunctionType(oldFuncTy);
    else
      newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
    auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy,
                                     addrOf.getSymbol());
    // Rather than converting all ops a function pointer might transit through
    // (e.g calls, stores, loads, converts...), cast new type to the abstract
    // type. A conversion will be added when calling indirect calls of abstract
    // types.
    rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf);
    return mlir::success();
  }

private:
  /// Mirrors the pass option: pass array/record results as a box.
  bool shouldBoxResult;
};

/// @brief Base CRTP class for AbstractResult pass family.
/// Contains common logic for abstract result conversion in a reusable fashion.
/// @tparam Pass target class that implements operation-specific logic.
/// @tparam PassBase base class template for the pass generated by TableGen.
/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
/// This function should implement operation-specific functionality.
/// runOnOperation() runs the derived hook first and then registers the
/// call/dispatch/save_result/address-of patterns. It then applies a partial
/// conversion and signals pass failure if that conversion fails.
template class PassBase> class AbstractResultOptTemplate : public PassBase {
public:
  void runOnOperation() override {
    auto *context = &this->getContext();
    auto op = this->getOperation();

    mlir::RewritePatternSet patterns(context);
    mlir::ConversionTarget target = *context;
    // Value of the pass option `passResultAsBox`.
    const bool shouldBoxResult = this->passResultAsBox.getValue();

    // Let the concrete pass add its own patterns / legality rules.
    auto &self = static_cast(*this);
    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);

    // Convert the calls and, if needed, the ReturnOp in the function body.
    target.addLegalDialect();
    target.addIllegalOp();
    // Calls, address-ofs and dispatches are legal once they no longer involve
    // a function type with an abstract result.
    target.addDynamicallyLegalOp([](fir::CallOp call) {
      return !hasAbstractResult(call.getFunctionType());
    });
    target.addDynamicallyLegalOp([](fir::AddrOfOp addrOf) {
      if (auto funTy = addrOf.getType().dyn_cast())
        return !hasAbstractResult(funTy);
      return true;
    });
    target.addDynamicallyLegalOp([](fir::DispatchOp dispatch) {
      return !hasAbstractResult(dispatch.getFunctionType());
    });

    patterns.insert>(context, shouldBoxResult);
    patterns.insert>(context, shouldBoxResult);
    patterns.insert(context);
    patterns.insert(context, shouldBoxResult);
    if (mlir::failed(
            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
      this->signalPassFailure();
    }
  }
};

/// Function-level pass. It converts the signature of a func::FuncOp whose
/// result is abstract, registering ReturnOpConversion when the function has
/// a body. Calls inside the function are handled by the patterns the base
/// class registers.
class AbstractResultOnFuncOpt : public AbstractResultOptTemplate {
public:
  void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
                              mlir::RewritePatternSet &patterns,
                              mlir::ConversionTarget &target) {
    auto loc = func.getLoc();
    auto *context = &getContext();
    // Convert function type itself if it has an abstract result.
    auto funcTy = func.getFunctionType().cast();
    if (hasAbstractResult(funcTy)) {
      // TODO: This should be generalized for derived types, and it is
      // architecture and OS dependent.
      if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
        // C_PTR: keep the inputs and return the first component by value.
        // Returns remain illegal while they still return a C_PTR.
        func.setType(getCPtrFunctionType(funcTy));
        patterns.insert(context, mlir::Value{});
        target.addDynamicallyLegalOp(
            [](mlir::func::ReturnOp ret) {
              mlir::Type retTy = ret.getOperand(0).getType();
              return !fir::isa_builtin_cptr_type(retTy);
            });
        return;
      }
      if (!func.empty()) {
        // Insert new argument.
        // A plain builder suffices: this runs before applyPartialConversion.
        mlir::OpBuilder rewriter(context);
        auto resultType = funcTy.getResult(0);
        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
        func.insertArgument(0u, argTy, {}, loc);
        func.eraseResult(0u);
        mlir::Value newArg = func.getArgument(0u);
        if (mustEmboxResult(resultType, shouldBoxResult)) {
          // The argument is a box. Derive a reference to the result buffer
          // at the start of the body for the return conversion to use.
          auto bufferType = fir::ReferenceType::get(resultType);
          rewriter.setInsertionPointToStart(&func.front());
          newArg = rewriter.create(loc, bufferType, newArg);
        }
        patterns.insert(context, newArg);
        // After conversion every return must be operand-less.
        target.addDynamicallyLegalOp(
            [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
        assert(func.getFunctionType() ==
               getNewFunctionType(funcTy, shouldBoxResult));
      } else {
        // Declaration only: update the signature. Prepend an empty attribute
        // dictionary so the existing argument attributes stay aligned with
        // their (now shifted) arguments.
        llvm::SmallVector allArgs;
        func.getAllArgAttrs(allArgs);
        allArgs.insert(allArgs.begin(),
                       mlir::DictionaryAttr::get(func->getContext()));
        func.setType(getNewFunctionType(funcTy, shouldBoxResult));
        func.setAllArgAttrs(allArgs);
      }
    }
  }
};

/// Returns true if \p type is a fir::BoxProcType or fir::PointerType whose
/// element type, once cast, has an abstract result. Returns false for any
/// other type. NOTE(review): the cast target type is missing in this copy.
/// Presumably it is a function type; a non-matching element type would make
/// the `cast` assert.
inline static bool containsFunctionTypeWithAbstractResult(mlir::Type type) {
  return mlir::TypeSwitch(type)
      .Case([](fir::BoxProcType boxProc) {
        return fir::hasAbstractResult(boxProc.getEleTy().cast());
      })
      .Case([](fir::PointerType pointer) {
        return fir::hasAbstractResult(pointer.getEleTy().cast());
      })
      .Default([](auto &&) { return false; });
}

/// Global-level pass. Converting fir::GlobalOp procedure pointers with
/// abstract results is not implemented; such globals are reported through
/// the TODO macro.
class AbstractResultOnGlobalOpt : public AbstractResultOptTemplate<
    AbstractResultOnGlobalOpt, fir::impl::AbstractResultOnGlobalOptBase> {
public:
  void runOnSpecificOperation(fir::GlobalOp global, bool,
                              mlir::RewritePatternSet &,
                              mlir::ConversionTarget &) {
    if (containsFunctionTypeWithAbstractResult(global.getType())) {
      TODO(global->getLoc(), "support for procedure pointers");
    }
  }
};
} // end anonymous namespace
} // namespace fir

/// Creates the function-level abstract result conversion pass.
std::unique_ptr fir::createAbstractResultOnFuncOptPass() {
  return std::make_unique();
}

/// Creates the global-level abstract result conversion pass.
std::unique_ptr fir::createAbstractResultOnGlobalOptPass() {
  return std::make_unique();
}