//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Defines the memory operations in the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "SPIRVOpUtils.h" #include "SPIRVParsingUtils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Casting.h" using namespace mlir::spirv::AttrNames; namespace mlir::spirv { // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' if (parser.parseOptionalLSquare()) { // Nothing to do return success(); } spirv::MemoryAccess memoryAccessAttr; if (spirv::parseEnumStrAttr( memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName)) return failure(); if (spirv::bitEnumContainsAll(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. Attribute alignmentAttr; Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseComma() || parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, state.attributes)) { return failure(); } } return parser.parseRSquare(); } // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. template static void printSourceMemoryAccessAttribute( MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs, std::optional memoryAccessAtrrValue = std::nullopt, std::optional alignmentAttrValue = std::nullopt) { printer << ", "; // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue : memoryOp.getMemoryAccess())) { elidedAttrs.push_back(kSourceMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? alignmentAttrValue : memoryOp.getAlignment())) { elidedAttrs.push_back(kSourceAlignmentAttrName); printer << ", " << *alignment; } } printer << "]"; } elidedAttrs.push_back(spirv::attributeName()); } template static void printMemoryAccessAttribute( MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs, std::optional memoryAccessAtrrValue = std::nullopt, std::optional alignmentAttrValue = std::nullopt) { // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue : memoryOp.getMemoryAccess())) { elidedAttrs.push_back(kMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? alignmentAttrValue : memoryOp.getAlignment())) { elidedAttrs.push_back(kAlignmentAttrName); printer << ", " << *alignment; } } printer << "]"; } elidedAttrs.push_back(spirv::attributeName()); } template static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val) { // ODS already checks ptr is spirv::PointerType. Just check that the pointee // type of the pointer and the type of the value are the same // // TODO: Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations if (val.getType() != llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); } template static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); } return success(); } auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") << memAccessAttr; } if (spirv::bitEnumContainsAll(memAccess.getValue(), spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); } } return success(); } // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. template static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. if (op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); } return success(); } auto memAccess = llvm::cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") << memAccess; } if (spirv::bitEnumContainsAll(memAccess.getValue(), spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { if (op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); } } return success(); } //===----------------------------------------------------------------------===// // spirv.AccessChainOp //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { auto ptrType = llvm::dyn_cast(type); if (!ptrType) { emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " "to composite type, but provided ") << type; return nullptr; } auto resultType = ptrType.getPointeeType(); auto resultStorageClass = ptrType.getStorageClass(); int32_t index = 0; for (auto indexSSA : indices) { auto cType = llvm::dyn_cast(resultType); if (!cType) { emitError( baseLoc, "'spirv.AccessChain' op cannot extract from non-composite type ") << resultType << " with index " << index; return nullptr; } index = 0; if (llvm::isa(resultType)) { Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spirv.AccessChain' op index must be an " "integer spirv.Constant to access " "element of spirv.struct"); return nullptr; } // TODO: this should be relaxed to allow // integer literals of other bitwidths. if (failed(spirv::extractValueFromConstOp(op, index))) { emitError( baseLoc, "'spirv.AccessChain' index must be an integer spirv.Constant to " "access element of spirv.struct, but provided ") << op->getName(); return nullptr; } if (index < 0 || static_cast(index) >= cType.getNumElements()) { emitError(baseLoc, "'spirv.AccessChain' op index ") << index << " out of bounds for " << resultType; return nullptr; } } resultType = cType.getElementType(index); } return spirv::PointerType::get(resultType, resultStorageClass); } void AccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); } ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand ptrInfo; SmallVector indicesInfo; Type type; auto loc = parser.getCurrentLocation(); SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || parser.resolveOperand(ptrInfo, type, result.operands)) { return failure(); } // Check that the provided indices list is not empty before parsing their // type list. if (indicesInfo.empty()) { return mlir::emitError(result.location, "'spirv.AccessChain' op expected at " "least one index "); } if (parser.parseComma() || parser.parseTypeList(indicesTypes)) return failure(); // Check that the indices types list is not empty and that it has a one-to-one // mapping to the provided indices. if (indicesTypes.size() != indicesInfo.size()) { return mlir::emitError( result.location, "'spirv.AccessChain' op indices types' count must be " "equal to indices info count"); } if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands)) return failure(); auto resultType = getElementPtrType( type, llvm::ArrayRef(result.operands).drop_front(), result.location); if (!resultType) { return failure(); } result.addTypes(resultType); return success(); } template static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { printer << ' ' << op.getBasePtr() << '[' << indices << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); } void spirv::AccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, getIndices(), printer); } template static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), indices, accessChainOp.getLoc()); if (!resultType) return failure(); auto providedResultType = llvm::dyn_cast(accessChainOp.getType()); if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") << providedResultType; if (resultType != providedResultType) return accessChainOp.emitOpError("invalid result type: expected ") << resultType << ", but provided " << providedResultType; return success(); } LogicalResult AccessChainOp::verify() { return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// // spirv.LoadOp //===----------------------------------------------------------------------===// void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { auto ptrType = llvm::cast(basePtr.getType()); build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::UnresolvedOperand ptrInfo; Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || parseMemoryAccessAttributes(parser, result) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { return failure(); } result.addTypes(elementType); return success(); } void LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr(); printMemoryAccessAttribute(*this, printer, elidedAttrs); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << " : " << getType(); } LogicalResult LoadOp::verify() { // SPIR-V spec : "Result Type is the type of the loaded object. It must be a // type with fixed size; i.e., it cannot be, nor include, any // OpTypeRuntimeArray types." if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { return failure(); } return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spirv.StoreOp //===----------------------------------------------------------------------===// ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; auto loc = parser.getCurrentLocation(); Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperandList(operandInfo, 2) || parseMemoryAccessAttributes(parser, result) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, result.operands)) { return failure(); } return success(); } void StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( llvm::cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); printMemoryAccessAttribute(*this, printer, elidedAttrs); printer << " : " << getValue().getType(); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); } LogicalResult StoreOp::verify() { // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an // OpTypePointer whose Type operand is the same as the type of Object." if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spirv.CopyMemory //===----------------------------------------------------------------------===// void CopyMemoryOp::print(OpAsmPrinter &printer) { printer << ' '; StringRef targetStorageClass = stringifyStorageClass( llvm::cast(getTarget().getType()).getStorageClass()); printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; StringRef sourceStorageClass = stringifyStorageClass( llvm::cast(getSource().getType()).getStorageClass()); printer << " \"" << sourceStorageClass << "\" " << getSource(); SmallVector elidedAttrs; printMemoryAccessAttribute(*this, printer, elidedAttrs); printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, getSourceMemoryAccess(), getSourceAlignment()); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = llvm::cast(getTarget().getType()).getPointeeType(); printer << " : " << pointeeType; } ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) { spirv::StorageClass targetStorageClass; OpAsmParser::UnresolvedOperand targetPtrInfo; spirv::StorageClass sourceStorageClass; OpAsmParser::UnresolvedOperand sourcePtrInfo; Type elementType; if (parseEnumStrAttr(targetStorageClass, parser) || parser.parseOperand(targetPtrInfo) || parser.parseComma() || parseEnumStrAttr(sourceStorageClass, parser) || parser.parseOperand(sourcePtrInfo) || parseMemoryAccessAttributes(parser, result)) { return failure(); } if (!parser.parseOptionalComma()) { // Parse 2nd memory access attributes. if (parseSourceMemoryAccessAttributes(parser, result)) { return failure(); } } if (parser.parseColon() || parser.parseType(elementType)) return failure(); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { return failure(); } return success(); } LogicalResult CopyMemoryOp::verify() { Type targetType = llvm::cast(getTarget().getType()).getPointeeType(); Type sourceType = llvm::cast(getSource().getType()).getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); if (failed(verifyMemoryAccessAttribute(*this))) return failure(); // TODO - According to the spec: // // If two masks are present, the first applies to Target and cannot include // MakePointerVisible, and the second applies to Source and cannot include // MakePointerAvailable. // // Add such verification here. return verifySourceMemoryAccessAttribute(*this); } static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state) { OpAsmParser::UnresolvedOperand ptrInfo; SmallVector indicesInfo; Type type; auto loc = parser.getCurrentLocation(); SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || parser.resolveOperand(ptrInfo, type, state.operands)) return failure(); // Check that the provided indices list is not empty before parsing their // type list. if (indicesInfo.empty()) return emitError(state.location) << opName << " expected element"; if (parser.parseComma() || parser.parseTypeList(indicesTypes)) return failure(); // Check that the indices types list is not empty and that it has a one-to-one // mapping to the provided indices. if (indicesTypes.size() != indicesInfo.size()) return emitError(state.location) << opName << " indices types' count must be equal to indices info count"; if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) return failure(); auto resultType = getElementPtrType( type, llvm::ArrayRef(state.operands).drop_front(2), state.location); if (!resultType) return failure(); state.addTypes(resultType); return success(); } template static auto concatElemAndIndices(Op op) { SmallVector ret(op.getIndices().size() + 1); ret[0] = op.getElement(); llvm::copy(op.getIndices(), ret.begin() + 1); return ret; } //===----------------------------------------------------------------------===// // spirv.InBoundsPtrAccessChainOp //===----------------------------------------------------------------------===// void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, Value element, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, element, indices); } ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, OperationState &result) { return parsePtrAccessChainOpImpl( spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result); } void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult InBoundsPtrAccessChainOp::verify() { return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// // spirv.PtrAccessChainOp //===----------------------------------------------------------------------===// void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, Value element, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, element, indices); } ParseResult PtrAccessChainOp::parse(OpAsmParser &parser, OperationState &result) { return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), parser, result); } void PtrAccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult PtrAccessChainOp::verify() { return verifyAccessChain(*this, getIndices()); } //===----------------------------------------------------------------------===// // spirv.Variable //===----------------------------------------------------------------------===// ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) { // Parse optional initializer std::optional initInfo; if (succeeded(parser.parseOptionalKeyword("init"))) { initInfo = OpAsmParser::UnresolvedOperand(); if (parser.parseLParen() || parser.parseOperand(*initInfo) || parser.parseRParen()) return failure(); } if (parseVariableDecorations(parser, result)) { return failure(); } // Parse result pointer type Type type; if (parser.parseColon()) return failure(); auto loc = parser.getCurrentLocation(); if (parser.parseType(type)) return failure(); auto ptrType = llvm::dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected spirv.ptr type"); result.addTypes(ptrType); // Resolve the initializer operand if (initInfo) { if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), result.operands)) return failure(); } auto attr = parser.getBuilder().getAttr( ptrType.getStorageClass()); result.addAttribute(spirv::attributeName(), attr); return success(); } void VariableOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; // Print optional initializer if (getNumOperands() != 0) printer << " init(" << getInitializer() << ")"; printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << getType(); } LogicalResult VariableOp::verify() { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." if (getStorageClass() != spirv::StorageClass::Function) { return emitOpError( "can only be used to model function-level variables. Use " "spirv.GlobalVariable for module-level variables."); } auto pointerType = llvm::cast(getPointer().getType()); if (getStorageClass() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); if (getNumOperands() != 0) { // SPIR-V spec: "Initializer must be an from a constant instruction or // a global (module scope) OpVariable instruction". auto *initOp = getOperand(0).getDefiningOp(); if (!initOp || !isa(initOp)) return emitOpError("initializer must be the result of a " "constant or spirv.GlobalVariable op"); } auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) { return op->getAttr( llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration))); }; // TODO: generate these strings using ODS. for (auto decoration : {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding, spirv::Decoration::BuiltIn}) { if (auto attr = getDecorationAttr(decoration)) return emitOpError("cannot have '") << llvm::convertToSnakeFromCamelCase( stringifyDecoration(decoration)) << "' attribute (only allowed in spirv.GlobalVariable)"; } // From SPV_KHR_physical_storage_buffer: // > If an OpVariable's pointee type is a pointer (or array of pointers) in // > PhysicalStorageBuffer storage class, then the variable must be decorated // > with exactly one of AliasedPointer or RestrictPointer. auto pointeePtrType = dyn_cast(getPointeeType()); if (!pointeePtrType) { if (auto pointeeArrayType = dyn_cast(getPointeeType())) { pointeePtrType = dyn_cast(pointeeArrayType.getElementType()); } } if (pointeePtrType && pointeePtrType.getStorageClass() == spirv::StorageClass::PhysicalStorageBuffer) { bool hasAliasedPtr = getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr; bool hasRestrictPtr = getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr; if (!hasAliasedPtr && !hasRestrictPtr) return emitOpError() << " with physical buffer pointer must be decorated " "either 'AliasedPointer' or 'RestrictPointer'"; if (hasAliasedPtr && hasRestrictPtr) return emitOpError() << " with physical buffer pointer must have exactly one " "aliasing decoration"; } return success(); } } // namespace mlir::spirv