//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the operations in the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "SPIRVOpUtils.h" #include "SPIRVParsingUtils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include #include #include #include using namespace mlir; using namespace mlir::spirv::AttrNames; //===----------------------------------------------------------------------===// // Common utility functions //===----------------------------------------------------------------------===// LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { auto constOp = dyn_cast_or_null(op); if (!constOp) { return failure(); } auto valueAttr = constOp.getValue(); auto integerValueAttr = llvm::dyn_cast(valueAttr); if (!integerValueAttr) { return failure(); } if (integerValueAttr.getType().isSignlessInteger()) value = integerValueAttr.getInt(); else value = integerValueAttr.getSInt(); return success(); } LogicalResult spirv::verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) { // According to the SPIR-V specification: // "Despite being a mask and allowing multiple bits to be combined, it is // invalid for more than one of these four bits to be set: Acquire, Release, // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and // Release semantics is done by setting the AcquireRelease bit, not by setting // two bits." auto atMostOneInSet = spirv::MemorySemantics::Acquire | spirv::MemorySemantics::Release | spirv::MemorySemantics::AcquireRelease | spirv::MemorySemantics::SequentiallyConsistent; auto bitCount = llvm::popcount(static_cast(memorySemantics & atMostOneInSet)); if (bitCount > 1) { return op->emitError( "expected at most one of these four memory constraints " "to be set: `Acquire`, `Release`," "`AcquireRelease` or `SequentiallyConsistent`"); } return success(); } void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs) { // Print optional descriptor binding auto descriptorSetName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); auto bindingName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::Binding)); auto descriptorSet = op->getAttrOfType(descriptorSetName); auto binding = op->getAttrOfType(bindingName); if (descriptorSet && binding) { elidedAttrs.push_back(descriptorSetName); elidedAttrs.push_back(bindingName); printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() << ")"; } // Print BuiltIn attribute if present auto builtInName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::BuiltIn)); if (auto builtin = op->getAttrOfType(builtInName)) { printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; elidedAttrs.push_back(builtInName); } printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; // If the operand list is in-between parentheses, then we have a generic form. // (see the fallback in `printOneResultOp`). SMLoc loc = parser.getCurrentLocation(); if (!parser.parseOptionalLParen()) { if (parser.parseOperandList(ops) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(type)) return failure(); auto fnType = llvm::dyn_cast(type); if (!fnType) { parser.emitError(loc, "expected function type"); return failure(); } if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) return failure(); result.addTypes(fnType.getResults()); return success(); } return failure(parser.parseOperandList(ops) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands) || parser.addTypeToList(type, result.types)); } static void printOneResultOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumResults() == 1 && "op should have one result"); // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. auto resultType = op->getResult(0).getType(); if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; })) { p.printGenericOp(op, /*printOpName=*/false); return; } p << ' '; p.printOperands(op->getOperands()); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. p << " : " << resultType; } template static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands) { if (!attr) { if (operands.empty()) return success(); return imageOp.emitError("the Image Operands should encode what operands " "follow, as per Image Operands"); } // TODO: Add the validation rules for the following Image Operands. spirv::ImageOperands noSupportOperands = spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | spirv::ImageOperands::MakeTexelAvailable | spirv::ImageOperands::MakeTexelVisible | spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands)) llvm_unreachable("unimplemented operands of Image Operands"); return success(); } template static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { auto valType = val.getType(); if (auto valVecTy = llvm::dyn_cast(valType)) valType = valVecTy.getElementType(); if (valType != llvm::cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); } /// Walks the given type hierarchy with the given indices, potentially down /// to component granularity, to select an element type. Returns null type and /// emits errors with the given loc on failure. static Type getElementType(Type type, ArrayRef indices, function_ref emitErrorFn) { if (indices.empty()) { emitErrorFn("expected at least one index for spirv.CompositeExtract"); return nullptr; } for (auto index : indices) { if (auto cType = llvm::dyn_cast(type)) { if (cType.hasCompileTimeKnownNumElements() && (index < 0 || static_cast(index) >= cType.getNumElements())) { emitErrorFn("index ") << index << " out of bounds for " << type; return nullptr; } type = cType.getElementType(index); } else { emitErrorFn("cannot extract from non-composite type ") << type << " with index " << index; return nullptr; } } return type; } static Type getElementType(Type type, Attribute indices, function_ref emitErrorFn) { auto indicesArrayAttr = llvm::dyn_cast(indices); if (!indicesArrayAttr) { emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); return nullptr; } if (indicesArrayAttr.empty()) { emitErrorFn("expected at least one index for spirv.CompositeExtract"); return nullptr; } SmallVector indexVals; for (auto indexAttr : indicesArrayAttr) { auto indexIntAttr = llvm::dyn_cast(indexAttr); if (!indexIntAttr) { emitErrorFn("expected an 32-bit integer for index, but found '") << indexAttr << "'"; return nullptr; } indexVals.push_back(indexIntAttr.getInt()); } return getElementType(type, indexVals, emitErrorFn); } static Type getElementType(Type type, Attribute indices, Location loc) { auto errorFn = [&](StringRef err) -> InFlightDiagnostic { return ::mlir::emitError(loc, err); }; return getElementType(type, indices, errorFn); } static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, SMLoc loc) { auto errorFn = [&](StringRef err) -> InFlightDiagnostic { return parser.emitError(loc, err); }; return getElementType(type, indices, errorFn); } template static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { auto resultType = llvm::cast(op.getType()); if (resultType.getNumElements() != 2) return op.emitOpError("expected result struct type containing two members"); if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), resultType.getElementType(0), resultType.getElementType(1)})) return op.emitOpError( "expected all operand types and struct member types are the same"); return success(); } static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, OperationState &result) { SmallVector operands; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseOperandList(operands) || parser.parseColon()) return failure(); Type resultType; SMLoc loc = parser.getCurrentLocation(); if (parser.parseType(resultType)) return failure(); auto structType = llvm::dyn_cast(resultType); if (!structType || structType.getNumElements() != 2) return parser.emitError(loc, "expected spirv.struct type with two members"); SmallVector operandTypes(2, structType.getElementType(0)); if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) return failure(); result.addTypes(resultType); return success(); } static void printArithmeticExtendedBinaryOp(Operation *op, OpAsmPrinter &printer) { printer << ' '; printer.printOptionalAttrDict(op->getAttrs()); printer.printOperands(op->getOperands()); printer << " : " << op->getResultTypes().front(); } static LogicalResult verifyShiftOp(Operation *op) { if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " "result, but provided ") << op->getOperand(0).getType() << " and " << op->getResult(0).getType(); } return success(); } //===----------------------------------------------------------------------===// // spirv.mlir.addressof //===----------------------------------------------------------------------===// void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, spirv::GlobalVariableOp var) { build(builder, state, var.getType(), SymbolRefAttr::get(var)); } LogicalResult spirv::AddressOfOp::verify() { auto varOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), getVariableAttr())); if (!varOp) { return emitOpError("expected spirv.GlobalVariable symbol"); } if (getPointer().getType() != varOp.getType()) { return emitOpError( "result type mismatch with the referenced global variable's type"); } return success(); } //===----------------------------------------------------------------------===// // spirv.CompositeConstruct //===----------------------------------------------------------------------===// LogicalResult spirv::CompositeConstructOp::verify() { operand_range constituents = this->getConstituents(); // There are 4 cases with varying verification rules: // 1. Cooperative Matrices (1 constituent) // 2. Structs (1 constituent for each member) // 3. Arrays (1 constituent for each array element) // 4. Vectors (1 constituent (sub-)element for each vector element) auto coopElementType = llvm::TypeSwitch(getType()) .Case( [](auto coopType) { return coopType.getElementType(); }) .Default([](Type) { return nullptr; }); // Case 1. -- matrices. if (coopElementType) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); if (coopElementType != constituents.front().getType()) return emitOpError("operand type mismatch: expected operand type ") << coopElementType << ", but provided " << constituents.front().getType(); return success(); } // Case 2./3./4. -- number of constituents matches the number of elements. auto cType = llvm::cast(getType()); if (constituents.size() == cType.getNumElements()) { for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { return emitOpError("operand type mismatch: expected operand type ") << cType.getElementType(index) << ", but provided " << constituents[index].getType(); } } return success(); } // Case 4. -- check that all constituents add up tp the expected vector type. auto resultType = llvm::dyn_cast(cType); if (!resultType) return emitOpError( "expected to return a vector or cooperative matrix when the number of " "constituents is less than what the result needs"); SmallVector sizes; for (Value component : constituents) { if (!llvm::isa(component.getType()) && !component.getType().isIntOrFloat()) return emitOpError("operand type mismatch: expected operand to have " "a scalar or vector type, but provided ") << component.getType(); Type elementType = component.getType(); if (auto vectorType = llvm::dyn_cast(component.getType())) { sizes.push_back(vectorType.getNumElements()); elementType = vectorType.getElementType(); } else { sizes.push_back(1); } if (elementType != resultType.getElementType()) return emitOpError("operand element type mismatch: expected to be ") << resultType.getElementType() << ", but provided " << elementType; } unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); if (totalCount != cType.getNumElements()) return emitOpError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << totalCount; return success(); } //===----------------------------------------------------------------------===// // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, Value composite, ArrayRef indices) { auto indexAttr = builder.getI32ArrayAttr(indices); auto elementType = getElementType(composite.getType(), indexAttr, state.location); if (!elementType) { return; } build(builder, state, elementType, composite, indexAttr); } ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand compositeInfo; Attribute indicesAttr; Type compositeType; SMLoc attrLocation; if (parser.parseOperand(compositeInfo) || parser.getCurrentLocation(&attrLocation) || parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || parser.parseColonType(compositeType) || parser.resolveOperand(compositeInfo, compositeType, result.operands)) { return failure(); } Type resultType = getElementType(compositeType, indicesAttr, parser, attrLocation); if (!resultType) { return failure(); } result.addTypes(resultType); return success(); } void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { printer << ' ' << getComposite() << getIndices() << " : " << getComposite().getType(); } LogicalResult spirv::CompositeExtractOp::verify() { auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto resultType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!resultType) return failure(); if (resultType != getType()) { return emitOpError("invalid result type: expected ") << resultType << " but provided " << getType(); } return success(); } //===----------------------------------------------------------------------===// // spirv.CompositeInsert //===----------------------------------------------------------------------===// void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, Value object, Value composite, ArrayRef indices) { auto indexAttr = builder.getI32ArrayAttr(indices); build(builder, state, composite.getType(), object, composite, indexAttr); } ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operands; Type objectType, compositeType; Attribute indicesAttr; auto loc = parser.getCurrentLocation(); return failure( parser.parseOperandList(operands, 2) || parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || parser.parseColonType(objectType) || parser.parseKeywordType("into", compositeType) || parser.resolveOperands(operands, {objectType, compositeType}, loc, result.operands) || parser.addTypesToList(compositeType, result.types)); } LogicalResult spirv::CompositeInsertOp::verify() { auto indicesArrayAttr = llvm::dyn_cast(getIndices()); auto objectType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!objectType) return failure(); if (objectType != getObject().getType()) { return emitOpError("object operand type should be ") << objectType << ", but found " << getObject().getType(); } if (getComposite().getType() != getType()) { return emitOpError("result type should be the same as " "the composite type, but found ") << getComposite().getType() << " vs " << getType(); } return success(); } void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { printer << " " << getObject() << ", " << getComposite() << getIndices() << " : " << getObject().getType() << " into " << getComposite().getType(); } //===----------------------------------------------------------------------===// // spirv.Constant //===----------------------------------------------------------------------===// ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, OperationState &result) { Attribute value; if (parser.parseAttribute(value, kValueAttrName, result.attributes)) return failure(); Type type = NoneType::get(parser.getContext()); if (auto typedAttr = llvm::dyn_cast(value)) type = typedAttr.getType(); if (llvm::isa(type)) { if (parser.parseColonType(type)) return failure(); } return parser.addTypeToList(type, result.types); } void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << getValue(); if (llvm::isa(getType())) printer << " : " << getType(); } static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { if (llvm::isa(value)) { auto valueType = llvm::cast(value).getType(); if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } if (llvm::isa(value)) { auto valueType = llvm::cast(value).getType(); if (valueType == opType) return success(); auto arrayType = llvm::dyn_cast(opType); auto shapedType = llvm::dyn_cast(valueType); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType << "), must be the same or spirv.array"; int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); while (auto t = llvm::dyn_cast(opElemType)) { numElements *= t.getNumElements(); opElemType = t.getElementType(); } if (!opElemType.isIntOrFloat()) return op.emitOpError("only support nested array result type"); auto valueElemType = shapedType.getElementType(); if (valueElemType != opElemType) { return op.emitOpError("result element type (") << opElemType << ") does not match value element type (" << valueElemType << ")"; } if (numElements != shapedType.getNumElements()) { return op.emitOpError("result number of elements (") << numElements << ") does not match value number of elements (" << shapedType.getNumElements() << ")"; } return success(); } if (auto arrayAttr = llvm::dyn_cast(value)) { auto arrayType = llvm::dyn_cast(opType); if (!arrayType) return op.emitOpError( "must have spirv.array result type for array value"); Type elemType = arrayType.getElementType(); for (Attribute element : arrayAttr.getValue()) { // Verify array elements recursively. if (failed(verifyConstantType(op, element, elemType))) return failure(); } return success(); } return op.emitOpError("cannot have attribute: ") << value; } LogicalResult spirv::ConstantOp::verify() { // ODS already generates checks to make sure the result type is valid. We just // need to additionally check that the value's attribute type is consistent // with the result type. return verifyConstantType(*this, getValueAttr(), getType()); } bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. if (!llvm::isa(type)) return false; if (isa(type.getDialect())) { // TODO: support constant struct return llvm::isa(type); } return true; } spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, OpBuilder &builder) { if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, builder.getBoolAttr(false)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 0.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 0.0).getValue())); } } llvm_unreachable("unimplemented types for ConstantOp::getZero()"); } spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, OpBuilder &builder) { if (auto intType = llvm::dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, builder.getBoolAttr(true)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } if (auto floatType = llvm::dyn_cast(type)) { return builder.create( loc, type, builder.getFloatAttr(floatType, 1.0)); } if (auto vectorType = llvm::dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (llvm::isa(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } if (llvm::isa(elemType)) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 1.0).getValue())); } } llvm_unreachable("unimplemented types for ConstantOp::getOne()"); } void mlir::spirv::ConstantOp::getAsmResultNames( llvm::function_ref setNameFn) { Type type = getType(); SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "cst"; IntegerType intTy = llvm::dyn_cast(type); if (IntegerAttr intCst = llvm::dyn_cast(getValue())) { if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } if (intTy.isSignless()) { specialName << intCst.getInt(); } else if (intTy.isUnsigned()) { specialName << intCst.getUInt(); } else { specialName << intCst.getSInt(); } } if (intTy || llvm::isa(type)) { specialName << '_' << type; } if (auto vecType = llvm::dyn_cast(type)) { specialName << "_vec_"; specialName << vecType.getDimSize(0); Type elementType = vecType.getElementType(); if (llvm::isa(elementType) || llvm::isa(elementType)) { specialName << "x" << elementType; } } setNameFn(getResult(), specialName.str()); } void mlir::spirv::AddressOfOp::getAsmResultNames( llvm::function_ref setNameFn) { SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << getVariable() << "_addr"; setNameFn(getResult(), specialName.str()); } //===----------------------------------------------------------------------===// // spirv.ControlBarrierOp //===----------------------------------------------------------------------===// LogicalResult spirv::ControlBarrierOp::verify() { return verifyMemorySemantics(getOperation(), getMemorySemantics()); } //===----------------------------------------------------------------------===// // spirv.EntryPoint //===----------------------------------------------------------------------===// void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, spirv::ExecutionModel executionModel, spirv::FuncOp function, ArrayRef interfaceVars) { build(builder, state, spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); } ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, OperationState &result) { spirv::ExecutionModel execModel; SmallVector identifiers; SmallVector idTypes; SmallVector interfaceVars; FlatSymbolRefAttr fn; if (parseEnumStrAttr(execModel, parser, result) || parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { return failure(); } if (!parser.parseOptionalComma()) { // Parse the interface variables if (parser.parseCommaSeparatedList([&]() -> ParseResult { // The name of the interface variable attribute isnt important FlatSymbolRefAttr var; NamedAttrList attrs; if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) return failure(); interfaceVars.push_back(var); return success(); })) return failure(); } result.addAttribute(kInterfaceAttrName, parser.getBuilder().getArrayAttr(interfaceVars)); return success(); } void spirv::EntryPointOp::print(OpAsmPrinter &printer) { printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; printer.printSymbolName(getFn()); auto interfaceVars = getInterface().getValue(); if (!interfaceVars.empty()) { printer << ", "; llvm::interleaveComma(interfaceVars, printer); } } LogicalResult spirv::EntryPointOp::verify() { // Checks for fn and interface symbol reference are done in spirv::ModuleOp // verification. return success(); } //===----------------------------------------------------------------------===// // spirv.ExecutionMode //===----------------------------------------------------------------------===// void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, spirv::FuncOp function, spirv::ExecutionMode executionMode, ArrayRef params) { build(builder, state, SymbolRefAttr::get(function), spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), builder.getI32ArrayAttr(params)); } ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, OperationState &result) { spirv::ExecutionMode execMode; Attribute fn; if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) || parseEnumStrAttr(execMode, parser, result)) { return failure(); } SmallVector values; Type i32Type = parser.getBuilder().getIntegerType(32); while (!parser.parseOptionalComma()) { NamedAttrList attr; Attribute value; if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); } values.push_back(llvm::cast(value).getInt()); } result.addAttribute(kValuesAttrName, parser.getBuilder().getI32ArrayAttr(values)); return success(); } void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getFn()); printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; auto values = this->getValues(); if (values.empty()) return; printer << ", "; llvm::interleaveComma(values, printer, [&](Attribute a) { printer << llvm::cast(a).getInt(); }); } //===----------------------------------------------------------------------===// // spirv.func //===----------------------------------------------------------------------===// ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; auto &builder = parser.getBuilder(); // Parse the name as a symbol. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); // Parse the function signature. bool isVariadic = false; if (function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); SmallVector argTypes; for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); // Parse the optional function control keyword. spirv::FunctionControl fnControl; if (parseEnumStrAttr(fnControl, parser, result)) return failure(); // If additional attributes are present, parse them. if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); function_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); // Parse the optional function body. auto *body = result.addRegion(); OptionalParseResult parseResult = parser.parseOptionalRegion(*body, entryArgs); return failure(parseResult.has_value() && failed(*parseResult)); } void spirv::FuncOp::print(OpAsmPrinter &printer) { // Print function name, signature, and control. printer << " "; printer.printSymbolName(getSymName()); auto fnType = getFunctionType(); function_interface_impl::printFunctionSignature( printer, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( printer, *this, {spirv::attributeName(), getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); if (!body.empty()) { printer << ' '; printer.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } } LogicalResult spirv::FuncOp::verifyType() { FunctionType fnType = getFunctionType(); if (fnType.getNumResults() > 1) return emitOpError("cannot have more than one result"); auto hasDecorationAttr = [&](spirv::Decoration decoration, unsigned argIndex) { auto func = llvm::cast(getOperation()); for (auto argAttr : cast(func).getArgAttrs(argIndex)) { if (argAttr.getName() != spirv::DecorationAttr::name) continue; if (auto decAttr = dyn_cast(argAttr.getValue())) return decAttr.getValue() == decoration; } return false; }; for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) { Type param = fnType.getInputs()[i]; auto inputPtrType = dyn_cast(param); if (!inputPtrType) continue; auto pointeePtrType = dyn_cast(inputPtrType.getPointeeType()); if (pointeePtrType) { // SPIR-V spec, from SPV_KHR_physical_storage_buffer: // > If an OpFunctionParameter is a pointer (or contains a pointer) // > and the type it points to is a pointer in the PhysicalStorageBuffer // > storage class, the function parameter must be decorated with exactly // > one of AliasedPointer or RestrictPointer. if (pointeePtrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) continue; bool hasAliasedPtr = hasDecorationAttr(spirv::Decoration::AliasedPointer, i); bool hasRestrictPtr = hasDecorationAttr(spirv::Decoration::RestrictPointer, i); if (!hasAliasedPtr && !hasRestrictPtr) return emitOpError() << "with a pointer points to a physical buffer pointer must " "be decorated either 'AliasedPointer' or 'RestrictPointer'"; continue; } // SPIR-V spec, from SPV_KHR_physical_storage_buffer: // > If an OpFunctionParameter is a pointer (or contains a pointer) in // > the PhysicalStorageBuffer storage class, the function parameter must // > be decorated with exactly one of Aliased or Restrict. if (auto pointeeArrayType = dyn_cast(inputPtrType.getPointeeType())) { pointeePtrType = dyn_cast(pointeeArrayType.getElementType()); } else { pointeePtrType = inputPtrType; } if (!pointeePtrType || pointeePtrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) continue; bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i); bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i); if (!hasAliased && !hasRestrict) return emitOpError() << "with physical buffer pointer must be decorated " "either 'Aliased' or 'Restrict'"; } return success(); } LogicalResult spirv::FuncOp::verifyBody() { FunctionType fnType = getFunctionType(); auto walkResult = walk([fnType](Operation *op) -> WalkResult { if (auto retOp = dyn_cast(op)) { if (fnType.getNumResults() != 0) return retOp.emitOpError("cannot be used in functions returning value"); } else if (auto retOp = dyn_cast(op)) { if (fnType.getNumResults() != 1) return retOp.emitOpError( "returns 1 value but enclosing function requires ") << fnType.getNumResults() << " results"; auto retOperandType = retOp.getValue().getType(); auto fnResultType = fnType.getResult(0); if (retOperandType != fnResultType) return retOp.emitOpError(" return value's type (") << retOperandType << ") mismatch with function's result type (" << fnResultType << ")"; } return WalkResult::advance(); }); // TODO: verify other bits like linkage type. return failure(walkResult.wasInterrupted()); } void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, spirv::FunctionControl control, ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); } //===----------------------------------------------------------------------===// // spirv.GLFClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spirv.GLUClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spirv.GLSClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spirv.GLFmaOp //===----------------------------------------------------------------------===// ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spirv.GlobalVariable //===----------------------------------------------------------------------===// void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, unsigned descriptorSet, unsigned binding) { build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), builder.getI32IntegerAttr(descriptorSet)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), builder.getI32IntegerAttr(binding)); } void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, spirv::BuiltIn builtin) { build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); } ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, OperationState &result) { // Parse variable name. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes)) { return failure(); } // Parse optional initializer if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { FlatSymbolRefAttr initSymbol; if (parser.parseLParen() || parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, result.attributes) || parser.parseRParen()) return failure(); } if (parseVariableDecorations(parser, result)) { return failure(); } Type type; auto loc = parser.getCurrentLocation(); if (parser.parseColonType(type)) { return failure(); } if (!llvm::isa(type)) { return parser.emitError(loc, "expected spirv.ptr type"); } result.addAttribute(kTypeAttrName, TypeAttr::get(type)); return success(); } void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; // Print variable name. printer << ' '; printer.printSymbolName(getSymName()); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); // Print optional initializer if (auto initializer = this->getInitializer()) { printer << " " << kInitializerAttrName << '('; printer.printSymbolName(*initializer); printer << ')'; elidedAttrs.push_back(kInitializerAttrName); } elidedAttrs.push_back(kTypeAttrName); spirv::printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << getType(); } LogicalResult spirv::GlobalVariableOp::verify() { if (!llvm::isa(getType())) return emitOpError("result must be of a !spv.ptr type"); // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." // Also, Function storage class is reserved by spirv.Variable. auto storageClass = this->storageClass(); if (storageClass == spirv::StorageClass::Generic || storageClass == spirv::StorageClass::Function) { return emitOpError("storage class cannot be '") << stringifyStorageClass(storageClass) << "'"; } if (auto init = (*this)->getAttrOfType(kInitializerAttrName)) { Operation *initOp = SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), init.getAttr()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. if (!initOp || !isa(initOp)) { return emitOpError("initializer must be result of a " "spirv.SpecConstant or spirv.GlobalVariable or " "spirv.SpecConstantCompositeOp op"); } } return success(); } //===----------------------------------------------------------------------===// // spirv.INTEL.SubgroupBlockRead //===----------------------------------------------------------------------===// ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::UnresolvedOperand ptrInfo; Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { return failure(); } result.addTypes(elementType); return success(); } void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) { printer << " " << getPtr() << " : " << getType(); } LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // spirv.INTEL.SubgroupBlockWrite //===----------------------------------------------------------------------===// ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; auto loc = parser.getCurrentLocation(); Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperandList(operandInfo, 2) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (auto valVecTy = llvm::dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, result.operands)) { return failure(); } return success(); } void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { printer << " " << getPtr() << ", " << getValue() << " : " << getValue().getType(); } LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // spirv.IAddCarryOp //===----------------------------------------------------------------------===// LogicalResult spirv::IAddCarryOp::verify() { return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseArithmeticExtendedBinaryOp(parser, result); } void spirv::IAddCarryOp::print(OpAsmPrinter &printer) { ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// // spirv.ISubBorrowOp //===----------------------------------------------------------------------===// LogicalResult spirv::ISubBorrowOp::verify() { return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseArithmeticExtendedBinaryOp(parser, result); } void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// // spirv.SMulExtended //===----------------------------------------------------------------------===// LogicalResult spirv::SMulExtendedOp::verify() { return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseArithmeticExtendedBinaryOp(parser, result); } void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) { ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// // spirv.UMulExtended //===----------------------------------------------------------------------===// LogicalResult spirv::UMulExtendedOp::verify() { return ::verifyArithmeticExtendedBinaryOp(*this); } ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseArithmeticExtendedBinaryOp(parser, result); } void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { ::printArithmeticExtendedBinaryOp(*this, printer); } //===----------------------------------------------------------------------===// // spirv.MemoryBarrierOp //===----------------------------------------------------------------------===// LogicalResult spirv::MemoryBarrierOp::verify() { return verifyMemorySemantics(getOperation(), getMemorySemantics()); } //===----------------------------------------------------------------------===// // spirv.module //===----------------------------------------------------------------------===// void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, std::optional name) { OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); if (name) { state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); } } void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, spirv::AddressingModel addressingModel, spirv::MemoryModel memoryModel, std::optional vceTriple, std::optional name) { state.addAttribute( "addressing_model", builder.getAttr(addressingModel)); state.addAttribute("memory_model", builder.getAttr(memoryModel)); OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); if (vceTriple) state.addAttribute(getVCETripleAttrName(), *vceTriple); if (name) state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); } ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &result) { Region *body = result.addRegion(); // If the name is present, parse it. StringAttr nameAttr; (void)parser.parseOptionalSymbolName( nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes); // Parse attributes spirv::AddressingModel addrModel; spirv::MemoryModel memoryModel; if (spirv::parseEnumKeywordAttr(addrModel, parser, result) || spirv::parseEnumKeywordAttr(memoryModel, parser, result)) return failure(); if (succeeded(parser.parseOptionalKeyword("requires"))) { spirv::VerCapExtAttr vceTriple; if (parser.parseAttribute(vceTriple, spirv::ModuleOp::getVCETripleAttrName(), result.attributes)) return failure(); } if (parser.parseOptionalAttrDictWithKeyword(result.attributes) || parser.parseRegion(*body, /*arguments=*/{})) return failure(); // Make sure we have at least one block. if (body->empty()) body->push_back(new Block()); return success(); } void spirv::ModuleOp::print(OpAsmPrinter &printer) { if (std::optional name = getName()) { printer << ' '; printer.printSymbolName(*name); } SmallVector elidedAttrs; printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " " << spirv::stringifyMemoryModel(getMemoryModel()); auto addressingModelAttrName = spirv::attributeName(); auto memoryModelAttrName = spirv::attributeName(); elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, mlir::SymbolTable::getSymbolAttrName()}); if (std::optional triple = getVceTriple()) { printer << " requires " << *triple; elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); } printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); printer << ' '; printer.printRegion(getRegion()); } LogicalResult spirv::ModuleOp::verifyRegions() { Dialect *dialect = (*this)->getDialect(); DenseMap, spirv::EntryPointOp> entryPoints; mlir::SymbolTable table(*this); for (auto &op : *getBody()) { if (op.getDialect() != dialect) return op.emitError("'spirv.module' can only contain spirv.* ops"); // For EntryPoint op, check that the function and execution model is not // duplicated in EntryPointOps. Also verify that the interface specified // comes from globalVariables here to make this check cheaper. if (auto entryPointOp = dyn_cast(op)) { auto funcOp = table.lookup(entryPointOp.getFn()); if (!funcOp) { return entryPointOp.emitError("function '") << entryPointOp.getFn() << "' not found in 'spirv.module'"; } if (auto interface = entryPointOp.getInterface()) { for (Attribute varRef : interface) { auto varSymRef = llvm::dyn_cast(varRef); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " "specification instead of '") << varRef; } auto variableOp = table.lookup(varSymRef.getValue()); if (!variableOp) { return entryPointOp.emitError("expected spirv.GlobalVariable " "symbol reference instead of'") << varSymRef << "'"; } } } auto key = std::pair( funcOp, entryPointOp.getExecutionModel()); auto entryPtIt = entryPoints.find(key); if (entryPtIt != entryPoints.end()) { return entryPointOp.emitError("duplicate of a previous EntryPointOp"); } entryPoints[key] = entryPointOp; } else if (auto funcOp = dyn_cast(op)) { // If the function is external and does not have 'Import' // linkage_attributes(LinkageAttributes), throw an error. 'Import' // LinkageAttributes is used to import external functions. auto linkageAttr = funcOp.getLinkageAttributes(); auto hasImportLinkage = linkageAttr && (linkageAttr.value().getLinkageType().getValue() == spirv::LinkageType::Import); if (funcOp.isExternal() && !hasImportLinkage) return op.emitError( "'spirv.module' cannot contain external functions " "without 'Import' linkage_attributes (LinkageAttributes)"); // TODO: move this check to spirv.func. for (auto &block : funcOp) for (auto &op : block) { if (op.getDialect() != dialect) return op.emitError( "functions in 'spirv.module' can only contain spirv.* ops"); } } } return success(); } //===----------------------------------------------------------------------===// // spirv.mlir.referenceof //===----------------------------------------------------------------------===// LogicalResult spirv::ReferenceOfOp::verify() { auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), getSpecConstAttr()); Type constType; auto specConstOp = dyn_cast_or_null(specConstSym); if (specConstOp) constType = specConstOp.getDefaultValue().getType(); auto specConstCompositeOp = dyn_cast_or_null(specConstSym); if (specConstCompositeOp) constType = specConstCompositeOp.getType(); if (!specConstOp && !specConstCompositeOp) return emitOpError( "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol"); if (getReference().getType() != constType) return emitOpError("result type mismatch with the referenced " "specialization constant's type"); return success(); } //===----------------------------------------------------------------------===// // spirv.SpecConstant //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr nameAttr; Attribute valueAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); // Parse optional spec_id. if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) { IntegerAttr specIdAttr; if (parser.parseLParen() || parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) || parser.parseRParen()) return failure(); } if (parser.parseEqual() || parser.parseAttribute(valueAttr, kDefaultValueAttrName, result.attributes)) return failure(); return success(); } void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { printer << ' '; printer.printSymbolName(getSymName()); if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; printer << " = " << getDefaultValue(); } LogicalResult spirv::SpecConstantOp::verify() { if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) if (specID.getValue().isNegative()) return emitOpError("SpecId cannot be negative"); auto value = getDefaultValue(); if (llvm::isa(value)) { // Make sure bitwidth is allowed. if (!llvm::isa(value.getType())) return emitOpError("default value bitwidth disallowed"); return success(); } return emitOpError( "default value can only be a bool, integer, or float scalar"); } //===----------------------------------------------------------------------===// // spirv.VectorShuffle //===----------------------------------------------------------------------===// LogicalResult spirv::VectorShuffleOp::verify() { VectorType resultType = llvm::cast(getType()); size_t numResultElements = resultType.getNumElements(); if (numResultElements != getComponents().size()) return emitOpError("result type element count (") << numResultElements << ") mismatch with the number of component selectors (" << getComponents().size() << ")"; size_t totalSrcElements = llvm::cast(getVector1().getType()).getNumElements() + llvm::cast(getVector2().getType()).getNumElements(); for (const auto &selector : getComponents().getAsValueRange()) { uint32_t index = selector.getZExtValue(); if (index >= totalSrcElements && index != std::numeric_limits().max()) return emitOpError("component selector ") << index << " out of range: expected to be in [0, " << totalSrcElements << ") or 0xffffffff"; } return success(); } //===----------------------------------------------------------------------===// // spirv.MatrixTimesScalar //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { Type elementType = llvm::TypeSwitch(getMatrix().getType()) .Case( [](auto matrixType) { return matrixType.getElementType(); }) .Default([](Type) { return nullptr; }); assert(elementType && "Unhandled type"); // Check that the scalar type is the same as the matrix element type. if (getScalar().getType() != elementType) return emitOpError("input matrix components' type and scaling value must " "have the same type"); return success(); } //===----------------------------------------------------------------------===// // spirv.Transpose //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { auto inputMatrix = llvm::cast(getMatrix().getType()); auto resultMatrix = llvm::cast(getResult().getType()); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) return emitError("input matrix rows count must be equal to " "output matrix columns count"); if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) return emitError("input matrix columns count must be equal to " "output matrix rows count"); // Verify that the input and output matrices have the same component type if (inputMatrix.getElementType() != resultMatrix.getElementType()) return emitError("input and output matrices must have the same " "component type"); return success(); } //===----------------------------------------------------------------------===// // spirv.MatrixTimesMatrix //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { auto leftMatrix = llvm::cast(getLeftmatrix().getType()); auto rightMatrix = llvm::cast(getRightmatrix().getType()); auto resultMatrix = llvm::cast(getResult().getType()); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) return emitError("left matrix columns' count must be equal to " "the right matrix rows' count"); // right and result matrices columns' count must be the same if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) return emitError( "right and result matrices must have equal columns' count"); // right and result matrices component type must be the same if (rightMatrix.getElementType() != resultMatrix.getElementType()) return emitError("right and result matrices' component type must" " be the same"); // left and result matrices component type must be the same if (leftMatrix.getElementType() != resultMatrix.getElementType()) return emitError("left and result matrices' component type" " must be the same"); // left and result matrices rows count must be the same if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) return emitError("left and result matrices must have equal rows' count"); return success(); } //===----------------------------------------------------------------------===// // spirv.SpecConstantComposite //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr compositeName; if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); if (parser.parseLParen()) return failure(); SmallVector constituents; do { // The name of the constituent attribute isn't important const char *attrName = "spec_const"; FlatSymbolRefAttr specConstRef; NamedAttrList attrs; if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) return failure(); constituents.push_back(specConstRef); } while (!parser.parseOptionalComma()); if (parser.parseRParen()) return failure(); result.addAttribute(kCompositeSpecConstituentsName, parser.getBuilder().getArrayAttr(constituents)); Type type; if (parser.parseColonType(type)) return failure(); result.addAttribute(kTypeAttrName, TypeAttr::get(type)); return success(); } void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getSymName()); printer << " ("; auto constituents = this->getConstituents().getValue(); if (!constituents.empty()) llvm::interleaveComma(constituents, printer); printer << ") : " << getType(); } LogicalResult spirv::SpecConstantCompositeOp::verify() { auto cType = llvm::dyn_cast(getType()); auto constituents = this->getConstituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") << getType(); if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; if (llvm::isa(cType)) return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { auto constituent = llvm::cast(constituents[index]); auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), constituent.getAttr())); if (constituentSpecConstOp.getDefaultValue().getType() != cType.getElementType(index)) return emitError("has incorrect types of operands: expected ") << cType.getElementType(index) << ", but provided " << constituentSpecConstOp.getDefaultValue().getType(); } return success(); } //===----------------------------------------------------------------------===// // spirv.SpecConstantOperation //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, OperationState &result) { Region *body = result.addRegion(); if (parser.parseKeyword("wraps")) return failure(); body->push_back(new Block); Block &block = body->back(); Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); if (!wrappedOp) return failure(); OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrappedOp->getLoc(), wrappedOp->getResult(0)); result.location = wrappedOp->getLoc(); result.addTypes(wrappedOp->getResult(0).getType()); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { printer << " wraps "; printer.printGenericOp(&getBody().front().front()); } LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { Block &block = getRegion().getBlocks().front(); if (block.getOperations().size() != 2) return emitOpError("expected exactly 2 nested ops"); Operation &enclosedOp = block.getOperations().front(); if (!enclosedOp.hasTrait()) return emitOpError("invalid enclosed op"); for (auto operand : enclosedOp.getOperands()) if (!isa(operand.getDefiningOp())) return emitOpError( "invalid operand, must be defined by a constant operation"); return success(); } //===----------------------------------------------------------------------===// // spirv.GL.FrexpStruct //===----------------------------------------------------------------------===// LogicalResult spirv::GLFrexpStructOp::verify() { spirv::StructType structTy = llvm::dyn_cast(getResult().getType()); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); VectorType exponentVecTy = llvm::dyn_cast(exponentTy); IntegerType exponentIntTy = llvm::dyn_cast(exponentTy); Type operandTy = getOperand().getType(); VectorType operandVecTy = llvm::dyn_cast(operandTy); FloatType operandFTy = llvm::dyn_cast(operandTy); if (significandTy != operandTy) return emitError("member zero of the resulting struct type must be the " "same type as the operand"); if (exponentVecTy) { IntegerType componentIntTy = llvm::dyn_cast(exponentVecTy.getElementType()); if (!componentIntTy || componentIntTy.getWidth() != 32) return emitError("member one of the resulting struct type must" "be a scalar or vector of 32 bit integer type"); } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { return emitError("member one of the resulting struct type " "must be a scalar or vector of 32 bit integer type"); } // Check that the two member types have the same number of components if (operandVecTy && exponentVecTy && (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) return success(); if (operandFTy && exponentIntTy) return success(); return emitError("member one of the resulting struct type must have the same " "number of components as the operand type"); } //===----------------------------------------------------------------------===// // spirv.GL.Ldexp //===----------------------------------------------------------------------===// LogicalResult spirv::GLLdexpOp::verify() { Type significandType = getX().getType(); Type exponentType = getExp().getType(); if (llvm::isa(significandType) != llvm::isa(exponentType)) return emitOpError("operands must both be scalars or vectors"); auto getNumElements = [](Type type) -> unsigned { if (auto vectorType = llvm::dyn_cast(type)) return vectorType.getNumElements(); return 1; }; if (getNumElements(significandType) != getNumElements(exponentType)) return emitOpError("operands must have the same number of elements"); return success(); } //===----------------------------------------------------------------------===// // spirv.ImageDrefGather //===----------------------------------------------------------------------===// LogicalResult spirv::ImageDrefGatherOp::verify() { VectorType resultType = llvm::cast(getResult().getType()); auto sampledImageType = llvm::cast(getSampledimage().getType()); auto imageType = llvm::cast(sampledImageType.getImageType()); if (resultType.getNumElements() != 4) return emitOpError("result type must be a vector of four components"); Type elementType = resultType.getElementType(); Type sampledElementType = imageType.getElementType(); if (!llvm::isa(sampledElementType) && elementType != sampledElementType) return emitOpError( "the component type of result must be the same as sampled type of the " "underlying image type"); spirv::Dim imageDim = imageType.getDim(); spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && imageDim != spirv::Dim::Rect) return emitOpError( "the Dim operand of the underlying image type must be 2D, Cube, or " "Rect"); if (imageMS != spirv::ImageSamplingInfo::SingleSampled) return emitOpError("the MS operand of the underlying image type must be 0"); spirv::ImageOperandsAttr attr = getImageoperandsAttr(); auto operandArguments = getOperandArguments(); return verifyImageOperands(*this, attr, operandArguments); } //===----------------------------------------------------------------------===// // spirv.ShiftLeftLogicalOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftLeftLogicalOp::verify() { return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spirv.ShiftRightArithmeticOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftRightArithmeticOp::verify() { return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spirv.ShiftRightLogicalOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftRightLogicalOp::verify() { return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spirv.ImageQuerySize //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { spirv::ImageType imageType = llvm::cast(getImage().getType()); Type resultType = getResult().getType(); spirv::Dim dim = imageType.getDim(); spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); switch (dim) { case spirv::Dim::Dim1D: case spirv::Dim::Dim2D: case spirv::Dim::Dim3D: case spirv::Dim::Cube: if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) return emitError( "if Dim is 1D, 2D, 3D, or Cube, " "it must also have either an MS of 1 or a Sampled of 0 or 2"); break; case spirv::Dim::Buffer: case spirv::Dim::Rect: break; default: return emitError("the Dim operand of the image type must " "be 1D, 2D, 3D, Buffer, Cube, or Rect"); } unsigned componentNumber = 0; switch (dim) { case spirv::Dim::Dim1D: case spirv::Dim::Buffer: componentNumber = 1; break; case spirv::Dim::Dim2D: case spirv::Dim::Cube: case spirv::Dim::Rect: componentNumber = 2; break; case spirv::Dim::Dim3D: componentNumber = 3; break; default: break; } if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) componentNumber += 1; unsigned resultComponentNumber = 1; if (auto resultVectorType = llvm::dyn_cast(resultType)) resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) return emitError("expected the result to have ") << componentNumber << " component(s), but found " << resultComponentNumber << " component(s)"; return success(); } //===----------------------------------------------------------------------===// // spirv.VectorTimesScalarOp //===----------------------------------------------------------------------===// LogicalResult spirv::VectorTimesScalarOp::verify() { if (getVector().getType() != getType()) return emitOpError("vector operand and result type mismatch"); auto scalarType = llvm::cast(getType()).getElementType(); if (getScalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); }