bolt/deps/llvm-18.1.8/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
2025-02-14 19:21:04 +01:00

2111 lines
78 KiB
C++

//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <numeric>
#include <optional>
#include <type_traits>
using namespace mlir;
using namespace mlir::spirv::AttrNames;
//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//
/// Reads the integer payload of a `spirv.Constant` op into `value`.
///
/// Fails when `op` is null, is not a `spirv.Constant`, or holds a value that is
/// not an IntegerAttr. Signless integer attributes are read with getInt(); all
/// other integer attributes with getSInt(). The 64-bit result is narrowed to
/// int32_t on assignment.
LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
  auto constantOp = dyn_cast_or_null<spirv::ConstantOp>(op);
  if (!constantOp)
    return failure();

  auto intAttr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue());
  if (!intAttr)
    return failure();

  value = intAttr.getType().isSignlessInteger() ? intAttr.getInt()
                                                : intAttr.getSInt();
  return success();
}
/// Verifies that at most one of the mutually exclusive ordering bits
/// (Acquire, Release, AcquireRelease, SequentiallyConsistent) is set in
/// `memorySemantics`; other bits are not inspected. Emits an error on `op`
/// otherwise.
LogicalResult
spirv::verifyMemorySemantics(Operation *op,
                             spirv::MemorySemantics memorySemantics) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
                        spirv::MemorySemantics::Release |
                        spirv::MemorySemantics::AcquireRelease |
                        spirv::MemorySemantics::SequentiallyConsistent;

  auto bitCount =
      llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
  if (bitCount > 1) {
    // Note the trailing space after "`Release`,": the literals below are
    // concatenated, and without it the message read "`Release`,`AcquireRelease`".
    return op->emitError(
        "expected at most one of these four memory constraints "
        "to be set: `Acquire`, `Release`, "
        "`AcquireRelease` or `SequentiallyConsistent`");
  }
  return success();
}
/// Prints the decoration shorthand used by SPIR-V variable ops, then the rest
/// of the attribute dictionary.
///
/// Attribute names are the snake_case form of the decoration names. If both
/// the DescriptorSet and Binding integer attributes are present they are
/// printed as ` bind(<set>, <binding>)`. A BuiltIn string attribute is printed
/// as ` <name>("<value>")`. Every attribute printed this way is appended to
/// `elidedAttrs` so it is not repeated in the trailing attribute dictionary.
void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
                                     SmallVectorImpl<StringRef> &elidedAttrs) {
  // The names are produced as std::string. Intern them in the MLIRContext so
  // the StringRefs appended to the caller-owned `elidedAttrs` remain valid
  // after this function returns. Previously they pointed into local
  // std::strings and dangled once this function returned.
  MLIRContext *context = op->getContext();
  auto getDecorationAttrName =
      [context](spirv::Decoration decoration) -> StringRef {
    return StringAttr::get(context, llvm::convertToSnakeFromCamelCase(
                                        stringifyDecoration(decoration)))
        .getValue();
  };

  // Print optional descriptor binding
  StringRef descriptorSetName =
      getDecorationAttrName(spirv::Decoration::DescriptorSet);
  StringRef bindingName = getDecorationAttrName(spirv::Decoration::Binding);
  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
  if (descriptorSet && binding) {
    elidedAttrs.push_back(descriptorSetName);
    elidedAttrs.push_back(bindingName);
    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
            << ")";
  }

  // Print BuiltIn attribute if present
  StringRef builtInName = getDecorationAttrName(spirv::Decoration::BuiltIn);
  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
    elidedAttrs.push_back(builtInName);
  }

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
/// Parses a single-result op whose operands and result share one type:
///   %a, %b {attrs} : type
/// Also accepts the generic fallback emitted by `printOneResultOp` when the
/// types differ:
///   (%a, %b) {attrs} : (in0, in1) -> out
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                                   OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandList;
  Type parsedType;
  SMLoc startLoc = parser.getCurrentLocation();

  // Compact form: there is no leading '(' so one type covers everything.
  if (failed(parser.parseOptionalLParen()))
    return failure(parser.parseOperandList(operandList) ||
                   parser.parseOptionalAttrDict(result.attributes) ||
                   parser.parseColonType(parsedType) ||
                   parser.resolveOperands(operandList, parsedType,
                                          result.operands) ||
                   parser.addTypeToList(parsedType, result.types));

  // Generic form: a parenthesized operand list followed by a function type.
  if (parser.parseOperandList(operandList) || parser.parseRParen() ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.parseType(parsedType))
    return failure();

  auto functionType = llvm::dyn_cast<FunctionType>(parsedType);
  if (!functionType) {
    parser.emitError(startLoc, "expected function type");
    return failure();
  }
  if (parser.resolveOperands(operandList, functionType.getInputs(), startLoc,
                             result.operands))
    return failure();

  result.addTypes(functionType.getResults());
  return success();
}
/// Prints a single-result op as `operands {attrs} : type` when every operand
/// type equals the result type. Otherwise it falls back to the generic form
/// (without the op name) so that no type information is lost.
static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumResults() == 1 && "op should have one result");

  Type sharedType = op->getResult(0).getType();
  bool allOperandsMatch = llvm::all_of(
      op->getOperandTypes(), [&](Type operandType) {
        return operandType == sharedType;
      });
  if (!allOperandsMatch) {
    p.printGenericOp(op, /*printOpName=*/false);
    return;
  }

  p << ' ';
  p.printOperands(op->getOperands());
  p.printOptionalAttrDict(op->getAttrs());
  // A single type stands for every operand and the result.
  p << " : " << sharedType;
}
/// Verifies the optional Image Operands mask `attr` of `imageOp` against the
/// trailing image `operands`.
///
/// Without a mask there must be no trailing operands. With a mask, per-operand
/// validation is not implemented yet (see TODO below).
template <typename Op>
static LogicalResult verifyImageOperands(Op imageOp,
                                         spirv::ImageOperandsAttr attr,
                                         Operation::operand_range operands) {
  if (!attr) {
    if (operands.empty())
      return success();

    return imageOp.emitError("the Image Operands should encode what operands "
                             "follow, as per Image Operands");
  }

  // TODO: Add the validation rules for the following Image Operands.
  spirv::ImageOperands noSupportOperands =
      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
      spirv::ImageOperands::MakeTexelAvailable |
      spirv::ImageOperands::MakeTexelVisible |
      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;

  // NOTE(review): bitEnumContainsAll only fires when *every* bit above is set.
  // bitEnumContainsAny may have been intended. It is kept as-is so IR that is
  // accepted today is not newly rejected.
  //
  // Report a diagnostic rather than calling llvm_unreachable. The mask comes
  // from an op attribute, so this path can be reached from parsed IR, and
  // llvm_unreachable is undefined behavior in release builds.
  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
    return imageOp.emitError("unimplemented operands of Image Operands");

  return success();
}
/// Checks that the scalar type of `val` (the element type if `val` is a vector)
/// equals the pointee type of the spirv.ptr `ptr`.
template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                        Value ptr, Value val) {
  Type scalarType = val.getType();
  if (auto vecType = llvm::dyn_cast<VectorType>(scalarType))
    scalarType = vecType.getElementType();

  Type pointeeType =
      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType();
  if (scalarType == pointeeType)
    return success();

  return op.emitOpError("mismatch in result type and pointer type");
}
/// Descends into `type` one composite level per entry of `indices` and returns
/// the type reached at the end (possibly a scalar component).
///
/// Returns a null Type after reporting through `emitErrorFn` when `indices` is
/// empty, when a level is not a spirv::CompositeType, or when an index is out
/// of range. Range checks are only performed for composites whose element count
/// is known at compile time.
static Type
getElementType(Type type, ArrayRef<int32_t> indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  Type current = type;
  for (int32_t idx : indices) {
    auto composite = llvm::dyn_cast<spirv::CompositeType>(current);
    if (!composite) {
      emitErrorFn("cannot extract from non-composite type ")
          << current << " with index " << idx;
      return nullptr;
    }
    // getNumElements() is only queried once the count is known to be static.
    if (composite.hasCompileTimeKnownNumElements() &&
        (idx < 0 ||
         static_cast<uint64_t>(idx) >= composite.getNumElements())) {
      emitErrorFn("index ") << idx << " out of bounds for " << current;
      return nullptr;
    }
    current = composite.getElementType(idx);
  }
  return current;
}
/// Attribute-based front end for the ArrayRef<int32_t> overload. `indices` must
/// be a non-empty ArrayAttr of IntegerAttr. Each value is narrowed to int32_t
/// before the type walk. Returns a null Type after reporting through
/// `emitErrorFn` on malformed input.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto arrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
  if (!arrayAttr) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (arrayAttr.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexValues;
  indexValues.reserve(arrayAttr.size());
  for (Attribute element : arrayAttr) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(element);
    if (!intAttr) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << element << "'";
      return nullptr;
    }
    indexValues.push_back(intAttr.getInt());
  }
  return getElementType(type, indexValues, emitErrorFn);
}
/// Variant of getElementType that reports diagnostics at `loc` through
/// mlir::emitError.
static Type getElementType(Type type, Attribute indices, Location loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return ::mlir::emitError(loc, message);
                        });
}
/// Variant of getElementType used while parsing. Diagnostics are reported
/// through `parser` at source location `loc`.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           SMLoc loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return parser.emitError(loc, message);
                        });
}
/// Verifies an extended binary arithmetic op: the result must be a struct with
/// exactly two members, and both operands and both members must have the same
/// type.
template <typename ExtendedBinaryOp>
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
  auto structTy = llvm::cast<spirv::StructType>(op.getType());
  if (structTy.getNumElements() != 2)
    return op.emitOpError("expected result struct type containing two members");

  Type lhsTy = op.getOperand1().getType();
  bool uniform = lhsTy == op.getOperand2().getType() &&
                 lhsTy == structTy.getElementType(0) &&
                 lhsTy == structTy.getElementType(1);
  if (uniform)
    return success();

  return op.emitOpError(
      "expected all operand types and struct member types are the same");
}
/// Parses `{attrs} %lhs, %rhs : !spirv.struct<(T, T)>`. Both operands are
/// resolved to the struct's first member type.
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
                                                   OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> lhsAndRhs;
  bool headerFailed = parser.parseOptionalAttrDict(result.attributes) ||
                      parser.parseOperandList(lhsAndRhs) || parser.parseColon();
  if (headerFailed)
    return failure();

  SMLoc typeLoc = parser.getCurrentLocation();
  Type resultType;
  if (parser.parseType(resultType))
    return failure();

  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
  bool isPairStruct = structType && structType.getNumElements() == 2;
  if (!isPairStruct)
    return parser.emitError(typeLoc,
                            "expected spirv.struct type with two members");

  SmallVector<Type, 2> memberTypes(2, structType.getElementType(0));
  if (parser.resolveOperands(lhsAndRhs, memberTypes, typeLoc, result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
/// Prints the form accepted by parseArithmeticExtendedBinaryOp: the attribute
/// dictionary, then the operands, then `: <result struct type>`.
static void printArithmeticExtendedBinaryOp(Operation *op,
                                            OpAsmPrinter &printer) {
  Type resultType = op->getResultTypes().front();
  printer << ' ';
  printer.printOptionalAttrDict(op->getAttrs());
  printer.printOperands(op->getOperands());
  printer << " : " << resultType;
}
/// Requires a shift op's first operand (the value being shifted) and its result
/// to have the same type.
static LogicalResult verifyShiftOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}
//===----------------------------------------------------------------------===//
// spirv.mlir.addressof
//===----------------------------------------------------------------------===//
/// Builds a spirv.mlir.addressof that refers to `var` by symbol. The result
/// type is the global variable's own (pointer) type.
void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  auto pointerType = var.getType();
  auto symbol = SymbolRefAttr::get(var);
  build(builder, state, pointerType, symbol);
}
/// The referenced symbol, looked up from the enclosing op, must be a
/// spirv.GlobalVariable whose type equals this op's result type.
LogicalResult spirv::AddressOfOp::verify() {
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), getVariableAttr());
  auto globalVar = dyn_cast_or_null<spirv::GlobalVariableOp>(symbolOp);
  if (!globalVar)
    return emitOpError("expected spirv.GlobalVariable symbol");

  if (globalVar.getType() != getPointer().getType())
    return emitOpError(
        "result type mismatch with the referenced global variable's type");

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//
/// Verifies that the constituents of spirv.CompositeConstruct can build the
/// result type. There are four cases:
///  1. Cooperative / joint matrices: exactly one constituent, whose type is
///     the matrix element type.
///  2./3. Structs and arrays (and vectors built one scalar per element): one
///     constituent per result element, each matching that element's type.
///  4. Vectors from mixed scalars/vectors: every constituent's (element) type
///     equals the result element type, and the element counts sum to the
///     result vector size.
LogicalResult spirv::CompositeConstructOp::verify() {
  operand_range constituents = this->getConstituents();

  auto coopElementType =
      llvm::TypeSwitch<Type, Type>(getType())
          .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
              [](auto coopType) { return coopType.getElementType(); })
          .Default([](Type) { return nullptr; });

  // Case 1. -- matrices.
  if (coopElementType) {
    if (constituents.size() != 1)
      return emitOpError("has incorrect number of operands: expected ")
             << "1, but provided " << constituents.size();
    if (coopElementType != constituents.front().getType())
      return emitOpError("operand type mismatch: expected operand type ")
             << coopElementType << ", but provided "
             << constituents.front().getType();
    return success();
  }

  // Case 2./3./4. -- number of constituents matches the number of elements.
  auto cType = llvm::cast<spirv::CompositeType>(getType());
  if (constituents.size() == cType.getNumElements()) {
    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
      if (constituents[index].getType() != cType.getElementType(index)) {
        return emitOpError("operand type mismatch: expected operand type ")
               << cType.getElementType(index) << ", but provided "
               << constituents[index].getType();
      }
    }
    return success();
  }

  // Case 4. -- check that all constituents add up to the expected vector type.
  auto resultType = llvm::dyn_cast<VectorType>(cType);
  if (!resultType)
    return emitOpError(
        "expected to return a vector or cooperative matrix when the number of "
        "constituents is less than what the result needs");

  // Sum element counts directly (a scalar contributes one element) instead of
  // collecting them into a temporary vector and summing with an int-seeded
  // std::accumulate.
  unsigned totalCount = 0;
  for (Value component : constituents) {
    Type componentType = component.getType();
    if (!llvm::isa<VectorType>(componentType) && !componentType.isIntOrFloat())
      return emitOpError("operand type mismatch: expected operand to have "
                         "a scalar or vector type, but provided ")
             << componentType;

    Type elementType = componentType;
    unsigned componentCount = 1;
    if (auto vectorType = llvm::dyn_cast<VectorType>(componentType)) {
      componentCount = vectorType.getNumElements();
      elementType = vectorType.getElementType();
    }

    if (elementType != resultType.getElementType())
      return emitOpError("operand element type mismatch: expected to be ")
             << resultType.getElementType() << ", but provided " << elementType;

    totalCount += componentCount;
  }

  if (totalCount != cType.getNumElements())
    return emitOpError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << totalCount;
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//
/// Builds a spirv.CompositeExtract whose result type is derived by walking
/// `composite`'s type with `indices`. If the walk fails, an error is emitted at
/// `state.location` and the op is not built.
void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
                                      Value composite,
                                      ArrayRef<int32_t> indices) {
  auto indicesAttr = builder.getI32ArrayAttr(indices);
  Type resultType =
      getElementType(composite.getType(), indicesAttr, state.location);
  if (resultType)
    build(builder, state, resultType, composite, indicesAttr);
}
/// Parses `%composite [indices] : composite-type`. The result type is computed
/// from the composite type and the indices, not written explicitly.
ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  OpAsmParser::UnresolvedOperand composite;
  Attribute indices;
  Type compositeType;
  SMLoc indicesLoc;

  bool parseFailed =
      parser.parseOperand(composite) ||
      parser.getCurrentLocation(&indicesLoc) ||
      parser.parseAttribute(indices, kIndicesAttrName, result.attributes) ||
      parser.parseColonType(compositeType) ||
      parser.resolveOperand(composite, compositeType, result.operands);
  if (parseFailed)
    return failure();

  Type elementType = getElementType(compositeType, indices, parser, indicesLoc);
  if (!elementType)
    return failure();

  result.addTypes(elementType);
  return success();
}
/// Prints `%composite [indices] : composite-type`, which is the form read by
/// CompositeExtractOp::parse.
void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
  Value composite = getComposite();
  printer << ' ' << composite << getIndices() << " : " << composite.getType();
}
/// The declared result type must equal the type obtained by walking the
/// composite type with the op's indices.
LogicalResult spirv::CompositeExtractOp::verify() {
  Type expectedType =
      getElementType(getComposite().getType(),
                     llvm::dyn_cast<ArrayAttr>(getIndices()), getLoc());
  if (!expectedType)
    return failure();

  Type actualType = getType();
  if (expectedType == actualType)
    return success();

  return emitOpError("invalid result type: expected ")
         << expectedType << " but provided " << actualType;
}
//===----------------------------------------------------------------------===//
// spirv.CompositeInsert
//===----------------------------------------------------------------------===//
/// Builds a spirv.CompositeInsert of `object` into `composite` at `indices`.
/// The result has the composite's type.
void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
                                     Value object, Value composite,
                                     ArrayRef<int32_t> indices) {
  Type resultType = composite.getType();
  build(builder, state, resultType, object, composite,
        builder.getI32ArrayAttr(indices));
}
/// Parses `%object, %composite [indices] : object-type into composite-type`.
/// The result type is the composite type.
ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> objectAndComposite;
  Type objectType;
  Type compositeType;
  Attribute indices;
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(objectAndComposite, 2) ||
      parser.parseAttribute(indices, kIndicesAttrName, result.attributes) ||
      parser.parseColonType(objectType) ||
      parser.parseKeywordType("into", compositeType) ||
      parser.resolveOperands(objectAndComposite, {objectType, compositeType},
                             operandsLoc, result.operands) ||
      parser.addTypesToList(compositeType, result.types))
    return failure();
  return success();
}
/// The object must have the type found at `indices` inside the composite, and
/// the result type must equal the composite type.
LogicalResult spirv::CompositeInsertOp::verify() {
  Type compositeType = getComposite().getType();
  Type expectedObjectType = getElementType(
      compositeType, llvm::dyn_cast<ArrayAttr>(getIndices()), getLoc());
  if (!expectedObjectType)
    return failure();

  Type objectType = getObject().getType();
  if (expectedObjectType != objectType)
    return emitOpError("object operand type should be ")
           << expectedObjectType << ", but found " << objectType;

  if (compositeType != getType())
    return emitOpError("result type should be the same as "
                       "the composite type, but found ")
           << compositeType << " vs " << getType();

  return success();
}
/// Prints `%object, %composite [indices] : object-type into composite-type`.
void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
  Value object = getObject();
  Value composite = getComposite();
  printer << " " << object << ", " << composite << getIndices() << " : "
          << object.getType() << " into " << composite.getType();
}
//===----------------------------------------------------------------------===//
// spirv.Constant
//===----------------------------------------------------------------------===//
/// Parses `value-attr (: type)?`. The result type is taken from a typed
/// attribute. An explicit `: type` is read when the attribute is untyped
/// (NoneType) or tensor-typed.
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  Attribute valueAttr;
  if (parser.parseAttribute(valueAttr, kValueAttrName, result.attributes))
    return failure();

  Type resultType;
  if (auto typedValue = llvm::dyn_cast<TypedAttr>(valueAttr))
    resultType = typedValue.getType();
  else
    resultType = NoneType::get(parser.getContext());

  bool needsExplicitType = llvm::isa<NoneType, TensorType>(resultType);
  if (needsExplicitType && parser.parseColonType(resultType))
    return failure();

  return parser.addTypeToList(resultType, result.types);
}
/// Prints the value attribute, followed by `: type` when the result is a
/// spirv.array. That trailing type is what ConstantOp::parse needs for untyped
/// or tensor-typed values.
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
  printer << ' ' << getValue();

  Type resultType = getType();
  if (!llvm::isa<spirv::ArrayType>(resultType))
    return;
  printer << " : " << resultType;
}
/// Checks that the spirv.Constant attribute `value` is consistent with the
/// result type `opType` of `op`.
///
/// The accepted combinations are:
///  - IntegerAttr / FloatAttr: the attribute type must equal `opType`.
///  - DenseIntOrFPElementsAttr / SparseElementsAttr: either the attribute type
///    equals `opType`, or `opType` is a (possibly nested) spirv.array of an
///    int/float scalar whose element type and total element count match the
///    attribute's.
///  - ArrayAttr: `opType` must be a spirv.array, and each element is verified
///    recursively against the array element type.
/// Any other attribute kind is rejected.
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                                        Type opType) {
  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
    auto valueType = llvm::cast<TypedAttr>(value).getType();
    if (valueType != opType)
      return op.emitOpError("result type (")
             << opType << ") does not match value type (" << valueType << ")";
    return success();
  }

  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
    auto valueType = llvm::cast<TypedAttr>(value).getType();
    if (valueType == opType)
      return success();
    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
    auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
    if (!arrayType)
      return op.emitOpError("result or element type (")
             << opType << ") does not match value type (" << valueType
             << "), must be the same or spirv.array";

    // Flatten nested spirv.array types. The element count is accumulated in
    // 64 bits: a product of nested array sizes can overflow `int` (undefined
    // behavior), and ShapedType::getNumElements() is 64-bit as well.
    int64_t numElements = arrayType.getNumElements();
    Type opElemType = arrayType.getElementType();
    while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
      numElements *= t.getNumElements();
      opElemType = t.getElementType();
    }
    if (!opElemType.isIntOrFloat())
      return op.emitOpError("only support nested array result type");

    auto valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType) {
      return op.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";
    }

    if (numElements != shapedType.getNumElements()) {
      return op.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    }
    return success();
  }

  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
    if (!arrayType)
      return op.emitOpError(
          "must have spirv.array result type for array value");
    Type elemType = arrayType.getElementType();
    for (Attribute element : arrayAttr.getValue()) {
      // Verify array elements recursively.
      if (failed(verifyConstantType(op, element, elemType)))
        return failure();
    }
    return success();
  }

  return op.emitOpError("cannot have attribute: ") << value;
}
/// ODS-generated checks already validate the result type on their own. This
/// additionally checks that the value attribute agrees with that result type.
LogicalResult spirv::ConstantOp::verify() {
  Attribute value = getValueAttr();
  Type resultType = getType();
  return verifyConstantType(*this, value, resultType);
}
/// Returns true if a spirv.Constant can be built with result type `type`.
/// `type` must be a valid SPIR-V type. Among types owned by the SPIR-V dialect
/// itself, only spirv.array is currently supported. Any other type that passes
/// the SPIRVType check is accepted.
bool spirv::ConstantOp::isBuildableWith(Type type) {
  if (!llvm::isa<spirv::SPIRVType>(type))
    return false;

  // TODO: support constant struct
  bool ownedBySpirvDialect = isa<SPIRVDialect>(type.getDialect());
  return !ownedBySpirvDialect || llvm::isa<spirv::ArrayType>(type);
}
/// Creates a spirv.Constant holding zero of `type` at `loc`.
///
/// The supported types are integers (i1 becomes `false`), floats, and vectors
/// of integers or floats. Any other type hits llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder &builder) {
  auto emitConstant = [&](Attribute value) {
    return builder.create<spirv::ConstantOp>(loc, type, value);
  };

  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return emitConstant(builder.getBoolAttr(false));
    return emitConstant(builder.getIntegerAttr(type, APInt(width, 0)));
  }

  if (auto floatType = llvm::dyn_cast<FloatType>(type))
    return emitConstant(builder.getFloatAttr(floatType, 0.0));

  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
    Type elemType = vectorType.getElementType();
    if (llvm::isa<IntegerType>(elemType))
      return emitConstant(DenseElementsAttr::get(
          vectorType, IntegerAttr::get(elemType, 0).getValue()));
    if (llvm::isa<FloatType>(elemType))
      return emitConstant(DenseFPElementsAttr::get(
          vectorType, FloatAttr::get(elemType, 0.0).getValue()));
  }

  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}
