//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // The purpose of this file is to provide negative deserialization tests. // For positive deserialization tests, please use serialization and // deserialization for roundtripping. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "gmock/gmock.h" #include using namespace mlir; using ::testing::StrEq; //===----------------------------------------------------------------------===// // Test Fixture //===----------------------------------------------------------------------===// /// A deserialization test fixture providing minimal SPIR-V building and /// diagnostic checking utilities. class DeserializationTest : public ::testing::Test { protected: DeserializationTest() { context.getOrLoadDialect(); // Register a diagnostic handler to capture the diagnostic so that we can // check it later. context.getDiagEngine().registerHandler([&](Diagnostic &diag) { diagnostic.reset(new Diagnostic(std::move(diag))); }); } /// Performs deserialization and returns the constructed spv.module op. spirv::OwningSPIRVModuleRef deserialize() { return spirv::deserialize(binary, &context); } /// Checks there is a diagnostic generated with the given `errorMessage`. void expectDiagnostic(StringRef errorMessage) { ASSERT_NE(nullptr, diagnostic.get()); // TODO: check error location too. 
EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); } //===--------------------------------------------------------------------===// // SPIR-V builder methods //===--------------------------------------------------------------------===// /// Adds the SPIR-V module header to `binary`. void addHeader() { spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); } /// Adds the SPIR-V instruction into `binary`. void addInstruction(spirv::Opcode op, ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); binary.append(operands.begin(), operands.end()); } uint32_t addVoidType() { auto id = nextID++; addInstruction(spirv::Opcode::OpTypeVoid, {id}); return id; } uint32_t addIntType(uint32_t bitwidth) { auto id = nextID++; addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); return id; } uint32_t addStructType(ArrayRef memberTypes) { auto id = nextID++; SmallVector words; words.push_back(id); words.append(memberTypes.begin(), memberTypes.end()); addInstruction(spirv::Opcode::OpTypeStruct, words); return id; } uint32_t addFunctionType(uint32_t retType, ArrayRef paramTypes) { auto id = nextID++; SmallVector operands; operands.push_back(id); operands.push_back(retType); operands.append(paramTypes.begin(), paramTypes.end()); addInstruction(spirv::Opcode::OpTypeFunction, operands); return id; } uint32_t addFunction(uint32_t retType, uint32_t fnType) { auto id = nextID++; addInstruction(spirv::Opcode::OpFunction, {retType, id, static_cast(spirv::FunctionControl::None), fnType}); return id; } void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } protected: SmallVector binary; uint32_t nextID = 1; MLIRContext context; std::unique_ptr diagnostic; }; //===----------------------------------------------------------------------===// // Basics 
//===----------------------------------------------------------------------===//

/// An empty word stream has no header at all and must be rejected.
TEST_F(DeserializationTest, EmptyModuleFailure) {
  ASSERT_FALSE(deserialize());
  expectDiagnostic("SPIR-V binary module must have a 5-word header");
}

/// A header whose first word is not the expected magic number is rejected.
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
  addHeader();
  binary.front() = 0xdeadbeef; // Change to a wrong magic number
  ASSERT_FALSE(deserialize());
  expectDiagnostic("incorrect magic number");
}

/// A module consisting of just the header deserializes successfully.
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
  addHeader();
  EXPECT_TRUE(deserialize());
}

/// An instruction word that encodes a word count of zero is rejected.
TEST_F(DeserializationTest, ZeroWordCountFailure) {
  addHeader();
  binary.push_back(0); // OpNop with zero word count

  ASSERT_FALSE(deserialize());
  expectDiagnostic("word count cannot be zero");
}

/// An instruction announcing more words than remain in the stream is
/// rejected.
TEST_F(DeserializationTest, InsufficientWordFailure) {
  addHeader();
  // Hand-encode the leading word (word count 2 in the upper half) instead of
  // going through addInstruction(), so that the operand can be left out.
  binary.push_back((2u << 16) |
                   static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
  // Missing word for type <id>.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("insufficient words for the last instruction");
}

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

/// OpTypeInt emitted with only a result id and a bitwidth is rejected.
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});

  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}

