bolt/deps/llvm-18.1.8/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

1164 lines
44 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
#define DEBUG_TYPE "convert-async-to-llvm"
using namespace mlir;
using namespace mlir::async;
//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===//
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
"mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
"mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
"mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
"mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
static constexpr const char *kGetNumWorkerThreads =
"mlirAsyncRuntimGetNumWorkerThreads";
namespace {
/// Async Runtime API function types.
///
/// Because we can't create API function signature for type parametrized
/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
/// lowering all async data types become opaque pointers at runtime.
struct AsyncAPI {
// All async types are lowered to opaque LLVM pointers at runtime.
static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
return LLVM::LLVMPointerType::get(ctx);
}
static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
return LLVM::LLVMTokenType::get(ctx);
}
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
auto ref = opaquePointerType(ctx);
auto count = IntegerType::get(ctx, 64);
return FunctionType::get(ctx, {ref, count}, {});
}
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
}
static FunctionType createValueFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(ctx, 64);
auto value = opaquePointerType(ctx);
return FunctionType::get(ctx, {i64}, {value});
}
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(ctx, 64);
return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
}
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
return FunctionType::get(ctx, {ptrType}, {ptrType});
}
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
auto value = opaquePointerType(ctx);
return FunctionType::get(ctx, {value}, {});
}
static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
auto value = opaquePointerType(ctx);
return FunctionType::get(ctx, {value}, {});
}
static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
auto i1 = IntegerType::get(ctx, 1);
return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
}
static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
auto value = opaquePointerType(ctx);
auto i1 = IntegerType::get(ctx, 1);
return FunctionType::get(ctx, {value}, {i1});
}
static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
auto i1 = IntegerType::get(ctx, 1);
return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
}
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
auto value = opaquePointerType(ctx);
return FunctionType::get(ctx, {value}, {});
}
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
}
static FunctionType executeFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
return FunctionType::get(ctx, {ptrType, ptrType}, {});
}
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(ctx, 64);
return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
{i64});
}
static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
}
static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
}
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
auto ptrType = opaquePointerType(ctx);
return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
}
static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
}
// Auxiliary coroutine resume intrinsic wrapper.
static Type resumeFunctionType(MLIRContext *ctx) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
auto ptrType = opaquePointerType(ctx);
return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
}
};
} // namespace
/// Adds Async Runtime C API declarations to the module.
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto builder =
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
auto addFuncDecl = [&](StringRef name, FunctionType type) {
if (module.lookupSymbol(name))
return;
builder.create<func::FuncOp>(name, type).setPrivate();
};
MLIRContext *ctx = module.getContext();
addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
addFuncDecl(kAwaitTokenAndExecute,
AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
addFuncDecl(kAwaitValueAndExecute,
AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
addFuncDecl(kAwaitAllAndExecute,
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
}
//===----------------------------------------------------------------------===//
// Coroutine resume function wrapper.
//===----------------------------------------------------------------------===//
static constexpr const char *kResume = "__resume";
/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
/// intrinsics. We need this function to be able to pass it to the async
/// runtime execute API.
static void addResumeFunction(ModuleOp module) {
if (module.lookupSymbol(kResume))
return;
MLIRContext *ctx = module.getContext();
auto loc = module.getLoc();
auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
auto voidTy = LLVM::LLVMVoidType::get(ctx);
Type ptrType = AsyncAPI::opaquePointerType(ctx);
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock();
auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
blockBuilder.create<LLVM::ReturnOp>(ValueRange());
}
//===----------------------------------------------------------------------===//
// Convert Async dialect types to LLVM types.
//===----------------------------------------------------------------------===//
namespace {
/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
/// their runtime type (opaque pointers) and does not convert any other types.
class AsyncRuntimeTypeConverter : public TypeConverter {
public:
AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
addConversion([](Type type) { return type; });
addConversion([](Type type) { return convertAsyncTypes(type); });
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
};
addSourceMaterialization(addUnrealizedCast);
addTargetMaterialization(addUnrealizedCast);
}
static std::optional<Type> convertAsyncTypes(Type type) {
if (isa<TokenType, GroupType, ValueType>(type))
return AsyncAPI::opaquePointerType(type.getContext());
if (isa<CoroIdType, CoroStateType>(type))
return AsyncAPI::tokenType(type.getContext());
if (isa<CoroHandleType>(type))
return AsyncAPI::opaquePointerType(type.getContext());
return std::nullopt;
}
};
/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
/// as type converter. Allows access to it via the 'getTypeConverter'
/// convenience method.
template <typename SourceOp>
class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
using Base = OpConversionPattern<SourceOp>;
public:
AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
MLIRContext *context)
: Base(typeConverter, context) {}
/// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
const AsyncRuntimeTypeConverter *getTypeConverter() const {
return static_cast<const AsyncRuntimeTypeConverter *>(
Base::getTypeConverter());
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.id to @llvm.coro.id intrinsic.
//===----------------------------------------------------------------------===//
namespace {
class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
public:
using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto token = AsyncAPI::tokenType(op->getContext());
auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
auto loc = op->getLoc();
// Constants for initializing coroutine frame.
auto constZero =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
// Get coroutine id: @llvm.coro.id.
rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.begin to @llvm.coro.begin intrinsic.
//===----------------------------------------------------------------------===//
namespace {
class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
public:
using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
auto loc = op->getLoc();
// Get coroutine frame size: @llvm.coro.size.i64.
Value coroSize =
rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
// Get coroutine frame alignment: @llvm.coro.align.i64.
Value coroAlign =
rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
// Round up the size to be multiple of the alignment. Since aligned_alloc
// requires the size parameter be an integral multiple of the alignment
// parameter.
auto makeConstant = [&](uint64_t c) {
return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
rewriter.getI64Type(), c);
};
coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
coroSize =
rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
Value negCoroAlign =
rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
coroSize =
rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
auto coroAlloc = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.free to @llvm.coro.free intrinsic.
//===----------------------------------------------------------------------===//
namespace {
class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
public:
using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
auto loc = op->getLoc();
// Get a pointer to the coroutine frame memory: @llvm.coro.free.
auto coroMem =
rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
ValueRange(coroMem.getResult()));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.end to @llvm.coro.end intrinsic.
//===----------------------------------------------------------------------===//
namespace {
class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We are not in the block that is part of the unwind sequence.
auto constFalse = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
// Mark the end of a coroutine: @llvm.coro.end.
auto coroHdl = adaptor.getHandle();
rewriter.create<LLVM::CoroEndOp>(
op->getLoc(), rewriter.getI1Type(),
ValueRange({coroHdl, constFalse, noneToken}));
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.save to @llvm.coro.save intrinsic.
//===----------------------------------------------------------------------===//
namespace {
class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Save the coroutine state: @llvm.coro.save
rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
//===----------------------------------------------------------------------===//
namespace {
/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
/// branch to the appropriate block based on the return code.
///
/// Before:
///
/// ^suspended:
/// "opBefore"(...)
/// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
/// ^resume:
/// "op"(...)
/// ^cleanup: ...
/// ^suspend: ...
///
/// After:
///
/// ^suspended:
/// "opBefore"(...)
/// %suspend = llmv.intr.coro.suspend ...
/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
/// ^resume:
/// "op"(...)
/// ^cleanup: ...
/// ^suspend: ...
///
class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto i8 = rewriter.getIntegerType(8);
auto i32 = rewriter.getI32Type();
auto loc = op->getLoc();
// This is not a final suspension point.
auto constFalse = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
// Suspend a coroutine: @llvm.coro.suspend
auto coroState = adaptor.getState();
auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
loc, i8, ValueRange({coroState, constFalse}));
// Cast return code to i32.
// After a suspension point decide if we should branch into resume, cleanup
// or suspend block of the coroutine (see @llvm.coro.suspend return code
// documentation).
llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
op.getCleanupDest()};
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
/*defaultDestination=*/op.getSuspendDest(),
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
/*caseDestinations=*/caseDest,
/*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
/*branchWeights=*/ArrayRef<int32_t>());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.create to the corresponding runtime API call.
//
// To allocate storage for the async values we use getelementptr trick:
// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
//===----------------------------------------------------------------------===//
namespace {
class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
public:
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const TypeConverter *converter = getTypeConverter();
Type resultType = op->getResultTypes()[0];
// Tokens creation maps to a simple function call.
if (isa<TokenType>(resultType)) {
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kCreateToken, converter->convertType(resultType));
return success();
}
// To create a value we need to compute the storage requirement.
if (auto value = dyn_cast<ValueType>(resultType)) {
// Returns the size requirements for the async value storage.
auto sizeOf = [&](ValueType valueType) -> Value {
auto loc = op->getLoc();
auto i64 = rewriter.getI64Type();
auto storedType = converter->convertType(valueType.getValueType());
auto storagePtrType =
AsyncAPI::opaquePointerType(rewriter.getContext());
// %Size = getelementptr %T* null, int 1
// %SizeI = ptrtoint %T* %Size to i64
auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
auto gep =
rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
nullPtr, ArrayRef<LLVM::GEPArg>{1});
return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
};
rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
sizeOf(value));
return success();
}
return rewriter.notifyMatchFailure(op, "unsupported async type");
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.create_group to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeCreateGroupOpLowering
: public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
public:
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const TypeConverter *converter = getTypeConverter();
Type resultType = op.getResult().getType();
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kCreateGroup, converter->convertType(resultType),
adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeSetAvailableOpLowering
: public OpConversionPattern<RuntimeSetAvailableOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kEmplaceToken; })
.Case<ValueType>([](Type) { return kEmplaceValue; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.set_error to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeSetErrorOpLowering
: public OpConversionPattern<RuntimeSetErrorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kSetTokenError; })
.Case<ValueType>([](Type) { return kSetValueError; });
rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.is_error to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kIsTokenError; })
.Case<GroupType>([](Type) { return kIsGroupError; })
.Case<ValueType>([](Type) { return kIsValueError; });
rewriter.replaceOpWithNewOp<func::CallOp>(
op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.await to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kAwaitToken; })
.Case<ValueType>([](Type) { return kAwaitValue; })
.Case<GroupType>([](Type) { return kAwaitGroup; });
rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
adaptor.getOperands());
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.await_and_resume to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeAwaitAndResumeOpLowering
: public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
public:
using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef apiFuncName =
TypeSwitch<Type, StringRef>(op.getOperand().getType())
.Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
.Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
.Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
Value operand = adaptor.getOperand();
Value handle = adaptor.getHandle();
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
rewriter.create<func::CallOp>(
op->getLoc(), apiFuncName, TypeRange(),
ValueRange({operand, handle, resumePtr.getRes()}));
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.resume to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeResumeOpLowering
: public AsyncOpConversionPattern<RuntimeResumeOp> {
public:
using AsyncOpConversionPattern::AsyncOpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// A pointer to coroutine resume intrinsic wrapper.
addResumeFunction(op->getParentOfType<ModuleOp>());
auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
kResume);
// Call async runtime API to execute a coroutine in the managed thread.
auto coroHdl = adaptor.getHandle();
rewriter.replaceOpWithNewOp<func::CallOp>(
op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.store to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
public:
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
auto storagePtr = rewriter.create<func::CallOp>(
loc, kGetValueStorage, TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getValue().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
op, "failed to convert stored value type to LLVM type");
Value castedStoragePtr = storagePtr.getResult(0);
// Store the yielded value into the async value storage.
auto value = adaptor.getValue();
rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
// Erase the original runtime store operation.
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.load to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
public:
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// Get a pointer to the async value storage from the runtime.
auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
auto storage = adaptor.getStorage();
auto storagePtr = rewriter.create<func::CallOp>(
loc, kGetValueStorage, TypeRange(ptrType), storage);
// Cast from i8* to the LLVM pointer type.
auto valueType = op.getResult().getType();
auto llvmValueType = getTypeConverter()->convertType(valueType);
if (!llvmValueType)
return rewriter.notifyMatchFailure(
op, "failed to convert loaded value type to LLVM type");
Value castedStoragePtr = storagePtr.getResult(0);
// Load from the casted pointer.
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
castedStoragePtr);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.add_to_group to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeAddToGroupOpLowering
: public OpConversionPattern<RuntimeAddToGroupOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently we can only add tokens to the group.
if (!isa<TokenType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "only token type is supported");
// Replace with a runtime API function call.
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert async.runtime.num_worker_threads to the corresponding runtime API
// call.
//===----------------------------------------------------------------------===//
namespace {
class RuntimeNumWorkerThreadsOpLowering
: public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace with a runtime API function call.
rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
rewriter.getIndexType());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Async reference counting ops lowering (`async.runtime.add_ref` and
// `async.runtime.drop_ref` to the corresponding API calls).
//===----------------------------------------------------------------------===//
namespace {
template <typename RefCountingOp>
class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
public:
explicit RefCountingOpLowering(const TypeConverter &converter,
MLIRContext *ctx, StringRef apiFunctionName)
: OpConversionPattern<RefCountingOp>(converter, ctx),
apiFunctionName(apiFunctionName) {}
LogicalResult
matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto count = rewriter.create<arith::ConstantOp>(
op->getLoc(), rewriter.getI64Type(),
rewriter.getI64IntegerAttr(op.getCount()));
auto operand = adaptor.getOperand();
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
ValueRange({operand, count}));
return success();
}
private:
StringRef apiFunctionName;
};
class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
public:
explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
MLIRContext *ctx)
: RefCountingOpLowering(converter, ctx, kAddRef) {}
};
class RuntimeDropRefOpLowering
: public RefCountingOpLowering<RuntimeDropRefOp> {
public:
explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
MLIRContext *ctx)
: RefCountingOpLowering(converter, ctx, kDropRef) {}
};
} // namespace
//===----------------------------------------------------------------------===//
// Convert return operations that return async values from async regions.
//===----------------------------------------------------------------------===//
namespace {
class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
namespace {
struct ConvertAsyncToLLVMPass
: public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
using Base::Base;
void runOnOperation() override;
};
} // namespace
void ConvertAsyncToLLVMPass::runOnOperation() {
ModuleOp module = getOperation();
MLIRContext *ctx = module->getContext();
LowerToLLVMOptions options(ctx);
// Add declarations for most functions required by the coroutines lowering.
// We delay adding the resume function until it's needed because it currently
// fails to compile unless '-O0' is specified.
addAsyncRuntimeApiDeclarations(module);
// Lower async.runtime and async.coro operations to Async Runtime API and
// LLVM coroutine intrinsics.
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter(options);
RewritePatternSet patterns(ctx);
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
LLVMTypeConverter llvmConverter(ctx, options);
llvmConverter.addConversion([&](Type type) {
return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
});
// Convert async types in function signatures and function calls.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
// Convert return operations inside async.execute regions.
patterns.add<ReturnOpOpConversion>(converter, ctx);
// Lower async.runtime operations to the async runtime API calls.
patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
ctx);
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
// Lower async coroutine operations to LLVM coroutine intrinsics.
patterns
.add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
converter, ctx);
ConversionTarget target(*ctx);
target.addLegalOp<arith::ConstantOp, func::ConstantOp,
UnrealizedConversionCastOp>();
target.addLegalDialect<LLVM::LLVMDialect>();
// All operations from Async dialect must be lowered to the runtime API and
// LLVM intrinsics calls.
target.addIllegalDialect<AsyncDialect>();
// Add dynamic legality constraints to apply conversions defined above.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return converter.isSignatureLegal(op.getCalleeType());
});
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
//===----------------------------------------------------------------------===//
// Patterns for structural type conversions for the Async dialect operations.
//===----------------------------------------------------------------------===//
namespace {
class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ExecuteOp newOp =
cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
// Set operands and update block argument and result types.
newOp->setOperands(adaptor.getOperands());
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
return failure();
for (auto result : newOp.getResults())
result.setType(typeConverter->convertType(result.getType()));
rewriter.replaceOp(op, newOp.getResults());
return success();
}
};
// Dummy pattern to trigger the appropriate type conversion / materialization.
class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
return success();
}
};
// Dummy pattern to trigger the appropriate type conversion / materialization.
class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
return success();
}
};
} // namespace
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
typeConverter.addConversion([&](TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
Type converted = typeConverter.convertType(type.getValueType());
return converted ? ValueType::get(converted) : converted;
});
patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
}