/// Creates a spirv.Constant holding one of `type` at `loc`.
///
/// The supported types are integers (i1 becomes `true`), floats, and vectors of
/// integers or floats. Any other type hits llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder &builder) {
  auto emitConstant = [&](Attribute value) {
    return builder.create<spirv::ConstantOp>(loc, type, value);
  };

  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
    unsigned width = intType.getWidth();
    if (width == 1)
      return emitConstant(builder.getBoolAttr(true));
    return emitConstant(builder.getIntegerAttr(type, APInt(width, 1)));
  }

  if (auto floatType = llvm::dyn_cast<FloatType>(type))
    return emitConstant(builder.getFloatAttr(floatType, 1.0));

  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
    Type elemType = vectorType.getElementType();
    if (llvm::isa<IntegerType>(elemType))
      return emitConstant(DenseElementsAttr::get(
          vectorType, IntegerAttr::get(elemType, 1).getValue()));
    if (llvm::isa<FloatType>(elemType))
      return emitConstant(DenseFPElementsAttr::get(
          vectorType, FloatAttr::get(elemType, 1.0).getValue()));
  }

  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}
/// Suggests an SSA name for the constant's result:
///  - `true` / `false` for i1 integer constants;
///  - `cst<value>_<type>` for other integer constants, e.g. `cst42_i32`;
///  - `cst_<type>` for float-typed results;
///  - `cst_vec_<N>x<elem>` for vectors of int/float, e.g. `cst_vec_4xi32`.
void mlir::spirv::ConstantOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  Type type = getType();
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << "cst";

  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);

  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
    // The verifier requires an IntegerAttr value to match the result type, but
    // this naming hook may run on unverified IR. Guard against a non-integer
    // result type instead of dereferencing a null `intTy`.
    if (intTy) {
      if (intTy.getWidth() == 1)
        return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

      if (intTy.isSignless()) {
        specialName << intCst.getInt();
      } else if (intTy.isUnsigned()) {
        specialName << intCst.getUInt();
      } else {
        specialName << intCst.getSInt();
      }
    }
  }

  if (intTy || llvm::isa<FloatType>(type)) {
    specialName << '_' << type;
  }

  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
    specialName << "_vec_";
    specialName << vecType.getDimSize(0);

    Type elementType = vecType.getElementType();

    if (llvm::isa<IntegerType>(elementType) ||
        llvm::isa<FloatType>(elementType)) {
      specialName << "x" << elementType;
    }
  }

  setNameFn(getResult(), specialName.str());
}
/// Suggests `<variable>_addr` as the SSA name of the result.
void mlir::spirv::AddressOfOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  SmallString<32> nameStorage;
  llvm::raw_svector_ostream os(nameStorage);
  os << getVariable() << "_addr";
  setNameFn(getResult(), os.str());
}
//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//
/// Delegates to the shared memory-semantics check. At most one ordering bit
/// may be set.
LogicalResult spirv::ControlBarrierOp::verify() {
  auto semantics = getMemorySemantics();
  return verifyMemorySemantics(getOperation(), semantics);
}
//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
/// Builds a spirv.EntryPoint for `function` under `executionModel`, with
/// `interfaceVars` stored as the interface array attribute.
void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
                                spirv::ExecutionModel executionModel,
                                spirv::FuncOp function,
                                ArrayRef<Attribute> interfaceVars) {
  MLIRContext *context = builder.getContext();
  auto modelAttr = spirv::ExecutionModelAttr::get(context, executionModel);
  auto fnSymbol = SymbolRefAttr::get(function);
  auto interfaceAttr = builder.getArrayAttr(interfaceVars);
  build(builder, state, modelAttr, fnSymbol, interfaceAttr);
}
/// Parses `"<ExecutionModel>" @fn (, @var)*`.
///
/// The execution model and function symbol are stored as attributes. The
/// optional comma-separated symbol list becomes the interface array attribute,
/// which is always added, possibly empty.
ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  spirv::ExecutionModel execModel;
  SmallVector<Attribute, 4> interfaceVars;

  FlatSymbolRefAttr fn;
  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse the interface variables
    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
          // The name of the interface variable attribute isn't important: the
          // parsed symbol goes into `interfaceVars` and the scratch list is
          // discarded.
          FlatSymbolRefAttr var;
          NamedAttrList attrs;
          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
            return failure();
          interfaceVars.push_back(var);
          return success();
        }))
      return failure();
  }
  result.addAttribute(kInterfaceAttrName,
                      parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}