//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//

/// Well-formed OpMemberName instructions for both members of a struct are
/// accepted.
TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();
  // Build the type declarations into a side buffer: their ids are needed for
  // the OpMemberName operands, but the declarations themselves are appended
  // to the binary only after the OpMemberName instructions. `typeDecl` uses
  // the same type as `binary` so the two can be swapped.
  decltype(binary) typeDecl;
  std::swap(typeDecl, binary);

  auto int32Type = addIntType(32);
  auto structType = addStructType({int32Type, int32Type});
  std::swap(typeDecl, binary);

  SmallVector<uint32_t, 5> operands1 = {structType, 0};
  spirv::encodeStringLiteralInto(operands1, "i1");
  addInstruction(spirv::Opcode::OpMemberName, operands1);

  SmallVector<uint32_t, 5> operands2 = {structType, 1};
  spirv::encodeStringLiteralInto(operands2, "i2");
  addInstruction(spirv::Opcode::OpMemberName, operands2);

  binary.append(typeDecl.begin(), typeDecl.end());
  EXPECT_TRUE(deserialize());
}

/// OpMemberName carrying only the struct id (no member index, no name) is
/// rejected.
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();
  // Same side-buffer arrangement as above: type declarations are emitted
  // after the OpMemberName instruction.
  decltype(binary) typeDecl;
  std::swap(typeDecl, binary);

  auto int32Type = addIntType(32);
  auto int64Type = addIntType(64);
  auto structType = addStructType({int32Type, int64Type});
  std::swap(typeDecl, binary);

  SmallVector<uint32_t, 5> operands1 = {structType};
  addInstruction(spirv::Opcode::OpMemberName, operands1);

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpMemberName must have at least 3 operands");
}

/// OpMemberName with an extra word after the encoded name is rejected.
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();
  // Same side-buffer arrangement as above: type declarations are emitted
  // after the OpMemberName instruction.
  decltype(binary) typeDecl;
  std::swap(typeDecl, binary);

  auto int32Type = addIntType(32);
  auto structType = addStructType({int32Type});
  std::swap(typeDecl, binary);

  SmallVector<uint32_t, 5> operands = {structType, 0};
  spirv::encodeStringLiteralInto(operands, "int32");
  operands.push_back(42); // Trailing word after the string literal.
  addInstruction(spirv::Opcode::OpMemberName, operands);

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}

//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//

/// A function that is never closed with OpFunctionEnd is rejected.
TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  auto voidType = addVoidType();
  auto fnType = addFunctionType(voidType, {});
  addFunction(voidType, fnType);
  // Missing OpFunctionEnd.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionEnd instruction");
}

/// A function whose type declares a parameter but which is not followed by
/// OpFunctionParameter is rejected.
TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  auto voidType = addVoidType();
  auto i32Type = addIntType(32);
  auto fnType = addFunctionType(voidType, {i32Type});
  addFunction(voidType, fnType);
  // Missing OpFunctionParameter.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionParameter instruction");
}

/// A function body that begins with OpReturn rather than OpLabel is rejected.
TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  auto voidType = addVoidType();
  auto fnType = addFunctionType(voidType, {});
  addFunction(voidType, fnType);
  // Missing OpLabel.
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  expectDiagnostic("a basic block must start with OpLabel");
}

/// An OpLabel emitted without any operand is rejected.
TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  auto voidType = addVoidType();
  auto fnType = addFunctionType(voidType, {});
  addFunction(voidType, fnType);
  addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  // NOTE(review): the expected text is compared verbatim; confirm the
  // trailing `<id>` against the deserializer's diagnostic.
  expectDiagnostic("OpLabel should only have result <id>");
}