785 lines
28 KiB
C++
785 lines
28 KiB
C++
//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the memory operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
|
|
#include "SPIRVOpUtils.h"
|
|
#include "SPIRVParsingUtils.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/Casting.h"
|
|
|
|
using namespace mlir::spirv::AttrNames;
|
|
|
|
namespace mlir::spirv {
|
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
|
// not detailed enough to understand what is happening.
|
|
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
|
|
OperationState &state) {
|
|
// Parse an optional list of attributes staring with '['
|
|
if (parser.parseOptionalLSquare()) {
|
|
// Nothing to do
|
|
return success();
|
|
}
|
|
|
|
spirv::MemoryAccess memoryAccessAttr;
|
|
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
|
|
memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
|
|
return failure();
|
|
|
|
if (spirv::bitEnumContainsAll(memoryAccessAttr,
|
|
spirv::MemoryAccess::Aligned)) {
|
|
// Parse integer attribute for alignment.
|
|
Attribute alignmentAttr;
|
|
Type i32Type = parser.getBuilder().getIntegerType(32);
|
|
if (parser.parseComma() ||
|
|
parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
|
|
state.attributes)) {
|
|
return failure();
|
|
}
|
|
}
|
|
return parser.parseRSquare();
|
|
}
|
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
|
// not detailed enough to understand what is happening.
|
|
template <typename MemoryOpTy>
|
|
static void printSourceMemoryAccessAttribute(
|
|
MemoryOpTy memoryOp, OpAsmPrinter &printer,
|
|
SmallVectorImpl<StringRef> &elidedAttrs,
|
|
std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
|
|
std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
|
|
|
|
printer << ", ";
|
|
|
|
// Print optional memory access attribute.
|
|
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
|
|
: memoryOp.getMemoryAccess())) {
|
|
elidedAttrs.push_back(kSourceMemoryAccessAttrName);
|
|
|
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
|
|
|
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
|
|
// Print integer alignment attribute.
|
|
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
|
: memoryOp.getAlignment())) {
|
|
elidedAttrs.push_back(kSourceAlignmentAttrName);
|
|
printer << ", " << *alignment;
|
|
}
|
|
}
|
|
printer << "]";
|
|
}
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
|
}
|
|
|
|
template <typename MemoryOpTy>
|
|
static void printMemoryAccessAttribute(
|
|
MemoryOpTy memoryOp, OpAsmPrinter &printer,
|
|
SmallVectorImpl<StringRef> &elidedAttrs,
|
|
std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
|
|
std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
|
|
// Print optional memory access attribute.
|
|
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
|
|
: memoryOp.getMemoryAccess())) {
|
|
elidedAttrs.push_back(kMemoryAccessAttrName);
|
|
|
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
|
|
|
|
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
|
|
// Print integer alignment attribute.
|
|
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
|
|
: memoryOp.getAlignment())) {
|
|
elidedAttrs.push_back(kAlignmentAttrName);
|
|
printer << ", " << *alignment;
|
|
}
|
|
}
|
|
printer << "]";
|
|
}
|
|
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
|
|
}
|
|
|
|
template <typename LoadStoreOpTy>
|
|
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
|
|
Value val) {
|
|
// ODS already checks ptr is spirv::PointerType. Just check that the pointee
|
|
// type of the pointer and the type of the value are the same
|
|
//
|
|
// TODO: Check that the value type satisfies restrictions of
|
|
// SPIR-V OpLoad/OpStore operations
|
|
if (val.getType() !=
|
|
llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
|
|
return op.emitOpError("mismatch in result type and pointer type");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
template <typename MemoryOpTy>
|
|
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
|
// ODS checks for attributes values. Just need to verify that if the
|
|
// memory-access attribute is Aligned, then the alignment attribute must be
|
|
// present.
|
|
auto *op = memoryOp.getOperation();
|
|
auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
|
|
if (!memAccessAttr) {
|
|
// Alignment attribute shouldn't be present if memory access attribute is
|
|
// not present.
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
|
return memoryOp.emitOpError(
|
|
"invalid alignment specification without aligned memory access "
|
|
"specification");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
|
|
|
|
if (!memAccess) {
|
|
return memoryOp.emitOpError("invalid memory access specifier: ")
|
|
<< memAccessAttr;
|
|
}
|
|
|
|
if (spirv::bitEnumContainsAll(memAccess.getValue(),
|
|
spirv::MemoryAccess::Aligned)) {
|
|
if (!op->getAttr(kAlignmentAttrName)) {
|
|
return memoryOp.emitOpError("missing alignment value");
|
|
}
|
|
} else {
|
|
if (op->getAttr(kAlignmentAttrName)) {
|
|
return memoryOp.emitOpError(
|
|
"invalid alignment specification with non-aligned memory access "
|
|
"specification");
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// TODO Make sure to merge this and the previous function into one template
|
|
// parameterized by memory access attribute name and alignment. Doing so now
|
|
// results in VS2017 in producing an internal error (at the call site) that's
|
|
// not detailed enough to understand what is happening.
|
|
template <typename MemoryOpTy>
|
|
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
|
|
// ODS checks for attributes values. Just need to verify that if the
|
|
// memory-access attribute is Aligned, then the alignment attribute must be
|
|
// present.
|
|
auto *op = memoryOp.getOperation();
|
|
auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
|
|
if (!memAccessAttr) {
|
|
// Alignment attribute shouldn't be present if memory access attribute is
|
|
// not present.
|
|
if (op->getAttr(kSourceAlignmentAttrName)) {
|
|
return memoryOp.emitOpError(
|
|
"invalid alignment specification without aligned memory access "
|
|
"specification");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
|
|
|
|
if (!memAccess) {
|
|
return memoryOp.emitOpError("invalid memory access specifier: ")
|
|
<< memAccess;
|
|
}
|
|
|
|
if (spirv::bitEnumContainsAll(memAccess.getValue(),
|
|
spirv::MemoryAccess::Aligned)) {
|
|
if (!op->getAttr(kSourceAlignmentAttrName)) {
|
|
return memoryOp.emitOpError("missing alignment value");
|
|
}
|
|
} else {
|
|
if (op->getAttr(kSourceAlignmentAttrName)) {
|
|
return memoryOp.emitOpError(
|
|
"invalid alignment specification with non-aligned memory access "
|
|
"specification");
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.AccessChainOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Computes the pointer type obtained by indexing into the pointee of `type`
/// with `indices`, keeping the storage class of `type`.
///
/// `type` must be a spirv::PointerType. Each index descends one level into the
/// current composite type. For struct types the index must be defined by an op
/// from which spirv::extractValueFromConstOp can extract a value, and that
/// value must be within the struct's element count; for other composite types
/// element 0 is used to obtain the element type.
///
/// On any violation an error is emitted at `baseLoc` and a null Type is
/// returned.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto basePtrType = llvm::dyn_cast<spirv::PointerType>(type);
  if (!basePtrType) {
    emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return Type();
  }

  Type currentType = basePtrType.getPointeeType();
  // Holds the most recently resolved index; also reported in the
  // "non-composite" diagnostic below, where it still carries the value from the
  // previous iteration.
  int32_t index = 0;

  for (Value indexValue : indices) {
    auto compositeType = llvm::dyn_cast<spirv::CompositeType>(currentType);
    if (!compositeType) {
      emitError(
          baseLoc,
          "'spirv.AccessChain' op cannot extract from non-composite type ")
          << currentType << " with index " << index;
      return Type();
    }

    index = 0;
    if (llvm::isa<spirv::StructType>(currentType)) {
      // Struct members may have different types, so the index has to be a
      // known constant.
      Operation *definingOp = indexValue.getDefiningOp();
      if (!definingOp) {
        emitError(baseLoc, "'spirv.AccessChain' op index must be an "
                           "integer spirv.Constant to access "
                           "element of spirv.struct");
        return Type();
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(spirv::extractValueFromConstOp(definingOp, index))) {
        emitError(
            baseLoc,
            "'spirv.AccessChain' index must be an integer spirv.Constant to "
            "access element of spirv.struct, but provided ")
            << definingOp->getName();
        return Type();
      }

      const bool inBounds =
          index >= 0 &&
          static_cast<uint64_t>(index) < compositeType.getNumElements();
      if (!inBounds) {
        emitError(baseLoc, "'spirv.AccessChain' op index ")
            << index << " out of bounds for " << currentType;
        return Type();
      }
    }

    currentType = compositeType.getElementType(index);
  }

  return spirv::PointerType::get(currentType, basePtrType.getStorageClass());
}
|
|
|
|
/// Builds an AccessChainOp whose result type is deduced from `basePtr` and
/// `indices` via getElementPtrType. Asserts that the deduction succeeded.
void AccessChainOp::build(OpBuilder &builder, OperationState &state,
                          Value basePtr, ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, indices);
}
|
|
|
|
/// Parses `base-ptr [ indices ] : base-ptr-type , index-types`.
///
/// The result type is not spelled in the assembly; it is deduced from the base
/// pointer type and the resolved index operands via getElementPtrType.
ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand basePtr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
  Type basePtrType;
  auto startLoc = parser.getCurrentLocation();

  if (failed(parser.parseOperand(basePtr)))
    return failure();
  if (failed(parser.parseOperandList(indexOperands,
                                     OpAsmParser::Delimiter::Square)))
    return failure();
  if (failed(parser.parseColonType(basePtrType)))
    return failure();
  if (failed(parser.resolveOperand(basePtr, basePtrType, result.operands)))
    return failure();

  // At least one index is required; check before attempting to parse the
  // index type list.
  if (indexOperands.empty()) {
    return mlir::emitError(result.location,
                           "'spirv.AccessChain' op expected at "
                           "least one index ");
  }

  SmallVector<Type, 4> indexTypes;
  if (failed(parser.parseComma()) || failed(parser.parseTypeList(indexTypes)))
    return failure();

  // Every index needs exactly one type.
  if (indexTypes.size() != indexOperands.size()) {
    return mlir::emitError(
        result.location, "'spirv.AccessChain' op indices types' count must be "
                         "equal to indices info count");
  }

  if (failed(parser.resolveOperands(indexOperands, indexTypes, startLoc,
                                    result.operands)))
    return failure();

  // result.operands is [basePtr, indices...]; skip the base pointer.
  auto resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(result.operands).drop_front(),
      result.location);
  if (!resultType)
    return failure();

  result.addTypes(resultType);
  return success();
}
|
|
|
|
template <typename Op>
|
|
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
|
|
printer << ' ' << op.getBasePtr() << '[' << indices
|
|
<< "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
|
|
}
|
|
|
|
/// Prints this op using the shared access chain format.
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
  ValueRange indices = getIndices();
  printAccessChain(*this, indices, printer);
}
|
|
|
|
template <typename Op>
|
|
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
|
|
auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
|
|
indices, accessChainOp.getLoc());
|
|
if (!resultType)
|
|
return failure();
|
|
|
|
auto providedResultType =
|
|
llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
|
|
if (!providedResultType)
|
|
return accessChainOp.emitOpError(
|
|
"result type must be a pointer, but provided")
|
|
<< providedResultType;
|
|
|
|
if (resultType != providedResultType)
|
|
return accessChainOp.emitOpError("invalid result type: expected ")
|
|
<< resultType << ", but provided " << providedResultType;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Verifies the result type against the base pointer type and indices.
LogicalResult AccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.LoadOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds a LoadOp whose result type is the pointee type of `basePtr`.
/// `basePtr` must have spirv::PointerType.
void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
                   MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
  Type loadedType =
      llvm::cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
  build(builder, state, loadedType, basePtr, memoryAccess, alignment);
}
|
|
|
|
/// Parses `storage-class ptr memory-access-list? attr-dict? : element-type`.
///
/// The pointer operand type is reconstructed from the parsed storage class and
/// element type; the element type is also the result type.
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrOperand;
  Type elementType;

  // Storage class, then the pointer operand.
  if (parseEnumStrAttr(storageClass, parser) ||
      failed(parser.parseOperand(ptrOperand)))
    return failure();

  // Optional memory access list and attribute dictionary.
  if (parseMemoryAccessAttributes(parser, result) ||
      failed(parser.parseOptionalAttrDict(result.attributes)))
    return failure();

  // Trailing `: element-type`.
  if (failed(parser.parseColon()) || failed(parser.parseType(elementType)))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (failed(parser.resolveOperand(ptrOperand, ptrType, result.operands)))
    return failure();

  result.addTypes(elementType);
  return success();
}
|
|
|
|
/// Prints ` "storage-class" ptr memory-access-list? attr-dict? : result-type`.
void LoadOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClassName = stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClassName << "\" " << getPtr();

  // Attributes rendered by the helper are recorded here so the attribute
  // dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : " << getType();
}
|
|
|
|
/// Verifies that the loaded value type matches the pointee type and that the
/// memory access / alignment attributes are consistent.
LogicalResult LoadOp::verify() {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  if (succeeded(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
    return verifyMemoryAccessAttribute(*this);
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.StoreOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses
///   `storage-class ptr, value memory-access-list? : element-type attr-dict?`.
///
/// The pointer operand type is reconstructed from the parsed storage class and
/// element type; the value operand has the element type.
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) ||
      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
      parser.parseType(elementType)) {
    return failure();
  }

  // StoreOp::print emits the optional attribute dictionary after the value
  // type, so accept it here as well; otherwise ops that carry additional
  // attributes cannot be parsed back from their printed form.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands)) {
    return failure();
  }
  return success();
}
|
|
|
|
/// Prints
///   ` "storage-class" ptr, value memory-access-list? : value-type attr-dict?`.
void StoreOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClassName = stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClassName << "\" " << getPtr() << ", "
          << getValue();

  // Attributes rendered by the helper are recorded here so the attribute
  // dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);

  printer << " : " << getValue().getType();
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
|
|
|
|
/// Verifies that the stored value type matches the pointee type and that the
/// memory access / alignment attributes are consistent.
LogicalResult StoreOp::verify() {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  if (succeeded(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
    return verifyMemoryAccessAttribute(*this);
  return failure();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.CopyMemory
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Prints
///   ` "target-sc" target, "source-sc" source memory-access-list?,
///     source-memory-access-list? : pointee-type attr-dict?`.
///
/// The attribute dictionary is printed after the pointee type because that is
/// where CopyMemoryOp::parse consumes it; printing it before the colon produced
/// text the parser rejected whenever non-elided attributes were present.
void CopyMemoryOp::print(OpAsmPrinter &printer) {
  printer << ' ';

  StringRef targetStorageClass = stringifyStorageClass(
      llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";

  StringRef sourceStorageClass = stringifyStorageClass(
      llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
  printer << " \"" << sourceStorageClass << "\" " << getSource();

  // Attributes rendered by the helpers are recorded here so the attribute
  // dictionary does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
                                   getSourceMemoryAccess(),
                                   getSourceAlignment());

  Type pointeeType =
      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
  printer << " : " << pointeeType;

  // Keep this after the type to stay in sync with the parser.
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
|
|
|
|
/// Parses
///   `target-sc target, source-sc source memory-access-list?
///    (, source-memory-access-list?)? : element-type attr-dict?`.
///
/// Both pointer operand types are reconstructed from their parsed storage
/// class and the shared element type.
ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
  spirv::StorageClass targetStorageClass;
  spirv::StorageClass sourceStorageClass;
  OpAsmParser::UnresolvedOperand targetOperand;
  OpAsmParser::UnresolvedOperand sourceOperand;

  // Target: storage class, operand, separating comma.
  if (parseEnumStrAttr(targetStorageClass, parser) ||
      failed(parser.parseOperand(targetOperand)) ||
      failed(parser.parseComma()))
    return failure();

  // Source: storage class and operand.
  if (parseEnumStrAttr(sourceStorageClass, parser) ||
      failed(parser.parseOperand(sourceOperand)))
    return failure();

  // First (target) memory access list.
  if (parseMemoryAccessAttributes(parser, result))
    return failure();

  // A comma introduces the second (source) memory access list.
  if (succeeded(parser.parseOptionalComma())) {
    if (parseSourceMemoryAccessAttributes(parser, result))
      return failure();
  }

  Type elementType;
  if (failed(parser.parseColon()) || failed(parser.parseType(elementType)))
    return failure();

  if (failed(parser.parseOptionalAttrDict(result.attributes)))
    return failure();

  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);

  if (failed(parser.resolveOperand(targetOperand, targetPtrType,
                                   result.operands)))
    return failure();
  if (failed(parser.resolveOperand(sourceOperand, sourcePtrType,
                                   result.operands)))
    return failure();

  return success();
}
|
|
|
|
/// Verifies that target and source point to the same type and that both the
/// target and the source memory access / alignment attribute pairs are
/// consistent.
LogicalResult CopyMemoryOp::verify() {
  auto targetPtrType = llvm::cast<spirv::PointerType>(getTarget().getType());
  auto sourcePtrType = llvm::cast<spirv::PointerType>(getSource().getType());

  if (targetPtrType.getPointeeType() != sourcePtrType.getPointeeType())
    return emitOpError("both operands must be pointers to the same type");

  if (failed(verifyMemoryAccessAttribute(*this)))
    return failure();

  // TODO - According to the spec:
  //
  // If two masks are present, the first applies to Target and cannot include
  // MakePointerVisible, and the second applies to Source and cannot include
  // MakePointerAvailable.
  //
  // Add such verification here.

  return verifySourceMemoryAccessAttribute(*this);
}
|
|
|
|
/// Shared parser for the pointer access chain ops:
///   `base-ptr [ element, indices ] : base-ptr-type , operand-types`.
///
/// `opName` is only used as a prefix in diagnostics. The first entry of the
/// bracketed list is the element operand; the result type is deduced from the
/// base pointer type and the operands that follow it via getElementPtrType.
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
                                             OpAsmParser &parser,
                                             OperationState &state) {
  OpAsmParser::UnresolvedOperand basePtr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> bracketOperands;
  Type basePtrType;
  auto startLoc = parser.getCurrentLocation();

  if (failed(parser.parseOperand(basePtr)))
    return failure();
  if (failed(parser.parseOperandList(bracketOperands,
                                     OpAsmParser::Delimiter::Square)))
    return failure();
  if (failed(parser.parseColonType(basePtrType)))
    return failure();
  if (failed(parser.resolveOperand(basePtr, basePtrType, state.operands)))
    return failure();

  // The bracketed list must not be empty; check before attempting to parse the
  // type list.
  if (bracketOperands.empty())
    return emitError(state.location) << opName << " expected element";

  SmallVector<Type, 4> bracketTypes;
  if (failed(parser.parseComma()) || failed(parser.parseTypeList(bracketTypes)))
    return failure();

  // Every bracketed operand needs exactly one type.
  if (bracketTypes.size() != bracketOperands.size())
    return emitError(state.location)
           << opName
           << " indices types' count must be equal to indices info count";

  if (failed(parser.resolveOperands(bracketOperands, bracketTypes, startLoc,
                                    state.operands)))
    return failure();

  // state.operands is [basePtr, element, indices...]; only the indices
  // participate in the result type deduction.
  auto resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(state.operands).drop_front(2),
      state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}
|
|
|
|
template <typename Op>
|
|
static auto concatElemAndIndices(Op op) {
|
|
SmallVector<Value> ret(op.getIndices().size() + 1);
|
|
ret[0] = op.getElement();
|
|
llvm::copy(op.getIndices(), ret.begin() + 1);
|
|
return ret;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.InBoundsPtrAccessChainOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an InBoundsPtrAccessChainOp whose result type is deduced from
/// `basePtr` and `indices` via getElementPtrType. Asserts that the deduction
/// succeeded.
void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                                     Value basePtr, Value element,
                                     ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
|
|
|
|
/// Parses this op with the shared pointer access chain parser, using this op's
/// operation name in diagnostics.
ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  StringRef opName = spirv::InBoundsPtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}
|
|
|
|
/// Prints this op in access chain format, with the element operand listed
/// ahead of the indices inside the brackets.
void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
|
|
|
|
/// Verifies the result type against the base pointer type and indices.
LogicalResult InBoundsPtrAccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.PtrAccessChainOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds a PtrAccessChainOp whose result type is deduced from `basePtr` and
/// `indices` via getElementPtrType. Asserts that the deduction succeeded.
void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                             Value basePtr, Value element, ValueRange indices) {
  Type resultType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultType, basePtr, element, indices);
}
|
|
|
|
/// Parses this op with the shared pointer access chain parser, using this op's
/// operation name in diagnostics.
ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  StringRef opName = spirv::PtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}
|
|
|
|
/// Prints this op in access chain format, with the element operand listed
/// ahead of the indices inside the brackets.
void PtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}
|
|
|
|
/// Verifies the result type against the base pointer type and indices.
LogicalResult PtrAccessChainOp::verify() {
  ValueRange indices = getIndices();
  return verifyAccessChain(*this, indices);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// spirv.Variable
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses `(init ( initializer ))? decorations : pointer-type`.
///
/// The parsed type must be a spirv::PointerType. Its pointee type is used to
/// resolve the optional initializer and its storage class is recorded as the
/// storage class attribute.
ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional `init(<operand>)`.
  std::optional<OpAsmParser::UnresolvedOperand> initializer;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initializer.emplace();
    if (failed(parser.parseLParen()) ||
        failed(parser.parseOperand(*initializer)) ||
        failed(parser.parseRParen()))
      return failure();
  }

  if (parseVariableDecorations(parser, result))
    return failure();

  // Result pointer type; remember where the type starts for the diagnostic.
  if (failed(parser.parseColon()))
    return failure();
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (failed(parser.parseType(parsedType)))
    return failure();

  auto ptrType = llvm::dyn_cast<spirv::PointerType>(parsedType);
  if (!ptrType)
    return parser.emitError(typeLoc, "expected spirv.ptr type");
  result.addTypes(ptrType);

  // The initializer, if any, has the pointee type.
  if (initializer &&
      failed(parser.resolveOperand(*initializer, ptrType.getPointeeType(),
                                   result.operands)))
    return failure();

  auto storageClassAttr =
      parser.getBuilder().getAttr<spirv::StorageClassAttr>(
          ptrType.getStorageClass());
  result.addAttribute(spirv::attributeName<spirv::StorageClass>(),
                      storageClassAttr);

  return success();
}
|
|
|
|
/// Prints `(init(initializer))? decorations : result-type`. The storage class
/// attribute is elided.
void VariableOp::print(OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());

  // The initializer is the op's only possible operand.
  const bool hasInitializer = getNumOperands() != 0;
  if (hasInitializer)
    printer << " init(" << getInitializer() << ")";

  printVariableDecorations(*this, printer, elidedAttrs);
  printer << " : " << getType();
}
|
|
|
|
/// Verifies a function-level variable:
///   - the storage class is `Function` and matches the result pointer's;
///   - an initializer, if present, is defined by a spirv.Constant,
///     spirv.ReferenceOf or spirv.AddressOf op;
///   - none of the descriptor set / binding / built-in decorations are
///     present;
///   - a pointee that is a PhysicalStorageBuffer pointer (or an array of them)
///     carries exactly one of the AliasedPointer / RestrictPointer decorations.
LogicalResult VariableOp::verify() {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  if (getStorageClass() != spirv::StorageClass::Function) {
    return emitOpError(
        "can only be used to model function-level variables. Use "
        "spirv.GlobalVariable for module-level variables.");
  }

  auto resultPtrType = llvm::cast<spirv::PointerType>(getPointer().getType());
  if (getStorageClass() != resultPtrType.getStorageClass())
    return emitOpError(
        "storage class must match result pointer's storage class");

  if (getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    Operation *initializerOp = getOperand(0).getDefiningOp();
    const bool isAllowedInitializer =
        initializerOp && isa<spirv::ConstantOp,    // for normal constant
                             spirv::ReferenceOfOp, // for spec constant
                             spirv::AddressOfOp>(initializerOp);
    if (!isAllowedInitializer)
      return emitOpError("initializer must be the result of a "
                         "constant or spirv.GlobalVariable op");
  }

  // Decorations are stored as attributes named after the snake_case form of
  // the decoration name.
  Operation *thisOp = getOperation();
  auto decorationAttrName = [](spirv::Decoration decoration) {
    return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
  };
  auto lookupDecoration = [&](spirv::Decoration decoration) {
    return thisOp->getAttr(decorationAttrName(decoration));
  };

  // TODO: generate these strings using ODS.
  const spirv::Decoration globalOnlyDecorations[] = {
      spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
      spirv::Decoration::BuiltIn};
  for (spirv::Decoration decoration : globalOnlyDecorations) {
    if (!lookupDecoration(decoration))
      continue;
    return emitOpError("cannot have '")
           << decorationAttrName(decoration)
           << "' attribute (only allowed in spirv.GlobalVariable)";
  }

  // From SPV_KHR_physical_storage_buffer:
  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
  // > PhysicalStorageBuffer storage class, then the variable must be decorated
  // > with exactly one of AliasedPointer or RestrictPointer.
  auto innerPtrType = dyn_cast<spirv::PointerType>(getPointeeType());
  if (!innerPtrType) {
    // Look through a single array level.
    if (auto arrayType = dyn_cast<spirv::ArrayType>(getPointeeType()))
      innerPtrType = dyn_cast<spirv::PointerType>(arrayType.getElementType());
  }

  const bool isPhysicalBufferPtr =
      innerPtrType && innerPtrType.getStorageClass() ==
                          spirv::StorageClass::PhysicalStorageBuffer;
  if (!isPhysicalBufferPtr)
    return success();

  const bool hasAliased =
      lookupDecoration(spirv::Decoration::AliasedPointer) != nullptr;
  const bool hasRestrict =
      lookupDecoration(spirv::Decoration::RestrictPointer) != nullptr;

  // Equal flags mean either none or both decorations are present.
  if (hasAliased == hasRestrict) {
    if (!hasAliased)
      return emitOpError() << " with physical buffer pointer must be decorated "
                              "either 'AliasedPointer' or 'RestrictPointer'";
    return emitOpError()
           << " with physical buffer pointer must have exactly one "
              "aliasing decoration";
  }

  return success();
}
|
|
|
|
} // namespace mlir::spirv
|