/// Prints `"<ExecutionModel>" @fn`, followed by `, @var, ...` when there are
/// interface variables.
void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
  printer.printSymbolName(getFn());

  ArrayRef<Attribute> vars = getInterface().getValue();
  if (vars.empty())
    return;
  printer << ", ";
  llvm::interleaveComma(vars, printer);
}
/// There is nothing to check locally. The `fn` and interface symbol references
/// are validated as part of spirv::ModuleOp verification.
LogicalResult spirv::EntryPointOp::verify() { return success(); }
//===----------------------------------------------------------------------===//
// spirv.ExecutionMode
//===----------------------------------------------------------------------===//
/// Builds a spirv.ExecutionMode that applies `executionMode` with the i32
/// parameters `params` to `function`.
void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                   spirv::FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  auto fnSymbol = SymbolRefAttr::get(function);
  auto modeAttr =
      spirv::ExecutionModeAttr::get(builder.getContext(), executionMode);
  auto paramsAttr = builder.getI32ArrayAttr(params);
  build(builder, state, fnSymbol, modeAttr, paramsAttr);
}
/// Parses `@fn "<ExecutionMode>" (, <i32>)*`. The trailing integers are stored
/// as an i32 array attribute.
ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
      parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    SMLoc valueLoc = parser.getCurrentLocation();
    NamedAttrList attr;
    Attribute value;
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    // parseAttribute can return non-integer attributes (e.g. a string
    // literal). Report a parse error rather than asserting in llvm::cast.
    auto intValue = llvm::dyn_cast<IntegerAttr>(value);
    if (!intValue)
      return parser.emitError(valueLoc,
                              "expected integer execution mode parameter, "
                              "but found ")
             << value;
    values.push_back(intValue.getInt());
  }
  result.addAttribute(kValuesAttrName,
                      parser.getBuilder().getI32ArrayAttr(values));
  return success();
}
/// Prints `@fn "<ExecutionMode>"`, followed by `, v0, v1, ...` when there are
/// parameters.
void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(getFn());
  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";

  auto params = this->getValues();
  if (!params.empty()) {
    printer << ", ";
    llvm::interleaveComma(params, printer, [&](Attribute param) {
      printer << llvm::cast<IntegerAttr>(param).getInt();
    });
  }
}
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
/// Parses `@name(args) -> results "<FunctionControl>" (attributes {...})?
/// (region)?`.
///
/// Variadic signatures are rejected. The function type is assembled from the
/// parsed argument and result types. Per-argument and per-result attributes are
/// attached through function_interface_impl, and the body region is optional.
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();

  // Symbol name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Signature.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  SmallVector<Type> argTypes;
  argTypes.reserve(entryArgs.size());
  for (const OpAsmParser::Argument &entryArg : entryArgs)
    argTypes.push_back(entryArg.type);
  FunctionType signature = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(signature));

  // Function control keyword.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
    return failure();

  // Extra attributes introduced by the `attributes` keyword.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  assert(resultAttrs.size() == resultTypes.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Body; absent for external functions.
  Region *body = result.addRegion();
  OptionalParseResult bodyResult =
      parser.parseOptionalRegion(*body, entryArgs);
  if (bodyResult.has_value() && failed(*bodyResult))
    return failure();
  return success();
}
/// Prints the name, signature, and quoted function-control keyword, then any
/// remaining attributes, then the body when the function is not external.
void spirv::FuncOp::print(OpAsmPrinter &printer) {
  auto signature = getFunctionType();

  printer << " ";
  printer.printSymbolName(getSymName());
  function_interface_impl::printFunctionSignature(
      printer, *this, signature.getInputs(),
      /*isVariadic=*/false, signature.getResults());
  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
          << "\"";

  // Exclude attributes that were already rendered above.
  function_interface_impl::printFunctionAttributes(
      printer, *this,
      {spirv::attributeName<spirv::FunctionControl>(),
       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getFunctionControlAttrName()});

  Region &body = this->getBody();
  if (body.empty())
    return;

  printer << ' ';
  printer.printRegion(body, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
// Verifies the function signature: at most one result, and the pointer
// decoration requirements from SPV_KHR_physical_storage_buffer on pointer
// arguments.
LogicalResult spirv::FuncOp::verifyType() {
  FunctionType fnType = getFunctionType();
  if (fnType.getNumResults() > 1)
    return emitOpError("cannot have more than one result");

  // Argument attributes are queried through FunctionOpInterface. Cast once
  // here rather than twice per query inside the lambda below.
  auto funcInterface = llvm::cast<FunctionOpInterface>(getOperation());

  // Returns true if, on argument `argIndex`, the first attribute named
  // `spirv::DecorationAttr::name` whose value is a DecorationAttr equals
  // `decoration`.
  auto hasDecorationAttr = [&](spirv::Decoration decoration,
                               unsigned argIndex) {
    for (auto argAttr : funcInterface.getArgAttrs(argIndex)) {
      if (argAttr.getName() != spirv::DecorationAttr::name)
        continue;
      if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
        return decAttr.getValue() == decoration;
    }
    return false;
  };

  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
    Type param = fnType.getInputs()[i];
    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
    // Only pointer arguments carry decoration requirements.
    if (!inputPtrType)
      continue;

    auto pointeePtrType =
        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
    if (pointeePtrType) {
      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
      // > If an OpFunctionParameter is a pointer (or contains a pointer)
      // > and the type it points to is a pointer in the PhysicalStorageBuffer
      // > storage class, the function parameter must be decorated with exactly
      // > one of AliasedPointer or RestrictPointer.
      if (pointeePtrType.getStorageClass() !=
          spirv::StorageClass::PhysicalStorageBuffer)
        continue;

      bool hasAliasedPtr =
          hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
      bool hasRestrictPtr =
          hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
      if (!hasAliasedPtr && !hasRestrictPtr)
        return emitOpError()
               << "with a pointer points to a physical buffer pointer must "
                  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
      continue;
    }

    // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
    // > If an OpFunctionParameter is a pointer (or contains a pointer) in
    // > the PhysicalStorageBuffer storage class, the function parameter must
    // > be decorated with exactly one of Aliased or Restrict.
    // For a pointer-to-array argument the array's element type is inspected;
    // otherwise the argument pointer itself is.
    if (auto pointeeArrayType =
            dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
      pointeePtrType =
          dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
    } else {
      pointeePtrType = inputPtrType;
    }

    if (!pointeePtrType || pointeePtrType.getStorageClass() !=
                               spirv::StorageClass::PhysicalStorageBuffer)
      continue;

    bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
    bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
    if (!hasAliased && !hasRestrict)
      return emitOpError() << "with physical buffer pointer must be decorated "
                              "either 'Aliased' or 'Restrict'";
  }

  return success();
}
// Checks every spirv.Return / spirv.ReturnValue nested in the body against
// the function's result list; the walk stops at the first offending op.
LogicalResult spirv::FuncOp::verifyBody() {
  FunctionType fnType = getFunctionType();
  const unsigned numResults = fnType.getNumResults();

  WalkResult walkResult = walk([&](Operation *op) -> WalkResult {
    // A value-less return is only legal in functions with no results.
    if (auto voidReturn = dyn_cast<spirv::ReturnOp>(op)) {
      if (numResults == 0)
        return WalkResult::advance();
      return voidReturn.emitOpError(
          "cannot be used in functions returning value");
    }

    auto valueReturn = dyn_cast<spirv::ReturnValueOp>(op);
    if (!valueReturn)
      return WalkResult::advance();

    if (numResults != 1)
      return valueReturn.emitOpError(
                 "returns 1 value but enclosing function requires ")
             << numResults << " results";

    Type operandType = valueReturn.getValue().getType();
    Type resultType = fnType.getResult(0);
    if (operandType == resultType)
      return WalkResult::advance();
    return valueReturn.emitOpError(" return value's type (")
           << operandType << ") mismatch with function's result type ("
           << resultType << ")";
  });

  // TODO: verify other bits like linkage type.
  return failure(walkResult.wasInterrupted());
}
// Builds a spirv.func with the given name, signature and function control.
// Extra attributes in `attrs` are appended after the core ones; the body
// region is created empty.
void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                          StringRef name, FunctionType type,
                          spirv::FunctionControl control,
                          ArrayRef<NamedAttribute> attrs) {
  StringAttr symName = builder.getStringAttr(name);
  TypeAttr signature = TypeAttr::get(type);
  auto controlAttr = builder.getAttr<spirv::FunctionControlAttr>(control);

  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
  state.addAttribute(getFunctionTypeAttrName(state.name), signature);
  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                     controlAttr);
  state.attributes.append(attrs.begin(), attrs.end());

  state.addRegion();
}
//===----------------------------------------------------------------------===//
// spirv.GLFClampOp
//===----------------------------------------------------------------------===//
// spirv.GL.FClamp delegates its assembly syntax to the shared one-result /
// same-operand-type helpers.
ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

