//===- Async.cpp - MLIR Async Operations ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::async; #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; void AsyncDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" >(); } //===----------------------------------------------------------------------===// /// ExecuteOp //===----------------------------------------------------------------------===// constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getBodyRegion() && "invalid region index"); return getBodyOperands(); } bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { const auto getValueOrTokenType = [](Type type) { if (auto value = llvm::dyn_cast(type)) return value.getValueType(); return type; }; return getValueOrTokenType(lhs) == getValueOrTokenType(rhs); } void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. if (point == getBodyRegion()) { regions.push_back(RegionSuccessor(getBodyResults())); return; } // Otherwise the successor is the body region. regions.push_back( RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); } void ExecuteOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange dependencies, ValueRange operands, BodyBuilderFn bodyBuilder) { result.addOperands(dependencies); result.addOperands(operands); // Add derived `operandSegmentSizes` attribute based on parsed operands. int32_t numDependencies = dependencies.size(); int32_t numOperands = operands.size(); auto operandSegmentSizes = builder.getDenseI32ArrayAttr({numDependencies, numOperands}); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); // First result is always a token, and then `resultTypes` wrapped into // `async.value`. result.addTypes({TokenType::get(result.getContext())}); for (Type type : resultTypes) result.addTypes(ValueType::get(type)); // Add a body region with block arguments as unwrapped async value operands. Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); for (Value operand : operands) { auto valueType = llvm::dyn_cast(operand.getType()); bodyBlock.addArgument(valueType ? valueType.getValueType() : operand.getType(), operand.getLoc()); } // Create the default terminator if the builder is not provided and if the // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. if (resultTypes.empty() && !bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); builder.create(result.location, ValueRange()); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); bodyBuilder(builder, result.location, bodyBlock.getArguments()); } } void ExecuteOp::print(OpAsmPrinter &p) { // [%tokens,...] if (!getDependencies().empty()) p << " [" << getDependencies() << "]"; // (%value as %unwrapped: !async.value, ...) if (!getBodyOperands().empty()) { p << " ("; Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front(); llvm::interleaveComma( getBodyOperands(), p, [&, n = 0](Value operand) mutable { Value argument = entry ? entry->getArgument(n++) : Value(); p << operand << " as " << argument << ": " << operand.getType(); }); p << ")"; } // -> (!async.value, ...) p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes())); p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {kOperandSegmentSizesAttr}); p << ' '; p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false); } ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = result.getContext(); // Sizes of parsed variadic operands, will be updated below after parsing. int32_t numDependencies = 0; auto tokenTy = TokenType::get(ctx); // Parse dependency tokens. if (succeeded(parser.parseOptionalLSquare())) { SmallVector tokenArgs; if (parser.parseOperandList(tokenArgs) || parser.resolveOperands(tokenArgs, tokenTy, result.operands) || parser.parseRSquare()) return failure(); numDependencies = tokenArgs.size(); } // Parse async value operands (%value as %unwrapped : !async.value). SmallVector valueArgs; SmallVector unwrappedArgs; SmallVector valueTypes; // Parse a single instance of `%value as %unwrapped : !async.value`. auto parseAsyncValueArg = [&]() -> ParseResult { if (parser.parseOperand(valueArgs.emplace_back()) || parser.parseKeyword("as") || parser.parseArgument(unwrappedArgs.emplace_back()) || parser.parseColonType(valueTypes.emplace_back())) return failure(); auto valueTy = llvm::dyn_cast(valueTypes.back()); unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type(); return success(); }; auto argsLoc = parser.getCurrentLocation(); if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen, parseAsyncValueArg) || parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) return failure(); int32_t numOperands = valueArgs.size(); // Add derived `operandSegmentSizes` attribute based on parsed operands. auto operandSegmentSizes = parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands}); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); // Parse the types of results returned from the async execute op. SmallVector resultTypes; NamedAttrList attrs; if (parser.parseOptionalArrowTypeList(resultTypes) || // Async execute first result is always a completion token. parser.addTypeToList(tokenTy, result.types) || parser.addTypesToList(resultTypes, result.types) || // Parse operation attributes. parser.parseOptionalAttrDictWithKeyword(attrs)) return failure(); result.addAttributes(attrs); // Parse asynchronous region. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/unwrappedArgs); } LogicalResult ExecuteOp::verifyRegions() { // Unwrap async.execute value operands types. auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) { return llvm::cast(operand.getType()).getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. if (getBodyRegion().getArgumentTypes() != unwrappedTypes) return emitOpError("async body region argument types do not match the " "execute operation arguments types"); return success(); } //===----------------------------------------------------------------------===// /// CreateGroupOp //===----------------------------------------------------------------------===// LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, PatternRewriter &rewriter) { // Find all `await_all` users of the group. llvm::SmallVector awaitAllUsers; auto isAwaitAll = [&](Operation *op) -> bool { if (AwaitAllOp awaitAll = dyn_cast(op)) { awaitAllUsers.push_back(awaitAll); return true; } return false; }; // Check if all users of the group are `await_all` operations. if (!llvm::all_of(op->getUsers(), isAwaitAll)) return failure(); // If group is only awaited without adding anything to it, we can safely erase // the create operation and all users. for (AwaitAllOp awaitAll : awaitAllUsers) rewriter.eraseOp(awaitAll); rewriter.eraseOp(op); return success(); } //===----------------------------------------------------------------------===// /// AwaitOp //===----------------------------------------------------------------------===// void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, ArrayRef attrs) { result.addOperands({operand}); result.attributes.append(attrs.begin(), attrs.end()); // Add unwrapped async.value type to the returned values types. if (auto valueType = llvm::dyn_cast(operand.getType())) result.addTypes(valueType.getValueType()); } static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType) { if (parser.parseType(operandType)) return failure(); // Add unwrapped async.value type to the returned values types. if (auto valueType = llvm::dyn_cast(operandType)) resultType = valueType.getValueType(); return success(); } static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType) { p << operandType; } LogicalResult AwaitOp::verify() { Type argType = getOperand().getType(); // Awaiting on a token does not have any results. if (llvm::isa(argType) && !getResultTypes().empty()) return emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. if (auto value = llvm::dyn_cast(argType)) { if (*getResultType() != value.getValueType()) return emitOpError() << "result type " << *getResultType() << " does not match async value type " << value.getValueType(); } return success(); } //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); function_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } /// Check that the result type of async.func is not void and must be /// some async token or async values. LogicalResult FuncOp::verify() { auto resultTypes = getResultTypes(); if (resultTypes.empty()) return emitOpError() << "result is expected to be at least of size 1, but got " << resultTypes.size(); for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) { auto type = resultTypes[i]; if (!llvm::isa(type) && !llvm::isa(type)) return emitOpError() << "result type must be async value type or async " "token type, but got " << type; // We only allow AsyncToken appear as the first return value if (llvm::isa(type) && i != 0) { return emitOpError() << " results' (optional) async token type is expected " "to appear as the 1st return value, but got " << i + 1; } } return success(); } //===----------------------------------------------------------------------===// /// CallOp //===----------------------------------------------------------------------===// LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the callee attribute was specified. auto fnAttr = (*this)->getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid async function"; // Verify that the operand and result types match the callee. auto fnType = fn.getFunctionType(); if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (getOperand(i).getType() != fnType.getInput(i)) return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (getResult(i).getType() != fnType.getResult(i)) { auto diag = emitOpError("result type mismatch at index ") << i; diag.attachNote() << " op result types: " << getResultTypes(); diag.attachNote() << "function result types: " << fnType.getResults(); return diag; } return success(); } FunctionType CallOp::getCalleeType() { return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } //===----------------------------------------------------------------------===// /// ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { auto funcOp = (*this)->getParentOfType(); ArrayRef resultTypes = funcOp.isStateful() ? funcOp.getResultTypes().drop_front() : funcOp.getResultTypes(); // Get the underlying value types from async types returned from the // parent `async.func` operation. auto types = llvm::map_range(resultTypes, [](const Type &result) { return llvm::cast(result).getValueType(); }); if (getOperandTypes() != types) return emitOpError("operand types do not match the types returned from " "the parent FuncOp"); return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" //===----------------------------------------------------------------------===// // TableGen'd type method definitions //===----------------------------------------------------------------------===// #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" void ValueType::print(AsmPrinter &printer) const { printer << "<"; printer.printType(getValueType()); printer << '>'; } Type ValueType::parse(mlir::AsmParser &parser) { Type ty; if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { parser.emitError(parser.getNameLoc(), "failed to parse async value type"); return Type(); } return ValueType::get(ty); }