//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the Deserializer methods for SPIR-V binary instructions. // //===----------------------------------------------------------------------===// #include "Deserializer.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include using namespace mlir; #define DEBUG_TYPE "spirv-deserialization" //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// /// Extracts the opcode from the given first word of a SPIR-V instruction. static inline spirv::Opcode extractOpcode(uint32_t word) { return static_cast(word & 0xffff); } //===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// Value spirv::Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spirv.Constant` op at every use site. return opBuilder.create(unknownLoc, constInfo->second, constInfo->first); } if (auto varOp = getGlobalVariable(id)) { auto addressOfOp = opBuilder.create( unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation())); return addressOfOp.getPointer(); } if (auto constOp = getSpecConstant(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constOp.getDefaultValue().getType(), SymbolRefAttr::get(constOp.getOperation())); return referenceOfOp.getReference(); } if (auto constCompositeOp = getSpecConstantComposite(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constCompositeOp.getType(), SymbolRefAttr::get(constCompositeOp.getOperation())); return referenceOfOp.getReference(); } if (auto specConstOperationInfo = getSpecConstantOperation(id)) { return materializeSpecConstantOperation( id, specConstOperationInfo->enclodesOpcode, specConstOperationInfo->resultTypeID, specConstOperationInfo->enclosedOpOperands); } if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } return valueMap.lookup(id); } LogicalResult spirv::Deserializer::sliceInstruction( spirv::Opcode &opcode, ArrayRef &operands, std::optional expectedOpcode) { auto binarySize = binary.size(); if (curOffset >= binarySize) { return emitError(unknownLoc, "expected ") << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) : "more") << " instruction"; } // For each instruction, get its word count from the first word to slice it // from the stream properly, and then dispatch to the instruction handler. uint32_t wordCount = binary[curOffset] >> 16; if (wordCount == 0) return emitError(unknownLoc, "word count cannot be zero"); uint32_t nextOffset = curOffset + wordCount; if (nextOffset > binarySize) return emitError(unknownLoc, "insufficient words for the last instruction"); opcode = extractOpcode(binary[curOffset]); operands = binary.slice(curOffset + 1, wordCount - 1); curOffset = nextOffset; return success(); } LogicalResult spirv::Deserializer::processInstruction( spirv::Opcode opcode, ArrayRef operands, bool deferInstructions) { LLVM_DEBUG(logger.startLine() << "[inst] processing instruction " << spirv::stringifyOpcode(opcode) << "\n"); // First dispatch all the instructions whose opcode does not correspond to // those that have a direct mirror in the SPIR-V dialect switch (opcode) { case spirv::Opcode::OpCapability: return processCapability(operands); case spirv::Opcode::OpExtension: return processExtension(operands); case spirv::Opcode::OpExtInst: return processExtInst(operands); case spirv::Opcode::OpExtInstImport: return processExtInstImport(operands); case spirv::Opcode::OpMemberName: return processMemberName(operands); case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); case spirv::Opcode::OpEntryPoint: case spirv::Opcode::OpExecutionMode: if (deferInstructions) { deferredInstructions.emplace_back(opcode, operands); return success(); } break; case spirv::Opcode::OpVariable: if (isa(opBuilder.getBlock()->getParentOp())) { return processGlobalVariable(operands); } break; case spirv::Opcode::OpLine: return processDebugLine(operands); case spirv::Opcode::OpNoLine: clearDebugLine(); return success(); case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpString: return processDebugString(operands); case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpSource: case spirv::Opcode::OpSourceContinued: case spirv::Opcode::OpSourceExtension: // TODO: This is debug information embedded in the binary which should be // translated into the spirv.module. return success(); case spirv::Opcode::OpTypeVoid: case spirv::Opcode::OpTypeBool: case spirv::Opcode::OpTypeInt: case spirv::Opcode::OpTypeFloat: case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeMatrix: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: case spirv::Opcode::OpTypeImage: case spirv::Opcode::OpTypeSampledImage: case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: case spirv::Opcode::OpTypeCooperativeMatrixKHR: return processType(opcode, operands); case spirv::Opcode::OpTypeForwardPointer: return processTypeForwardPointer(operands); case spirv::Opcode::OpTypeJointMatrixINTEL: return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstant: return processConstant(operands, /*isSpec=*/true); case spirv::Opcode::OpConstantComposite: return processConstantComposite(operands); case spirv::Opcode::OpSpecConstantComposite: return processSpecConstantComposite(operands); case spirv::Opcode::OpSpecConstantOp: return processSpecConstantOperation(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); case spirv::Opcode::OpMemberDecorate: return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); case spirv::Opcode::OpLabel: return processLabel(operands); case spirv::Opcode::OpBranch: return processBranch(operands); case spirv::Opcode::OpBranchConditional: return processBranchConditional(operands); case spirv::Opcode::OpSelectionMerge: return processSelectionMerge(operands); case spirv::Opcode::OpLoopMerge: return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: break; } return dispatchToAutogenDeserialization(opcode, operands); } LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( ArrayRef words, StringRef opName, bool hasResult, unsigned numOperands) { SmallVector resultTypes; uint32_t valueID = 0; size_t wordIndex = 0; if (hasResult) { if (wordIndex >= words.size()) return emitError(unknownLoc, "expected result type while deserializing for ") << opName; // Decode the type auto type = getType(words[wordIndex]); if (!type) return emitError(unknownLoc, "unknown type result : ") << words[wordIndex]; resultTypes.push_back(type); ++wordIndex; // Decode the result if (wordIndex >= words.size()) return emitError(unknownLoc, "expected result while deserializing for ") << opName; valueID = words[wordIndex]; ++wordIndex; } SmallVector operands; SmallVector attributes; // Decode operands size_t operandIndex = 0; for (; operandIndex < numOperands && wordIndex < words.size(); ++operandIndex, ++wordIndex) { auto arg = getValue(words[wordIndex]); if (!arg) return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; operands.push_back(arg); } if (operandIndex != numOperands) { return emitError( unknownLoc, "found less operands than expected when deserializing for ") << opName << "; only " << operandIndex << " of " << numOperands << " processed"; } if (wordIndex != words.size()) { return emitError( unknownLoc, "found more operands than expected when deserializing for ") << opName << "; only " << wordIndex << " of " << words.size() << " processed"; } // Attach attributes from decorations if (decorations.count(valueID)) { auto attrs = decorations[valueID].getAttrs(); attributes.append(attrs.begin(), attrs.end()); } // Create the op and update bookkeeping maps Location loc = createFileLineColLoc(opBuilder); OperationState opState(loc, opName); opState.addOperands(operands); if (hasResult) opState.addTypes(resultTypes); opState.addAttributes(attributes); Operation *op = opBuilder.create(opState); if (hasResult) valueMap[valueID] = op->getResult(0); if (op->hasTrait()) clearDebugLine(); return success(); } LogicalResult spirv::Deserializer::processUndef(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpUndef instruction must have two operands"); } auto type = getType(operands[0]); if (!type) { return emitError(unknownLoc, "unknown type with OpUndef instruction"); } undefMap[operands[1]] = type; return success(); } LogicalResult spirv::Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, "OpExtInst must have at least 4 operands, result type " ", result , set and instruction opcode"); } if (!extendedInstSets.count(operands[2])) { return emitError(unknownLoc, "undefined set in OpExtInst"); } SmallVector slicedOperands; slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); slicedOperands.append(std::next(operands.begin(), 4), operands.end()); return dispatchToExtensionSetAutogenDeserialization( extendedInstSets[operands[2]], operands[3], slicedOperands); } namespace mlir { namespace spirv { template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } auto execModel = spirv::ExecutionModelAttr::get( context, static_cast(words[wordIndex++])); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing in OpEntryPoint"); } // Get the function auto fnID = words[wordIndex++]; // Get the function name auto fnName = decodeStringLiteral(words, wordIndex); // Verify that the function matches the fnName auto parsedFunc = getFunction(fnID); if (!parsedFunc) { return emitError(unknownLoc, "no function matching ") << fnID; } if (parsedFunc.getName() != fnName) { // The deserializer uses "spirv_fn_" as the function name if the input // SPIR-V blob does not contain a name for it. We should use a more clear // indication for such case rather than relying on naming details. if (!parsedFunc.getName().starts_with("spirv_fn_")) return emitError(unknownLoc, "function name mismatch between OpEntryPoint " "and OpFunction with ") << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); parsedFunc.setName(fnName); } SmallVector interface; while (wordIndex < words.size()) { auto arg = getGlobalVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result ") << words[wordIndex] << " while decoding OpEntryPoint"; } interface.push_back(SymbolRefAttr::get(arg.getOperation())); wordIndex++; } opBuilder.create( unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), opBuilder.getArrayAttr(interface)); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing function result in OpExecutionMode"); } // Get the function to get the name of the function auto fnID = words[wordIndex++]; auto fn = getFunction(fnID); if (!fn) { return emitError(unknownLoc, "no function matching ") << fnID; } // Get the Execution mode if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); } auto execMode = spirv::ExecutionModeAttr::get( context, static_cast(words[wordIndex++])); // Get the values SmallVector attrListElems; while (wordIndex < words.size()) { attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); } auto values = opBuilder.getArrayAttr(attrListElems); opBuilder.create( unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode, values); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() < 3) { return emitError(unknownLoc, "OpFunctionCall must have at least 3 operands"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } // Use null type to mean no result type. if (isVoidType(resultType)) resultType = nullptr; auto resultID = operands[1]; auto functionID = operands[2]; auto functionName = getFunctionSymbol(functionID); SmallVector arguments; for (auto operand : llvm::drop_begin(operands, 3)) { auto value = getValue(operand); if (!value) { return emitError(unknownLoc, "unknown ") << operand << " used by OpFunctionCall"; } arguments.push_back(value); } auto opFunctionCall = opBuilder.create( unknownLoc, resultType, SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); if (resultType) valueMap[resultID] = opFunctionCall.getResult(0); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef words) { SmallVector resultTypes; size_t wordIndex = 0; SmallVector operands; SmallVector attributes; if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; } operands.push_back(arg); wordIndex++; } if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; } operands.push_back(arg); wordIndex++; } bool isAlignedAttr = false; if (wordIndex < words.size()) { auto attrValue = words[wordIndex++]; auto attr = opBuilder.getAttr( static_cast(attrValue)); attributes.push_back(opBuilder.getNamedAttr("memory_access", attr)); isAlignedAttr = (attrValue == 2); } if (isAlignedAttr && wordIndex < words.size()) { attributes.push_back(opBuilder.getNamedAttr( "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); } if (wordIndex < words.size()) { auto attrValue = words[wordIndex++]; auto attr = opBuilder.getAttr( static_cast(attrValue)); attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr)); } if (wordIndex < words.size()) { attributes.push_back(opBuilder.getNamedAttr( "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); } if (wordIndex != words.size()) { return emitError(unknownLoc, "found more operands than expected when deserializing " "spirv::CopyMemoryOp, only ") << wordIndex << " of " << words.size() << " processed"; } Location loc = createFileLineColLoc(opBuilder); opBuilder.create(loc, resultTypes, operands, attributes); return success(); } template <> LogicalResult Deserializer::processOp( ArrayRef words) { if (words.size() != 4) { return emitError(unknownLoc, "expected 4 words in GenericCastToPtrExplicitOp" " but got : ") << words.size(); } SmallVector resultTypes; SmallVector operands; uint32_t valueID = 0; auto type = getType(words[0]); if (!type) return emitError(unknownLoc, "unknown type result : ") << words[0]; resultTypes.push_back(type); valueID = words[1]; auto arg = getValue(words[2]); if (!arg) return emitError(unknownLoc, "unknown result : ") << words[2]; operands.push_back(arg); Location loc = createFileLineColLoc(opBuilder); Operation *op = opBuilder.create( loc, resultTypes, operands); valueMap[valueID] = op->getResult(0); return success(); } // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" } // namespace spirv } // namespace mlir