void spirv::GLFClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GLUClampOp
//===----------------------------------------------------------------------===//
// spirv.GL.UClamp delegates its assembly syntax to the shared one-result /
// same-operand-type helpers.
ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

void spirv::GLUClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GLSClampOp
//===----------------------------------------------------------------------===//
// spirv.GL.SClamp delegates its assembly syntax to the shared one-result /
// same-operand-type helpers.
ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

void spirv::GLSClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GLFmaOp
//===----------------------------------------------------------------------===//
// spirv.GL.Fma delegates its assembly syntax to the shared one-result /
// same-operand-type helpers.
ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser,
                                  OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

void spirv::GLFmaOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GlobalVariable
//===----------------------------------------------------------------------===//
// Builds a global variable decorated with a descriptor set and binding.
// Both decorations are stored as i32 integer attributes.
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    unsigned descriptorSet, unsigned binding) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));

  auto addI32Decoration = [&](spirv::Decoration decoration, unsigned value) {
    state.addAttribute(spirv::SPIRVDialect::getAttributeName(decoration),
                       builder.getI32IntegerAttr(value));
  };
  addI32Decoration(spirv::Decoration::DescriptorSet, descriptorSet);
  addI32Decoration(spirv::Decoration::Binding, binding);
}

// Builds a global variable decorated as a builtin. The builtin is recorded
// as its string spelling.
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    spirv::BuiltIn builtin) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));

  StringAttr builtinSpelling =
      builder.getStringAttr(spirv::stringifyBuiltIn(builtin));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
      builtinSpelling);
}
// Parses: @name [<kInitializerAttrName>(@symbol)] <decorations> : <ptr-type>
// The trailing type must be a spirv::PointerType.
ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Optional initializer clause referencing another symbol.
  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
    FlatSymbolRefAttr initSymbol;
    bool malformed =
        parser.parseLParen() ||
        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
                              result.attributes) ||
        parser.parseRParen();
    if (malformed)
      return failure();
  }

  if (parseVariableDecorations(parser, result))
    return failure();

  auto typeLoc = parser.getCurrentLocation();
  Type type;
  if (parser.parseColonType(type))
    return failure();
  if (!llvm::isa<spirv::PointerType>(type))
    return parser.emitError(typeLoc, "expected spirv.ptr type");

  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
  return success();
}
// Prints the counterpart of GlobalVariableOp::parse. Attributes rendered by
// the custom syntax (storage class, symbol name, initializer, type) are
// elided from the decoration/attribute output.
void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs{
      spirv::attributeName<spirv::StorageClass>(),
      SymbolTable::getSymbolAttrName()};

  printer << ' ';
  printer.printSymbolName(getSymName());

  if (auto initializer = this->getInitializer()) {
    printer << " " << kInitializerAttrName << '(';
    printer.printSymbolName(*initializer);
    printer << ')';
    elidedAttrs.push_back(kInitializerAttrName);
  }

  elidedAttrs.push_back(kTypeAttrName);
  spirv::printVariableDecorations(*this, printer, elidedAttrs);
  printer << " : " << getType();
}
// Verifies the variable's pointer type, its storage class, and (if present)
// that the initializer symbol names a supported op.
LogicalResult spirv::GlobalVariableOp::verify() {
  if (!llvm::isa<spirv::PointerType>(getType()))
    return emitOpError("result must be of a !spv.ptr type");

  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  // Also, Function storage class is reserved by spirv.Variable.
  auto sc = this->storageClass();
  bool reservedClass = sc == spirv::StorageClass::Generic ||
                       sc == spirv::StorageClass::Function;
  if (reservedClass)
    return emitOpError("storage class cannot be '")
           << stringifyStorageClass(sc) << "'";

  auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName);
  if (!init)
    return success();

  // TODO: Currently only variable initialization with specialization
  // constants and other variables is supported. They could be normal
  // constants in the module scope as well.
  Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), init.getAttr());
  bool supportedInit =
      initOp && isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
                    spirv::SpecConstantCompositeOp>(initOp);
  if (!supportedInit)
    return emitOpError("initializer must be result of a "
                       "spirv.SpecConstant or spirv.GlobalVariable or "
                       "spirv.SpecConstantCompositeOp op");
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.INTEL.SubgroupBlockRead
//===----------------------------------------------------------------------===//
// Parses: "<StorageClass>" %ptr : <result-type>
// The pointer operand's type is reconstructed from the storage class and the
// result type; for a vector result the pointee is the vector's element type.
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
                                                   OperationState &result) {
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser))
    return failure();
  if (parser.parseOperand(ptrInfo) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  Type pointeeType = elementType;
  if (auto vecTy = llvm::dyn_cast<VectorType>(elementType))
    pointeeType = vecTy.getElementType();
  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);

  if (parser.resolveOperand(ptrInfo, ptrType, result.operands))
    return failure();
  result.addTypes(elementType);
  return success();
}
// Prints: "<StorageClass>" %ptr : <result-type>
// The custom parser requires the pointer's storage class as a leading quoted
// string (parseEnumStrAttr). Without it, the printed form could not be parsed
// back, so the storage class is emitted here.
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
          << "\" " << getPtr() << " : " << getType();
}
// Checks pointer/value type compatibility using the helper shared with
// spirv.INTEL.SubgroupBlockWrite.
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
  return verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue());
}
//===----------------------------------------------------------------------===//
// spirv.INTEL.SubgroupBlockWrite
//===----------------------------------------------------------------------===//
// Parses: "<StorageClass>" %ptr, %value : <value-type>
// The pointer type is rebuilt from the storage class and value type; for a
// vector value the pointee is the vector's element type.
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
                                                    OperationState &result) {
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser))
    return failure();
  if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  Type pointeeType = elementType;
  if (auto vecTy = llvm::dyn_cast<VectorType>(elementType))
    pointeeType = vecTy.getElementType();
  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);

  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands))
    return failure();
  return success();
}
// Prints: "<StorageClass>" %ptr, %value : <value-type>
// The custom parser requires the pointer's storage class as a leading quoted
// string (parseEnumStrAttr). Without it, the printed form could not be parsed
// back, so the storage class is emitted here.
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
          << "\" " << getPtr() << ", " << getValue() << " : "
          << getValue().getType();
}
// Checks pointer/value type compatibility using the helper shared with
// spirv.INTEL.SubgroupBlockRead.
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
  return verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue());
}
//===----------------------------------------------------------------------===//
// spirv.IAddCarryOp
//===----------------------------------------------------------------------===//
// spirv.IAddCarry uses the syntax and verifier shared by the extended
// arithmetic binary ops (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::IAddCarryOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.ISubBorrowOp
//===----------------------------------------------------------------------===//
// spirv.ISubBorrow uses the syntax and verifier shared by the extended
// arithmetic binary ops (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::ISubBorrowOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//
// spirv.SMulExtended uses the syntax and verifier shared by the extended
// arithmetic binary ops (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::SMulExtendedOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.UMulExtended
//===----------------------------------------------------------------------===//
// spirv.UMulExtended uses the syntax and verifier shared by the extended
// arithmetic binary ops (IAddCarry, ISubBorrow, SMulExtended, UMulExtended).
ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::UMulExtendedOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.MemoryBarrierOp
//===----------------------------------------------------------------------===//
// Validates the op's memory-semantics value via the shared helper.
LogicalResult spirv::MemoryBarrierOp::verify() {
  auto semantics = getMemorySemantics();
  return verifyMemorySemantics(getOperation(), semantics);
}
//===----------------------------------------------------------------------===//
// spirv.module
//===----------------------------------------------------------------------===//
// Builds a spirv.module with a single empty body block and an optional
// symbol name. The insertion guard restores the builder's insertion point
// after the block is created.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            std::optional<StringRef> name) {
  OpBuilder::InsertionGuard guard(builder);
  Region *bodyRegion = state.addRegion();
  builder.createBlock(bodyRegion);
  if (!name)
    return;
  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                          builder.getStringAttr(*name));
}

