//===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains corner case tests for the SPIR-V serializer that are not // covered by normal serialization and deserialization roundtripping. // //===----------------------------------------------------------------------===// #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "gmock/gmock.h" using namespace mlir; //===----------------------------------------------------------------------===// // Test Fixture //===----------------------------------------------------------------------===// class SerializationTest : public ::testing::Test { protected: SerializationTest() { context.getOrLoadDialect(); initModuleOp(); } /// Initializes an empty SPIR-V module op. void initModuleOp() { OpBuilder builder(&context); OperationState state(UnknownLoc::get(&context), spirv::ModuleOp::getOperationName()); state.addAttribute("addressing_model", builder.getAttr( spirv::AddressingModel::Logical)); state.addAttribute("memory_model", builder.getAttr( spirv::MemoryModel::GLSL450)); state.addAttribute("vce_triple", spirv::VerCapExtAttr::get( spirv::Version::V_1_0, ArrayRef(), ArrayRef(), &context)); spirv::ModuleOp::build(builder, state); module = cast(Operation::create(state)); } /// Gets the `struct { float }` type. spirv::StructType getFloatStructType() { OpBuilder builder(module->getRegion()); llvm::SmallVector elementTypes{builder.getF32Type()}; llvm::SmallVector offsetInfo{0}; return spirv::StructType::get(elementTypes, offsetInfo); } /// Inserts a global variable of the given `type` and `name`. spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { OpBuilder builder(module->getRegion()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); return builder.create( UnknownLoc::get(&context), TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr); } // Inserts an Integer or a Vector of Integers constant of value 'val'. spirv::ConstantOp AddConstInt(Type type, APInt val) { OpBuilder builder(module->getRegion()); auto loc = UnknownLoc::get(&context); if (auto intType = dyn_cast(type)) { return builder.create( loc, type, builder.getIntegerAttr(type, val)); } if (auto vectorType = dyn_cast(type)) { Type elemType = vectorType.getElementType(); if (auto intType = dyn_cast(elemType)) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, val).getValue())); } } llvm_unreachable("unimplemented types for AddConstInt()"); } /// Handles a SPIR-V instruction with the given `opcode` and `operand`. /// Returns true to interrupt. using HandleFn = llvm::function_ref operands)>; /// Returns true if we can find a matching instruction in the SPIR-V blob. bool scanInstruction(HandleFn handleFn) { auto binarySize = binary.size(); auto *begin = binary.begin(); auto currOffset = spirv::kHeaderWordCount; while (currOffset < binarySize) { auto wordCount = binary[currOffset] >> 16; if (!wordCount || (currOffset + wordCount > binarySize)) return false; spirv::Opcode opcode = static_cast(binary[currOffset] & 0xffff); llvm::ArrayRef operands(begin + currOffset + 1, begin + currOffset + wordCount); if (handleFn(opcode, operands)) return true; currOffset += wordCount; } return false; } protected: MLIRContext context; OwningOpRef module; SmallVector binary; }; //===----------------------------------------------------------------------===// // Block decoration //===----------------------------------------------------------------------===// TEST_F(SerializationTest, ContainsBlockDecoration) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0"); ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); auto hasBlockDecoration = [](spirv::Opcode opcode, ArrayRef operands) { return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && operands[1] == static_cast(spirv::Decoration::Block); }; EXPECT_TRUE(scanInstruction(hasBlockDecoration)); } TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { auto structType = getFloatStructType(); // Two global variables using the same type should not decorate the type with // duplicated `Block` decorations. addGlobalVar(structType, "var0"); addGlobalVar(structType, "var1"); ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); unsigned count = 0; auto countBlockDecoration = [&count](spirv::Opcode opcode, ArrayRef operands) { if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && operands[1] == static_cast(spirv::Decoration::Block)) ++count; return false; }; ASSERT_FALSE(scanInstruction(countBlockDecoration)); EXPECT_EQ(count, 1u); } TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) { auto signlessInt16Type = IntegerType::get(&context, 16, IntegerType::Signless); auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed); // Check the bit extension of same value under different signedness semantics. APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1, signlessInt16Type.getSignedness()); APInt signedIntConstVal(signedInt16Type.getWidth(), -1, signedInt16Type.getSignedness()); AddConstInt(signlessInt16Type, signlessIntConstVal); AddConstInt(signedInt16Type, signedIntConstVal); ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef operands) { return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && operands[2] == 65535; }; EXPECT_TRUE(scanInstruction(hasSignlessVal)); auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef operands) { return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && operands[2] == 4294967295; }; EXPECT_TRUE(scanInstruction(hasSignedVal)); } TEST_F(SerializationTest, ContainsSymbolName) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0"); spirv::SerializationOptions options; options.emitSymbolName = true; ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); auto hasVarName = [](spirv::Opcode opcode, ArrayRef operands) { unsigned index = 1; // Skip the result return opcode == spirv::Opcode::OpName && spirv::decodeStringLiteral(operands, index) == "var0"; }; EXPECT_TRUE(scanInstruction(hasVarName)); } TEST_F(SerializationTest, DoesNotContainSymbolName) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0"); spirv::SerializationOptions options; options.emitSymbolName = false; ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); auto hasVarName = [](spirv::Opcode opcode, ArrayRef operands) { unsigned index = 1; // Skip the result return opcode == spirv::Opcode::OpName && spirv::decodeStringLiteral(operands, index) == "var0"; }; EXPECT_FALSE(scanInstruction(hasVarName)); }