//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // // ============================================================================= #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace acc; #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc" #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc" namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel { Type getElementType(Type pointer) const { return llvm::cast(pointer).getElementType(); } }; struct LLVMPointerPointerLikeModel : public PointerLikeType::ExternalModel { Type getElementType(Type pointer) const { return Type(); } }; } // namespace //===----------------------------------------------------------------------===// // OpenACC operations //===----------------------------------------------------------------------===// void OpenACCDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" >(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" >(); // By attaching interfaces here, we make the OpenACC dialect dependent on // the other dialects. This is probably better than having dialects like LLVM // and memref be dependent on OpenACC. MemRefType::attachInterface(*getContext()); LLVM::LLVMPointerType::attachInterface( *getContext()); } //===----------------------------------------------------------------------===// // device_type support helpers //===----------------------------------------------------------------------===// static bool hasDeviceTypeValues(std::optional arrayAttr) { if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) return true; return false; } static bool hasDeviceType(std::optional arrayAttr, mlir::acc::DeviceType deviceType) { if (!hasDeviceTypeValues(arrayAttr)) return false; for (auto attr : *arrayAttr) { auto deviceTypeAttr = mlir::dyn_cast(attr); if (deviceTypeAttr.getValue() == deviceType) return true; } return false; } static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional deviceTypes) { if (!hasDeviceTypeValues(deviceTypes)) return; p << "["; llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) { p << attr; }); p << "]"; } //===----------------------------------------------------------------------===// // DataBoundsOp //===----------------------------------------------------------------------===// LogicalResult acc::DataBoundsOp::verify() { auto extent = getExtent(); auto upperbound = getUpperbound(); if (!extent && !upperbound) return emitError("expected extent or upperbound."); return success(); } //===----------------------------------------------------------------------===// // PrivateOp //===----------------------------------------------------------------------===// LogicalResult acc::PrivateOp::verify() { if (getDataClause() != acc::DataClause::acc_private) return emitError( "data clause associated with private operation must match its intent"); return success(); } //===----------------------------------------------------------------------===// // FirstprivateOp //===----------------------------------------------------------------------===// LogicalResult acc::FirstprivateOp::verify() { if (getDataClause() != acc::DataClause::acc_firstprivate) return emitError("data clause associated with firstprivate operation must " "match its intent"); return success(); } //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// LogicalResult acc::ReductionOp::verify() { if (getDataClause() != acc::DataClause::acc_reduction) return emitError("data clause associated with reduction operation must " "match its intent"); return success(); } //===----------------------------------------------------------------------===// // DevicePtrOp //===----------------------------------------------------------------------===// LogicalResult acc::DevicePtrOp::verify() { if (getDataClause() != acc::DataClause::acc_deviceptr) return emitError("data clause associated with deviceptr operation must " "match its intent"); return success(); } //===----------------------------------------------------------------------===// // PresentOp //===----------------------------------------------------------------------===// LogicalResult acc::PresentOp::verify() { if (getDataClause() != acc::DataClause::acc_present) return emitError( "data clause associated with present operation must match its intent"); return success(); } //===----------------------------------------------------------------------===// // CopyinOp //===----------------------------------------------------------------------===// LogicalResult acc::CopyinOp::verify() { // Test for all clauses this operation can be decomposed from: if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin && getDataClause() != acc::DataClause::acc_copyin_readonly && getDataClause() != acc::DataClause::acc_copy && getDataClause() != acc::DataClause::acc_reduction) return emitError( "data clause associated with copyin operation must match its intent" " or specify original clause this operation was decomposed from"); return success(); } bool acc::CopyinOp::isCopyinReadonly() { return getDataClause() == acc::DataClause::acc_copyin_readonly; } //===----------------------------------------------------------------------===// // CreateOp //===----------------------------------------------------------------------===// LogicalResult acc::CreateOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_create && getDataClause() != acc::DataClause::acc_create_zero && getDataClause() != acc::DataClause::acc_copyout && getDataClause() != acc::DataClause::acc_copyout_zero) return emitError( "data clause associated with create operation must match its intent" " or specify original clause this operation was decomposed from"); return success(); } bool acc::CreateOp::isCreateZero() { // The zero modifier is encoded in the data clause. return getDataClause() == acc::DataClause::acc_create_zero || getDataClause() == acc::DataClause::acc_copyout_zero; } //===----------------------------------------------------------------------===// // NoCreateOp //===----------------------------------------------------------------------===// LogicalResult acc::NoCreateOp::verify() { if (getDataClause() != acc::DataClause::acc_no_create) return emitError("data clause associated with no_create operation must " "match its intent"); return success(); } //===----------------------------------------------------------------------===// // AttachOp //===----------------------------------------------------------------------===// LogicalResult acc::AttachOp::verify() { if (getDataClause() != acc::DataClause::acc_attach) return emitError( "data clause associated with attach operation must match its intent"); return success(); } //===----------------------------------------------------------------------===// // DeclareDeviceResidentOp //===----------------------------------------------------------------------===// LogicalResult acc::DeclareDeviceResidentOp::verify() { if (getDataClause() != acc::DataClause::acc_declare_device_resident) return emitError("data clause associated with device_resident operation " "must match its intent"); return success(); } //===----------------------------------------------------------------------===// // DeclareLinkOp //===----------------------------------------------------------------------===// LogicalResult acc::DeclareLinkOp::verify() { if (getDataClause() != acc::DataClause::acc_declare_link) return emitError( "data clause associated with link operation must match its intent"); return success(); } //===----------------------------------------------------------------------===// // CopyoutOp //===----------------------------------------------------------------------===// LogicalResult acc::CopyoutOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_copyout && getDataClause() != acc::DataClause::acc_copyout_zero && getDataClause() != acc::DataClause::acc_copy && getDataClause() != acc::DataClause::acc_reduction) return emitError( "data clause associated with copyout operation must match its intent" " or specify original clause this operation was decomposed from"); if (!getVarPtr() || !getAccPtr()) return emitError("must have both host and device pointers"); return success(); } bool acc::CopyoutOp::isCopyoutZero() { return getDataClause() == acc::DataClause::acc_copyout_zero; } //===----------------------------------------------------------------------===// // DeleteOp //===----------------------------------------------------------------------===// LogicalResult acc::DeleteOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_delete && getDataClause() != acc::DataClause::acc_create && getDataClause() != acc::DataClause::acc_create_zero && getDataClause() != acc::DataClause::acc_copyin && getDataClause() != acc::DataClause::acc_copyin_readonly && getDataClause() != acc::DataClause::acc_present && getDataClause() != acc::DataClause::acc_declare_device_resident && getDataClause() != acc::DataClause::acc_declare_link) return emitError( "data clause associated with delete operation must match its intent" " or specify original clause this operation was decomposed from"); if (!getAccPtr()) return emitError("must have device pointer"); return success(); } //===----------------------------------------------------------------------===// // DetachOp //===----------------------------------------------------------------------===// LogicalResult acc::DetachOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_detach && getDataClause() != acc::DataClause::acc_attach) return emitError( "data clause associated with detach operation must match its intent" " or specify original clause this operation was decomposed from"); if (!getAccPtr()) return emitError("must have device pointer"); return success(); } //===----------------------------------------------------------------------===// // HostOp //===----------------------------------------------------------------------===// LogicalResult acc::UpdateHostOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_update_host && getDataClause() != acc::DataClause::acc_update_self) return emitError( "data clause associated with host operation must match its intent" " or specify original clause this operation was decomposed from"); if (!getVarPtr() || !getAccPtr()) return emitError("must have both host and device pointers"); return success(); } //===----------------------------------------------------------------------===// // DeviceOp //===----------------------------------------------------------------------===// LogicalResult acc::UpdateDeviceOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_update_device) return emitError( "data clause associated with device operation must match its intent" " or specify original clause this operation was decomposed from"); return success(); } //===----------------------------------------------------------------------===// // UseDeviceOp //===----------------------------------------------------------------------===// LogicalResult acc::UseDeviceOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_use_device) return emitError( "data clause associated with use_device operation must match its intent" " or specify original clause this operation was decomposed from"); return success(); } //===----------------------------------------------------------------------===// // CacheOp //===----------------------------------------------------------------------===// LogicalResult acc::CacheOp::verify() { // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_cache && getDataClause() != acc::DataClause::acc_cache_readonly) return emitError( "data clause associated with cache operation must match its intent" " or specify original clause this operation was decomposed from"); return success(); } template static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions = 1) { SmallVector regions; for (unsigned i = 0; i < nRegions; ++i) regions.push_back(state.addRegion()); for (Region *region : regions) if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) return failure(); return success(); } static bool isComputeOperation(Operation *op) { return isa(op) || isa(op); } namespace { /// Pattern to remove operation without region that have constant false `ifCond` /// and remove the condition from the operation if the `ifCond` is a true /// constant. template struct RemoveConstantIfCondition : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { // Early return if there is no condition. Value ifCond = op.getIfCond(); if (!ifCond) return failure(); IntegerAttr constAttr; if (!matchPattern(ifCond, m_Constant(&constAttr))) return failure(); if (constAttr.getInt()) rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else rewriter.eraseOp(op); return success(); } }; /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs = {}) { assert(llvm::hasSingleElement(region) && "expected single-region block"); Block *block = ®ion.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); rewriter.inlineBlockBefore(block, op, blockArgs); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } /// Pattern to remove operation with region that have constant false `ifCond` /// and remove the condition from the operation if the `ifCond` is constant /// true. template struct RemoveConstantIfConditionWithRegion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { // Early return if there is no condition. Value ifCond = op.getIfCond(); if (!ifCond) return failure(); IntegerAttr constAttr; if (!matchPattern(ifCond, m_Constant(&constAttr))) return failure(); if (constAttr.getInt()) rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // PrivateRecipeOp //===----------------------------------------------------------------------===// static LogicalResult verifyInitLikeSingleArgRegion( Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional = false) { if (optional && region.empty()) return success(); if (region.empty()) return op->emitOpError() << "expects non-empty " << regionName << " region"; Block &firstBlock = region.front(); if (firstBlock.getNumArguments() < 1 || firstBlock.getArgument(0).getType() != type) return op->emitOpError() << "expects " << regionName << " region first " "argument of the " << regionType << " type"; if (verifyYield) { for (YieldOp yieldOp : region.getOps()) { if (yieldOp.getOperands().size() != 1 || yieldOp.getOperands().getTypes()[0] != type) return op->emitOpError() << "expects " << regionName << " region to " "yield a value of the " << regionType << " type"; } } return success(); } LogicalResult acc::PrivateRecipeOp::verifyRegions() { if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "privatization", "init", getType(), /*verifyYield=*/false))) return failure(); if (failed(verifyInitLikeSingleArgRegion( *this, getDestroyRegion(), "privatization", "destroy", getType(), /*verifyYield=*/false, /*optional=*/true))) return failure(); return success(); } //===----------------------------------------------------------------------===// // FirstprivateRecipeOp //===----------------------------------------------------------------------===// LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "privatization", "init", getType(), /*verifyYield=*/false))) return failure(); if (getCopyRegion().empty()) return emitOpError() << "expects non-empty copy region"; Block &firstBlock = getCopyRegion().front(); if (firstBlock.getNumArguments() < 2 || firstBlock.getArgument(0).getType() != getType()) return emitOpError() << "expects copy region with two arguments of the " "privatization type"; if (getDestroyRegion().empty()) return success(); if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(), "privatization", "destroy", getType(), /*verifyYield=*/false))) return failure(); return success(); } //===----------------------------------------------------------------------===// // ReductionRecipeOp //===----------------------------------------------------------------------===// LogicalResult acc::ReductionRecipeOp::verifyRegions() { if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction", "init", getType(), /*verifyYield=*/false))) return failure(); if (getCombinerRegion().empty()) return emitOpError() << "expects non-empty combiner region"; Block &reductionBlock = getCombinerRegion().front(); if (reductionBlock.getNumArguments() < 2 || reductionBlock.getArgument(0).getType() != getType() || reductionBlock.getArgument(1).getType() != getType()) return emitOpError() << "expects combiner region with the first two " << "arguments of the reduction type"; for (YieldOp yieldOp : getCombinerRegion().getOps()) { if (yieldOp.getOperands().size() != 1 || yieldOp.getOperands().getTypes()[0] != getType()) return emitOpError() << "expects combiner region to yield a value " "of the reduction type"; } return success(); } //===----------------------------------------------------------------------===// // Custom parser and printer verifier for private clause //===----------------------------------------------------------------------===// static ParseResult parseSymOperandList( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &symbols) { llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); return success(); }))) return failure(); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); symbols = ArrayAttr::get(parser.getContext(), arrayAttr); return success(); } static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional attributes) { llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " << std::get<1>(it).getType(); }); } //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// /// Check dataOperands for acc.parallel, acc.serial and acc.kernels. template static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands) { for (mlir::Value operand : operands) if (!mlir::isa( operand.getDefiningOp())) return op.emitError( "expect data entry/exit operation or acc.getdeviceptr " "as defining op"); return success(); } template static LogicalResult checkSymOperandList(Operation *op, std::optional attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType = true) { if (!operands.empty()) { if (!attributes || attributes->size() != operands.size()) return op->emitOpError() << "expected as many " << symbolName << " symbol reference as " << operandName << " operands"; } else { if (attributes) return op->emitOpError() << "unexpected " << symbolName << " symbol reference"; return success(); } llvm::DenseSet set; for (auto args : llvm::zip(operands, *attributes)) { mlir::Value operand = std::get<0>(args); if (!set.insert(operand).second) return op->emitOpError() << operandName << " operand appears more than once"; mlir::Type varType = operand.getType(); auto symbolRef = llvm::cast(std::get<1>(args)); auto decl = SymbolTable::lookupNearestSymbolFrom(op, symbolRef); if (!decl) return op->emitOpError() << "expected symbol reference " << symbolRef << " to point to a " << operandName << " declaration"; if (checkOperandType && decl.getType() && decl.getType() != varType) return op->emitOpError() << "expected " << operandName << " (" << varType << ") to be the same type as " << operandName << " declaration (" << decl.getType() << ")"; } return success(); } unsigned ParallelOp::getNumDataOperands() { return getReductionOperands().size() + getGangPrivateOperands().size() + getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); } Value ParallelOp::getDataOperand(unsigned i) { unsigned numOptional = getAsync().size(); numOptional += getNumGangs().size(); numOptional += getNumWorkers().size(); numOptional += getVectorLength().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } template static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword) { if (!operands.empty() && deviceTypes.getValue().size() != operands.size()) return op.emitOpError() << keyword << " operands count must match " << keyword << " device_type count"; return success(); } template static LogicalResult verifyDeviceTypeAndSegmentCountMatch( Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { std::size_t numOperandsInSegments = 0; if (!segments) return success(); for (auto segCount : segments.asArrayRef()) { if (maxInSegment != 0 && segCount > maxInSegment) return op.emitOpError() << keyword << " expects a maximum of " << maxInSegment << " values per segment"; numOperandsInSegments += segCount; } if (numOperandsInSegments != operands.size()) return op.emitOpError() << keyword << " operand count does not match count in segments"; if (deviceTypes.getValue().size() != (size_t)segments.size()) return op.emitOpError() << keyword << " segment count does not match device_type count"; return success(); } LogicalResult acc::ParallelOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", "privatizations", /*checkOperandType=*/false))) return failure(); if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getNumGangs(), getNumGangsSegmentsAttr(), getNumGangsDeviceTypeAttr(), "num_gangs", 3))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), getWaitOperandsDeviceTypeAttr(), "wait"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), getNumWorkersDeviceTypeAttr(), "num_workers"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), getVectorLengthDeviceTypeAttr(), "vector_length"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), getAsyncDeviceTypeAttr(), "async"))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); } static std::optional findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType) { unsigned segmentIdx = 0; for (auto attr : segments) { auto deviceTypeAttr = mlir::dyn_cast(attr); if (deviceTypeAttr.getValue() == deviceType) return std::make_optional(segmentIdx); ++segmentIdx; } return std::nullopt; } static mlir::Value getValueInDeviceTypeSegment(std::optional arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType) { if (!arrayAttr) return {}; if (auto pos = findSegment(*arrayAttr, deviceType)) return range[*pos]; return {}; } bool acc::ParallelOp::hasAsyncOnly() { return hasAsyncOnly(mlir::acc::DeviceType::None); } bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::ParallelOp::getAsyncValue() { return getAsyncValue(mlir::acc::DeviceType::None); } mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), deviceType); } mlir::Value acc::ParallelOp::getNumWorkersValue() { return getNumWorkersValue(mlir::acc::DeviceType::None); } mlir::Value acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), deviceType); } mlir::Value acc::ParallelOp::getVectorLengthValue() { return getVectorLengthValue(mlir::acc::DeviceType::None); } mlir::Value acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), getVectorLength(), deviceType); } mlir::Operation::operand_range ParallelOp::getNumGangsValues() { return getNumGangsValues(mlir::acc::DeviceType::None); } static mlir::Operation::operand_range getValuesFromSegments(std::optional arrayAttr, mlir::Operation::operand_range range, std::optional> segments, mlir::acc::DeviceType deviceType) { if (!arrayAttr) return range.take_front(0); if (auto pos = findSegment(*arrayAttr, deviceType)) { int32_t nbOperandsBefore = 0; for (unsigned i = 0; i < *pos; ++i) nbOperandsBefore += (*segments)[i]; return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]); } return range.take_front(0); } mlir::Operation::operand_range ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(), deviceType); } bool acc::ParallelOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range ParallelOp::getWaitValues() { return getWaitValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), deviceType); } static ParseResult parseNumGangs( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments) { llvm::SmallVector attributes; llvm::SmallVector seg; do { if (failed(parser.parseLBrace())) return failure(); int32_t crtOperandsSize = operands.size(); if (failed(parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::None, [&]() { if (parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); return success(); }))) return failure(); seg.push_back(operands.size() - crtOperandsSize); if (failed(parser.parseRBrace())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseRSquare()) return failure(); } else { attributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } } while (succeeded(parser.parseOptionalComma())); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); segments = DenseI32ArrayAttr::get(parser.getContext(), seg); return success(); } static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) { auto deviceTypeAttr = mlir::dyn_cast(attr); if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) p << " [" << attr << "]"; } static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes, std::optional segments) { unsigned opIdx = 0; llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { p << "{"; llvm::interleaveComma( llvm::seq(0, (*segments)[it.index()]), p, [&](auto it) { p << operands[opIdx] << " : " << operands[opIdx].getType(); ++opIdx; }); p << "}"; printSingleDeviceType(p, it.value()); }); } static ParseResult parseDeviceTypeOperandsWithSegment( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments) { llvm::SmallVector attributes; llvm::SmallVector seg; do { if (failed(parser.parseLBrace())) return failure(); int32_t crtOperandsSize = operands.size(); if (failed(parser.parseCommaSeparatedList( mlir::AsmParser::Delimiter::None, [&]() { if (parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); return success(); }))) return failure(); seg.push_back(operands.size() - crtOperandsSize); if (failed(parser.parseRBrace())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseRSquare()) return failure(); } else { attributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } } while (succeeded(parser.parseOptionalComma())); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); segments = DenseI32ArrayAttr::get(parser.getContext(), seg); return success(); } static void printDeviceTypeOperandsWithSegment( mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes, std::optional segments) { unsigned opIdx = 0; llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { p << "{"; llvm::interleaveComma( llvm::seq(0, (*segments)[it.index()]), p, [&](auto it) { p << operands[opIdx] << " : " << operands[opIdx].getType(); ++opIdx; }); p << "}"; printSingleDeviceType(p, it.value()); }); } static ParseResult parseDeviceTypeOperands( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes) { llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseRSquare()) return failure(); } else { attributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } return success(); }))) return failure(); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); return success(); } static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes) { llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) { p << std::get<1>(it) << " : " << std::get<1>(it).getType(); printSingleDeviceType(p, std::get<0>(it)); }); } static ParseResult parseDeviceTypeOperandsWithKeywordOnly( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType) { llvm::SmallVector keywordOnlyDeviceTypeAttributes; bool needCommaBeforeOperands = false; if (failed(parser.parseOptionalLParen())) return failure(); // Parse keyword only attributes if (succeeded(parser.parseOptionalLSquare())) { // Parse keyword only attributes if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute( keywordOnlyDeviceTypeAttributes.emplace_back())) return failure(); return success(); }))) return failure(); if (parser.parseRSquare()) return failure(); needCommaBeforeOperands = true; } else if (succeeded(parser.parseOptionalRParen())) { // Keyword only keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); keywordOnlyDeviceType = ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes); return success(); } if (needCommaBeforeOperands && failed(parser.parseComma())) return failure(); llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(attributes.emplace_back()) || parser.parseRSquare()) return failure(); } else { attributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } return success(); }))) return failure(); if (failed(parser.parseRParen())) return failure(); llvm::SmallVector arrayAttr(attributes.begin(), attributes.end()); deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); return success(); } static void printDeviceTypeOperandsWithKeywordOnly( mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional deviceTypes, std::optional keywordOnlyDeviceTypes) { p << "("; if (operands.begin() == operands.end() && keywordOnlyDeviceTypes && keywordOnlyDeviceTypes->size() == 1) { auto deviceTypeAttr = mlir::dyn_cast((*keywordOnlyDeviceTypes)[0]); if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) { p << ")"; return; } } printDeviceTypes(p, keywordOnlyDeviceTypes); if (hasDeviceTypeValues(keywordOnlyDeviceTypes) && hasDeviceTypeValues(deviceTypes)) p << ", "; printDeviceTypeOperands(p, op, operands, types, deviceTypes); p << ")"; } //===----------------------------------------------------------------------===// // SerialOp //===----------------------------------------------------------------------===// unsigned SerialOp::getNumDataOperands() { return getReductionOperands().size() + getGangPrivateOperands().size() + getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); } Value SerialOp::getDataOperand(unsigned i) { unsigned numOptional = getAsync().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } bool acc::SerialOp::hasAsyncOnly() { return hasAsyncOnly(mlir::acc::DeviceType::None); } bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::SerialOp::getAsyncValue() { return getAsyncValue(mlir::acc::DeviceType::None); } mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), deviceType); } bool acc::SerialOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range SerialOp::getWaitValues() { return getWaitValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), deviceType); } LogicalResult acc::SerialOp::verify() { if (failed(checkSymOperandList( *this, getPrivatizations(), getGangPrivateOperands(), "private", "privatizations", /*checkOperandType=*/false))) return failure(); if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), getWaitOperandsDeviceTypeAttr(), "wait"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), getAsyncDeviceTypeAttr(), "async"))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); } //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// unsigned KernelsOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value KernelsOp::getDataOperand(unsigned i) { unsigned numOptional = getAsync().size(); numOptional += getWaitOperands().size(); numOptional += getNumGangs().size(); numOptional += getNumWorkers().size(); numOptional += getVectorLength().size(); numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ? 1 : 0; return getOperand(numOptional + i); } bool acc::KernelsOp::hasAsyncOnly() { return hasAsyncOnly(mlir::acc::DeviceType::None); } bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value acc::KernelsOp::getAsyncValue() { return getAsyncValue(mlir::acc::DeviceType::None); } mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), deviceType); } mlir::Value acc::KernelsOp::getNumWorkersValue() { return getNumWorkersValue(mlir::acc::DeviceType::None); } mlir::Value acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), deviceType); } mlir::Value acc::KernelsOp::getVectorLengthValue() { return getVectorLengthValue(mlir::acc::DeviceType::None); } mlir::Value acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), getVectorLength(), deviceType); } mlir::Operation::operand_range KernelsOp::getNumGangsValues() { return getNumGangsValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(), deviceType); } bool acc::KernelsOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range KernelsOp::getWaitValues() { return getWaitValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), deviceType); } LogicalResult acc::KernelsOp::verify() { if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getNumGangs(), getNumGangsSegmentsAttr(), getNumGangsDeviceTypeAttr(), "num_gangs", 3))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), getWaitOperandsDeviceTypeAttr(), "wait"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), getNumWorkersDeviceTypeAttr(), "num_workers"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), getVectorLengthDeviceTypeAttr(), "vector_length"))) return failure(); if (failed(verifyDeviceTypeCountMatch(*this, getAsync(), getAsyncDeviceTypeAttr(), "async"))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); } //===----------------------------------------------------------------------===// // HostDataOp //===----------------------------------------------------------------------===// LogicalResult acc::HostDataOp::verify() { if (getDataClauseOperands().empty()) return emitError("at least one operand must appear on the host_data " "operation"); for (mlir::Value operand : getDataClauseOperands()) if (!mlir::isa(operand.getDefiningOp())) return emitError("expect data entry operation as defining op"); return success(); } void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// static ParseResult parseGangValue( OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl &operands, llvm::SmallVectorImpl &types, llvm::SmallVector &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue) { if (succeeded(parser.parseOptionalKeyword(keyword))) { if (parser.parseEqual()) return failure(); if (parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); attributes.push_back(gangArgType); needCommaBetweenValues = true; newValue = true; } return success(); } static ParseResult parseGangClause( OpAsmParser &parser, llvm::SmallVectorImpl &gangOperands, llvm::SmallVectorImpl &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType) { llvm::SmallVector gangArgTypeAttributes; llvm::SmallVector deviceTypeAttributes; llvm::SmallVector gangOnlyDeviceTypeAttributes; llvm::SmallVector seg; bool needCommaBetweenValues = false; bool needCommaBeforeOperands = false; if (failed(parser.parseOptionalLParen())) return failure(); // Parse gang only attributes if (succeeded(parser.parseOptionalLSquare())) { // Parse gang only attributes if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute( gangOnlyDeviceTypeAttributes.emplace_back())) return failure(); return success(); }))) return failure(); if (parser.parseRSquare()) return failure(); needCommaBeforeOperands = true; } else if (succeeded(parser.parseOptionalRParen())) { // Gang only keyword gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes); return success(); } auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(), mlir::acc::GangArgType::Num); auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(), mlir::acc::GangArgType::Dim); auto argStatic = mlir::acc::GangArgTypeAttr::get( parser.getContext(), mlir::acc::GangArgType::Static); do { if (needCommaBeforeOperands) { needCommaBeforeOperands = false; continue; } if (failed(parser.parseLBrace())) return failure(); int32_t crtOperandsSize = gangOperands.size(); while (true) { bool newValue = false; bool needValue = false; if (needCommaBetweenValues) { if (succeeded(parser.parseOptionalComma())) needValue = true; // expect a new value after comma. else break; } if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), gangOperands, gangOperandsType, gangArgTypeAttributes, argNum, needCommaBetweenValues, newValue))) return failure(); if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), gangOperands, gangOperandsType, gangArgTypeAttributes, argDim, needCommaBetweenValues, newValue))) return failure(); if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(), gangOperands, gangOperandsType, gangArgTypeAttributes, argStatic, needCommaBetweenValues, newValue))) return failure(); if (!newValue && needValue) { parser.emitError(parser.getCurrentLocation(), "new value expected after comma"); return failure(); } if (!newValue) break; } if (gangOperands.empty()) return parser.emitError( parser.getCurrentLocation(), "expect at least one of num, dim or static values"); if (failed(parser.parseRBrace())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) || parser.parseRSquare()) return failure(); } else { deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } seg.push_back(gangOperands.size() - crtOperandsSize); } while (succeeded(parser.parseOptionalComma())); if (failed(parser.parseRParen())) return failure(); llvm::SmallVector arrayAttr(gangArgTypeAttributes.begin(), gangArgTypeAttributes.end()); gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr); deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes); llvm::SmallVector gangOnlyAttr( gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end()); gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr); segments = DenseI32ArrayAttr::get(parser.getContext(), seg); return success(); } void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional gangArgTypes, std::optional deviceTypes, std::optional segments, std::optional gangOnlyDeviceTypes) { p << "("; if (operands.begin() == operands.end() && hasDeviceTypeValues(gangOnlyDeviceTypes) && gangOnlyDeviceTypes->size() == 1) { auto deviceTypeAttr = mlir::dyn_cast((*gangOnlyDeviceTypes)[0]); if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) { p << ")"; return; } } printDeviceTypes(p, gangOnlyDeviceTypes); if (hasDeviceTypeValues(gangOnlyDeviceTypes) && hasDeviceTypeValues(deviceTypes)) p << ", "; if (hasDeviceTypeValues(deviceTypes)) { unsigned opIdx = 0; llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { p << "{"; llvm::interleaveComma( llvm::seq(0, (*segments)[it.index()]), p, [&](auto it) { auto gangArgTypeAttr = mlir::dyn_cast( (*gangArgTypes)[opIdx]); if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num) p << LoopOp::getGangNumKeyword(); else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim) p << LoopOp::getGangDimKeyword(); else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static) p << LoopOp::getGangStaticKeyword(); p << "=" << operands[opIdx] << " : " << operands[opIdx].getType(); ++opIdx; }); p << "}"; printSingleDeviceType(p, it.value()); }); } p << ")"; } bool hasDuplicateDeviceTypes( std::optional segments, llvm::SmallSet &deviceTypes) { if (!segments) return false; for (auto attr : *segments) { auto deviceTypeAttr = mlir::dyn_cast(attr); if (deviceTypes.contains(deviceTypeAttr.getValue())) return true; deviceTypes.insert(deviceTypeAttr.getValue()); } return false; } /// Check for duplicates in the DeviceType array attribute. LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { llvm::SmallSet crtDeviceTypes; if (!deviceTypes) return success(); for (auto attr : deviceTypes) { auto deviceTypeAttr = mlir::dyn_cast_or_null(attr); if (!deviceTypeAttr) return failure(); if (crtDeviceTypes.contains(deviceTypeAttr.getValue())) return failure(); crtDeviceTypes.insert(deviceTypeAttr.getValue()); } return success(); } LogicalResult acc::LoopOp::verify() { if (!getUpperbound().empty() && getInclusiveUpperbound() && (getUpperbound().size() != getInclusiveUpperbound()->size())) return emitError() << "inclusiveUpperbound size is expected to be the same" << " as upperbound size"; // Check collapse if (getCollapseAttr() && !getCollapseDeviceTypeAttr()) return emitOpError() << "collapse device_type attr must be define when" << " collapse attr is present"; if (getCollapseAttr() && getCollapseDeviceTypeAttr() && getCollapseAttr().getValue().size() != getCollapseDeviceTypeAttr().getValue().size()) return emitOpError() << "collapse attribute count must match collapse" << " device_type count"; if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) return emitOpError() << "duplicate device_type found in collapseDeviceType attribute"; // Check gang if (!getGangOperands().empty()) { if (!getGangOperandsArgType()) return emitOpError() << "gangOperandsArgType attribute must be defined" << " when gang operands are present"; if (getGangOperands().size() != getGangOperandsArgTypeAttr().getValue().size()) return emitOpError() << "gangOperandsArgType attribute count must match" << " gangOperands count"; } if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) return emitOpError() << "duplicate device_type found in gang attribute"; if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getGangOperands(), getGangOperandsSegmentsAttr(), getGangOperandsDeviceTypeAttr(), "gang"))) return failure(); // Check worker if (failed(checkDeviceTypes(getWorkerAttr()))) return emitOpError() << "duplicate device_type found in worker attribute"; if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) return emitOpError() << "duplicate device_type found in " "workerNumOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), getWorkerNumOperandsDeviceTypeAttr(), "worker"))) return failure(); // Check vector if (failed(checkDeviceTypes(getVectorAttr()))) return emitOpError() << "duplicate device_type found in vector attribute"; if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) return emitOpError() << "duplicate device_type found in " "vectorOperandsDeviceType attribute"; if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), getVectorOperandsDeviceTypeAttr(), "vector"))) return failure(); if (failed(verifyDeviceTypeAndSegmentCountMatch( *this, getTileOperands(), getTileOperandsSegmentsAttr(), getTileOperandsDeviceTypeAttr(), "tile"))) return failure(); // auto, independent and seq attribute are mutually exclusive. llvm::SmallSet deviceTypes; if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) || hasDuplicateDeviceTypes(getIndependent(), deviceTypes) || hasDuplicateDeviceTypes(getSeq(), deviceTypes)) { return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName() << "\", " << getIndependentAttrName() << ", " << getSeqAttrName() << " can be present at the same time"; } // Gang, worker and vector are incompatible with seq. if (getSeqAttr()) { for (auto attr : getSeqAttr()) { auto deviceTypeAttr = mlir::dyn_cast(attr); if (hasVector(deviceTypeAttr.getValue()) || getVectorValue(deviceTypeAttr.getValue()) || hasWorker(deviceTypeAttr.getValue()) || getWorkerValue(deviceTypeAttr.getValue()) || hasGang(deviceTypeAttr.getValue()) || getGangValue(mlir::acc::GangArgType::Num, deviceTypeAttr.getValue()) || getGangValue(mlir::acc::GangArgType::Dim, deviceTypeAttr.getValue()) || getGangValue(mlir::acc::GangArgType::Static, deviceTypeAttr.getValue())) return emitError() << "gang, worker or vector cannot appear with the seq attr"; } } if (failed(checkSymOperandList( *this, getPrivatizations(), getPrivateOperands(), "private", "privatizations", false))) return failure(); if (failed(checkSymOperandList( *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); // Check non-empty body(). if (getRegion().empty()) return emitError("expected non-empty body."); return success(); } unsigned LoopOp::getNumDataOperands() { return getReductionOperands().size() + getPrivateOperands().size(); } Value LoopOp::getDataOperand(unsigned i) { unsigned numOptional = getLowerbound().size() + getUpperbound().size() + getStep().size(); numOptional += getGangOperands().size(); numOptional += getVectorOperands().size(); numOptional += getWorkerNumOperands().size(); numOptional += getTileOperands().size(); numOptional += getCacheOperands().size(); return getOperand(numOptional + i); } bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); } bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) { return hasDeviceType(getAuto_(), deviceType); } bool LoopOp::hasIndependent() { return hasIndependent(mlir::acc::DeviceType::None); } bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) { return hasDeviceType(getIndependent(), deviceType); } bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) { return hasDeviceType(getSeq(), deviceType); } mlir::Value LoopOp::getVectorValue() { return getVectorValue(mlir::acc::DeviceType::None); } mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(), getVectorOperands(), deviceType); } bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) { return hasDeviceType(getVector(), deviceType); } mlir::Value LoopOp::getWorkerValue() { return getWorkerValue(mlir::acc::DeviceType::None); } mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(), getWorkerNumOperands(), deviceType); } bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWorker(), deviceType); } mlir::Operation::operand_range LoopOp::getTileValues() { return getTileValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range LoopOp::getTileValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(), getTileOperandsSegments(), deviceType); } std::optional LoopOp::getCollapseValue() { return getCollapseValue(mlir::acc::DeviceType::None); } std::optional LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) { if (!getCollapseAttr()) return std::nullopt; if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) { auto intAttr = mlir::dyn_cast(getCollapseAttr().getValue()[*pos]); return intAttr.getValue().getZExtValue(); } return std::nullopt; } mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) { return getGangValue(gangArgType, mlir::acc::DeviceType::None); } mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType, mlir::acc::DeviceType deviceType) { if (getGangOperands().empty()) return {}; if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) { int32_t nbOperandsBefore = 0; for (unsigned i = 0; i < *pos; ++i) nbOperandsBefore += (*getGangOperandsSegments())[i]; mlir::Operation::operand_range values = getGangOperands() .drop_front(nbOperandsBefore) .take_front((*getGangOperandsSegments())[*pos]); int32_t argTypeIdx = nbOperandsBefore; for (auto value : values) { auto gangArgTypeAttr = mlir::dyn_cast( (*getGangOperandsArgType())[argTypeIdx]); if (gangArgTypeAttr.getValue() == gangArgType) return value; ++argTypeIdx; } } return {}; } bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) { return hasDeviceType(getGang(), deviceType); } llvm::SmallVector acc::LoopOp::getLoopRegions() { return {&getRegion()}; } /// loop-control ::= `(` ssa-id-and-type-list `)` `=` `(` ssa-id-and-type-list /// `)` `to` `(` ssa-id-and-type-list `)` `step` `(` ssa-id-and-type-list `)` ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl &lowerbound, SmallVectorImpl &lowerboundType, SmallVectorImpl &upperbound, SmallVectorImpl &upperboundType, SmallVectorImpl &step, SmallVectorImpl &stepType) { SmallVector inductionVars; if (succeeded(parser.parseOptionalLParen())) { if (parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None, /*allowType=*/true) || parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || parser.parseOperandList(lowerbound, inductionVars.size(), OpAsmParser::Delimiter::None) || parser.parseColonTypeList(lowerboundType) || parser.parseRParen() || parser.parseKeyword("to") || parser.parseLParen() || parser.parseOperandList(upperbound, inductionVars.size(), OpAsmParser::Delimiter::None) || parser.parseColonTypeList(upperboundType) || parser.parseRParen() || parser.parseKeyword("step") || parser.parseLParen() || parser.parseOperandList(step, inductionVars.size(), OpAsmParser::Delimiter::None) || parser.parseColonTypeList(stepType) || parser.parseRParen()) return failure(); } return parser.parseRegion(region, inductionVars); } void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType) { ValueRange regionArgs = region.front().getArguments(); if (!regionArgs.empty()) { p << "("; llvm::interleaveComma(regionArgs, p, [&p](Value v) { p << v << " : " << v.getType(); }); p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" << upperbound << " : " << upperboundType << ") " << " step (" << steps << " : " << stepType << ") "; } p.printRegion(region, /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // DataOp //===----------------------------------------------------------------------===// LogicalResult acc::DataOp::verify() { // 2.6.5. Data Construct restriction // At least one copy, copyin, copyout, create, no_create, present, deviceptr, // attach, or default clause must appear on a data construct. if (getOperands().empty() && !getDefaultAttr()) return emitError("at least one operand or the default attribute " "must appear on the data operation"); for (mlir::Value operand : getDataClauseOperands()) if (!mlir::isa( operand.getDefiningOp())) return emitError("expect data entry/exit operation or acc.getdeviceptr " "as defining op"); return success(); } unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value DataOp::getDataOperand(unsigned i) { unsigned numOptional = getIfCond() ? 1 : 0; numOptional += getAsync().size() ? 1 : 0; numOptional += getWaitOperands().size(); return getOperand(numOptional + i); } bool acc::DataOp::hasAsyncOnly() { return hasAsyncOnly(mlir::acc::DeviceType::None); } bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getAsyncOnly(), deviceType); } mlir::Value DataOp::getAsyncValue() { return getAsyncValue(mlir::acc::DeviceType::None); } mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) { return getValueInDeviceTypeSegment(getAsyncDeviceType(), getAsync(), deviceType); } bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWaitOnly(), deviceType); } mlir::Operation::operand_range DataOp::getWaitValues() { return getWaitValues(mlir::acc::DeviceType::None); } mlir::Operation::operand_range DataOp::getWaitValues(mlir::acc::DeviceType deviceType) { return getValuesFromSegments(getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), deviceType); } //===----------------------------------------------------------------------===// // ExitDataOp //===----------------------------------------------------------------------===// LogicalResult acc::ExitDataOp::verify() { // 2.6.6. Data Exit Directive restriction // At least one copyout, delete, or detach clause must appear on an exit data // directive. if (getDataClauseOperands().empty()) return emitError("at least one operand must be present in dataOperands on " "the exit data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. if (getAsyncOperand() && getAsync()) return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. if (!getWaitOperands().empty() && getWait()) return emitError("wait attribute cannot appear with waitOperands"); if (getWaitDevnum() && getWaitOperands().empty()) return emitError("wait_devnum cannot appear without waitOperands"); return success(); } unsigned ExitDataOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value ExitDataOp::getDataOperand(unsigned i) { unsigned numOptional = getIfCond() ? 1 : 0; numOptional += getAsyncOperand() ? 1 : 0; numOptional += getWaitDevnum() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // EnterDataOp //===----------------------------------------------------------------------===// LogicalResult acc::EnterDataOp::verify() { // 2.6.6. Data Enter Directive restriction // At least one copyin, create, or attach clause must appear on an enter data // directive. if (getDataClauseOperands().empty()) return emitError("at least one operand must be present in dataOperands on " "the enter data operation"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. if (getAsyncOperand() && getAsync()) return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. if (!getWaitOperands().empty() && getWait()) return emitError("wait attribute cannot appear with waitOperands"); if (getWaitDevnum() && getWaitOperands().empty()) return emitError("wait_devnum cannot appear without waitOperands"); for (mlir::Value operand : getDataClauseOperands()) if (!mlir::isa( operand.getDefiningOp())) return emitError("expect data entry operation as defining op"); return success(); } unsigned EnterDataOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value EnterDataOp::getDataOperand(unsigned i) { unsigned numOptional = getIfCond() ? 1 : 0; numOptional += getAsyncOperand() ? 1 : 0; numOptional += getWaitDevnum() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // AtomicReadOp //===----------------------------------------------------------------------===// LogicalResult AtomicReadOp::verify() { return verifyCommon(); } //===----------------------------------------------------------------------===// // AtomicWriteOp //===----------------------------------------------------------------------===// LogicalResult AtomicWriteOp::verify() { return verifyCommon(); } //===----------------------------------------------------------------------===// // AtomicUpdateOp //===----------------------------------------------------------------------===// LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, PatternRewriter &rewriter) { if (op.isNoOp()) { rewriter.eraseOp(op); return success(); } if (Value writeVal = op.getWriteOpVal()) { rewriter.replaceOpWithNewOp(op, op.getX(), writeVal); return success(); } return failure(); } LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); } LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); } //===----------------------------------------------------------------------===// // AtomicCaptureOp //===----------------------------------------------------------------------===// AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { if (auto op = dyn_cast(getFirstOp())) return op; return dyn_cast(getSecondOp()); } AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { if (auto op = dyn_cast(getFirstOp())) return op; return dyn_cast(getSecondOp()); } AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { if (auto op = dyn_cast(getFirstOp())) return op; return dyn_cast(getSecondOp()); } LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); } //===----------------------------------------------------------------------===// // DeclareEnterOp //===----------------------------------------------------------------------===// template static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand = true) { if (operands.empty() && requireAtLeastOneOperand) return emitError( op->getLoc(), "at least one operand must appear on the declare operation"); for (mlir::Value operand : operands) { if (!mlir::isa( operand.getDefiningOp())) return op.emitError( "expect valid declare data entry operation or acc.getdeviceptr " "as defining op"); mlir::Value varPtr{getVarPtr(operand.getDefiningOp())}; assert(varPtr && "declare operands can only be data entry operations which " "must have varPtr"); std::optional dataClauseOptional{ getDataClause(operand.getDefiningOp())}; assert(dataClauseOptional.has_value() && "declare operands can only be data entry operations which must have " "dataClause"); // If varPtr has no defining op - there is nothing to check further. if (!varPtr.getDefiningOp()) continue; // Check that the varPtr has a declare attribute. auto declareAttribute{ varPtr.getDefiningOp()->getAttr(mlir::acc::getDeclareAttrName())}; if (!declareAttribute) return op.emitError( "expect declare attribute on variable in declare operation"); auto declAttr = mlir::cast(declareAttribute); if (declAttr.getDataClause().getValue() != dataClauseOptional.value()) return op.emitError( "expect matching declare attribute on variable in declare operation"); // If the variable is marked with implicit attribute, the matching declare // data action must also be marked implicit. The reverse is not checked // since implicit data action may be inserted to do actions like updating // device copy, in which case the variable is not necessarily implicitly // declare'd. if (declAttr.getImplicit() && declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp())) return op.emitError( "implicitness must match between declare op and flag on variable"); } return success(); } LogicalResult acc::DeclareEnterOp::verify() { return checkDeclareOperands(*this, this->getDataClauseOperands()); } //===----------------------------------------------------------------------===// // DeclareExitOp //===----------------------------------------------------------------------===// LogicalResult acc::DeclareExitOp::verify() { if (getToken()) return checkDeclareOperands(*this, this->getDataClauseOperands(), /*requireAtLeastOneOperand=*/false); return checkDeclareOperands(*this, this->getDataClauseOperands()); } //===----------------------------------------------------------------------===// // DeclareOp //===----------------------------------------------------------------------===// LogicalResult acc::DeclareOp::verify() { return checkDeclareOperands(*this, this->getDataClauseOperands()); } //===----------------------------------------------------------------------===// // RoutineOp //===----------------------------------------------------------------------===// static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype) { unsigned parallelism = 0; parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0; parallelism += op.hasWorker(dtype) ? 1 : 0; parallelism += op.hasVector(dtype) ? 1 : 0; parallelism += op.hasSeq(dtype) ? 1 : 0; return parallelism; } LogicalResult acc::RoutineOp::verify() { unsigned baseParallelism = getParallelismForDeviceType(*this, acc::DeviceType::None); if (baseParallelism > 1) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " "be present at the same time"; for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); ++dtypeInt) { auto dtype = static_cast(dtypeInt); if (dtype == acc::DeviceType::None) continue; unsigned parallelism = getParallelismForDeviceType(*this, dtype); if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " "be present at the same time"; } return success(); } static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes) { llvm::SmallVector bindNameAttrs; llvm::SmallVector deviceTypeAttrs; if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute(bindNameAttrs.emplace_back())) return failure(); if (failed(parser.parseOptionalLSquare())) { deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } else { if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) || parser.parseRSquare()) return failure(); } return success(); }))) return failure(); bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs); deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); return success(); } static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional bindName, std::optional deviceTypes) { llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p, [&](const auto &pair) { p << std::get<0>(pair); printSingleDeviceType(p, std::get<1>(pair)); }); } static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes) { llvm::SmallVector gangAttrs, gangDimAttrs, gangDimDeviceTypeAttrs; bool needCommaBeforeOperands = false; // Gang keyword only if (failed(parser.parseOptionalLParen())) { gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); gang = ArrayAttr::get(parser.getContext(), gangAttrs); return success(); } // Parse keyword only attributes if (succeeded(parser.parseOptionalLSquare())) { if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute(gangAttrs.emplace_back())) return failure(); return success(); }))) return failure(); if (parser.parseRSquare()) return failure(); needCommaBeforeOperands = true; } if (needCommaBeforeOperands && failed(parser.parseComma())) return failure(); if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) || parser.parseColon() || parser.parseAttribute(gangDimAttrs.emplace_back())) return failure(); if (succeeded(parser.parseOptionalLSquare())) { if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) || parser.parseRSquare()) return failure(); } else { gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); } return success(); }))) return failure(); if (failed(parser.parseRParen())) return failure(); gang = ArrayAttr::get(parser.getContext(), gangAttrs); gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs); gangDimDeviceTypes = ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs); return success(); } void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional gang, std::optional gangDim, std::optional gangDimDeviceTypes) { if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) && gang->size() == 1) { auto deviceTypeAttr = mlir::dyn_cast((*gang)[0]); if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) return; } p << "("; printDeviceTypes(p, gang); if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes)) p << ", "; if (hasDeviceTypeValues(gangDimDeviceTypes)) llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p, [&](const auto &pair) { p << acc::RoutineOp::getGangDimKeyword() << ": "; p << std::get<0>(pair); printSingleDeviceType(p, std::get<1>(pair)); }); p << ")"; } static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes) { llvm::SmallVector attributes; // Keyword only if (failed(parser.parseOptionalLParen())) { attributes.push_back(mlir::acc::DeviceTypeAttr::get( parser.getContext(), mlir::acc::DeviceType::None)); deviceTypes = ArrayAttr::get(parser.getContext(), attributes); return success(); } // Parse device type attributes if (succeeded(parser.parseOptionalLSquare())) { if (failed(parser.parseCommaSeparatedList([&]() { if (parser.parseAttribute(attributes.emplace_back())) return failure(); return success(); }))) return failure(); if (parser.parseRSquare() || parser.parseRParen()) return failure(); } deviceTypes = ArrayAttr::get(parser.getContext(), attributes); return success(); } static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional deviceTypes) { if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) { auto deviceTypeAttr = mlir::dyn_cast((*deviceTypes)[0]); if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) return; } if (!hasDeviceTypeValues(deviceTypes)) return; p << "(["; llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) { auto dTypeAttr = mlir::dyn_cast(attr); p << dTypeAttr; }); p << "])"; } bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) { return hasDeviceType(getWorker(), deviceType); } bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) { return hasDeviceType(getVector(), deviceType); } bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) { return hasDeviceType(getSeq(), deviceType); } std::optional RoutineOp::getBindNameValue() { return getBindNameValue(mlir::acc::DeviceType::None); } std::optional RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) { if (!hasDeviceTypeValues(getBindNameDeviceType())) return std::nullopt; if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) { auto attr = (*getBindName())[*pos]; auto stringAttr = dyn_cast(attr); return stringAttr.getValue(); } return std::nullopt; } bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) { return hasDeviceType(getGang(), deviceType); } std::optional RoutineOp::getGangDimValue() { return getGangDimValue(mlir::acc::DeviceType::None); } std::optional RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { if (!hasDeviceTypeValues(getGangDimDeviceType())) return std::nullopt; if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) { auto intAttr = mlir::dyn_cast((*getGangDim())[*pos]); return intAttr.getInt(); } return std::nullopt; } //===----------------------------------------------------------------------===// // InitOp //===----------------------------------------------------------------------===// LogicalResult acc::InitOp::verify() { Operation *currOp = *this; while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) return emitOpError("cannot be nested in a compute operation"); return success(); } //===----------------------------------------------------------------------===// // ShutdownOp //===----------------------------------------------------------------------===// LogicalResult acc::ShutdownOp::verify() { Operation *currOp = *this; while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) return emitOpError("cannot be nested in a compute operation"); return success(); } //===----------------------------------------------------------------------===// // SetOp //===----------------------------------------------------------------------===// LogicalResult acc::SetOp::verify() { Operation *currOp = *this; while ((currOp = currOp->getParentOp())) if (isComputeOperation(currOp)) return emitOpError("cannot be nested in a compute operation"); if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum()) return emitOpError("at least one default_async, device_num, or device_type " "operand must appear"); return success(); } //===----------------------------------------------------------------------===// // UpdateOp //===----------------------------------------------------------------------===// LogicalResult acc::UpdateOp::verify() { // At least one of host or device should have a value. if (getDataClauseOperands().empty()) return emitError("at least one value must be present in dataOperands"); // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. if (getAsyncOperand() && getAsync()) return emitError("async attribute cannot appear with asyncOperand"); // The wait attribute represent the wait clause without values. Therefore the // attribute and operands cannot appear at the same time. if (!getWaitOperands().empty() && getWait()) return emitError("wait attribute cannot appear with waitOperands"); if (getWaitDevnum() && getWaitOperands().empty()) return emitError("wait_devnum cannot appear without waitOperands"); for (mlir::Value operand : getDataClauseOperands()) if (!mlir::isa( operand.getDefiningOp())) return emitError("expect data entry/exit operation or acc.getdeviceptr " "as defining op"); return success(); } unsigned UpdateOp::getNumDataOperands() { return getDataClauseOperands().size(); } Value UpdateOp::getDataOperand(unsigned i) { unsigned numOptional = getAsyncOperand() ? 1 : 0; numOptional += getWaitDevnum() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; return getOperand(getWaitOperands().size() + numOptional + i); } void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add>(context); } //===----------------------------------------------------------------------===// // WaitOp //===----------------------------------------------------------------------===// LogicalResult acc::WaitOp::verify() { // The async attribute represent the async clause without value. Therefore the // attribute and operand cannot appear at the same time. if (getAsyncOperand() && getAsync()) return emitError("async attribute cannot appear with asyncOperand"); if (getWaitDevnum() && getWaitOperands().empty()) return emitError("wait_devnum cannot appear without waitOperands"); return success(); } #define GET_OP_CLASSES #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" //===----------------------------------------------------------------------===// // acc dialect utilities //===----------------------------------------------------------------------===// mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { auto varPtr{llvm::TypeSwitch(accDataClauseOp) .Case( [&](auto entry) { return entry.getVarPtr(); }) .Case( [&](auto exit) { return exit.getVarPtr(); }) .Default([&](mlir::Operation *) { return mlir::Value(); })}; return varPtr; } mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { auto accPtr{llvm::TypeSwitch(accDataClauseOp) .Case( [&](auto dataClause) { return dataClause.getAccPtr(); }) .Default([&](mlir::Operation *) { return mlir::Value(); })}; return accPtr; } mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) { auto varPtrPtr{ llvm::TypeSwitch(accDataClauseOp) .Case( [&](auto dataClause) { return dataClause.getVarPtrPtr(); }) .Default([&](mlir::Operation *) { return mlir::Value(); })}; return varPtrPtr; } mlir::SmallVector mlir::acc::getBounds(mlir::Operation *accDataClauseOp) { mlir::SmallVector bounds{ llvm::TypeSwitch>( accDataClauseOp) .Case([&](auto dataClause) { return mlir::SmallVector( dataClause.getBounds().begin(), dataClause.getBounds().end()); }) .Default([&](mlir::Operation *) { return mlir::SmallVector(); })}; return bounds; } std::optional mlir::acc::getVarName(mlir::Operation *accOp) { auto name{ llvm::TypeSwitch>(accOp) .Case([&](auto entry) { return entry.getName(); }) .Default([&](mlir::Operation *) -> std::optional { return {}; })}; return name; } std::optional mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) { auto dataClause{ llvm::TypeSwitch>( accDataEntryOp) .Case( [&](auto entry) { return entry.getDataClause(); }) .Default([&](mlir::Operation *) { return std::nullopt; })}; return dataClause; } bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) { auto implicit{llvm::TypeSwitch(accDataEntryOp) .Case( [&](auto entry) { return entry.getImplicit(); }) .Default([&](mlir::Operation *) { return false; })}; return implicit; } mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) { auto dataOperands{ llvm::TypeSwitch(accOp) .Case( [&](auto entry) { return entry.getDataClauseOperands(); }) .Default([&](mlir::Operation *) { return mlir::ValueRange(); })}; return dataOperands; } mlir::MutableOperandRange mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { auto dataOperands{ llvm::TypeSwitch(accOp) .Case( [&](auto entry) { return entry.getDataClauseOperandsMutable(); }) .Default([&](mlir::Operation *) { return nullptr; })}; return dataOperands; }