// Builds a spirv.module with explicit addressing/memory models, plus an
// optional (version, capabilities, extensions) triple and symbol name.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            spirv::AddressingModel addressingModel,
                            spirv::MemoryModel memoryModel,
                            std::optional<VerCapExtAttr> vceTriple,
                            std::optional<StringRef> name) {
  auto addressingAttr =
      builder.getAttr<spirv::AddressingModelAttr>(addressingModel);
  auto memoryAttr = builder.getAttr<spirv::MemoryModelAttr>(memoryModel);
  state.addAttribute("addressing_model", addressingAttr);
  state.addAttribute("memory_model", memoryAttr);

  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());

  if (vceTriple)
    state.addAttribute(getVCETripleAttrName(), *vceTriple);
  if (name)
    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(*name));
}
// Parses: [@name] <AddressingModel> <MemoryModel> [requires <vce-triple>]
//         [attributes {...}] <region>
// An empty region gets one empty block so the body is never blockless.
ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  Region *body = result.addRegion();

  // The symbol name is optional; its absence is not an error.
  StringAttr nameAttr;
  (void)parser.parseOptionalSymbolName(
      nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);

  spirv::AddressingModel addrModel;
  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
                                                              result))
    return failure();
  spirv::MemoryModel memoryModel;
  if (spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
                                                          result))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              result.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();
  if (parser.parseRegion(*body, /*arguments=*/{}))
    return failure();

  if (body->empty())
    body->push_back(new Block());
  return success();
}
// Prints the counterpart of ModuleOp::parse. Attributes conveyed by the
// custom syntax are elided from the trailing attribute dictionary.
void spirv::ModuleOp::print(OpAsmPrinter &printer) {
  if (std::optional<StringRef> name = getName()) {
    printer << ' ';
    printer.printSymbolName(*name);
  }

  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
          << spirv::stringifyMemoryModel(getMemoryModel());

  SmallVector<StringRef, 2> elidedAttrs = {
      spirv::attributeName<spirv::AddressingModel>(),
      spirv::attributeName<spirv::MemoryModel>(),
      mlir::SymbolTable::getSymbolAttrName()};

  if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
    printer << " requires " << *triple;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
  printer << ' ';
  printer.printRegion(getRegion());
}
// Verifies the module body:
//  - only spirv.* ops appear at module level and directly inside functions;
//  - each EntryPoint names an existing function, its interface lists only
//    spirv.GlobalVariable symbols, and no (function, execution model) pair
//    is declared twice;
//  - external functions carry 'Import' linkage.
LogicalResult spirv::ModuleOp::verifyRegions() {
  Dialect *dialect = (*this)->getDialect();
  // Records which EntryPoint op first claimed each (function, model) pair.
  DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
      entryPoints;
  mlir::SymbolTable table(*this);

  for (auto &op : *getBody()) {
    if (op.getDialect() != dialect)
      return op.emitError("'spirv.module' can only contain spirv.* ops");

    // For EntryPoint op, check that the function and execution model is not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.getFn() << "' not found in 'spirv.module'";
      }
      if (auto interface = entryPointOp.getInterface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
          if (!varSymRef) {
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef;
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            return entryPointOp.emitError("expected spirv.GlobalVariable "
                                          "symbol reference instead of'")
                   << varSymRef << "'";
          }
        }
      }

      // try_emplace does a single hash lookup and inserts only when the pair
      // has not been seen before (replaces the former find + operator[]).
      std::pair<spirv::FuncOp, spirv::ExecutionModel> key(
          funcOp, entryPointOp.getExecutionModel());
      if (!entryPoints.try_emplace(key, entryPointOp).second)
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      // If the function is external and does not have 'Import'
      // linkage_attributes(LinkageAttributes), throw an error. 'Import'
      // LinkageAttributes is used to import external functions.
      auto linkageAttr = funcOp.getLinkageAttributes();
      auto hasImportLinkage =
          linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
                          spirv::LinkageType::Import);
      if (funcOp.isExternal() && !hasImportLinkage)
        return op.emitError(
            "'spirv.module' cannot contain external functions "
            "without 'Import' linkage_attributes (LinkageAttributes)");

      // TODO: move this check to spirv.func.
      // Only ops directly inside the function's blocks are checked here.
      for (auto &block : funcOp)
        for (auto &nestedOp : block)
          if (nestedOp.getDialect() != dialect)
            return nestedOp.emitError(
                "functions in 'spirv.module' can only contain spirv.* ops");
    }
  }

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.mlir.referenceof
//===----------------------------------------------------------------------===//
// Verifies that the referenced symbol is a spirv.SpecConstant or
// spirv.SpecConstantComposite, and that this op's result type matches the
// type that constant produces.
LogicalResult spirv::ReferenceOfOp::verify() {
  Operation *specConstSym = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), getSpecConstAttr());

  Type constType;
  if (auto scalarOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym))
    constType = scalarOp.getDefaultValue().getType();
  else if (auto compositeOp =
               dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym))
    constType = compositeOp.getType();
  else
    return emitOpError(
        "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");

  if (getReference().getType() == constType)
    return success();
  return emitOpError("result type mismatch with the referenced "
                     "specialization constant's type");
}
//===----------------------------------------------------------------------===//
// spirv.SpecConstant
//===----------------------------------------------------------------------===//
// Parses: @name [<kSpecIdAttrName>(<integer>)] = <default-value>
ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Optional specialization id clause.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specIdAttr;
    bool malformed =
        parser.parseLParen() ||
        parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
        parser.parseRParen();
    if (malformed)
      return failure();
  }

  if (parser.parseEqual())
    return failure();
  Attribute valueAttr;
  return parser.parseAttribute(valueAttr, kDefaultValueAttrName,
                               result.attributes);
}
// Prints the counterpart of SpecConstantOp::parse; the spec id clause is
// emitted only when the attribute is present.
void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
  printer << ' ';
  printer.printSymbolName(getSymName());

  auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specID)
    printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';

  printer << " = " << getDefaultValue();
}
// Verifies that any spec id is non-negative and that the default value is an
// integer/float attribute whose type is a valid SPIR-V type.
LogicalResult spirv::SpecConstantOp::verify() {
  auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specID && specID.getValue().isNegative())
    return emitOpError("SpecId cannot be negative");

  auto value = getDefaultValue();
  if (!llvm::isa<IntegerAttr, FloatAttr>(value))
    return emitOpError(
        "default value can only be a bool, integer, or float scalar");

  // Make sure bitwidth is allowed.
  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
    return emitOpError("default value bitwidth disallowed");
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.VectorShuffle
//===----------------------------------------------------------------------===//
// Verifies that:
//  - the number of component selectors equals the result vector length;
//  - every selector indexes into the concatenation of both source vectors,
//    or is 0xFFFFFFFF (the "undefined component" literal of OpVectorShuffle
//    in the SPIR-V spec).
LogicalResult spirv::VectorShuffleOp::verify() {
  VectorType resultType = llvm::cast<VectorType>(getType());

  size_t numResultElements = resultType.getNumElements();
  if (numResultElements != getComponents().size())
    return emitOpError("result type element count (")
           << numResultElements
           << ") mismatch with the number of component selectors ("
           << getComponents().size() << ")";

  size_t totalSrcElements =
      llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
      llvm::cast<VectorType>(getVector2().getType()).getNumElements();

  constexpr uint64_t kUndefSelector = std::numeric_limits<uint32_t>::max();
  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
    // Keep the full zero-extended value so a wider-than-32-bit selector
    // cannot be truncated into an apparently valid index.
    uint64_t index = selector.getZExtValue();
    if (index >= totalSrcElements && index != kUndefSelector)
      return emitOpError("component selector ")
             << index << " out of range: expected to be in [0, "
             << totalSrcElements << ") or 0xffffffff";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesScalar
//===----------------------------------------------------------------------===//
// Verifies that the scalar operand has the matrix's component type. Both
// spirv::CooperativeMatrixType and spirv::MatrixType operands are handled.
LogicalResult spirv::MatrixTimesScalarOp::verify() {
  Type matrixType = getMatrix().getType();

  Type elementType;
  if (auto coopMatrix =
          llvm::dyn_cast<spirv::CooperativeMatrixType>(matrixType))
    elementType = coopMatrix.getElementType();
  else if (auto plainMatrix = llvm::dyn_cast<spirv::MatrixType>(matrixType))
    elementType = plainMatrix.getElementType();
  assert(elementType && "Unhandled type");

  if (getScalar().getType() == elementType)
    return success();
  return emitOpError("input matrix components' type and scaling value must "
                     "have the same type");
}
//===----------------------------------------------------------------------===//
// spirv.Transpose
//===----------------------------------------------------------------------===//
// Verifies that the result is the transpose shape of the input (rows and
// columns swapped) and that both share a component type.
LogicalResult spirv::TransposeOp::verify() {
  auto input = llvm::cast<spirv::MatrixType>(getMatrix().getType());
  auto output = llvm::cast<spirv::MatrixType>(getResult().getType());

  bool rowsToColumns = input.getNumRows() == output.getNumColumns();
  if (!rowsToColumns)
    return emitError("input matrix rows count must be equal to "
                     "output matrix columns count");

  bool columnsToRows = input.getNumColumns() == output.getNumRows();
  if (!columnsToRows)
    return emitError("input matrix columns count must be equal to "
                     "output matrix rows count");

  bool sameComponent = input.getElementType() == output.getElementType();
  if (!sameComponent)
    return emitError("input and output matrices must have the same "
                     "component type");

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
// Verifies the shapes and component types for left * right = result:
// inner dimensions agree, the result takes its column count from the right
// operand and its row count from the left, and all component types match.
// Checks run in a fixed order; the first mismatch is reported.
LogicalResult spirv::MatrixTimesMatrixOp::verify() {
  auto lhs = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
  auto rhs = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
  auto res = llvm::cast<spirv::MatrixType>(getResult().getType());

  if (lhs.getNumColumns() != rhs.getNumRows())
    return emitError("left matrix columns' count must be equal to "
                     "the right matrix rows' count");

  if (rhs.getNumColumns() != res.getNumColumns())
    return emitError(
        "right and result matrices must have equal columns' count");

  if (rhs.getElementType() != res.getElementType())
    return emitError("right and result matrices' component type must"
                     " be the same");

  if (lhs.getElementType() != res.getElementType())
    return emitError("left and result matrices' component type"
                     " must be the same");

  if (lhs.getNumRows() != res.getNumRows())
    return emitError("left and result matrices must have equal rows' count");

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.SpecConstantComposite
//===----------------------------------------------------------------------===//
// Parses: @name (@c0[, @c1, ...]) : <composite-type>
// At least one constituent symbol reference is required.
ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
                                                  OperationState &result) {
  StringAttr compositeName;
  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  SmallVector<Attribute, 4> constituents;
  // Parses one flat symbol reference. The scratch attribute list and the
  // attribute name only satisfy the parseAttribute API and are discarded.
  auto parseConstituent = [&]() -> ParseResult {
    FlatSymbolRefAttr specConstRef;
    NamedAttrList scratch;
    if (parser.parseAttribute(specConstRef, Type(), "spec_const", scratch))
      return failure();
    constituents.push_back(specConstRef);
    return success();
  };
  do {
    if (parseConstituent())
      return failure();
  } while (!parser.parseOptionalComma());

  if (parser.parseRParen())
    return failure();
  result.addAttribute(kCompositeSpecConstituentsName,
                      parser.getBuilder().getArrayAttr(constituents));

  Type type;
  if (parser.parseColonType(type))
    return failure();
  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
  return success();
}
// Prints: @name (<constituents>) : <type>
// interleaveComma emits nothing for an empty range, so no emptiness check is
// needed.
void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(getSymName());
  printer << " (";
  llvm::interleaveComma(getConstituents().getValue(), printer);
  printer << ") : " << getType();
}
// Verifies that:
//  - the result type is a supported composite type;
//  - the number of constituents matches its element count;
//  - each constituent resolves to a spirv.SpecConstant whose default value
//    type matches the corresponding element type.
LogicalResult spirv::SpecConstantCompositeOp::verify() {
  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
  auto constituents = this->getConstituents().getValue();

  if (!cType)
    return emitError("result type must be a composite type, but provided ")
           << getType();

  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
    return emitError("unsupported composite type ") << cType;
  if (llvm::isa<spirv::JointMatrixINTELType>(cType))
    return emitError("unsupported composite type ") << cType;
  if (constituents.size() != cType.getNumElements())
    return emitError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();

  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
    // Resolve the constituent symbol. Previously, a dangling reference
    // (null lookup result passed to dyn_cast) or a reference to a non-
    // SpecConstant op (null op dereferenced below) crashed the verifier;
    // both are now reported as errors. Resolving via SymbolRefAttr also
    // avoids asserting on nested references.
    auto constituentRef = llvm::dyn_cast<SymbolRefAttr>(constituents[index]);
    Operation *constituentOp = nullptr;
    if (constituentRef)
      constituentOp = SymbolTable::lookupNearestSymbolFrom(
          (*this)->getParentOp(), constituentRef);
    auto constituentSpecConstOp =
        dyn_cast_or_null<spirv::SpecConstantOp>(constituentOp);
    if (!constituentSpecConstOp)
      return emitError("has invalid constituent #")
             << index
             << ": expected a symbol reference to a spirv.SpecConstant op";

    Type constituentType = constituentSpecConstOp.getDefaultValue().getType();
    if (constituentType != cType.getElementType(index))
      return emitError("has incorrect types of operands: expected ")
             << cType.getElementType(index) << ", but provided "
             << constituentType;
  }

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//
// Parses: wraps <generic-op> [attr-dict]
// The wrapped op is placed in a fresh single block followed by a
// spirv.mlir.yield of its first result. This op's result type and location
// are taken from the wrapped op.
ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
                                                  OperationState &result) {
  Region *body = result.addRegion();

  if (parser.parseKeyword("wraps"))
    return failure();

  body->push_back(new Block);
  Block &block = body->back();
  auto wrappedOpLoc = parser.getCurrentLocation();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // getResult(0) below would assert on a result-less op; report it as a
  // parse error instead.
  if (wrappedOp->getNumResults() == 0)
    return parser.emitError(wrappedOpLoc,
                            "expected the wrapped op to produce a result");

  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
  result.location = wrappedOp->getLoc();

  result.addTypes(wrappedOp->getResult(0).getType());

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
// Prints: wraps <generic-op> [attr-dict]
// The parser accepts a trailing attribute dictionary, so it is printed here
// too; otherwise attributes on this op are dropped on round-trip. Nothing
// extra is printed when the op has no attributes.
void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
  printer << " wraps ";
  printer.printGenericOp(&getBody().front().front());
  printer.printOptionalAttrDict((*this)->getAttrs());
}
// Verifies the body:
//  - it holds exactly two ops (the enclosed op plus one more);
//  - the enclosed op has the UsableInSpecConstantOp trait;
//  - every operand of the enclosed op is defined by a spirv.Constant,
//    spirv.mlir.referenceof, or spirv.SpecConstantOperation op.
LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
  Block &block = getRegion().getBlocks().front();

  if (block.getOperations().size() != 2)
    return emitOpError("expected exactly 2 nested ops");

  Operation &enclosedOp = block.getOperations().front();

  if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
    return emitOpError("invalid enclosed op");

  for (auto operand : enclosedOp.getOperands()) {
    // A block-argument operand has no defining op; isa<> would assert on the
    // null pointer, so use isa_and_nonnull and reject it as invalid.
    if (!llvm::isa_and_nonnull<spirv::ConstantOp, spirv::ReferenceOfOp,
                               spirv::SpecConstantOperationOp>(
            operand.getDefiningOp()))
      return emitOpError(
          "invalid operand, must be defined by a constant operation");
  }

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.GL.FrexpStruct
//===----------------------------------------------------------------------===//
/// Verifies spirv.GL.FrexpStruct.
///
/// The result must be a two-member struct:
/// - member zero has exactly the operand's type;
/// - member one is a 32-bit integer scalar or vector with as many components
///   as the operand.
LogicalResult spirv::GLFrexpStructOp::verify() {
  // NOTE(review): the result is presumably already constrained to a struct
  // type by ODS. The null check only guards against dereferencing a failed
  // dyn_cast if that ever changes.
  auto structTy = llvm::dyn_cast<spirv::StructType>(getResult().getType());
  if (!structTy || structTy.getNumElements() != 2)
    return emitError("result type must be a struct type with two members");

  Type significandTy = structTy.getElementType(0);
  Type exponentTy = structTy.getElementType(1);
  Type operandTy = getOperand().getType();

  if (significandTy != operandTy)
    return emitError("member zero of the resulting struct type must be the "
                     "same type as the operand");

  // Member one must be i32 or a vector of i32. Look through a vector to its
  // element type so both shapes share one width check and one message.
  auto exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
  Type exponentElemTy =
      exponentVecTy ? exponentVecTy.getElementType() : exponentTy;
  auto exponentIntTy = llvm::dyn_cast<IntegerType>(exponentElemTy);
  if (!exponentIntTy || exponentIntTy.getWidth() != 32)
    return emitError("member one of the resulting struct type must be a scalar "
                     "or vector of 32 bit integer type");

  // Check that the two member types have the same number of components:
  // - a vector operand pairs with an integer vector of the same length;
  // - a float scalar operand pairs with an integer scalar.
  auto operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
  if (operandVecTy && exponentVecTy &&
      exponentVecTy.getNumElements() == operandVecTy.getNumElements())
    return success();
  if (llvm::isa<FloatType>(operandTy) && !exponentVecTy)
    return success();

  return emitError("member one of the resulting struct type must have the same "
                   "number of components as the operand type");
}
//===----------------------------------------------------------------------===//
// spirv.GL.Ldexp
//===----------------------------------------------------------------------===//
/// Verifies spirv.GL.Ldexp.
///
/// `x` and `exp` must agree on being scalars (a float and an integer) or
/// vectors, and must have the same element count. A scalar counts as one
/// element.
LogicalResult spirv::GLLdexpOp::verify() {
  Type xTy = getX().getType();
  Type expTy = getExp().getType();

  const bool xIsScalar = llvm::isa<FloatType>(xTy);
  const bool expIsScalar = llvm::isa<IntegerType>(expTy);
  if (xIsScalar != expIsScalar)
    return emitOpError("operands must both be scalars or vectors");

  // Element count of a vector type, or 1 for anything else.
  auto countElements = [](Type ty) -> unsigned {
    auto vecTy = llvm::dyn_cast<VectorType>(ty);
    return vecTy ? vecTy.getNumElements() : 1;
  };

  if (countElements(xTy) == countElements(expTy))
    return success();
  return emitOpError("operands must have the same number of elements");
}
//===----------------------------------------------------------------------===//
// spirv.ImageDrefGather
//===----------------------------------------------------------------------===//
/// Verifies spirv.ImageDrefGather. The checks, in order:
/// - the result is a 4-component vector;
/// - its element type matches the image's sampled type, unless that type is
///   NoneType;
/// - the underlying image is 2D, Cube or Rect and single-sampled;
/// - the image operands are valid.
LogicalResult spirv::ImageDrefGatherOp::verify() {
  auto resultTy = llvm::cast<VectorType>(getResult().getType());
  auto sampledImageTy =
      llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
  auto imageTy = llvm::cast<spirv::ImageType>(sampledImageTy.getImageType());

  if (resultTy.getNumElements() != 4)
    return emitOpError("result type must be a vector of four components");

  // A NoneType sampled type places no constraint on the result element type.
  Type sampledTy = imageTy.getElementType();
  const bool hasSampledTy = !llvm::isa<NoneType>(sampledTy);
  if (hasSampledTy && resultTy.getElementType() != sampledTy)
    return emitOpError(
        "the component type of result must be the same as sampled type of the "
        "underlying image type");

  switch (imageTy.getDim()) {
  case spirv::Dim::Dim2D:
  case spirv::Dim::Cube:
  case spirv::Dim::Rect:
    break;
  default:
    return emitOpError(
        "the Dim operand of the underlying image type must be 2D, Cube, or "
        "Rect");
  }

  if (imageTy.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
    return emitOpError("the MS operand of the underlying image type must be 0");

  auto imageOperands = getOperandArguments();
  return verifyImageOperands(*this, getImageoperandsAttr(), imageOperands);
}
//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogicalOp
//===----------------------------------------------------------------------===//
/// Verifies spirv.ShiftLeftLogical by delegating to the shared verifyShiftOp
/// helper, which the other shift-op verifiers in this file also call.
LogicalResult spirv::ShiftLeftLogicalOp::verify() {
  return verifyShiftOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.ShiftRightArithmeticOp
//===----------------------------------------------------------------------===//
/// Verifies spirv.ShiftRightArithmetic by delegating to the shared
/// verifyShiftOp helper, which the other shift-op verifiers in this file also
/// call.
LogicalResult spirv::ShiftRightArithmeticOp::verify() {
  return verifyShiftOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.ShiftRightLogicalOp
//===----------------------------------------------------------------------===//
/// Verifies spirv.ShiftRightLogical by delegating to the shared verifyShiftOp
/// helper, which the other shift-op verifiers in this file also call.
LogicalResult spirv::ShiftRightLogicalOp::verify() {
  return verifyShiftOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//
/// Verifies spirv.ImageQuerySize.
///
/// The image's Dim must be 1D, 2D, 3D, Buffer, Cube or Rect. A 1D, 2D, 3D or
/// Cube image must also be multisampled or have a Sampled value of 0 or 2.
/// The result must have one component for 1D/Buffer, two for 2D/Cube/Rect
/// and three for 3D, plus one more when the image is arrayed. A scalar result
/// counts as one component.
LogicalResult spirv::ImageQuerySizeOp::verify() {
  auto imageTy = llvm::cast<spirv::ImageType>(getImage().getType());
  Type resultTy = getResult().getType();
  spirv::Dim dim = imageTy.getDim();

  // MS == 1, or Sampled is 0 (SamplerUnknown) or 2 (NoSampler).
  spirv::ImageSamplerUseInfo samplerUse = imageTy.getSamplerUseInfo();
  const bool msOrSampledOk =
      imageTy.getSamplingInfo() == spirv::ImageSamplingInfo::MultiSampled ||
      samplerUse == spirv::ImageSamplerUseInfo::SamplerUnknown ||
      samplerUse == spirv::ImageSamplerUseInfo::NoSampler;

  switch (dim) {
  case spirv::Dim::Dim1D:
  case spirv::Dim::Dim2D:
  case spirv::Dim::Dim3D:
  case spirv::Dim::Cube:
    if (!msOrSampledOk)
      return emitError(
          "if Dim is 1D, 2D, 3D, or Cube, "
          "it must also have either an MS of 1 or a Sampled of 0 or 2");
    break;
  case spirv::Dim::Buffer:
  case spirv::Dim::Rect:
    break;
  default:
    return emitError("the Dim operand of the image type must "
                     "be 1D, 2D, 3D, Buffer, Cube, or Rect");
  }

  // Only the six accepted dims reach this point; 2D, Cube and Rect take the
  // default of two components.
  unsigned expected = 2;
  if (dim == spirv::Dim::Dim1D || dim == spirv::Dim::Buffer)
    expected = 1;
  else if (dim == spirv::Dim::Dim3D)
    expected = 3;

  if (imageTy.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
    ++expected;

  auto resultVecTy = llvm::dyn_cast<VectorType>(resultTy);
  unsigned actual = resultVecTy ? resultVecTy.getNumElements() : 1;

  if (expected == actual)
    return success();
  return emitError("expected the result to have ")
         << expected << " component(s), but found " << actual
         << " component(s)";
}
//===----------------------------------------------------------------------===//
// spirv.VectorTimesScalarOp
//===----------------------------------------------------------------------===//
/// Verifies spirv.VectorTimesScalar.
///
/// The vector operand must have exactly the result type, and the scalar
/// operand must have the result vector's element type.
LogicalResult spirv::VectorTimesScalarOp::verify() {
  if (getVector().getType() != getType())
    return emitOpError("vector operand and result type mismatch");

  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
  // This branch is taken when the types differ, so the message must report a
  // mismatch. The original text said "match", which stated the opposite.
  if (getScalar().getType() != scalarType)
    return emitOpError("scalar operand and result element type mismatch");

  return success();
}