1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2023 LunarG, Inc.
6 * Copyright (c) 2023 Nintendo
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Wrapper that can construct monolithic pipeline or use
23 VK_EXT_shader_object for compute pipeline construction.
24 *//*--------------------------------------------------------------------*/
25
26 #include "vkComputePipelineConstructionUtil.hpp"
27 #include "vkQueryUtil.hpp"
28 #include "vkObjUtil.hpp"
29
30 namespace vk
31 {
32
checkShaderObjectRequirements(const InstanceInterface & vki,VkPhysicalDevice physicalDevice,ComputePipelineConstructionType computePipelineConstructionType)33 void checkShaderObjectRequirements (const InstanceInterface& vki,
34 VkPhysicalDevice physicalDevice,
35 ComputePipelineConstructionType computePipelineConstructionType)
36 {
37 if (computePipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
38 return;
39
40 const auto& supportedExtensions = enumerateCachedDeviceExtensionProperties(vki, physicalDevice);
41 if (!isExtensionStructSupported(supportedExtensions, RequiredExtension("VK_EXT_shader_object")))
42 TCU_THROW(NotSupportedError, "VK_EXT_shader_object not supported");
43 }
44
45 struct ComputePipelineWrapper::InternalData
46 {
47 const DeviceInterface& vk;
48 VkDevice device;
49 const ComputePipelineConstructionType pipelineConstructionType;
50
51 // initialize with most common values
InternalDatavk::ComputePipelineWrapper::InternalData52 InternalData(const DeviceInterface& vkd, VkDevice vkDevice, const ComputePipelineConstructionType constructionType)
53 : vk (vkd)
54 , device (vkDevice)
55 , pipelineConstructionType (constructionType)
56 {
57 }
58 };
59
ComputePipelineWrapper(const DeviceInterface & vk,VkDevice device,const ComputePipelineConstructionType pipelineConstructionType)60 ComputePipelineWrapper::ComputePipelineWrapper (const DeviceInterface& vk,
61 VkDevice device,
62 const ComputePipelineConstructionType pipelineConstructionType)
63 : m_internalData (new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
64 , m_programBinary (DE_NULL)
65 , m_specializationInfo {}
66 , m_pipelineCreateFlags ((VkPipelineCreateFlags)0u)
67 , m_pipelineCreatePNext (DE_NULL)
68 , m_subgroupSize (0)
69 {
70
71 }
72
ComputePipelineWrapper(const DeviceInterface & vk,VkDevice device,const ComputePipelineConstructionType pipelineConstructionType,const ProgramBinary & programBinary)73 ComputePipelineWrapper::ComputePipelineWrapper (const DeviceInterface& vk,
74 VkDevice device,
75 const ComputePipelineConstructionType pipelineConstructionType,
76 const ProgramBinary& programBinary)
77 : m_internalData (new ComputePipelineWrapper::InternalData(vk, device, pipelineConstructionType))
78 , m_programBinary (&programBinary)
79 , m_specializationInfo {}
80 , m_pipelineCreateFlags ((VkPipelineCreateFlags)0u)
81 , m_pipelineCreatePNext (DE_NULL)
82 , m_subgroupSize (0)
83 {
84 }
85
// Copies the configuration of a wrapper that has not been built yet: the shared
// construction data, program binary, layouts, specialization info, create flags,
// pNext and subgroup size. Built objects (pipeline, pipeline layout, shader) are
// not part of the copy.
ComputePipelineWrapper::ComputePipelineWrapper (const ComputePipelineWrapper& other) noexcept
	: m_internalData			(other.m_internalData)
	, m_programBinary			(other.m_programBinary)
	, m_descriptorSetLayouts	(other.m_descriptorSetLayouts)
	, m_specializationInfo		(other.m_specializationInfo)
	, m_pipelineCreateFlags		(other.m_pipelineCreateFlags)
	, m_pipelineCreatePNext		(other.m_pipelineCreatePNext)
	, m_subgroupSize			(other.m_subgroupSize)
{
	// Copying a wrapper that already owns a built pipeline or shader is a usage error.
#ifndef CTS_USES_VULKANSC
	DE_ASSERT(other.m_shader.get() == DE_NULL);
#endif
	DE_ASSERT(other.m_pipeline.get() == DE_NULL);
}
100
ComputePipelineWrapper(ComputePipelineWrapper && rhs)101 ComputePipelineWrapper::ComputePipelineWrapper (ComputePipelineWrapper&& rhs) noexcept
102 : m_internalData (rhs.m_internalData)
103 , m_programBinary (rhs.m_programBinary)
104 , m_descriptorSetLayouts (rhs.m_descriptorSetLayouts)
105 , m_specializationInfo (rhs.m_specializationInfo)
106 , m_pipelineCreateFlags (rhs.m_pipelineCreateFlags)
107 , m_pipelineCreatePNext (rhs.m_pipelineCreatePNext)
108 , m_subgroupSize (rhs.m_subgroupSize)
109 {
110 DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
111 #ifndef CTS_USES_VULKANSC
112 DE_ASSERT(rhs.m_shader.get() == DE_NULL);
113 #endif
114 }
115
// Replaces this wrapper's configuration with a copy of other's (shared
// construction data, program binary, layouts, specialization info, create flags,
// pNext and subgroup size). Objects already built by *this are left untouched.
ComputePipelineWrapper& ComputePipelineWrapper::operator= (const ComputePipelineWrapper& other) noexcept
{
	// The source must not own a built pipeline or shader.
	DE_ASSERT(other.m_pipeline.get() == DE_NULL);
#ifndef CTS_USES_VULKANSC
	DE_ASSERT(other.m_shader.get() == DE_NULL);
#endif

	m_internalData			= other.m_internalData;
	m_programBinary			= other.m_programBinary;
	m_descriptorSetLayouts	= other.m_descriptorSetLayouts;
	m_specializationInfo	= other.m_specializationInfo;
	m_pipelineCreateFlags	= other.m_pipelineCreateFlags;
	m_pipelineCreatePNext	= other.m_pipelineCreatePNext;
	m_subgroupSize			= other.m_subgroupSize;

	return *this;
}
131
operator =(ComputePipelineWrapper && rhs)132 ComputePipelineWrapper& ComputePipelineWrapper::operator= (ComputePipelineWrapper&& rhs) noexcept
133 {
134 m_internalData = std::move(rhs.m_internalData);
135 m_programBinary = rhs.m_programBinary;
136 m_descriptorSetLayouts = std::move(rhs.m_descriptorSetLayouts);
137 m_specializationInfo = rhs.m_specializationInfo;
138 m_pipelineCreateFlags = rhs.m_pipelineCreateFlags;
139 m_pipelineCreatePNext = rhs.m_pipelineCreatePNext;
140 DE_ASSERT(rhs.m_pipeline.get() == DE_NULL);
141 #ifndef CTS_USES_VULKANSC
142 DE_ASSERT(rhs.m_shader.get() == DE_NULL);
143 #endif
144 m_subgroupSize = rhs.m_subgroupSize;
145 return *this;
146 }
147
setDescriptorSetLayout(VkDescriptorSetLayout descriptorSetLayout)148 void ComputePipelineWrapper::setDescriptorSetLayout (VkDescriptorSetLayout descriptorSetLayout)
149 {
150 m_descriptorSetLayouts = { descriptorSetLayout };
151 }
152
setDescriptorSetLayouts(deUint32 setLayoutCount,const VkDescriptorSetLayout * descriptorSetLayouts)153 void ComputePipelineWrapper::setDescriptorSetLayouts (deUint32 setLayoutCount, const VkDescriptorSetLayout* descriptorSetLayouts)
154 {
155 m_descriptorSetLayouts.assign(descriptorSetLayouts, descriptorSetLayouts + setLayoutCount);
156 }
157
setSpecializationInfo(VkSpecializationInfo specializationInfo)158 void ComputePipelineWrapper::setSpecializationInfo (VkSpecializationInfo specializationInfo)
159 {
160 m_specializationInfo = specializationInfo;
161 }
162
setPipelineCreateFlags(VkPipelineCreateFlags pipelineCreateFlags)163 void ComputePipelineWrapper::setPipelineCreateFlags (VkPipelineCreateFlags pipelineCreateFlags)
164 {
165 m_pipelineCreateFlags = pipelineCreateFlags;
166 }
167
setPipelineCreatePNext(void * pipelineCreatePNext)168 void ComputePipelineWrapper::setPipelineCreatePNext (void* pipelineCreatePNext)
169 {
170 m_pipelineCreatePNext = pipelineCreatePNext;
171 }
172
setSubgroupSize(uint32_t subgroupSize)173 void ComputePipelineWrapper::setSubgroupSize (uint32_t subgroupSize)
174 {
175 m_subgroupSize = subgroupSize;
176 }
buildPipeline(void)177 void ComputePipelineWrapper::buildPipeline (void)
178 {
179 const auto& vk = m_internalData->vk;
180 const auto& device = m_internalData->device;
181
182 VkSpecializationInfo* specializationInfo = m_specializationInfo.mapEntryCount > 0 ? &m_specializationInfo : DE_NULL;
183 if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
184 {
185 DE_ASSERT(m_pipeline.get() == DE_NULL);
186 const Unique<VkShaderModule> shaderModule (createShaderModule(vk, device, *m_programBinary));
187 buildPipelineLayout();
188 m_pipeline = vk::makeComputePipeline(vk, device, *m_pipelineLayout, m_pipelineCreateFlags, m_pipelineCreatePNext, *shaderModule, 0u, specializationInfo, 0, m_subgroupSize);
189 }
190 else
191 {
192 #ifndef CTS_USES_VULKANSC
193 DE_ASSERT(m_shader.get() == DE_NULL);
194 buildPipelineLayout();
195 vk::VkShaderCreateInfoEXT createInfo =
196 {
197 vk::VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT, // VkStructureType sType;
198 DE_NULL, // const void* pNext;
199 0u, // VkShaderCreateFlagsEXT flags;
200 vk::VK_SHADER_STAGE_COMPUTE_BIT, // VkShaderStageFlagBits stage;
201 0u, // VkShaderStageFlags nextStage;
202 vk::VK_SHADER_CODE_TYPE_SPIRV_EXT, // VkShaderCodeTypeEXT codeType;
203 m_programBinary->getSize(), // size_t codeSize;
204 m_programBinary->getBinary(), // const void* pCode;
205 "main", // const char* pName;
206 (deUint32)m_descriptorSetLayouts.size(), // uint32_t setLayoutCount;
207 m_descriptorSetLayouts.data(), // VkDescriptorSetLayout* pSetLayouts;
208 0u, // uint32_t pushConstantRangeCount;
209 DE_NULL, // const VkPushConstantRange* pPushConstantRanges;
210 specializationInfo, // const VkSpecializationInfo* pSpecializationInfo;
211 };
212
213 m_shader = createShader(vk, device, createInfo);
214
215 if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_SHADER_OBJECT_BINARY)
216 {
217 size_t dataSize;
218 vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, DE_NULL);
219 std::vector<deUint8> data(dataSize);
220 vk.getShaderBinaryDataEXT(device, *m_shader, &dataSize, data.data());
221
222 createInfo.codeType = vk::VK_SHADER_CODE_TYPE_BINARY_EXT;
223 createInfo.codeSize = dataSize;
224 createInfo.pCode = data.data();
225
226 m_shader = createShader(vk, device, createInfo);
227 }
228 #endif
229 }
230 }
231
bind(VkCommandBuffer commandBuffer)232 void ComputePipelineWrapper::bind (VkCommandBuffer commandBuffer)
233 {
234 if (m_internalData->pipelineConstructionType == COMPUTE_PIPELINE_CONSTRUCTION_TYPE_PIPELINE)
235 {
236 m_internalData->vk.cmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, m_pipeline.get());
237 }
238 else
239 {
240 #ifndef CTS_USES_VULKANSC
241 const vk::VkShaderStageFlagBits stage = vk::VK_SHADER_STAGE_COMPUTE_BIT;
242 m_internalData->vk.cmdBindShadersEXT(commandBuffer, 1, &stage, &*m_shader);
243 #endif
244 }
245 }
246
buildPipelineLayout(void)247 void ComputePipelineWrapper::buildPipelineLayout (void)
248 {
249 m_pipelineLayout = makePipelineLayout(m_internalData->vk, m_internalData->device, m_descriptorSetLayouts);
250 }
251
getPipelineLayout(void)252 VkPipelineLayout ComputePipelineWrapper::getPipelineLayout (void)
253 {
254 return *m_pipelineLayout;
255 }
256
257 } // vk
258