• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2024 Collabora, Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "vk_shader.h"
25 
26 #include "vk_alloc.h"
27 #include "vk_command_buffer.h"
28 #include "vk_common_entrypoints.h"
29 #include "vk_descriptor_set_layout.h"
30 #include "vk_device.h"
31 #include "vk_nir.h"
32 #include "vk_physical_device.h"
33 #include "vk_pipeline.h"
34 
35 #include "util/mesa-sha1.h"
36 
37 void *
vk_shader_zalloc(struct vk_device * device,const struct vk_shader_ops * ops,gl_shader_stage stage,const VkAllocationCallbacks * alloc,size_t size)38 vk_shader_zalloc(struct vk_device *device,
39                  const struct vk_shader_ops *ops,
40                  gl_shader_stage stage,
41                  const VkAllocationCallbacks *alloc,
42                  size_t size)
43 {
44    /* For internal allocations, we need to allocate from the device scope
45     * because they might be put in pipeline caches.  Importantly, it is
46     * impossible for the client to get at this pointer and we apply this
47     * heuristic before we account for allocation fallbacks so this will only
48     * ever happen for internal shader objectx.
49     */
50    const VkSystemAllocationScope alloc_scope =
51       alloc == &device->alloc ? VK_SYSTEM_ALLOCATION_SCOPE_DEVICE
52                               : VK_SYSTEM_ALLOCATION_SCOPE_OBJECT;
53 
54    struct vk_shader *shader = vk_zalloc2(&device->alloc, alloc, size, 8,
55                                          alloc_scope);
56    if (shader == NULL)
57       return NULL;
58 
59    vk_object_base_init(device, &shader->base, VK_OBJECT_TYPE_SHADER_EXT);
60    shader->ops = ops;
61    shader->stage = stage;
62 
63    return shader;
64 }
65 
66 void
vk_shader_free(struct vk_device * device,const VkAllocationCallbacks * alloc,struct vk_shader * shader)67 vk_shader_free(struct vk_device *device,
68                const VkAllocationCallbacks *alloc,
69                struct vk_shader *shader)
70 {
71    vk_object_base_finish(&shader->base);
72    vk_free2(&device->alloc, alloc, shader);
73 }
74 
75 int
vk_shader_cmp_graphics_stages(gl_shader_stage a,gl_shader_stage b)76 vk_shader_cmp_graphics_stages(gl_shader_stage a, gl_shader_stage b)
77 {
78    static const int stage_order[MESA_SHADER_MESH + 1] = {
79       [MESA_SHADER_VERTEX] = 1,
80       [MESA_SHADER_TESS_CTRL] = 2,
81       [MESA_SHADER_TESS_EVAL] = 3,
82       [MESA_SHADER_GEOMETRY] = 4,
83       [MESA_SHADER_TASK] = 5,
84       [MESA_SHADER_MESH] = 6,
85       [MESA_SHADER_FRAGMENT] = 7,
86    };
87 
88    assert(a < ARRAY_SIZE(stage_order) && stage_order[a] > 0);
89    assert(b < ARRAY_SIZE(stage_order) && stage_order[b] > 0);
90 
91    return stage_order[a] - stage_order[b];
92 }
93 
94 struct stage_idx {
95    gl_shader_stage stage;
96    uint32_t idx;
97 };
98 
99 static int
cmp_stage_idx(const void * _a,const void * _b)100 cmp_stage_idx(const void *_a, const void *_b)
101 {
102    const struct stage_idx *a = _a, *b = _b;
103    return vk_shader_cmp_graphics_stages(a->stage, b->stage);
104 }
105 
106 static nir_shader *
vk_shader_to_nir(struct vk_device * device,const VkShaderCreateInfoEXT * info,const struct vk_pipeline_robustness_state * rs)107 vk_shader_to_nir(struct vk_device *device,
108                  const VkShaderCreateInfoEXT *info,
109                  const struct vk_pipeline_robustness_state *rs)
110 {
111    const struct vk_device_shader_ops *ops = device->shader_ops;
112 
113    const gl_shader_stage stage = vk_to_mesa_shader_stage(info->stage);
114    const nir_shader_compiler_options *nir_options =
115       ops->get_nir_options(device->physical, stage, rs);
116    struct spirv_to_nir_options spirv_options =
117       ops->get_spirv_options(device->physical, stage, rs);
118 
119    enum gl_subgroup_size subgroup_size = vk_get_subgroup_size(
120       vk_spirv_version(info->pCode, info->codeSize),
121       stage, info->pNext,
122       info->flags & VK_SHADER_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT_EXT,
123       info->flags &VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT);
124 
125    nir_shader *nir = vk_spirv_to_nir(device,
126                                      info->pCode, info->codeSize,
127                                      stage, info->pName,
128                                      subgroup_size,
129                                      info->pSpecializationInfo,
130                                      &spirv_options, nir_options,
131                                      false /* internal */, NULL);
132    if (nir == NULL)
133       return NULL;
134 
135    if (ops->preprocess_nir != NULL)
136       ops->preprocess_nir(device->physical, nir);
137 
138    return nir;
139 }
140 
141 struct set_layouts {
142    struct vk_descriptor_set_layout *set_layouts[MESA_VK_MAX_DESCRIPTOR_SETS];
143 };
144 
145 static void
vk_shader_compile_info_init(struct vk_shader_compile_info * info,struct set_layouts * set_layouts,const VkShaderCreateInfoEXT * vk_info,const struct vk_pipeline_robustness_state * rs,nir_shader * nir)146 vk_shader_compile_info_init(struct vk_shader_compile_info *info,
147                             struct set_layouts *set_layouts,
148                             const VkShaderCreateInfoEXT *vk_info,
149                             const struct vk_pipeline_robustness_state *rs,
150                             nir_shader *nir)
151 {
152    for (uint32_t sl = 0; sl < vk_info->setLayoutCount; sl++) {
153       set_layouts->set_layouts[sl] =
154          vk_descriptor_set_layout_from_handle(vk_info->pSetLayouts[sl]);
155    }
156 
157    *info = (struct vk_shader_compile_info) {
158       .stage = nir->info.stage,
159       .flags = vk_info->flags,
160       .next_stage_mask = vk_info->nextStage,
161       .nir = nir,
162       .robustness = rs,
163       .set_layout_count = vk_info->setLayoutCount,
164       .set_layouts = set_layouts->set_layouts,
165       .push_constant_range_count = vk_info->pushConstantRangeCount,
166       .push_constant_ranges = vk_info->pPushConstantRanges,
167    };
168 }
169 
170 struct vk_shader_bin_header {
171    char mesavkshaderbin[16];
172    VkDriverId driver_id;
173    uint8_t uuid[VK_UUID_SIZE];
174    uint32_t version;
175    uint64_t size;
176    uint8_t sha1[SHA1_DIGEST_LENGTH];
177    uint32_t _pad;
178 };
179 static_assert(sizeof(struct vk_shader_bin_header) == 72,
180               "This struct has no holes");
181 
182 static void
vk_shader_bin_header_init(struct vk_shader_bin_header * header,struct vk_physical_device * device)183 vk_shader_bin_header_init(struct vk_shader_bin_header *header,
184                           struct vk_physical_device *device)
185 {
186    *header = (struct vk_shader_bin_header) {
187       .mesavkshaderbin = "MesaVkShaderBin",
188       .driver_id = device->properties.driverID,
189    };
190 
191    memcpy(header->uuid, device->properties.shaderBinaryUUID, VK_UUID_SIZE);
192    header->version = device->properties.shaderBinaryVersion;
193 }
194 
195 static VkResult
vk_shader_serialize(struct vk_device * device,struct vk_shader * shader,struct blob * blob)196 vk_shader_serialize(struct vk_device *device,
197                     struct vk_shader *shader,
198                     struct blob *blob)
199 {
200    struct vk_shader_bin_header header;
201    vk_shader_bin_header_init(&header, device->physical);
202 
203    ASSERTED intptr_t header_offset = blob_reserve_bytes(blob, sizeof(header));
204    assert(header_offset == 0);
205 
206    bool success = shader->ops->serialize(device, shader, blob);
207    if (!success || blob->out_of_memory)
208       return VK_INCOMPLETE;
209 
210    /* Finalize and write the header */
211    header.size = blob->size;
212    if (blob->data != NULL) {
213       assert(sizeof(header) <= blob->size);
214 
215       struct mesa_sha1 sha1_ctx;
216       _mesa_sha1_init(&sha1_ctx);
217 
218       /* Hash the header with a zero SHA1 */
219       _mesa_sha1_update(&sha1_ctx, &header, sizeof(header));
220 
221       /* Hash the serialized data */
222       _mesa_sha1_update(&sha1_ctx, blob->data + sizeof(header),
223                         blob->size - sizeof(header));
224 
225       _mesa_sha1_final(&sha1_ctx, header.sha1);
226 
227       blob_overwrite_bytes(blob, header_offset, &header, sizeof(header));
228    }
229 
230    return VK_SUCCESS;
231 }
232 
233 static VkResult
vk_shader_deserialize(struct vk_device * device,size_t data_size,const void * data,const VkAllocationCallbacks * pAllocator,struct vk_shader ** shader_out)234 vk_shader_deserialize(struct vk_device *device,
235                       size_t data_size, const void *data,
236                       const VkAllocationCallbacks* pAllocator,
237                       struct vk_shader **shader_out)
238 {
239    const struct vk_device_shader_ops *ops = device->shader_ops;
240 
241    struct blob_reader blob;
242    blob_reader_init(&blob, data, data_size);
243 
244    struct vk_shader_bin_header header, ref_header;
245    blob_copy_bytes(&blob, &header, sizeof(header));
246    if (blob.overrun)
247       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
248 
249    vk_shader_bin_header_init(&ref_header, device->physical);
250 
251    if (memcmp(header.mesavkshaderbin, ref_header.mesavkshaderbin,
252               sizeof(header.mesavkshaderbin)))
253       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
254 
255    if (header.driver_id != ref_header.driver_id)
256       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
257 
258    if (memcmp(header.uuid, ref_header.uuid, sizeof(header.uuid)))
259       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
260 
261    /* From the Vulkan 1.3.276 spec:
262     *
263     *    "Guaranteed compatibility of shader binaries is expressed through a
264     *    combination of the shaderBinaryUUID and shaderBinaryVersion members
265     *    of the VkPhysicalDeviceShaderObjectPropertiesEXT structure queried
266     *    from a physical device. Binary shaders retrieved from a physical
267     *    device with a certain shaderBinaryUUID are guaranteed to be
268     *    compatible with all other physical devices reporting the same
269     *    shaderBinaryUUID and the same or higher shaderBinaryVersion."
270     *
271     * We handle the version check here on behalf of the driver and then pass
272     * the version into the driver's deserialize callback.
273     *
274     * If a driver doesn't want to mess with versions, they can always make the
275     * UUID a hash and always report version 0 and that will make this check
276     * effectively a no-op.
277     */
278    if (header.version > ref_header.version)
279       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
280 
281    /* Reject shader binaries that are the wrong size. */
282    if (header.size != data_size)
283       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
284 
285    assert(blob.current == (uint8_t *)data + sizeof(header));
286    blob.end = (uint8_t *)data + data_size;
287 
288    struct mesa_sha1 sha1_ctx;
289    _mesa_sha1_init(&sha1_ctx);
290 
291    /* Hash the header with a zero SHA1 */
292    struct vk_shader_bin_header sha1_header = header;
293    memset(sha1_header.sha1, 0, sizeof(sha1_header.sha1));
294    _mesa_sha1_update(&sha1_ctx, &sha1_header, sizeof(sha1_header));
295 
296    /* Hash the serialized data */
297    _mesa_sha1_update(&sha1_ctx, (uint8_t *)data + sizeof(header),
298                      data_size - sizeof(header));
299 
300    _mesa_sha1_final(&sha1_ctx, ref_header.sha1);
301    if (memcmp(header.sha1, ref_header.sha1, sizeof(header.sha1)))
302       return vk_error(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
303 
304    /* We've now verified that the header matches and that the data has the
305     * right SHA1 hash so it's safe to call into the driver.
306     */
307    return ops->deserialize(device, &blob, header.version,
308                            pAllocator, shader_out);
309 }
310 
311 VKAPI_ATTR VkResult VKAPI_CALL
vk_common_GetShaderBinaryDataEXT(VkDevice _device,VkShaderEXT _shader,size_t * pDataSize,void * pData)312 vk_common_GetShaderBinaryDataEXT(VkDevice _device,
313                                  VkShaderEXT _shader,
314                                  size_t *pDataSize,
315                                  void *pData)
316 {
317    VK_FROM_HANDLE(vk_device, device, _device);
318    VK_FROM_HANDLE(vk_shader, shader, _shader);
319    VkResult result;
320 
321    /* From the Vulkan 1.3.275 spec:
322     *
323     *    "If pData is NULL, then the size of the binary shader code of the
324     *    shader object, in bytes, is returned in pDataSize. Otherwise,
325     *    pDataSize must point to a variable set by the user to the size of the
326     *    buffer, in bytes, pointed to by pData, and on return the variable is
327     *    overwritten with the amount of data actually written to pData. If
328     *    pDataSize is less than the size of the binary shader code, nothing is
329     *    written to pData, and VK_INCOMPLETE will be returned instead of
330     *    VK_SUCCESS."
331     *
332     * This is annoying.  Unlike basically every other Vulkan data return
333     * method, we're not allowed to overwrite the client-provided memory region
334     * on VK_INCOMPLETE.  This means we either need to query the blob size
335     * up-front by serializing twice or we need to serialize into temporary
336     * memory and memcpy into the client-provided region.  We choose the first
337     * approach.
338     *
339     * In the common case, this means that vk_shader_ops::serialize will get
340     * called 3 times: Once for the client to get the size, once for us to
341     * validate the client's size, and once to actually write the data.  It's a
342     * bit heavy-weight but this shouldn't be in a hot path and this is better
343     * for memory efficiency.  Also, the vk_shader_ops::serialize should be
344     * pretty fast on a null blob.
345     */
346    struct blob blob;
347    blob_init_fixed(&blob, NULL, SIZE_MAX);
348    result = vk_shader_serialize(device, shader, &blob);
349    assert(result == VK_SUCCESS);
350 
351    if (result != VK_SUCCESS) {
352       *pDataSize = 0;
353       return result;
354    } else if (pData == NULL) {
355       *pDataSize = blob.size;
356       return VK_SUCCESS;
357    } else if (blob.size > *pDataSize) {
358       /* No data written */
359       *pDataSize = 0;
360       return VK_INCOMPLETE;
361    }
362 
363    blob_init_fixed(&blob, pData, *pDataSize);
364    result = vk_shader_serialize(device, shader, &blob);
365    assert(result == VK_SUCCESS);
366 
367    *pDataSize = blob.size;
368 
369    return result;
370 }
371 
372 /* The only place where we have "real" linking is graphics shaders and there
373  * is a limit as to how many of them can be linked together at one time.
374  */
375 #define VK_MAX_LINKED_SHADER_STAGES MESA_VK_MAX_GRAPHICS_PIPELINE_STAGES
376 
377 VKAPI_ATTR VkResult VKAPI_CALL
vk_common_CreateShadersEXT(VkDevice _device,uint32_t createInfoCount,const VkShaderCreateInfoEXT * pCreateInfos,const VkAllocationCallbacks * pAllocator,VkShaderEXT * pShaders)378 vk_common_CreateShadersEXT(VkDevice _device,
379                            uint32_t createInfoCount,
380                            const VkShaderCreateInfoEXT *pCreateInfos,
381                            const VkAllocationCallbacks *pAllocator,
382                            VkShaderEXT *pShaders)
383 {
384    VK_FROM_HANDLE(vk_device, device, _device);
385    const struct vk_device_shader_ops *ops = device->shader_ops;
386    VkResult first_fail_or_success = VK_SUCCESS;
387 
388    struct vk_pipeline_robustness_state rs = {
389       .storage_buffers = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT,
390       .uniform_buffers = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT,
391       .vertex_inputs = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT,
392       .images = VK_PIPELINE_ROBUSTNESS_IMAGE_BEHAVIOR_DISABLED_EXT,
393    };
394 
395    /* From the Vulkan 1.3.274 spec:
396     *
397     *    "When this function returns, whether or not it succeeds, it is
398     *    guaranteed that every element of pShaders will have been overwritten
399     *    by either VK_NULL_HANDLE or a valid VkShaderEXT handle."
400     *
401     * Zeroing up-front makes the error path easier.
402     */
403    memset(pShaders, 0, createInfoCount * sizeof(*pShaders));
404 
405    bool has_linked_spirv = false;
406    for (uint32_t i = 0; i < createInfoCount; i++) {
407       if (pCreateInfos[i].codeType == VK_SHADER_CODE_TYPE_SPIRV_EXT &&
408           (pCreateInfos[i].flags & VK_SHADER_CREATE_LINK_STAGE_BIT_EXT))
409          has_linked_spirv = true;
410    }
411 
412    uint32_t linked_count = 0;
413    struct stage_idx linked[VK_MAX_LINKED_SHADER_STAGES];
414 
415    for (uint32_t i = 0; i < createInfoCount; i++) {
416       const VkShaderCreateInfoEXT *vk_info = &pCreateInfos[i];
417       VkResult result = VK_SUCCESS;
418 
419       switch (vk_info->codeType) {
420       case VK_SHADER_CODE_TYPE_BINARY_EXT: {
421          /* This isn't required by Vulkan but we're allowed to fail binary
422           * import for basically any reason.  This seems like a pretty good
423           * reason.
424           */
425          if (has_linked_spirv &&
426              (vk_info->flags & VK_SHADER_CREATE_LINK_STAGE_BIT_EXT)) {
427             result = vk_errorf(device, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT,
428                                "Cannot mix linked binary and SPIR-V");
429             break;
430          }
431 
432          struct vk_shader *shader;
433          result = vk_shader_deserialize(device, vk_info->codeSize,
434                                         vk_info->pCode, pAllocator,
435                                         &shader);
436          if (result != VK_SUCCESS)
437             break;
438 
439          pShaders[i] = vk_shader_to_handle(shader);
440          break;
441       }
442 
443       case VK_SHADER_CODE_TYPE_SPIRV_EXT: {
444          if (vk_info->flags & VK_SHADER_CREATE_LINK_STAGE_BIT_EXT) {
445             /* Stash it and compile later */
446             assert(linked_count < ARRAY_SIZE(linked));
447             linked[linked_count++] = (struct stage_idx) {
448                .stage = vk_to_mesa_shader_stage(vk_info->stage),
449                .idx = i,
450             };
451          } else {
452             nir_shader *nir = vk_shader_to_nir(device, vk_info, &rs);
453             if (nir == NULL) {
454                result = vk_errorf(device, VK_ERROR_UNKNOWN,
455                                   "Failed to compile shader to NIR");
456                break;
457             }
458 
459             struct vk_shader_compile_info info;
460             struct set_layouts set_layouts;
461             vk_shader_compile_info_init(&info, &set_layouts,
462                                         vk_info, &rs, nir);
463 
464             struct vk_shader *shader;
465             result = ops->compile(device, 1, &info, NULL /* state */,
466                                   pAllocator, &shader);
467             if (result != VK_SUCCESS)
468                break;
469 
470             pShaders[i] = vk_shader_to_handle(shader);
471          }
472          break;
473       }
474 
475       default:
476          unreachable("Unknown shader code type");
477       }
478 
479       if (first_fail_or_success == VK_SUCCESS)
480          first_fail_or_success = result;
481    }
482 
483    if (linked_count > 0) {
484       struct set_layouts set_layouts[VK_MAX_LINKED_SHADER_STAGES];
485       struct vk_shader_compile_info infos[VK_MAX_LINKED_SHADER_STAGES];
486       VkResult result = VK_SUCCESS;
487 
488       /* Sort so we guarantee the driver always gets them in-order */
489       qsort(linked, linked_count, sizeof(*linked), cmp_stage_idx);
490 
491       /* Memset for easy error handling */
492       memset(infos, 0, sizeof(infos));
493 
494       for (uint32_t l = 0; l < linked_count; l++) {
495          const VkShaderCreateInfoEXT *vk_info = &pCreateInfos[linked[l].idx];
496 
497          nir_shader *nir = vk_shader_to_nir(device, vk_info, &rs);
498          if (nir == NULL) {
499             result = vk_errorf(device, VK_ERROR_UNKNOWN,
500                                "Failed to compile shader to NIR");
501             break;
502          }
503 
504          vk_shader_compile_info_init(&infos[l], &set_layouts[l],
505                                      vk_info, &rs, nir);
506       }
507 
508       if (result == VK_SUCCESS) {
509          struct vk_shader *shaders[VK_MAX_LINKED_SHADER_STAGES];
510 
511          result = ops->compile(device, linked_count, infos, NULL /* state */,
512                                pAllocator, shaders);
513          if (result == VK_SUCCESS) {
514             for (uint32_t l = 0; l < linked_count; l++)
515                pShaders[linked[l].idx] = vk_shader_to_handle(shaders[l]);
516          }
517       } else {
518          for (uint32_t l = 0; l < linked_count; l++) {
519             if (infos[l].nir != NULL)
520                ralloc_free(infos[l].nir);
521          }
522       }
523 
524       if (first_fail_or_success == VK_SUCCESS)
525          first_fail_or_success = result;
526    }
527 
528    return first_fail_or_success;
529 }
530 
531 VKAPI_ATTR void VKAPI_CALL
vk_common_DestroyShaderEXT(VkDevice _device,VkShaderEXT _shader,const VkAllocationCallbacks * pAllocator)532 vk_common_DestroyShaderEXT(VkDevice _device,
533                            VkShaderEXT _shader,
534                            const VkAllocationCallbacks *pAllocator)
535 {
536    VK_FROM_HANDLE(vk_device, device, _device);
537    VK_FROM_HANDLE(vk_shader, shader, _shader);
538 
539    if (shader == NULL)
540       return;
541 
542    vk_shader_destroy(device, shader, pAllocator);
543 }
544 
545 VKAPI_ATTR void VKAPI_CALL
vk_common_CmdBindShadersEXT(VkCommandBuffer commandBuffer,uint32_t stageCount,const VkShaderStageFlagBits * pStages,const VkShaderEXT * pShaders)546 vk_common_CmdBindShadersEXT(VkCommandBuffer commandBuffer,
547                             uint32_t stageCount,
548                             const VkShaderStageFlagBits *pStages,
549                             const VkShaderEXT *pShaders)
550 {
551    VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, commandBuffer);
552    struct vk_device *device = cmd_buffer->base.device;
553    const struct vk_device_shader_ops *ops = device->shader_ops;
554 
555    STACK_ARRAY(gl_shader_stage, stages, stageCount);
556    STACK_ARRAY(struct vk_shader *, shaders, stageCount);
557 
558    VkShaderStageFlags vk_stages = 0;
559    for (uint32_t i = 0; i < stageCount; i++) {
560       vk_stages |= pStages[i];
561       stages[i] = vk_to_mesa_shader_stage(pStages[i]);
562       shaders[i] = pShaders != NULL ? vk_shader_from_handle(pShaders[i]) : NULL;
563    }
564 
565    vk_cmd_unbind_pipelines_for_stages(cmd_buffer, vk_stages);
566    if (vk_stages & ~VK_SHADER_STAGE_COMPUTE_BIT)
567       vk_cmd_set_rp_attachments(cmd_buffer, ~0);
568 
569    ops->cmd_bind_shaders(cmd_buffer, stageCount, stages, shaders);
570 }
571