• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2015 The Khronos Group Inc.
6  * Copyright (c) 2023 ARM Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief Test cases for VK_KHR_shader_expect_assume.
23  *        Verifies that the OpAssumeTrueKHR/OpExpectKHR opcodes work correctly.
24  *//*--------------------------------------------------------------------*/
25 
26 #include "vktShaderExpectAssumeTests.hpp"
27 #include "vktShaderExecutor.hpp"
28 #include "vktTestGroupUtil.hpp"
29 
30 #include "tcuStringTemplate.hpp"
31 
32 #include "vkBuilderUtil.hpp"
33 #include "vkCmdUtil.hpp"
34 #include "vkMemUtil.hpp"
35 #include "vkObjUtil.hpp"
36 #include "vkQueryUtil.hpp"
37 #include "vkRefUtil.hpp"
38 #include "vkTypeUtil.hpp"
39 
40 #include "tcuResultCollector.hpp"
41 
42 #include "deArrayUtil.hpp"
43 #include "deSharedPtr.hpp"
44 #include "deStringUtil.hpp"
45 
46 #include <cassert>
47 #include <string>
48 
49 namespace vkt
50 {
51 namespace shaderexecutor
52 {
53 
54 namespace
55 {
56 
57 using namespace vk;
58 constexpr uint32_t kNumElements           = 32;
59 constexpr VkFormat kColorAttachmentFormat = VK_FORMAT_R32G32_UINT;
60 
// Which instruction of VK_KHR_shader_expect_assume a test case exercises.
enum class OpType
{
    Expect = 0, // OpExpectKHR
    Assume = 1, // OpAssumeTrueKHR
};
66 
// How the value checked by the shader is supplied to it.
enum class DataClass
{
    Constant               = 0, // no pipeline-layout resources; presumably a literal in the shader
    SpecializationConstant = 1, // supplied through VkSpecializationInfo
    PushConstant           = 2, // supplied through a push-constant range
    StorageBuffer          = 3, // read from the input storage buffer (descriptor binding 0)
};
74 
// Scalar type of the tested value. In the input storage buffer, Bool is stored as a
// 32-bit integer and the IntN types are stored as N-bit signed integers.
enum class DataType
{
    Bool  = 0,
    Int8  = 1,
    Int16 = 2,
    Int32 = 3,
    Int64 = 4
};
83 
// Parameters of a single generated test case. Field order matters: instances are built by
// aggregate initialization.
84 struct TestParam
85 {
86     OpType opType;                    // instruction under test (expect or assume)
87     DataClass dataClass;              // how the tested value reaches the shader
88     DataType dataType;                // scalar type of the tested value
89     uint32_t dataChannelCount;        // components per element; 3 components use a stride of 4 in the input buffer
90     VkShaderStageFlagBits shaderType; // VK_SHADER_STAGE_COMPUTE_BIT selects the compute path, anything else the graphics path
91     bool wrongExpectation;            // when true, the input buffer holds value+1 (VK_FALSE for Bool) instead of the expected value
92     std::string testName;             // presumably the name the test case is registered under — confirm at the call site
93 };
94 
95 class ShaderExpectAssumeTestInstance : public TestInstance
96 {
97 public:
// Stores the test parameters and eagerly creates every resource the test needs via initialize().
ShaderExpectAssumeTestInstance(Context &ctx, const TestParam &param)
    : TestInstance(ctx)
    , m_testParam(param)
    , m_vk(ctx.getDeviceInterface())
{
    initialize();
}
105 
iterate(void)106     virtual tcu::TestStatus iterate(void)
107     {
108         if (m_testParam.shaderType == VK_SHADER_STAGE_COMPUTE_BIT)
109         {
110             dispatch();
111         }
112         else
113         {
114             render();
115         }
116 
117         const uint32_t *outputData = reinterpret_cast<uint32_t *>(m_outputAlloc->getHostPtr());
118         return validateOutput(outputData);
119     }
120 
121 private:
validateOutput(const uint32_t * outputData)122     tcu::TestStatus validateOutput(const uint32_t *outputData)
123     {
124         for (uint32_t i = 0; i < kNumElements; i++)
125         {
126             // (gl_GlobalInvocationID.x, verification result)
127             if (outputData[i * 2] != i || outputData[i * 2 + 1] != 1)
128             {
129                 return tcu::TestStatus::fail("Result comparison failed");
130             }
131         }
132         return tcu::TestStatus::pass("Pass");
133     }
134 
initialize()135     void initialize()
136     {
137         generateCmdBuffer();
138         if (m_testParam.shaderType == VK_SHADER_STAGE_COMPUTE_BIT)
139         {
140             generateStorageBuffers();
141             generateComputePipeline();
142         }
143         else
144         {
145             generateAttachments();
146             generateVertexBuffer();
147             generateStorageBuffers();
148             generateGraphicsPipeline();
149         }
150     }
151 
generateCmdBuffer()152     void generateCmdBuffer()
153     {
154         const VkDevice device = m_context.getDevice();
155 
156         m_cmdPool   = createCommandPool(m_vk, device, VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT,
157                                         m_context.getUniversalQueueFamilyIndex());
158         m_cmdBuffer = allocateCommandBuffer(m_vk, device, *m_cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
159     }
160 
generateVertexBuffer()161     void generateVertexBuffer()
162     {
163         const VkDevice device           = m_context.getDevice();
164         const DeviceInterface &vk       = m_context.getDeviceInterface();
165         const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
166         Allocator &memAlloc             = m_context.getDefaultAllocator();
167         std::vector<tcu::Vec2> vbo;
168         // _____
169         // |  /
170         // | /
171         // |/
172         vbo.emplace_back(tcu::Vec2(-1, -1));
173         vbo.emplace_back(tcu::Vec2(1, 1));
174         vbo.emplace_back(tcu::Vec2(-1, 1));
175         //   /|
176         //  / |
177         // /__|
178         vbo.emplace_back(tcu::Vec2(-1, -1));
179         vbo.emplace_back(tcu::Vec2(1, -1));
180         vbo.emplace_back(tcu::Vec2(1, 1));
181 
182         const size_t dataSize               = vbo.size() * sizeof(tcu::Vec2);
183         const VkBufferCreateInfo bufferInfo = {
184             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // VkStructureType sType;
185             nullptr,                              // const void* pNext;
186             0u,                                   // VkBufferCreateFlags flags;
187             dataSize,                             // VkDeviceSize size;
188             VK_BUFFER_USAGE_VERTEX_BUFFER_BIT,    // VkBufferUsageFlags usage;
189             VK_SHARING_MODE_EXCLUSIVE,            // VkSharingMode sharingMode;
190             1u,                                   // uint32_t queueFamilyCount;
191             &queueFamilyIndex                     // const uint32_t* pQueueFamilyIndices;
192         };
193         m_vertexBuffer = createBuffer(vk, device, &bufferInfo);
194         m_vertexAlloc =
195             memAlloc.allocate(getBufferMemoryRequirements(vk, device, *m_vertexBuffer), MemoryRequirement::HostVisible);
196 
197         void *vertexData = m_vertexAlloc->getHostPtr();
198 
199         VK_CHECK(vk.bindBufferMemory(device, *m_vertexBuffer, m_vertexAlloc->getMemory(), m_vertexAlloc->getOffset()));
200 
201         /* Load vertices into vertex buffer */
202         deMemcpy(vertexData, vbo.data(), dataSize);
203         flushAlloc(vk, device, *m_vertexAlloc);
204     }
205 
generateAttachments()206     void generateAttachments()
207     {
208         const VkDevice device = m_context.getDevice();
209         Allocator &allocator  = m_context.getDefaultAllocator();
210 
211         const VkImageUsageFlags imageUsage = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT;
212 
213         // Color Attachment
214         const VkImageCreateInfo imageInfo = {
215             VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
216             nullptr,                             // const void* pNext;
217             (VkImageCreateFlags)0,               // VkImageCreateFlags flags;
218             VK_IMAGE_TYPE_2D,                    // VkImageType imageType;
219             kColorAttachmentFormat,              // VkFormat format;
220             makeExtent3D(kNumElements, 1, 1),    // VkExtent3D extent;
221             1u,                                  // uint32_t mipLevels;
222             1u,                                  // uint32_t arrayLayers;
223             VK_SAMPLE_COUNT_1_BIT,               // VkSampleCountFlagBits samples;
224             VK_IMAGE_TILING_OPTIMAL,             // VkImageTiling tiling;
225             imageUsage,                          // VkImageUsageFlags usage;
226             VK_SHARING_MODE_EXCLUSIVE,           // VkSharingMode sharingMode;
227             0u,                                  // uint32_t queueFamilyIndexCount;
228             nullptr,                             // const uint32_t* pQueueFamilyIndices;
229             VK_IMAGE_LAYOUT_UNDEFINED,           // VkImageLayout initialLayout;
230         };
231 
232         const VkImageSubresourceRange imageSubresource =
233             makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
234 
235         m_imageColor      = makeImage(m_vk, device, imageInfo);
236         m_imageColorAlloc = bindImage(m_vk, device, allocator, *m_imageColor, MemoryRequirement::Any);
237         m_imageColorView =
238             makeImageView(m_vk, device, *m_imageColor, VK_IMAGE_VIEW_TYPE_2D, kColorAttachmentFormat, imageSubresource);
239     }
240 
generateGraphicsPipeline()241     void generateGraphicsPipeline()
242     {
243         const VkDevice device = m_context.getDevice();
244         std::vector<VkDescriptorSetLayoutBinding> bindings;
245 
246         if (m_testParam.dataClass == DataClass::StorageBuffer)
247         {
248             VkDescriptorSetLayoutCreateFlags layoutCreateFlags = 0;
249 
250             bindings.emplace_back(VkDescriptorSetLayoutBinding{
251                 0,                                                       // binding
252                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,                       // descriptorType
253                 1,                                                       // descriptorCount
254                 static_cast<VkShaderStageFlags>(m_testParam.shaderType), // stageFlags
255                 nullptr,                                                 // pImmutableSamplers
256             });                                                          // input binding
257 
258             // Create a layout and allocate a descriptor set for it.
259             const VkDescriptorSetLayoutCreateInfo setLayoutCreateInfo = {
260                 vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
261                 nullptr,                                                 // pNext
262                 layoutCreateFlags,                                       // flags
263                 static_cast<uint32_t>(bindings.size()),                  // bindingCount
264                 bindings.data()                                          // pBindings
265             };
266 
267             m_descriptorSetLayout = vk::createDescriptorSetLayout(m_vk, device, &setLayoutCreateInfo);
268             m_pipelineLayout      = makePipelineLayout(m_vk, device, 1, &m_descriptorSetLayout.get(), 0, nullptr);
269         }
270         else if (m_testParam.dataClass == DataClass::PushConstant)
271         {
272             VkPushConstantRange pushConstant{static_cast<VkShaderStageFlags>(m_testParam.shaderType), 0,
273                                              sizeof(VkBool32)};
274             m_pipelineLayout = makePipelineLayout(m_vk, device, 0, nullptr, 1, &pushConstant);
275         }
276         else
277         {
278             m_pipelineLayout = makePipelineLayout(m_vk, device, 0, nullptr, 0, nullptr);
279         }
280 
281         Move<VkShaderModule> vertexModule =
282             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("vert"), 0u);
283         Move<VkShaderModule> fragmentModule =
284             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("frag"), 0u);
285 
286         const VkVertexInputBindingDescription vertexInputBindingDescription = {
287             0,                           // uint32_t binding;
288             sizeof(tcu::Vec2),           // uint32_t strideInBytes;
289             VK_VERTEX_INPUT_RATE_VERTEX, // VkVertexInputStepRate stepRate;
290         };
291 
292         const VkVertexInputAttributeDescription vertexInputAttributeDescription = {
293             0u,                      // uint32_t location;
294             0u,                      // uint32_t binding;
295             VK_FORMAT_R32G32_SFLOAT, // VkFormat format;
296             0u,                      // uint32_t offsetInBytes;
297         };
298 
299         const VkPipelineVertexInputStateCreateInfo vertexInputStateParams = {
300             VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, // VkStructureType sType;
301             nullptr,                                                   // const void* pNext;
302             0,                                                         // VkPipelineVertexInputStateCreateFlags flags;
303             1u,                                                        // uint32_t bindingCount;
304             &vertexInputBindingDescription,   // const VkVertexInputBindingDescription* pVertexBindingDescriptions;
305             1u,                               // uint32_t attributeCount;
306             &vertexInputAttributeDescription, // const VkVertexInputAttributeDescription* pVertexAttributeDescriptions;
307         };
308 
309         const VkPipelineInputAssemblyStateCreateInfo pipelineInputAssemblyStateInfo = {
310             VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO, // VkStructureType sType;
311             nullptr,                                                     // const void* pNext;
312             (VkPipelineInputAssemblyStateCreateFlags)0, // VkPipelineInputAssemblyStateCreateFlags flags;
313             VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST,        // VkPrimitiveTopology topology;
314             VK_FALSE,                                   // VkBool32 primitiveRestartEnable;
315         };
316 
317         const VkViewport viewport{0, 0, static_cast<float>(kNumElements), 1, 0, 1};
318         const VkRect2D scissor{{0, 0}, {kNumElements, 1}};
319 
320         const VkPipelineViewportStateCreateInfo pipelineViewportStateInfo = {
321             VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, // VkStructureType sType;
322             nullptr,                                               // const void* pNext;
323             (VkPipelineViewportStateCreateFlags)0,                 // VkPipelineViewportStateCreateFlags flags;
324             1u,                                                    // uint32_t viewportCount;
325             &viewport,                                             // const VkViewport* pViewports;
326             1u,                                                    // uint32_t scissorCount;
327             &scissor,                                              // const VkRect2D* pScissors;
328         };
329 
330         const VkPipelineRasterizationStateCreateInfo pipelineRasterizationStateInfo = {
331             VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO, // VkStructureType sType;
332             nullptr,                                                    // const void* pNext;
333             0u,                              // VkPipelineRasterizationStateCreateFlags flags;
334             VK_FALSE,                        // VkBool32 depthClampEnable;
335             VK_FALSE,                        // VkBool32 rasterizerDiscardEnable;
336             VK_POLYGON_MODE_FILL,            // VkPolygonMode polygonMode;
337             VK_CULL_MODE_NONE,               // VkCullModeFlags cullMode;
338             VK_FRONT_FACE_COUNTER_CLOCKWISE, // VkFrontFace frontFace;
339             VK_FALSE,                        // VkBool32 depthBiasEnable;
340             0.0f,                            // float depthBiasConstantFactor;
341             0.0f,                            // float depthBiasClamp;
342             0.0f,                            // float depthBiasSlopeFactor;
343             1.0f,                            // float lineWidth;
344         };
345 
346         const VkPipelineMultisampleStateCreateInfo pipelineMultisampleStateInfo = {
347             VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO, // VkStructureType sType;
348             nullptr,                                                  // const void* pNext;
349             0u,                                                       // VkPipelineMultisampleStateCreateFlags flags;
350             VK_SAMPLE_COUNT_1_BIT,                                    // VkSampleCountFlagBits rasterizationSamples;
351             VK_FALSE,                                                 // VkBool32 sampleShadingEnable;
352             1.0f,                                                     // float minSampleShading;
353             nullptr,                                                  // const VkSampleMask* pSampleMask;
354             VK_FALSE,                                                 // VkBool32 alphaToCoverageEnable;
355             VK_FALSE                                                  // VkBool32 alphaToOneEnable;
356         };
357 
358         std::vector<VkPipelineColorBlendAttachmentState> colorBlendAttachmentState(
359             1,
360             {
361                 false,                                                // VkBool32 blendEnable;
362                 VK_BLEND_FACTOR_ONE,                                  // VkBlend srcBlendColor;
363                 VK_BLEND_FACTOR_ONE,                                  // VkBlend destBlendColor;
364                 VK_BLEND_OP_ADD,                                      // VkBlendOp blendOpColor;
365                 VK_BLEND_FACTOR_ONE,                                  // VkBlend srcBlendAlpha;
366                 VK_BLEND_FACTOR_ONE,                                  // VkBlend destBlendAlpha;
367                 VK_BLEND_OP_ADD,                                      // VkBlendOp blendOpAlpha;
368                 (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT) // VkChannelFlags channelWriteMask;
369             });
370 
371         const VkPipelineColorBlendStateCreateInfo pipelineColorBlendStateInfo = {
372             VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO, // VkStructureType sType;
373             nullptr,                                                  // const void* pNext;
374             /* always needed */
375             0,                                          // VkPipelineColorBlendStateCreateFlags flags;
376             false,                                      // VkBool32 logicOpEnable;
377             VK_LOGIC_OP_COPY,                           // VkLogicOp logicOp;
378             (uint32_t)colorBlendAttachmentState.size(), // uint32_t attachmentCount;
379             colorBlendAttachmentState.data(),           // const VkPipelineColorBlendAttachmentState* pAttachments;
380             {0.0f, 0.0f, 0.0f, 0.0f},                   // float blendConst[4];
381         };
382 
383         VkStencilOpState stencilOpState = {
384             VK_STENCIL_OP_ZERO,               // VkStencilOp failOp;
385             VK_STENCIL_OP_INCREMENT_AND_WRAP, // VkStencilOp passOp;
386             VK_STENCIL_OP_INCREMENT_AND_WRAP, // VkStencilOp depthFailOp;
387             VK_COMPARE_OP_ALWAYS,             // VkCompareOp compareOp;
388             0xff,                             // uint32_t compareMask;
389             0xff,                             // uint32_t writeMask;
390             0,                                // uint32_t reference;
391         };
392 
393         VkPipelineDepthStencilStateCreateInfo pipelineDepthStencilStateInfo = {
394             VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO,
395             // VkStructureType sType;
396             nullptr, // const void* pNext;
397             0,
398             // VkPipelineDepthStencilStateCreateFlags flags;
399             VK_FALSE,             // VkBool32 depthTestEnable;
400             VK_FALSE,             // VkBool32 depthWriteEnable;
401             VK_COMPARE_OP_ALWAYS, // VkCompareOp depthCompareOp;
402             VK_FALSE,             // VkBool32 depthBoundsTestEnable;
403             VK_FALSE,             // VkBool32 stencilTestEnable;
404             stencilOpState,       // VkStencilOpState front;
405             stencilOpState,       // VkStencilOpState back;
406             0.0f,                 // float minDepthBounds;
407             1.0f,                 // float maxDepthBounds;
408         };
409 
410         const VkPipelineRenderingCreateInfoKHR renderingCreateInfo = {
411             VK_STRUCTURE_TYPE_PIPELINE_RENDERING_CREATE_INFO_KHR, // VkStructureType sType;
412             nullptr,                                              // const void* pNext;
413             0u,                                                   // uint32_t viewMask;
414             1,                                                    // uint32_t colorAttachmentCount;
415             &kColorAttachmentFormat,                              // const VkFormat* pColorAttachmentFormats;
416             VK_FORMAT_UNDEFINED,                                  // VkFormat depthAttachmentFormat;
417             VK_FORMAT_UNDEFINED,                                  // VkFormat stencilAttachmentFormat;
418         };
419 
420         VkSpecializationMapEntry specializationMapEntry = {0, 0, sizeof(VkBool32)};
421         VkBool32 specializationData                     = VK_TRUE;
422         VkSpecializationInfo specializationInfo = {1, &specializationMapEntry, sizeof(VkBool32), &specializationData};
423 
424         const VkPipelineShaderStageCreateInfo pShaderStages[] = {
425             {
426                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
427                 nullptr,                                             // const void*  pNext;
428                 (VkPipelineShaderStageCreateFlags)0,                 // VkPipelineShaderStageCreateFlags flags;
429                 VK_SHADER_STAGE_VERTEX_BIT,                          // VkShaderStageFlagBits stage;
430                 *vertexModule,                                       // VkShaderModule module;
431                 "main",                                              // const char* pName;
432                 (m_testParam.dataClass == DataClass::SpecializationConstant &&
433                  m_testParam.shaderType == VK_SHADER_STAGE_VERTEX_BIT) ?
434                     &specializationInfo :
435                     nullptr, // const VkSpecializationInfo* pSpecializationInfo;
436             },
437             {
438                 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
439                 nullptr,                                             // const void* pNext;
440                 (VkPipelineShaderStageCreateFlags)0,                 // VkPipelineShaderStageCreateFlags flags;
441                 VK_SHADER_STAGE_FRAGMENT_BIT,                        // VkShaderStageFlagBits stage;
442                 *fragmentModule,                                     // VkShaderModule module;
443                 "main",                                              // const char* pName;
444                 (m_testParam.dataClass == DataClass::SpecializationConstant &&
445                  m_testParam.shaderType == VK_SHADER_STAGE_FRAGMENT_BIT) ?
446                     &specializationInfo :
447                     nullptr, // const VkSpecializationInfo* pSpecializationInfo;
448             },
449         };
450 
451         const VkGraphicsPipelineCreateInfo graphicsPipelineInfo = {
452             VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO, // VkStructureType sType;
453             &renderingCreateInfo,                            // const void* pNext;
454             (VkPipelineCreateFlags)0,                        // VkPipelineCreateFlags flags;
455             2u,                                              // uint32_t stageCount;
456             pShaderStages,                                   // const VkPipelineShaderStageCreateInfo* pStages;
457             &vertexInputStateParams,         // const VkPipelineVertexInputStateCreateInfo* pVertexInputState;
458             &pipelineInputAssemblyStateInfo, // const VkPipelineInputAssemblyStateCreateInfo* pInputAssemblyState;
459             nullptr,                         // const VkPipelineTessellationStateCreateInfo* pTessellationState;
460             &pipelineViewportStateInfo,      // const VkPipelineViewportStateCreateInfo* pViewportState;
461             &pipelineRasterizationStateInfo, // const VkPipelineRasterizationStateCreateInfo* pRasterizationState;
462             &pipelineMultisampleStateInfo,   // const VkPipelineMultisampleStateCreateInfo* pMultisampleState;
463             &pipelineDepthStencilStateInfo,  // const VkPipelineDepthStencilStateCreateInfo* pDepthStencilState;
464             &pipelineColorBlendStateInfo,    // const VkPipelineColorBlendStateCreateInfo* pColorBlendState;
465             nullptr,                         // const VkPipelineDynamicStateCreateInfo* pDynamicState;
466             *m_pipelineLayout,               // VkPipelineLayout layout;
467             VK_NULL_HANDLE,                  // VkRenderPass renderPass;
468             0u,                              // uint32_t subpass;
469             VK_NULL_HANDLE,                  // VkPipeline basePipelineHandle;
470             0,                               // int32_t basePipelineIndex;
471         };
472 
473         m_pipeline = createGraphicsPipeline(m_vk, device, VK_NULL_HANDLE, &graphicsPipelineInfo);
474 
475         // DescriptorSet create/update for input storage buffer
476         if (m_testParam.dataClass == DataClass::StorageBuffer)
477         {
478             // DescriptorPool/DescriptorSet create
479             VkDescriptorPoolCreateFlags poolCreateFlags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
480 
481             vk::DescriptorPoolBuilder poolBuilder;
482             for (uint32_t i = 0; i < static_cast<uint32_t>(bindings.size()); ++i)
483             {
484                 poolBuilder.addType(bindings[i].descriptorType, bindings[i].descriptorCount);
485             }
486             m_descriptorPool = poolBuilder.build(m_vk, device, poolCreateFlags, 1);
487 
488             m_descriptorSet = makeDescriptorSet(m_vk, device, *m_descriptorPool, *m_descriptorSetLayout);
489 
490             // DescriptorSet update
491             VkDescriptorBufferInfo inputBufferInfo;
492             std::vector<VkDescriptorBufferInfo> bufferInfos;
493 
494             inputBufferInfo = makeDescriptorBufferInfo(m_inputBuffer.get(), 0, VK_WHOLE_SIZE);
495             bufferInfos.push_back(inputBufferInfo); // binding 1 is input if needed
496 
497             VkWriteDescriptorSet w = {
498                 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,    // sType
499                 nullptr,                                   // pNext
500                 *m_descriptorSet,                          // dstSet
501                 (uint32_t)0,                               // dstBinding
502                 0,                                         // dstArrayEllement
503                 static_cast<uint32_t>(bufferInfos.size()), // descriptorCount
504                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         // descriptorType
505                 nullptr,                                   // pImageInfo
506                 bufferInfos.data(),                        // pBufferInfo
507                 nullptr,                                   // pTexelBufferView
508             };
509 
510             m_vk.updateDescriptorSets(device, 1, &w, 0, nullptr);
511         }
512     }
513 
generateStorageBuffers()514     void generateStorageBuffers()
515     {
516         // Avoid creating zero-sized buffer/memory
517         const size_t inputBufferSize  = kNumElements * sizeof(uint64_t) * 4; // maximum size, 4 vector of 64bit
518         const size_t outputBufferSize = kNumElements * sizeof(uint32_t) * 2;
519 
520         // Upload data to buffer
521         const VkDevice device           = m_context.getDevice();
522         const uint32_t queueFamilyIndex = m_context.getUniversalQueueFamilyIndex();
523         Allocator &memAlloc             = m_context.getDefaultAllocator();
524 
525         const VkBufferCreateInfo inputBufferParams = {
526             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // VkStructureType sType;
527             nullptr,                              // const void* pNext;
528             0u,                                   // VkBufferCreateFlags flags;
529             inputBufferSize,                      // VkDeviceSize size;
530             VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,   // VkBufferUsageFlags usage;
531             VK_SHARING_MODE_EXCLUSIVE,            // VkSharingMode sharingMode;
532             1u,                                   // uint32_t queueFamilyCount;
533             &queueFamilyIndex                     // const uint32_t* pQueueFamilyIndices;
534         };
535 
536         m_inputBuffer   = createBuffer(m_vk, device, &inputBufferParams);
537         m_inputAlloc    = memAlloc.allocate(getBufferMemoryRequirements(m_vk, device, *m_inputBuffer),
538                                             MemoryRequirement::HostVisible);
539         void *inputData = m_inputAlloc->getHostPtr();
540 
541         // element stride of channel count 3 is 4, otherwise same to channel count
542         const uint32_t elementStride = (m_testParam.dataChannelCount != 3) ? m_testParam.dataChannelCount : 4;
543 
544         for (uint32_t i = 0; i < kNumElements; i++)
545         {
546             for (uint32_t channel = 0; channel < m_testParam.dataChannelCount; channel++)
547             {
548                 const uint32_t index = (i * elementStride) + channel;
549                 uint32_t value       = i + channel;
550                 if (m_testParam.wrongExpectation)
551                 {
552                     value += 1; // write wrong value to storage buffer
553                 }
554 
555                 switch (m_testParam.dataType)
556                 {
557                 case DataType::Bool: // std430 layout alignment of machine type(GLfloat)
558                     reinterpret_cast<int32_t *>(inputData)[index] = m_testParam.wrongExpectation ? VK_FALSE : VK_TRUE;
559                     break;
560                 case DataType::Int8:
561                     reinterpret_cast<int8_t *>(inputData)[index] = static_cast<int8_t>(value);
562                     break;
563                 case DataType::Int16:
564                     reinterpret_cast<int16_t *>(inputData)[index] = static_cast<int16_t>(value);
565                     break;
566                 case DataType::Int32:
567                     reinterpret_cast<int32_t *>(inputData)[index] = static_cast<int32_t>(value);
568                     break;
569                 case DataType::Int64:
570                     reinterpret_cast<int64_t *>(inputData)[index] = static_cast<int64_t>(value);
571                     break;
572                 default:
573                     assert(false);
574                 }
575             }
576         }
577 
578         VK_CHECK(m_vk.bindBufferMemory(device, *m_inputBuffer, m_inputAlloc->getMemory(), m_inputAlloc->getOffset()));
579         flushAlloc(m_vk, device, *m_inputAlloc);
580 
581         const VkBufferCreateInfo outputBufferParams = {
582             VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,                                  // VkStructureType sType;
583             nullptr,                                                               // const void* pNext;
584             0u,                                                                    // VkBufferCreateFlags flags;
585             outputBufferSize,                                                      // VkDeviceSize size;
586             VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT, // VkBufferUsageFlags usage;
587             VK_SHARING_MODE_EXCLUSIVE,                                             // VkSharingMode sharingMode;
588             1u,                                                                    // uint32_t queueFamilyCount;
589             &queueFamilyIndex // const uint32_t* pQueueFamilyIndices;
590         };
591 
592         m_outputBuffer = createBuffer(m_vk, device, &outputBufferParams);
593         m_outputAlloc  = memAlloc.allocate(getBufferMemoryRequirements(m_vk, device, *m_outputBuffer),
594                                            MemoryRequirement::HostVisible);
595 
596         void *outputData = m_outputAlloc->getHostPtr();
597         deMemset(outputData, 0, sizeof(outputBufferSize));
598 
599         VK_CHECK(
600             m_vk.bindBufferMemory(device, *m_outputBuffer, m_outputAlloc->getMemory(), m_outputAlloc->getOffset()));
601         flushAlloc(m_vk, device, *m_outputAlloc);
602     }
603 
generateComputePipeline()604     void generateComputePipeline()
605     {
606         const VkDevice device = m_context.getDevice();
607 
608         const Unique<VkShaderModule> cs(
609             createShaderModule(m_vk, device, m_context.getBinaryCollection().get("comp"), 0));
610 
611         VkDescriptorSetLayoutCreateFlags layoutCreateFlags = 0;
612 
613         std::vector<VkDescriptorSetLayoutBinding> bindings;
614         bindings.emplace_back(VkDescriptorSetLayoutBinding{
615             0,                                 // binding
616             VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, // descriptorType
617             1,                                 // descriptorCount
618             VK_SHADER_STAGE_COMPUTE_BIT,       // stageFlags
619             nullptr,                           // pImmutableSamplers
620         });                                    // output binding
621 
622         if (m_testParam.dataClass == DataClass::StorageBuffer)
623         {
624             bindings.emplace_back(VkDescriptorSetLayoutBinding{
625                 1,                                 // binding
626                 VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, // descriptorType
627                 1,                                 // descriptorCount
628                 VK_SHADER_STAGE_COMPUTE_BIT,       // stageFlags
629                 nullptr,                           // pImmutableSamplers
630             });                                    // input binding
631         }
632 
633         // Create a layout and allocate a descriptor set for it.
634         const VkDescriptorSetLayoutCreateInfo setLayoutCreateInfo = {
635             vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, // sType
636             nullptr,                                                 // pNext
637             layoutCreateFlags,                                       // flags
638             static_cast<uint32_t>(bindings.size()),                  // bindingCount
639             bindings.data()                                          // pBindings
640         };
641 
642         m_descriptorSetLayout = vk::createDescriptorSetLayout(m_vk, device, &setLayoutCreateInfo);
643 
644         VkSpecializationMapEntry specializationMapEntry = {0, 0, sizeof(VkBool32)};
645         VkBool32 specializationData                     = VK_TRUE;
646         VkSpecializationInfo specializationInfo = {1, &specializationMapEntry, sizeof(VkBool32), &specializationData};
647         const VkPipelineShaderStageCreateInfo csShaderCreateInfo = {
648             VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
649             nullptr,
650             (VkPipelineShaderStageCreateFlags)0,
651             VK_SHADER_STAGE_COMPUTE_BIT, // stage
652             *cs,                         // shader
653             "main",
654             (m_testParam.dataClass == DataClass::SpecializationConstant) ? &specializationInfo :
655                                                                            nullptr, // pSpecializationInfo
656         };
657 
658         VkPushConstantRange pushConstantRange = {VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(VkBool32)};
659         m_pipelineLayout = makePipelineLayout(m_vk, device, 1, &m_descriptorSetLayout.get(), 1, &pushConstantRange);
660 
661         const VkComputePipelineCreateInfo pipelineCreateInfo = {
662             VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
663             nullptr,
664             0u,                 // flags
665             csShaderCreateInfo, // cs
666             *m_pipelineLayout,  // layout
667             (vk::VkPipeline)0,  // basePipelineHandle
668             0u,                 // basePipelineIndex
669         };
670 
671         m_pipeline = createComputePipeline(m_vk, device, VK_NULL_HANDLE, &pipelineCreateInfo, nullptr);
672 
673         // DescriptorSet create for input/output storage buffer
674         VkDescriptorPoolCreateFlags poolCreateFlags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
675 
676         vk::DescriptorPoolBuilder poolBuilder;
677         for (uint32_t i = 0; i < static_cast<uint32_t>(bindings.size()); ++i)
678         {
679             poolBuilder.addType(bindings[i].descriptorType, bindings[i].descriptorCount);
680         }
681         m_descriptorPool = poolBuilder.build(m_vk, device, poolCreateFlags, 1);
682 
683         m_descriptorSet = makeDescriptorSet(m_vk, device, *m_descriptorPool, *m_descriptorSetLayout);
684 
685         // DescriptorSet update
686         VkDescriptorBufferInfo outputBufferInfo;
687         VkDescriptorBufferInfo inputBufferInfo;
688         std::vector<VkDescriptorBufferInfo> bufferInfos;
689 
690         outputBufferInfo = makeDescriptorBufferInfo(m_outputBuffer.get(), 0, VK_WHOLE_SIZE);
691         bufferInfos.push_back(outputBufferInfo); // binding 0 is output
692 
693         if (m_testParam.dataClass == DataClass::StorageBuffer)
694         {
695             inputBufferInfo = makeDescriptorBufferInfo(m_inputBuffer.get(), 0, VK_WHOLE_SIZE);
696             bufferInfos.push_back(inputBufferInfo); // binding 1 is input if needed
697         }
698 
699         VkWriteDescriptorSet w = {
700             VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,    // sType
701             nullptr,                                   // pNext
702             *m_descriptorSet,                          // dstSet
703             (uint32_t)0,                               // dstBinding
704             0,                                         // dstArrayEllement
705             static_cast<uint32_t>(bufferInfos.size()), // descriptorCount
706             VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,         // descriptorType
707             nullptr,                                   // pImageInfo
708             bufferInfos.data(),                        // pBufferInfo
709             nullptr,                                   // pTexelBufferView
710         };
711 
712         m_vk.updateDescriptorSets(device, 1, &w, 0, nullptr);
713     }
714 
dispatch()715     void dispatch()
716     {
717         const VkDevice device = m_context.getDevice();
718         const VkQueue queue   = m_context.getUniversalQueue();
719 
720         beginCommandBuffer(m_vk, *m_cmdBuffer);
721         m_vk.cmdBindPipeline(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *m_pipeline);
722         m_vk.cmdBindDescriptorSets(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, *m_pipelineLayout, 0u, 1,
723                                    &m_descriptorSet.get(), 0u, nullptr);
724 
725         if (m_testParam.dataClass == DataClass::PushConstant)
726         {
727             VkBool32 pcValue = VK_TRUE;
728             m_vk.cmdPushConstants(*m_cmdBuffer, *m_pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(VkBool32),
729                                   &pcValue);
730         }
731         m_vk.cmdDispatch(*m_cmdBuffer, 1, 1, 1);
732         VK_CHECK(m_vk.endCommandBuffer(*m_cmdBuffer));
733         submitCommandsAndWait(m_vk, device, queue, m_cmdBuffer.get());
734         flushMappedMemoryRange(m_vk, device, m_outputAlloc->getMemory(), 0, VK_WHOLE_SIZE);
735     }
736 
render()737     void render()
738     {
739         const VkDevice device = m_context.getDevice();
740         const VkQueue queue   = m_context.getUniversalQueue();
741 
742         beginCommandBuffer(m_vk, *m_cmdBuffer);
743 
744         // begin render pass
745         const VkClearValue clearValue = {}; // { 0, 0, 0, 0 }
746         const VkRect2D renderArea     = {{0, 0}, {kNumElements, 1}};
747 
748         const VkRenderingAttachmentInfoKHR renderingAttInfo = {
749             VK_STRUCTURE_TYPE_RENDERING_ATTACHMENT_INFO_KHR, // VkStructureType sType;
750             nullptr,                                         // const void* pNext;
751             *m_imageColorView,                               // VkImageView imageView;
752             VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,        // VkImageLayout imageLayout;
753             VK_RESOLVE_MODE_NONE,                            // VkResolveModeFlagBits resolveMode;
754             VK_NULL_HANDLE,                                  // VkImageView resolveImageView;
755             VK_IMAGE_LAYOUT_UNDEFINED,                       // VkImageLayout resolveImageLayout;
756             VK_ATTACHMENT_LOAD_OP_CLEAR,                     // VkAttachmentLoadOp loadOp;
757             VK_ATTACHMENT_STORE_OP_STORE,                    // VkAttachmentStoreOp storeOp;
758             clearValue,                                      // VkClearValue clearValue;
759         };
760 
761         const VkRenderingInfoKHR renderingInfo = {
762             VK_STRUCTURE_TYPE_RENDERING_INFO_KHR, // VkStructureType sType;
763             nullptr,                              // const void* pNext;
764             0,                                    // VkRenderingFlagsKHR flags;
765             renderArea,                           // VkRect2D renderArea;
766             1u,                                   // uint32_t layerCount;
767             0u,                                   // uint32_t viewMask;
768             1,                                    // uint32_t colorAttachmentCount;
769             &renderingAttInfo,                    // const VkRenderingAttachmentInfoKHR* pColorAttachments;
770             nullptr,                              // const VkRenderingAttachmentInfoKHR* pDepthAttachment;
771             nullptr                               // const VkRenderingAttachmentInfoKHR* pStencilAttachment;
772         };
773 
774         auto transition2DImage = [](const vk::DeviceInterface &vk, vk::VkCommandBuffer cmdBuffer, vk::VkImage image,
775                                     vk::VkImageAspectFlags aspectMask, vk::VkImageLayout oldLayout,
776                                     vk::VkImageLayout newLayout, vk::VkAccessFlags srcAccessMask,
777                                     vk::VkAccessFlags dstAccessMask, vk::VkPipelineStageFlags srcStageMask,
778                                     vk::VkPipelineStageFlags dstStageMask)
779         {
780             vk::VkImageMemoryBarrier barrier;
781             barrier.sType                           = vk::VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
782             barrier.pNext                           = nullptr;
783             barrier.srcAccessMask                   = srcAccessMask;
784             barrier.dstAccessMask                   = dstAccessMask;
785             barrier.oldLayout                       = oldLayout;
786             barrier.newLayout                       = newLayout;
787             barrier.srcQueueFamilyIndex             = VK_QUEUE_FAMILY_IGNORED;
788             barrier.dstQueueFamilyIndex             = VK_QUEUE_FAMILY_IGNORED;
789             barrier.image                           = image;
790             barrier.subresourceRange.aspectMask     = aspectMask;
791             barrier.subresourceRange.baseMipLevel   = 0;
792             barrier.subresourceRange.levelCount     = 1;
793             barrier.subresourceRange.baseArrayLayer = 0;
794             barrier.subresourceRange.layerCount     = 1;
795 
796             vk.cmdPipelineBarrier(cmdBuffer, srcStageMask, dstStageMask, (vk::VkDependencyFlags)0, 0,
797                                   (const vk::VkMemoryBarrier *)nullptr, 0, (const vk::VkBufferMemoryBarrier *)nullptr,
798                                   1, &barrier);
799         };
800 
801         transition2DImage(m_vk, *m_cmdBuffer, *m_imageColor, VK_IMAGE_ASPECT_COLOR_BIT, VK_IMAGE_LAYOUT_UNDEFINED,
802                           VK_IMAGE_LAYOUT_GENERAL, 0, VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT,
803                           VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT);
804 
805         m_vk.cmdBeginRendering(*m_cmdBuffer, &renderingInfo);
806 
807         // vertex input setup
808         // pipeline setup
809         m_vk.cmdBindPipeline(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *m_pipeline);
810 
811         const uint32_t vertexCount = 6;
812         const VkDeviceSize pOffset = 0;
813         assert(vertexCount <= kNumElements);
814         if (m_testParam.dataClass == DataClass::PushConstant)
815         {
816             const VkBool32 pcValue = VK_TRUE;
817             m_vk.cmdPushConstants(*m_cmdBuffer, *m_pipelineLayout,
818                                   static_cast<VkShaderStageFlags>(m_testParam.shaderType), 0, sizeof(VkBool32),
819                                   &pcValue);
820         }
821         else if (m_testParam.dataClass == DataClass::StorageBuffer)
822         {
823             m_vk.cmdBindDescriptorSets(*m_cmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, *m_pipelineLayout, 0u, 1,
824                                        &m_descriptorSet.get(), 0u, nullptr);
825         }
826         m_vk.cmdBindVertexBuffers(*m_cmdBuffer, 0, 1, &m_vertexBuffer.get(), &pOffset);
827 
828         m_vk.cmdDraw(*m_cmdBuffer, vertexCount, 1, 0, 0u);
829 
830         m_vk.cmdEndRendering(*m_cmdBuffer);
831 
832         VkMemoryBarrier memBarrier = {
833             VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
834             nullptr,                          // pNext
835             0u,                               // srcAccessMask
836             0u,                               // dstAccessMask
837         };
838         memBarrier.srcAccessMask = VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT;
839         memBarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT;
840         m_vk.cmdPipelineBarrier(*m_cmdBuffer, VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT,
841                                 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &memBarrier, 0, nullptr, 0, nullptr);
842 
843         // copy color image to output buffer
844         const VkImageSubresourceLayers imageSubresource = {VK_IMAGE_ASPECT_COLOR_BIT, 0, 0, 1};
845         const VkOffset3D imageOffset                    = {};
846         const VkExtent3D imageExtent                    = {kNumElements, 1, 1};
847         const VkBufferImageCopy copyRegion              = {0, 0, 0, imageSubresource, imageOffset, imageExtent};
848 
849         m_vk.cmdCopyImageToBuffer(*m_cmdBuffer, *m_imageColor, VK_IMAGE_LAYOUT_GENERAL, *m_outputBuffer, 1,
850                                   &copyRegion);
851 
852         VK_CHECK(m_vk.endCommandBuffer(*m_cmdBuffer));
853 
854         submitCommandsAndWait(m_vk, device, queue, m_cmdBuffer.get());
855         flushMappedMemoryRange(m_vk, device, m_outputAlloc->getMemory(), 0, VK_WHOLE_SIZE);
856     }
857 
858     TestParam m_testParam;
859     const DeviceInterface &m_vk;
860 
861     Move<VkCommandPool> m_cmdPool;
862     Move<VkCommandBuffer> m_cmdBuffer;
863     Move<VkDescriptorPool> m_descriptorPool;
864     Move<VkDescriptorSet> m_descriptorSet;
865     Move<VkDescriptorSetLayout> m_descriptorSetLayout;
866     Move<VkPipelineLayout> m_pipelineLayout;
867     Move<VkPipeline> m_pipeline;
868     Move<VkBuffer> m_inputBuffer;
869     de::MovePtr<Allocation> m_inputAlloc;
870     Move<VkBuffer> m_outputBuffer;
871     de::MovePtr<Allocation> m_outputAlloc;
872     Move<VkBuffer> m_vertexBuffer;
873     de::MovePtr<Allocation> m_vertexAlloc;
874     Move<VkImage> m_imageColor;
875     de::MovePtr<Allocation> m_imageColorAlloc;
876     Move<VkImageView> m_imageColorView;
877 };
878 
879 class ShaderExpectAssumeCase : public TestCase
880 {
881 public:
ShaderExpectAssumeCase(tcu::TestContext & testCtx,TestParam testParam)882     ShaderExpectAssumeCase(tcu::TestContext &testCtx, TestParam testParam)
883         : TestCase(testCtx, testParam.testName)
884         , m_testParam(testParam)
885     {
886     }
887     ShaderExpectAssumeCase(const ShaderExpectAssumeCase &)            = delete;
888     ShaderExpectAssumeCase &operator=(const ShaderExpectAssumeCase &) = delete;
889 
createInstance(Context & ctx) const890     TestInstance *createInstance(Context &ctx) const override
891     {
892         return new ShaderExpectAssumeTestInstance(ctx, m_testParam);
893     }
894 
initPrograms(vk::SourceCollections & programCollection) const895     void initPrograms(vk::SourceCollections &programCollection) const override
896     {
897         std::map<std::string, std::string> params;
898 
899         params["TEST_ELEMENT_COUNT"] = std::to_string(kNumElements);
900         assert(kNumElements < 127); // less than int byte
901 
902         switch (m_testParam.opType)
903         {
904         case OpType::Expect:
905             params["TEST_OPERATOR"] = "expectKHR";
906             break;
907         case OpType::Assume:
908             params["TEST_OPERATOR"] = "assumeTrueKHR";
909             break;
910         default:
911             assert(false);
912         }
913 
914         // default no need additional extension.
915         params["DATATYPE_EXTENSION_ENABLE"] = "";
916 
917         switch (m_testParam.dataType)
918         {
919         case DataType::Bool:
920             if (m_testParam.dataChannelCount == 1)
921             {
922                 params["DATATYPE"] = "bool";
923             }
924             else
925             {
926                 params["DATATYPE"] = "bvec" + std::to_string(m_testParam.dataChannelCount);
927             }
928             break;
929         case DataType::Int8:
930             assert(m_testParam.opType != OpType::Assume);
931             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable";
932             if (m_testParam.dataChannelCount == 1)
933             {
934                 params["DATATYPE"] = "int8_t";
935             }
936             else
937             {
938                 params["DATATYPE"] = "i8vec" + std::to_string(m_testParam.dataChannelCount);
939             }
940             break;
941         case DataType::Int16:
942             assert(m_testParam.opType != OpType::Assume);
943             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int16: enable";
944             if (m_testParam.dataChannelCount == 1)
945             {
946                 params["DATATYPE"] = "int16_t";
947             }
948             else
949             {
950                 params["DATATYPE"] = "i16vec" + std::to_string(m_testParam.dataChannelCount);
951             }
952             break;
953         case DataType::Int32:
954             assert(m_testParam.opType != OpType::Assume);
955             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int32: enable";
956             if (m_testParam.dataChannelCount == 1)
957             {
958                 params["DATATYPE"] = "int32_t";
959             }
960             else
961             {
962                 params["DATATYPE"] = "i32vec" + std::to_string(m_testParam.dataChannelCount);
963             }
964             break;
965         case DataType::Int64:
966             assert(m_testParam.opType != OpType::Assume);
967             params["DATATYPE_EXTENSION_ENABLE"] = "#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable";
968             if (m_testParam.dataChannelCount == 1)
969             {
970                 params["DATATYPE"] = "int64_t";
971             }
972             else
973             {
974                 params["DATATYPE"] = "i64vec" + std::to_string(m_testParam.dataChannelCount);
975             }
976             break;
977         default:
978             assert(false);
979         }
980 
981         switch (m_testParam.dataClass)
982         {
983         case DataClass::Constant:
984             assert(m_testParam.dataChannelCount == 1);
985 
986             params["VARNAME"] = "kThisIsTrue";
987             if (m_testParam.opType == OpType::Expect)
988             {
989                 params["EXPECTEDVALUE"] = "true";
990                 params["WRONGVALUE"]    = "false";
991             }
992             break;
993         case DataClass::SpecializationConstant:
994             assert(m_testParam.dataChannelCount == 1);
995 
996             params["VARNAME"] = "scThisIsTrue";
997             if (m_testParam.opType == OpType::Expect)
998             {
999                 params["EXPECTEDVALUE"] = "true";
1000                 params["WRONGVALUE"]    = "false";
1001             }
1002             break;
1003         case DataClass::StorageBuffer:
1004         {
1005             std::string indexingOffset;
1006             switch (m_testParam.shaderType)
1007             {
1008             case VK_SHADER_STAGE_COMPUTE_BIT:
1009                 indexingOffset = "gl_GlobalInvocationID.x";
1010                 break;
1011             case VK_SHADER_STAGE_VERTEX_BIT:
1012                 indexingOffset = "gl_VertexIndex";
1013                 break;
1014             case VK_SHADER_STAGE_FRAGMENT_BIT:
1015                 indexingOffset = "uint(gl_FragCoord.x)";
1016                 break;
1017             default:
1018                 assert(false);
1019             }
1020 
1021             params["VARNAME"] = "inputBuffer[" + indexingOffset + "]";
1022 
1023             if (m_testParam.opType == OpType::Expect)
1024             {
1025                 if (m_testParam.dataType == DataType::Bool)
1026                 {
1027                     params["EXPECTEDVALUE"] =
1028                         params["DATATYPE"] + "(true)"; // inputBuffer should be same as invocation id
1029                     params["WRONGVALUE"] =
1030                         params["DATATYPE"] + "(false)"; // inputBuffer should be same as invocation id
1031                 }
1032                 else
1033                 {
1034                     // inputBuffer should be same as invocation id + channel
1035                     params["EXPECTEDVALUE"] = params["DATATYPE"] + "(" + indexingOffset;
1036                     for (uint32_t channel = 1; channel < m_testParam.dataChannelCount; channel++) // from channel 1
1037                     {
1038                         params["EXPECTEDVALUE"] += ", " + indexingOffset + " + " + std::to_string(channel);
1039                     }
1040                     params["EXPECTEDVALUE"] += ")";
1041 
1042                     params["WRONGVALUE"] = params["DATATYPE"] + "(" + indexingOffset + "*2 + 3";
1043                     for (uint32_t channel = 1; channel < m_testParam.dataChannelCount; channel++) // from channel 1
1044                     {
1045                         params["WRONGVALUE"] += ", " + indexingOffset + "*2 + 3" + " + " + std::to_string(channel);
1046                     }
1047                     params["WRONGVALUE"] += ")";
1048                 }
1049             }
1050             break;
1051         }
1052         case DataClass::PushConstant:
1053             assert(m_testParam.dataChannelCount == 1);
1054             params["VARNAME"] = "pcThisIsTrue";
1055 
1056             if (m_testParam.opType == OpType::Expect)
1057             {
1058                 params["EXPECTEDVALUE"] = "true";
1059                 params["WRONGVALUE"]    = "false";
1060             }
1061 
1062             break;
1063         default:
1064             assert(false);
1065         }
1066 
1067         assert(!params["VARNAME"].empty());
1068         if (params["EXPECTEDVALUE"].empty())
1069         {
1070             params["TEST_OPERANDS"] = "(" + params["VARNAME"] + ")";
1071         }
1072         else
1073         {
1074             params["TEST_OPERANDS"] = "(" + params["VARNAME"] + ", " + params["EXPECTEDVALUE"] + ")";
1075         }
1076 
1077         switch (m_testParam.shaderType)
1078         {
1079         case VK_SHADER_STAGE_COMPUTE_BIT:
1080             addComputeTestShader(programCollection, params);
1081             break;
1082         case VK_SHADER_STAGE_VERTEX_BIT:
1083             addVertexTestShaders(programCollection, params);
1084             break;
1085         case VK_SHADER_STAGE_FRAGMENT_BIT:
1086             addFragmentTestShaders(programCollection, params);
1087             break;
1088         default:
1089             assert(0);
1090         }
1091     }
1092 
checkSupport(Context & context) const1093     void checkSupport(Context &context) const override
1094     {
1095         context.requireDeviceFunctionality("VK_KHR_shader_expect_assume");
1096 
1097         const auto &features          = context.getDeviceFeatures();
1098         const auto &featuresStorage16 = context.get16BitStorageFeatures();
1099         const auto &featuresF16I8     = context.getShaderFloat16Int8Features();
1100         const auto &featuresStorage8  = context.get8BitStorageFeatures();
1101 
1102         if (m_testParam.dataType == DataType::Int64)
1103         {
1104             if (!features.shaderInt64)
1105                 TCU_THROW(NotSupportedError, "64-bit integers not supported");
1106         }
1107         else if (m_testParam.dataType == DataType::Int16)
1108         {
1109             context.requireDeviceFunctionality("VK_KHR_16bit_storage");
1110 
1111             if (!features.shaderInt16)
1112                 TCU_THROW(NotSupportedError, "16-bit integers not supported");
1113 
1114             if (!featuresStorage16.storageBuffer16BitAccess)
1115                 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
1116         }
1117         else if (m_testParam.dataType == DataType::Int8)
1118         {
1119             context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
1120             context.requireDeviceFunctionality("VK_KHR_8bit_storage");
1121 
1122             if (!featuresF16I8.shaderInt8)
1123                 TCU_THROW(NotSupportedError, "8-bit integers not supported");
1124 
1125             if (!featuresStorage8.storageBuffer8BitAccess)
1126                 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
1127 
1128             if (!featuresStorage8.uniformAndStorageBuffer8BitAccess)
1129                 TCU_THROW(NotSupportedError, "8-bit Uniform storage buffer access not supported");
1130         }
1131     }
1132 
1133 private:
addComputeTestShader(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1134     void addComputeTestShader(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1135     {
1136         std::stringstream compShader;
1137 
1138         // Compute shader copies color to linear layout in buffer memory
1139         compShader << "#version 460 core\n"
1140                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1141                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1142                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1143                    << "void assumeTrueKHR(bool);\n"
1144                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1145                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1146                    << "precision highp float;\n"
1147                    << "precision highp int;\n"
1148                    << "layout(set = 0, binding = 0, std430) buffer Block0 { uvec2 outputBuffer[]; };\n";
1149 
1150         // declare input variable.
1151         if (m_testParam.dataClass == DataClass::Constant)
1152         {
1153             compShader << "bool kThisIsTrue = true;\n";
1154         }
1155         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1156         {
1157             compShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1158         }
1159         else if (m_testParam.dataClass == DataClass::PushConstant)
1160         {
1161             compShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1162         }
1163         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1164         {
1165             compShader << "layout(set = 0, binding = 1, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1166         }
1167 
1168         compShader << "layout(local_size_x = ${TEST_ELEMENT_COUNT}, local_size_y = 1, local_size_z = 1) in;\n"
1169                    << "void main()\n"
1170                    << "{\n";
1171         if (m_testParam.opType == OpType::Assume)
1172         {
1173             compShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1174         }
1175         else if (m_testParam.opType == OpType::Expect)
1176         {
1177             compShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1178                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1179                        << "        control = ${EXPECTEDVALUE};\n"
1180                        << "    } else {\n"
1181                        << "        // set wrong value\n"
1182                        << "        control = ${WRONGVALUE};\n"
1183                        << "    }\n";
1184         }
1185         compShader << "    outputBuffer[gl_GlobalInvocationID.x].x = gl_GlobalInvocationID.x;\n";
1186 
1187         if (params["EXPECTEDVALUE"].empty())
1188         {
1189             compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(${VARNAME});\n";
1190         }
1191         else
1192         {
1193             if (m_testParam.opType == OpType::Assume)
1194             {
1195                 compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1196             }
1197             else if (m_testParam.opType == OpType::Expect)
1198             {
1199                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1200                 if (m_testParam.wrongExpectation)
1201                     compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(control == ${WRONGVALUE});\n";
1202                 else
1203                     compShader << "    outputBuffer[gl_GlobalInvocationID.x].y = uint(control == ${EXPECTEDVALUE});\n";
1204             }
1205         }
1206         compShader << "}\n";
1207 
1208         tcu::StringTemplate computeShaderTpl(compShader.str());
1209         programCollection.glslSources.add("comp") << glu::ComputeSource(computeShaderTpl.specialize(params));
1210     }
1211 
addVertexTestShaders(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1212     void addVertexTestShaders(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1213     {
1214         //vertex shader
1215         std::stringstream vertShader;
1216         vertShader << "#version 460\n"
1217                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1218                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1219                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1220                    << "void assumeTrueKHR(bool);\n"
1221                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1222                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1223                    << "precision highp float;\n"
1224                    << "precision highp int;\n"
1225                    << "layout(location = 0) in vec4 in_position;\n"
1226                    << "layout(location = 0) out flat uint value;\n";
1227 
1228         // declare input variable.
1229         if (m_testParam.dataClass == DataClass::Constant)
1230         {
1231             vertShader << "bool kThisIsTrue = true;\n";
1232         }
1233         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1234         {
1235             vertShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1236         }
1237         else if (m_testParam.dataClass == DataClass::PushConstant)
1238         {
1239             vertShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1240         }
1241         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1242         {
1243             vertShader << "layout(set = 0, binding = 0, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1244         }
1245 
1246         vertShader << "void main() {\n";
1247         if (m_testParam.opType == OpType::Assume)
1248         {
1249             vertShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1250         }
1251         else if (m_testParam.opType == OpType::Expect)
1252         {
1253             vertShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1254                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1255                        << "        control = ${EXPECTEDVALUE};\n"
1256                        << "    } else {\n"
1257                        << "        // set wrong value\n"
1258                        << "        control = ${WRONGVALUE};\n"
1259                        << "    }\n";
1260         }
1261 
1262         vertShader << "    gl_Position  = in_position;\n";
1263 
1264         if (params["EXPECTEDVALUE"].empty())
1265         {
1266             vertShader << "    value = uint(${VARNAME});\n";
1267         }
1268         else
1269         {
1270             if (m_testParam.opType == OpType::Assume)
1271             {
1272                 vertShader << "    value = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1273             }
1274             else if (m_testParam.opType == OpType::Expect)
1275             {
1276                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1277                 if (m_testParam.wrongExpectation)
1278                     vertShader << "    value = uint(control == ${WRONGVALUE});\n";
1279                 else
1280                     vertShader << "    value = uint(control == ${EXPECTEDVALUE});\n";
1281             }
1282         }
1283         vertShader << "}\n";
1284 
1285         tcu::StringTemplate vertexShaderTpl(vertShader.str());
1286         programCollection.glslSources.add("vert") << glu::VertexSource(vertexShaderTpl.specialize(params));
1287 
1288         // fragment shader
1289         std::stringstream fragShader;
1290         fragShader << "#version 460\n"
1291                    << "precision highp float;\n"
1292                    << "precision highp int;\n"
1293                    << "layout(location = 0) in flat uint value;\n"
1294                    << "layout(location = 0) out uvec2 out_color;\n"
1295                    << "void main()\n"
1296                    << "{\n"
1297                    << "    out_color.r = uint(gl_FragCoord.x);\n"
1298                    << "    out_color.g = value;\n"
1299                    << "}\n";
1300 
1301         tcu::StringTemplate fragmentShaderTpl(fragShader.str());
1302         programCollection.glslSources.add("frag") << glu::FragmentSource(fragmentShaderTpl.specialize(params));
1303     }
1304 
addFragmentTestShaders(SourceCollections & programCollection,std::map<std::string,std::string> & params) const1305     void addFragmentTestShaders(SourceCollections &programCollection, std::map<std::string, std::string> &params) const
1306     {
1307         //vertex shader
1308         std::stringstream vertShader;
1309         vertShader << "#version 460\n"
1310                    << "precision highp float;\n"
1311                    << "precision highp int;\n"
1312                    << "layout(location = 0) in vec4 in_position;\n"
1313                    << "void main() {\n"
1314                    << "    gl_Position  = in_position;\n"
1315                    << "}\n";
1316 
1317         tcu::StringTemplate vertexShaderTpl(vertShader.str());
1318         programCollection.glslSources.add("vert") << glu::VertexSource(vertexShaderTpl.specialize(params));
1319 
1320         // fragment shader
1321         std::stringstream fragShader;
1322         fragShader << "#version 460\n"
1323                    << "#extension GL_EXT_spirv_intrinsics: enable\n"
1324                    << "${DATATYPE_EXTENSION_ENABLE}\n"
1325                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5630)\n"
1326                    << "void assumeTrueKHR(bool);\n"
1327                    << "spirv_instruction (extensions = [\"SPV_KHR_expect_assume\"], capabilities = [5629], id = 5631)\n"
1328                    << "${DATATYPE} expectKHR(${DATATYPE}, ${DATATYPE});\n"
1329                    << "precision highp float;\n"
1330                    << "precision highp int;\n"
1331                    << "layout(location = 0) out uvec2 out_color;\n";
1332         if (m_testParam.dataClass == DataClass::Constant)
1333         {
1334             fragShader << "bool kThisIsTrue = true;\n";
1335         }
1336         else if (m_testParam.dataClass == DataClass::SpecializationConstant)
1337         {
1338             fragShader << "layout (constant_id = 0) const bool scThisIsTrue = false;\n";
1339         }
1340         else if (m_testParam.dataClass == DataClass::PushConstant)
1341         {
1342             fragShader << "layout( push_constant, std430 ) uniform pc { layout(offset = 0) bool pcThisIsTrue; };\n";
1343         }
1344         else if (m_testParam.dataClass == DataClass::StorageBuffer)
1345         {
1346             fragShader << "layout(set = 0, binding = 0, std430) buffer Block1 { ${DATATYPE} inputBuffer[]; };\n";
1347         }
1348 
1349         fragShader << "void main()\n"
1350                    << "{\n";
1351 
1352         if (m_testParam.opType == OpType::Assume)
1353         {
1354             fragShader << "    ${TEST_OPERATOR} ${TEST_OPERANDS};\n";
1355         }
1356         else if (m_testParam.opType == OpType::Expect)
1357         {
1358             fragShader << "    ${DATATYPE} control = ${WRONGVALUE};\n"
1359                        << "    if ( ${TEST_OPERATOR}(${VARNAME}, ${EXPECTEDVALUE}) == ${EXPECTEDVALUE} ) {\n"
1360                        << "        control = ${EXPECTEDVALUE};\n"
1361                        << "    } else {\n"
1362                        << "        // set wrong value\n"
1363                        << "        control = ${WRONGVALUE};\n"
1364                        << "    }\n";
1365         }
1366         fragShader << "    out_color.r = int(gl_FragCoord.x);\n";
1367 
1368         if (params["EXPECTEDVALUE"].empty())
1369         {
1370             fragShader << "    out_color.g = uint(${VARNAME});\n";
1371         }
1372         else
1373         {
1374             if (m_testParam.opType == OpType::Assume)
1375             {
1376                 fragShader << "    out_color.g = uint(${VARNAME} == ${EXPECTEDVALUE});\n";
1377             }
1378             else if (m_testParam.opType == OpType::Expect)
1379             {
1380                 // when m_testParam.wrongExpectation == true, the value of ${VARNAME} is set to ${EXPECTEDVALUE} + 1
1381                 if (m_testParam.wrongExpectation)
1382                     fragShader << "    out_color.g = uint(control == ${WRONGVALUE});\n";
1383                 else
1384                     fragShader << "    out_color.g = uint(control == ${EXPECTEDVALUE});\n";
1385             }
1386         }
1387         fragShader << "}\n";
1388 
1389         tcu::StringTemplate fragmentShaderTpl(fragShader.str());
1390         programCollection.glslSources.add("frag") << glu::FragmentSource(fragmentShaderTpl.specialize(params));
1391     }
1392 
1393 private:
1394     TestParam m_testParam;
1395 };
1396 
addShaderExpectAssumeTests(tcu::TestCaseGroup * testGroup)1397 void addShaderExpectAssumeTests(tcu::TestCaseGroup *testGroup)
1398 {
1399     VkShaderStageFlagBits stages[] = {
1400         VK_SHADER_STAGE_VERTEX_BIT,
1401         VK_SHADER_STAGE_FRAGMENT_BIT,
1402         VK_SHADER_STAGE_COMPUTE_BIT,
1403     };
1404 
1405     TestParam testParams[] = {
1406         {OpType::Expect, DataClass::Constant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "constant"},
1407         {OpType::Expect, DataClass::SpecializationConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false,
1408          "specializationconstant"},
1409         {OpType::Expect, DataClass::PushConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "pushconstant"},
1410         {OpType::Expect, DataClass::StorageBuffer, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer_bool"},
1411         {OpType::Expect, DataClass::StorageBuffer, DataType::Int8, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer_int8"},
1412         {OpType::Expect, DataClass::StorageBuffer, DataType::Int16, 0, VK_SHADER_STAGE_ALL, false,
1413          "storagebuffer_int16"},
1414         {OpType::Expect, DataClass::StorageBuffer, DataType::Int32, 0, VK_SHADER_STAGE_ALL, false,
1415          "storagebuffer_int32"},
1416         {OpType::Expect, DataClass::StorageBuffer, DataType::Int64, 0, VK_SHADER_STAGE_ALL, false,
1417          "storagebuffer_int64"},
1418         {OpType::Assume, DataClass::Constant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "constant"},
1419         {OpType::Assume, DataClass::SpecializationConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false,
1420          "specializationconstant"},
1421         {OpType::Assume, DataClass::PushConstant, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "pushconstant"},
1422         {OpType::Assume, DataClass::StorageBuffer, DataType::Bool, 0, VK_SHADER_STAGE_ALL, false, "storagebuffer"},
1423     };
1424 
1425     tcu::TestContext &testCtx = testGroup->getTestContext();
1426 
1427     for (VkShaderStageFlagBits stage : stages)
1428     {
1429         const char *stageName = (stage == VK_SHADER_STAGE_VERTEX_BIT)   ? ("vertex") :
1430                                 (stage == VK_SHADER_STAGE_FRAGMENT_BIT) ? ("fragment") :
1431                                 (stage == VK_SHADER_STAGE_COMPUTE_BIT)  ? ("compute") :
1432                                                                           (nullptr);
1433 
1434         const std::string setName = std::string() + stageName;
1435         de::MovePtr<tcu::TestCaseGroup> stageGroupTest(
1436             new tcu::TestCaseGroup(testCtx, setName.c_str(), "Shader Expect Assume Tests"));
1437 
1438         de::MovePtr<tcu::TestCaseGroup> expectGroupTest(
1439             new tcu::TestCaseGroup(testCtx, "expect", "Shader Expect Tests"));
1440 
1441         de::MovePtr<tcu::TestCaseGroup> assumeGroupTest(
1442             new tcu::TestCaseGroup(testCtx, "assume", "Shader Assume Tests"));
1443 
1444         for (uint32_t expectationState = 0; expectationState < 2; expectationState++)
1445         {
1446             bool wrongExpected = (expectationState == 0) ? false : true;
1447             for (uint32_t channelCount = 1; channelCount <= 4; channelCount++)
1448             {
1449                 for (TestParam testParam : testParams)
1450                 {
1451                     testParam.dataChannelCount = channelCount;
1452                     testParam.wrongExpectation = wrongExpected;
1453                     if (channelCount > 1 || wrongExpected)
1454                     {
1455                         if (testParam.opType != OpType::Expect || testParam.dataClass != DataClass::StorageBuffer)
1456                         {
1457                             continue;
1458                         }
1459 
1460                         if (channelCount > 1)
1461                         {
1462                             testParam.testName = testParam.testName + "_vec" + std::to_string(channelCount);
1463                         }
1464 
1465                         if (wrongExpected)
1466                         {
1467                             testParam.testName = testParam.testName + "_wrong_expected";
1468                         }
1469                     }
1470 
1471                     testParam.shaderType = stage;
1472 
1473                     switch (testParam.opType)
1474                     {
1475                     case OpType::Expect:
1476                         expectGroupTest->addChild(new ShaderExpectAssumeCase(testCtx, testParam));
1477                         break;
1478                     case OpType::Assume:
1479                         assumeGroupTest->addChild(new ShaderExpectAssumeCase(testCtx, testParam));
1480                         break;
1481                     default:
1482                         assert(false);
1483                     }
1484                 }
1485             }
1486         }
1487 
1488         stageGroupTest->addChild(expectGroupTest.release());
1489         stageGroupTest->addChild(assumeGroupTest.release());
1490 
1491         testGroup->addChild(stageGroupTest.release());
1492     }
1493 }
1494 
1495 } // namespace
1496 
createShaderExpectAssumeTests(tcu::TestContext & testCtx)1497 tcu::TestCaseGroup *createShaderExpectAssumeTests(tcu::TestContext &testCtx)
1498 {
1499     return createTestGroup(testCtx, "shader_expect_assume", addShaderExpectAssumeTests);
1500 }
1501 
1502 } // namespace shaderexecutor
1503 } // namespace vkt
1504