• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Fuchsia Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <inttypes.h>
6 #include <stdio.h>
7 #include <stdlib.h>
8 
9 #include "common/macros.h"
10 #include "common/util.h"
11 #include "common/vk/barrier.h"
12 #include "radix_sort_vk_devaddr.h"
13 #include "shaders/push.h"
14 
15 //
16 //
17 //
18 
19 #ifdef RS_VK_ENABLE_DEBUG_UTILS
20 #include "common/vk/debug_utils.h"
21 #endif
22 
23 //
24 //
25 //
26 
27 #ifdef RS_VK_ENABLE_EXTENSIONS
28 #include "radix_sort_vk_ext.h"
29 #endif
30 
31 //
32 // FIXME(allanmac): memoize some of these calculations
33 //
34 void
radix_sort_vk_get_memory_requirements(radix_sort_vk_t const * rs,uint32_t count,radix_sort_vk_memory_requirements_t * mr)35 radix_sort_vk_get_memory_requirements(radix_sort_vk_t const *               rs,
36                                       uint32_t                              count,
37                                       radix_sort_vk_memory_requirements_t * mr)
38 {
39   //
40   // Keyval size
41   //
42   mr->keyval_size = rs->config.keyval_dwords * sizeof(uint32_t);
43 
44   //
45   // Subgroup and workgroup sizes
46   //
47   uint32_t const histo_sg_size    = 1 << rs->config.histogram.subgroup_size_log2;
48   uint32_t const histo_wg_size    = 1 << rs->config.histogram.workgroup_size_log2;
49   uint32_t const prefix_sg_size   = 1 << rs->config.prefix.subgroup_size_log2;
50   uint32_t const scatter_wg_size  = 1 << rs->config.scatter.workgroup_size_log2;
51   uint32_t const internal_sg_size = MAX_MACRO(uint32_t, histo_sg_size, prefix_sg_size);
52 
53   //
54   // If for some reason count is zero then initialize appropriately.
55   //
56   if (count == 0)
57     {
58       mr->keyvals_size       = 0;
59       mr->keyvals_alignment  = mr->keyval_size * histo_sg_size;
60       mr->internal_size      = 0;
61       mr->internal_alignment = internal_sg_size * sizeof(uint32_t);
62       mr->indirect_size      = 0;
63       mr->indirect_alignment = internal_sg_size * sizeof(uint32_t);
64     }
65   else
66     {
67       //
68       // Keyvals
69       //
70 
71       // Round up to the scatter block size.
72       //
73       // Then round up to the histogram block size.
74       //
75       // Fill the difference between this new count and the original keyval
76       // count.
77       //
78       // How many scatter blocks?
79       //
80       uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
81       uint32_t const scatter_blocks    = (count + scatter_block_kvs - 1) / scatter_block_kvs;
82       uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;
83 
84       //
85       // How many histogram blocks?
86       //
87       // Note that it's OK to have more max-valued digits counted by the histogram
88       // than sorted by the scatters because the sort is stable.
89       //
90       uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
91       uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
92       uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;
93 
94       mr->keyvals_size      = mr->keyval_size * count_ru_histo;
95       mr->keyvals_alignment = mr->keyval_size * histo_sg_size;
96 
97       //
98       // Internal
99       //
100       // NOTE: Assumes .histograms are before .partitions.
101       //
102       // Last scatter workgroup skips writing to a partition.
103       //
104       // One histogram per (keyval byte + partitions)
105       //
106       uint32_t const partitions = scatter_blocks - 1;
107 
108       mr->internal_size      = (mr->keyval_size + partitions) * (RS_RADIX_SIZE * sizeof(uint32_t));
109       mr->internal_alignment = internal_sg_size * sizeof(uint32_t);
110 
111       //
112       // Indirect
113       //
114       mr->indirect_size      = sizeof(struct rs_indirect_info);
115       mr->indirect_alignment = sizeof(struct u32vec4);
116     }
117 }
118 
119 //
120 //
121 //
122 #ifdef RS_VK_ENABLE_DEBUG_UTILS
123 
124 static void
rs_debug_utils_set(VkDevice device,struct radix_sort_vk * rs)125 rs_debug_utils_set(VkDevice device, struct radix_sort_vk * rs)
126 {
127   if (pfn_vkSetDebugUtilsObjectNameEXT != NULL)
128     {
129       VkDebugUtilsObjectNameInfoEXT duoni = {
130         .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT,
131         .pNext      = NULL,
132         .objectType = VK_OBJECT_TYPE_PIPELINE,
133       };
134 
135       duoni.objectHandle = (uint64_t)rs->pipelines.named.init;
136       duoni.pObjectName  = "radix_sort_init";
137       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
138 
139       duoni.objectHandle = (uint64_t)rs->pipelines.named.fill;
140       duoni.pObjectName  = "radix_sort_fill";
141       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
142 
143       duoni.objectHandle = (uint64_t)rs->pipelines.named.histogram;
144       duoni.pObjectName  = "radix_sort_histogram";
145       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
146 
147       duoni.objectHandle = (uint64_t)rs->pipelines.named.prefix;
148       duoni.pObjectName  = "radix_sort_prefix";
149       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
150 
151       duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].even;
152       duoni.pObjectName  = "radix_sort_scatter_0_even";
153       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
154 
155       duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].odd;
156       duoni.pObjectName  = "radix_sort_scatter_0_odd";
157       vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
158 
159       if (rs->config.keyval_dwords >= 2)
160         {
161           duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].even;
162           duoni.pObjectName  = "radix_sort_scatter_1_even";
163           vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
164 
165           duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].odd;
166           duoni.pObjectName  = "radix_sort_scatter_1_odd";
167           vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
168         }
169     }
170 }
171 
172 #endif
173 
174 //
175 // How many pipelines are there?
176 //
177 static uint32_t
rs_pipeline_count(struct radix_sort_vk const * rs)178 rs_pipeline_count(struct radix_sort_vk const * rs)
179 {
180   return 1 +                            // init
181          1 +                            // fill
182          1 +                            // histogram
183          1 +                            // prefix
184          2 * rs->config.keyval_dwords;  // scatters.even/odd[keyval_dwords]
185 }
186 
187 radix_sort_vk_t *
radix_sort_vk_create(VkDevice device,VkAllocationCallbacks const * ac,VkPipelineCache pc,const uint32_t * const * spv,const uint32_t * spv_sizes,struct radix_sort_vk_target_config config)188 radix_sort_vk_create(VkDevice                           device,
189                     VkAllocationCallbacks const *      ac,
190                     VkPipelineCache                    pc,
191                     const uint32_t* const*             spv,
192                     const uint32_t*                    spv_sizes,
193                     struct radix_sort_vk_target_config config)
194 {
195   //
196   // Allocate radix_sort_vk
197   //
198   struct radix_sort_vk * const rs = calloc(1, sizeof(*rs));
199 
200   //
201   // Save the config for layer
202   //
203   rs->config = config;
204 
205   //
206   // How many pipelines?
207   //
208   uint32_t const pipeline_count = rs_pipeline_count(rs);
209 
210   //
211   // Prepare to create pipelines
212   //
213   VkPushConstantRange const pcr[] = {
214     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
215       .offset     = 0,
216       .size       = sizeof(struct rs_push_init) },
217 
218     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
219       .offset     = 0,
220       .size       = sizeof(struct rs_push_fill) },
221 
222     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
223       .offset     = 0,
224       .size       = sizeof(struct rs_push_histogram) },
225 
226     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
227       .offset     = 0,
228       .size       = sizeof(struct rs_push_prefix) },
229 
230     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
231       .offset     = 0,
232       .size       = sizeof(struct rs_push_scatter) },  // scatter_0_even
233 
234     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
235       .offset     = 0,
236       .size       = sizeof(struct rs_push_scatter) },  // scatter_0_odd
237 
238     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
239       .offset     = 0,
240       .size       = sizeof(struct rs_push_scatter) },  // scatter_1_even
241 
242     { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
243       .offset     = 0,
244       .size       = sizeof(struct rs_push_scatter) },  // scatter_1_odd
245   };
246 
247   VkPipelineLayoutCreateInfo plci = {
248 
249     .sType                  = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
250     .pNext                  = NULL,
251     .flags                  = 0,
252     .setLayoutCount         = 0,
253     .pSetLayouts            = NULL,
254     .pushConstantRangeCount = 1,
255     // .pPushConstantRanges = pcr + ii;
256   };
257 
258   for (uint32_t ii = 0; ii < pipeline_count; ii++)
259     {
260       plci.pPushConstantRanges = pcr + ii;
261 
262       if (vkCreatePipelineLayout(device, &plci, NULL, rs->pipeline_layouts.handles + ii) != VK_SUCCESS)
263         goto fail_layout;
264     }
265 
266   //
267   // Create compute pipelines
268   //
269   VkShaderModuleCreateInfo smci = {
270 
271     .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
272     .pNext = NULL,
273     .flags = 0,
274     // .codeSize = ar_entries[...].size;
275     // .pCode    = ar_data + ...;
276   };
277 
278   VkShaderModule sms[ARRAY_LENGTH_MACRO(rs->pipelines.handles)] = {0};
279 
280   for (uint32_t ii = 0; ii < pipeline_count; ii++)
281     {
282       smci.codeSize = spv_sizes[ii];
283       smci.pCode    = spv[ii];
284 
285       if (vkCreateShaderModule(device, &smci, ac, sms + ii) != VK_SUCCESS)
286         goto fail_shader;
287     }
288 
289     //
290     // If necessary, set the expected subgroup size
291     //
292 #define RS_SUBGROUP_SIZE_CREATE_INFO_SET(size_)                                                    \
293   {                                                                                                \
294     .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,       \
295     .pNext = NULL,                                                                                 \
296     .requiredSubgroupSize = size_,                                                                 \
297   }
298 
299 #undef RS_SUBGROUP_SIZE_CREATE_INFO_NAME
300 #define RS_SUBGROUP_SIZE_CREATE_INFO_NAME(name_)                                                   \
301   RS_SUBGROUP_SIZE_CREATE_INFO_SET(1 << config.name_.subgroup_size_log2)
302 
303 #define RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(name_) RS_SUBGROUP_SIZE_CREATE_INFO_SET(0)
304 
305   VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT const rsscis[] = {
306     RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(init),       // init
307     RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(fill),       // fill
308     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(histogram),  // histogram
309     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(prefix),     // prefix
310     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].even
311     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].odd
312     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].even
313     RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].odd
314   };
315 
316   //
317   // Define compute pipeline create infos
318   //
319 #undef RS_COMPUTE_PIPELINE_CREATE_INFO_DECL
320 #define RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(idx_)                                                 \
321   { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,                                       \
322     .pNext = NULL,                                                                                 \
323     .flags = 0,                                                                                    \
324     .stage = { .sType               = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,         \
325                .pNext               = NULL,                                                        \
326                .flags               = 0,                                                           \
327                .stage               = VK_SHADER_STAGE_COMPUTE_BIT,                                 \
328                .module              = sms[idx_],                                                   \
329                .pName               = "main",                                                      \
330                .pSpecializationInfo = NULL },                                                      \
331                                                                                                    \
332     .layout             = rs->pipeline_layouts.handles[idx_],                                      \
333     .basePipelineHandle = VK_NULL_HANDLE,                                                          \
334     .basePipelineIndex  = 0 }
335 
336   VkComputePipelineCreateInfo cpcis[] = {
337     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(0),  // init
338     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(1),  // fill
339     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(2),  // histogram
340     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(3),  // prefix
341     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(4),  // scatter[0].even
342     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(5),  // scatter[0].odd
343     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(6),  // scatter[1].even
344     RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(7),  // scatter[1].odd
345   };
346 
347   //
348   // Which of these compute pipelines require subgroup size control?
349   //
350   for (uint32_t ii = 0; ii < pipeline_count; ii++)
351     {
352       if (rsscis[ii].requiredSubgroupSize > 1)
353         {
354           cpcis[ii].stage.pNext = rsscis + ii;
355         }
356     }
357 
358   //
359   // Create the compute pipelines
360   //
361   if (vkCreateComputePipelines(device, pc, pipeline_count, cpcis, ac, rs->pipelines.handles) != VK_SUCCESS)
362     goto fail_pipeline;
363 
364   //
365   // Shader modules can be destroyed now
366   //
367   for (uint32_t ii = 0; ii < pipeline_count; ii++)
368     {
369       vkDestroyShaderModule(device, sms[ii], ac);
370     }
371 
372 #ifdef RS_VK_ENABLE_DEBUG_UTILS
373   //
374   // Tag pipelines with names
375   //
376   rs_debug_utils_set(device, rs);
377 #endif
378 
379   //
380   // Calculate "internal" buffer offsets
381   //
382   size_t const keyval_bytes = rs->config.keyval_dwords * sizeof(uint32_t);
383 
384   // the .range calculation assumes an 8-bit radix
385   rs->internal.histograms.offset = 0;
386   rs->internal.histograms.range  = keyval_bytes * (RS_RADIX_SIZE * sizeof(uint32_t));
387 
388   //
389   // NOTE(allanmac): The partitions.offset must be aligned differently if
390   // RS_RADIX_LOG2 is less than the target's subgroup size log2.  At this time,
391   // no GPU that meets this criteria.
392   //
393   rs->internal.partitions.offset = rs->internal.histograms.offset + rs->internal.histograms.range;
394 
395   return rs;
396 
397 fail_pipeline:
398   for (uint32_t ii = 0; ii < pipeline_count; ii++)
399     {
400       vkDestroyPipeline(device, rs->pipelines.handles[ii], ac);
401     }
402 fail_shader:
403   for (uint32_t ii = 0; ii < pipeline_count; ii++)
404     {
405       vkDestroyShaderModule(device, sms[ii], ac);
406     }
407 fail_layout:
408    for (uint32_t ii = 0; ii < pipeline_count; ii++)
409     {
410       vkDestroyPipelineLayout(device, rs->pipeline_layouts.handles[ii], ac);
411     }
412 
413   free(rs);
414   return NULL;
415 }
416 
417 //
418 //
419 //
420 void
radix_sort_vk_destroy(struct radix_sort_vk * rs,VkDevice d,VkAllocationCallbacks const * const ac)421 radix_sort_vk_destroy(struct radix_sort_vk * rs, VkDevice d, VkAllocationCallbacks const * const ac)
422 {
423   uint32_t const pipeline_count = rs_pipeline_count(rs);
424 
425   // destroy pipelines
426   for (uint32_t ii = 0; ii < pipeline_count; ii++)
427     {
428       vkDestroyPipeline(d, rs->pipelines.handles[ii], ac);
429     }
430 
431   // destroy pipeline layouts
432   for (uint32_t ii = 0; ii < pipeline_count; ii++)
433     {
434       vkDestroyPipelineLayout(d, rs->pipeline_layouts.handles[ii], ac);
435     }
436 
437   free(rs);
438 }
439 
440 //
441 //
442 //
443 static VkDeviceAddress
rs_get_devaddr(VkDevice device,VkDescriptorBufferInfo const * dbi)444 rs_get_devaddr(VkDevice device, VkDescriptorBufferInfo const * dbi)
445 {
446   VkBufferDeviceAddressInfo const bdai = {
447 
448     .sType  = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
449     .pNext  = NULL,
450     .buffer = dbi->buffer
451   };
452 
453   VkDeviceAddress const devaddr = vkGetBufferDeviceAddress(device, &bdai) + dbi->offset;
454 
455   return devaddr;
456 }
457 
458 //
459 //
460 //
461 #ifdef RS_VK_ENABLE_EXTENSIONS
462 
463 void
rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps,VkCommandBuffer cb,VkPipelineStageFlagBits pipeline_stage)464 rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps,
465                            VkCommandBuffer                       cb,
466                            VkPipelineStageFlagBits               pipeline_stage)
467 {
468   if ((ext_timestamps != NULL) &&
469       (ext_timestamps->timestamps_set < ext_timestamps->timestamp_count))
470     {
471       vkCmdWriteTimestamp(cb,
472                           pipeline_stage,
473                           ext_timestamps->timestamps,
474                           ext_timestamps->timestamps_set++);
475     }
476 }
477 
478 #endif
479 
480 //
481 //
482 //
483 
484 #ifdef RS_VK_ENABLE_EXTENSIONS
485 
//
// Common header shared by every extension struct so the `info->ext` chain
// can be walked generically: `type` discriminates the concrete extension
// struct and `ext` links to the next extension (or NULL at the end).
//
struct radix_sort_vk_ext_base
{
  void *                      ext;   // next extension in the chain, or NULL
  enum radix_sort_vk_ext_type type;  // discriminant for the concrete struct
};
491 
492 #endif
493 
494 //
495 //
496 //
//
// Record a complete radix sort of `info->count` keyvals into command buffer
// `cb`.  On return, `*keyvals_sorted` holds the device address of whichever
// buffer (even or odd) will contain the sorted keyvals once `cb` executes.
//
// NOTE(review): the `device` parameter is not referenced in this body —
// presumably kept for interface symmetry; confirm before removing.
//
void
radix_sort_vk_sort_devaddr(radix_sort_vk_t const *                   rs,
                           radix_sort_vk_sort_devaddr_info_t const * info,
                           VkDevice                                  device,
                           VkCommandBuffer                           cb,
                           VkDeviceAddress *                         keyvals_sorted)
{
  //
  // Anything to do?  Zero or one keyvals (or zero key bits) are already
  // "sorted" — report the even buffer and record nothing.
  //
  if ((info->count <= 1) || (info->key_bits == 0))
    {
      *keyvals_sorted = info->keyvals_even.devaddr;

      return;
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Walk the extension chain.  Currently only the timestamps extension is
  // recognized; its write cursor is reset for this sort.
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
    {
      struct radix_sort_vk_ext_base * const base = ext_next;

      switch (base->type)
        {
          case RS_VK_EXT_TIMESTAMPS:
            ext_timestamps                 = ext_next;
            ext_timestamps->timestamps_set = 0;
            break;
        }

      ext_next = base->ext;
    }
#endif

    ////////////////////////////////////////////////////////////////////////
    //
    // OVERVIEW
    //
    //   1. Pad the keyvals in `scatter_even`.
    //   2. Zero the `histograms` and `partitions`.
    //      --- BARRIER ---
    //   3. HISTOGRAM is dispatched before PREFIX.
    //      --- BARRIER ---
    //   4. PREFIX is dispatched before the first SCATTER.
    //      --- BARRIER ---
    //   5. One or more SCATTER dispatches.
    //
    // Note that the `partitions` buffer can be zeroed anytime before the first
    // scatter.
    //
    ////////////////////////////////////////////////////////////////////////

    //
    // Label the command buffer
    //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL)
    {
      VkDebugUtilsLabelEXT const label = {
        .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
        .pNext      = NULL,
        .pLabelName = "radix_sort_vk_sort",
      };

      pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label);
    }
#endif

  //
  // How many passes?  One pass per RS_RADIX_LOG2 bits of key, capped at the
  // keyval width.
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;

  // Each pass ping-pongs even<->odd, so an odd pass count ends in the odd
  // buffer.
  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even.devaddr;

  ////////////////////////////////////////////////////////////////////////
  //
  // PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
  //
  // Pad fractional blocks with max-valued keyvals.
  //
  // Zero the histograms and partitions buffer.
  //
  // This assumes the partitions follow the histograms.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // FIXME(allanmac): Consider precomputing some of these values and hang them
  // off `rs`.
  //

  //
  // How many scatter blocks?
  //
  uint32_t const scatter_wg_size   = 1 << rs->config.scatter.workgroup_size_log2;
  uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
  uint32_t const scatter_blocks    = (info->count + scatter_block_kvs - 1) / scatter_block_kvs;
  uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;

  //
  // How many histogram blocks?
  //
  // Note that it's OK to have more max-valued digits counted by the histogram
  // than sorted by the scatters because the sort is stable.
  //
  uint32_t const histo_wg_size   = 1 << rs->config.histogram.workgroup_size_log2;
  uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
  uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
  uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;

  //
  // Pad the fractional tail block with max-valued keyvals so they sort to
  // the end without disturbing the real data.
  //
  if (count_ru_histo > info->count)
    {
      info->fill_buffer(cb,
                        &info->keyvals_even,
                        info->count * keyval_bytes,
                        (count_ru_histo - info->count) * keyval_bytes,
                        0xFFFFFFFF);
    }

  //
  // Zero histograms and invalidate partitions.
  //
  // Note that the partition invalidation only needs to be performed once
  // because the even/odd scatter dispatches rely on the previous pass to
  // leave the partitions in an invalid state.
  //
  // Note that the last workgroup doesn't read/write a partition so it doesn't
  // need to be initialized.
  //
  uint32_t const histo_partition_count = passes + scatter_blocks - 1;
  // Passes are processed from the least-significant byte actually needed;
  // skipped high bytes shift the starting index forward.
  uint32_t       pass_idx              = (keyval_bytes - passes);

  VkDeviceSize const fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));

  info->fill_buffer(cb,
                    &info->internal,
                    rs->internal.histograms.offset + fill_base,
                    histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)),
                    0);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: HISTOGRAM
  //
  // TODO(allanmac): All subgroups should try to process approximately the same
  // number of blocks in order to minimize tail effects.  This was implemented
  // and reverted but should be reimplemented and benchmarked later.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TRANSFER_BIT);
#endif

  // The fills above are transfer writes; make them visible to compute.
  vk_barrier_transfer_w_to_compute_r(cb);

  // clang-format off
  VkDeviceAddress const devaddr_histograms   = info->internal.devaddr + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even.devaddr;
  // clang-format on

  //
  // Dispatch histogram
  //
  struct rs_push_histogram const push_histogram = {

    .devaddr_histograms = devaddr_histograms,
    .devaddr_keyvals    = devaddr_keyvals_even,
    .passes             = passes
  };

  vkCmdPushConstants(cb,
                     rs->pipeline_layouts.named.histogram,
                     VK_SHADER_STAGE_COMPUTE_BIT,
                     0,
                     sizeof(push_histogram),
                     &push_histogram);

  vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

  vkCmdDispatch(cb, histo_blocks, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: PREFIX
  //
  // Launch one workgroup per pass.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  struct rs_push_prefix const push_prefix = {

    .devaddr_histograms = devaddr_histograms,
  };

  vkCmdPushConstants(cb,
                     rs->pipeline_layouts.named.prefix,
                     VK_SHADER_STAGE_COMPUTE_BIT,
                     0,
                     sizeof(push_prefix),
                     &push_prefix);

  vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

  vkCmdDispatch(cb, passes, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: SCATTER
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  // clang-format off
  uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
  VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
  VkDeviceAddress const devaddr_partitions  = info->internal.devaddr + rs->internal.partitions.offset;
  // clang-format on

  struct rs_push_scatter push_scatter = {

    .devaddr_keyvals_even = devaddr_keyvals_even,
    .devaddr_keyvals_odd  = devaddr_keyvals_odd,
    .devaddr_partitions   = devaddr_partitions,
    .devaddr_histograms   = devaddr_histograms + histogram_offset,
    // byte position within the keyval's current dword
    .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
  };

  // First scatter pass always reads the even buffer.
  {
    uint32_t const pass_dword = pass_idx / 4;

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.scatter[pass_dword].even,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_scatter),
                       &push_scatter);

    vkCmdBindPipeline(cb,
                      VK_PIPELINE_BIND_POINT_COMPUTE,
                      rs->pipelines.named.scatter[pass_dword].even);
  }

  bool is_even = true;

  while (true)
    {
      vkCmdDispatch(cb, scatter_blocks, 1, 1);

      //
      // Continue?
      //
      if (++pass_idx >= keyval_bytes)
        break;

#ifdef RS_VK_ENABLE_EXTENSIONS
      rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif
      vk_barrier_compute_w_to_compute_r(cb);

      // clang-format off
      is_even                         ^= true;
      push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
      push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
      // clang-format on

      uint32_t const pass_dword = pass_idx / 4;

      //
      // Update push constants that changed
      //
      VkPipelineLayout const pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                          : rs->pipeline_layouts.named.scatter[pass_dword].odd;
      vkCmdPushConstants(cb,
                         pl,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
                         sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
                         &push_scatter.devaddr_histograms);

      //
      // Bind new pipeline
      //
      VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                   : rs->pipelines.named.scatter[pass_dword].odd;

      vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL)
    {
      pfn_vkCmdEndDebugUtilsLabelEXT(cb);
    }
#endif
}
821 
822 //
823 //
824 //
//
// Records an indirectly-dispatched radix sort into `cb` using raw device
// addresses.  The keyval count lives in a device buffer (`info->count`), so
// the INIT pipeline computes all workgroup counts on the GPU and subsequent
// stages are launched through `info->dispatch_indirect`.
//
// On return, `*keyvals_sorted` holds the device address of whichever keyvals
// buffer ("even" or "odd") will contain the sorted keyvals after submission.
//
// NOTE(review): `device` appears unused here; presumably kept for signature
// symmetry with the non-devaddr entry points — confirm.
//
void
radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const *                            rs,
                                    radix_sort_vk_sort_indirect_devaddr_info_t const * info,
                                    VkDevice                                           device,
                                    VkCommandBuffer                                    cb,
                                    VkDeviceAddress * keyvals_sorted)
{
  //
  // Anything to do?
  //
  // Zero key bits means no passes are required — the "even" buffer is
  // reported as already sorted.
  //
  if (info->key_bits == 0)
    {
      *keyvals_sorted = info->keyvals_even;
      return;
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Any extensions?
  //
  // Walk the extension chain; capture the timestamps extension if present
  // and reset its write index.
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
    {
      struct radix_sort_vk_ext_base * const base = ext_next;

      switch (base->type)
        {
          case RS_VK_EXT_TIMESTAMPS:
            ext_timestamps                 = ext_next;
            ext_timestamps->timestamps_set = 0;
            break;
        }

      ext_next = base->ext;
    }
#endif

    ////////////////////////////////////////////////////////////////////////
    //
    // OVERVIEW
    //
    //   1. Init
    //      --- BARRIER ---
    //   2. Pad the keyvals in `scatter_even`.
    //   3. Zero the `histograms` and `partitions`.
    //      --- BARRIER ---
    //   4. HISTOGRAM is dispatched before PREFIX.
    //      --- BARRIER ---
    //   5. PREFIX is dispatched before the first SCATTER.
    //      --- BARRIER ---
    //   6. One or more SCATTER dispatches.
    //
    // Note that the `partitions` buffer can be zeroed anytime before the first
    // scatter.
    //
    ////////////////////////////////////////////////////////////////////////

    //
    // Label the command buffer
    //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  if (pfn_vkCmdBeginDebugUtilsLabelEXT != NULL)
    {
      VkDebugUtilsLabelEXT const label = {
        .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
        .pNext      = NULL,
        .pLabelName = "radix_sort_vk_sort_indirect",
      };

      pfn_vkCmdBeginDebugUtilsLabelEXT(cb, &label);
    }
#endif

  //
  // How many passes?
  //
  // One radix pass per RS_RADIX_LOG2 bits of the (clamped) key width.
  // NOTE(review): the `pass_idx` arithmetic assumes RS_RADIX_LOG2 == 8
  // (one pass per keyval byte) — sorting starts at byte `pass_idx` so that
  // exactly `passes` scatters are recorded ending at the top byte.
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
  uint32_t       pass_idx     = (keyval_bytes - passes);

  // Scatters ping-pong between the even/odd buffers: an odd number of
  // passes leaves the sorted keyvals in the odd buffer.
  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even;

  //
  // NOTE(allanmac): Some of these initializations appear redundant but for now
  // we're going to assume the compiler will elide them.
  //
  // clang-format off
  VkDeviceAddress const devaddr_info         = info->indirect.devaddr;
  VkDeviceAddress const devaddr_count        = info->count;
  VkDeviceAddress const devaddr_histograms   = info->internal + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even;
  // clang-format on

  //
  // START
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // INIT
  //
  // A single workgroup reads the device-side count and writes the
  // `rs_indirect_info` structure (dispatch sizes for pad/zero/histogram/
  // scatter) at `devaddr_info`.
  //
  {
    struct rs_push_init const push_init = {

      .devaddr_info  = devaddr_info,
      .devaddr_count = devaddr_count,
      .passes        = passes
    };

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.init,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_init),
                       &push_init);

    vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.init);

    vkCmdDispatch(cb, 1, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  // The indirect dispatch parameters written by INIT must be visible to the
  // indirect-command reads below.
  vk_barrier_compute_w_to_indirect_compute_r(cb);

  {
    //
    // PAD
    //
    // Fill the trailing slack of the even keyvals buffer with 0xFFFFFFFF so
    // padded entries sort to the end.  The fill range comes from the
    // device-side `rs_indirect_info.pad` written by INIT.
    //
    struct rs_push_fill const push_pad = {

      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, pad),
      .devaddr_dwords = devaddr_keyvals_even,
      .dword          = 0xFFFFFFFF
    };

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.fill,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_pad),
                       &push_pad);

    vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.pad));
  }

  //
  // ZERO
  //
  // Zero the internal buffer starting at the first active pass's histogram.
  // NOTE(review): per the OVERVIEW this fill also covers the `partitions`
  // region — confirm against the internal buffer layout and the
  // device-side `rs_indirect_info.zero` range.
  //
  {
    VkDeviceSize const histo_offset = pass_idx * (sizeof(uint32_t) * RS_RADIX_SIZE);

    struct rs_push_fill const push_zero = {

      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, zero),
      .devaddr_dwords = devaddr_histograms + histo_offset,
      .dword          = 0
    };

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.fill,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_zero),
                       &push_zero);

    vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.zero));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  // Pad/zero writes must land before HISTOGRAM reads the keyvals and
  // accumulates into the zeroed histograms.
  vk_barrier_compute_w_to_compute_r(cb);

  //
  // HISTOGRAM
  //
  // Accumulate per-pass digit histograms over the (padded) even keyvals.
  //
  {
    struct rs_push_histogram const push_histogram = {

      .devaddr_histograms = devaddr_histograms,
      .devaddr_keyvals    = devaddr_keyvals_even,
      .passes             = passes
    };

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.histogram,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_histogram),
                       &push_histogram);

    vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

    info->dispatch_indirect(cb,
                            &info->indirect,
                            offsetof(struct rs_indirect_info, dispatch.histogram));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // PREFIX
  //
  // Convert each pass's histogram into exclusive prefix sums — one
  // workgroup per pass.
  //
  {
    struct rs_push_prefix const push_prefix = {
      .devaddr_histograms = devaddr_histograms,
    };

    vkCmdPushConstants(cb,
                       rs->pipeline_layouts.named.prefix,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_prefix),
                       &push_prefix);

    vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

    vkCmdDispatch(cb, passes, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // SCATTER
  //
  // One scatter dispatch per pass, ping-ponging even<->odd keyvals buffers.
  // The pipeline variant is selected by the dword (`pass_idx / 4`) and the
  // in-dword byte (`pass_idx & 3`) being scattered.
  //
  {
    // clang-format off
    uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
    VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
    VkDeviceAddress const devaddr_partitions  = info->internal + rs->internal.partitions.offset;
    // clang-format on

    struct rs_push_scatter push_scatter = {
      .devaddr_keyvals_even = devaddr_keyvals_even,
      .devaddr_keyvals_odd  = devaddr_keyvals_odd,
      .devaddr_partitions   = devaddr_partitions,
      .devaddr_histograms   = devaddr_histograms + histogram_offset,
      .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
    };

    // First pass always scatters from the "even" buffer.
    {
      uint32_t const pass_dword = pass_idx / 4;

      vkCmdPushConstants(cb,
                         rs->pipeline_layouts.named.scatter[pass_dword].even,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         0,
                         sizeof(push_scatter),
                         &push_scatter);

      vkCmdBindPipeline(cb,
                        VK_PIPELINE_BIND_POINT_COMPUTE,
                        rs->pipelines.named.scatter[pass_dword].even);
    }

    bool is_even = true;

    while (true)
      {
        info->dispatch_indirect(cb,
                                &info->indirect,
                                offsetof(struct rs_indirect_info, dispatch.scatter));

        //
        // Continue?
        //
        if (++pass_idx >= keyval_bytes)
          break;

#ifdef RS_VK_ENABLE_EXTENSIONS
        rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

        vk_barrier_compute_w_to_compute_r(cb);

        // clang-format off
        is_even                         ^= true;
        push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
        push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
        // clang-format on

        uint32_t const pass_dword = pass_idx / 4;

        //
        // Update push constants that changed
        //
        // Only `devaddr_histograms` and `pass_offset` change per pass; a
        // single ranged push updates both — assumes `pass_offset`
        // immediately follows `devaddr_histograms` in rs_push_scatter.
        //
        VkPipelineLayout const pl = is_even
                                      ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                      : rs->pipeline_layouts.named.scatter[pass_dword].odd;
        vkCmdPushConstants(
          cb,
          pl,
          VK_SHADER_STAGE_COMPUTE_BIT,
          OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
          sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
          &push_scatter.devaddr_histograms);

        //
        // Bind new pipeline
        //
        VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                     : rs->pipelines.named.scatter[pass_dword].odd;

        vkCmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
      }
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  if (pfn_vkCmdEndDebugUtilsLabelEXT != NULL)
    {
      pfn_vkCmdEndDebugUtilsLabelEXT(cb);
    }
#endif
}
1169 
1170 //
1171 //
1172 //
1173 static void
radix_sort_vk_fill_buffer(VkCommandBuffer cb,radix_sort_vk_buffer_info_t const * buffer_info,VkDeviceSize offset,VkDeviceSize size,uint32_t data)1174 radix_sort_vk_fill_buffer(VkCommandBuffer                     cb,
1175                           radix_sort_vk_buffer_info_t const * buffer_info,
1176                           VkDeviceSize                        offset,
1177                           VkDeviceSize                        size,
1178                           uint32_t                            data)
1179 {
1180   vkCmdFillBuffer(cb, buffer_info->buffer, buffer_info->offset + offset, size, data);
1181 }
1182 
1183 //
1184 //
1185 //
1186 void
radix_sort_vk_sort(radix_sort_vk_t const * rs,radix_sort_vk_sort_info_t const * info,VkDevice device,VkCommandBuffer cb,VkDescriptorBufferInfo * keyvals_sorted)1187 radix_sort_vk_sort(radix_sort_vk_t const *           rs,
1188                    radix_sort_vk_sort_info_t const * info,
1189                    VkDevice                          device,
1190                    VkCommandBuffer                   cb,
1191                    VkDescriptorBufferInfo *          keyvals_sorted)
1192 {
1193   struct radix_sort_vk_sort_devaddr_info const di = {
1194     .ext          = info->ext,
1195     .key_bits     = info->key_bits,
1196     .count        = info->count,
1197     .keyvals_even = { .buffer  = info->keyvals_even.buffer,
1198                       .offset  = info->keyvals_even.offset,
1199                       .devaddr = rs_get_devaddr(device, &info->keyvals_even) },
1200     .keyvals_odd  = rs_get_devaddr(device, &info->keyvals_odd),
1201     .internal     = { .buffer  = info->internal.buffer,
1202                       .offset  = info->internal.offset,
1203                       .devaddr = rs_get_devaddr(device, &info->internal), },
1204     .fill_buffer  = radix_sort_vk_fill_buffer,
1205   };
1206 
1207   VkDeviceAddress di_keyvals_sorted;
1208 
1209   radix_sort_vk_sort_devaddr(rs, &di, device, cb, &di_keyvals_sorted);
1210 
1211   *keyvals_sorted = (di_keyvals_sorted == di.keyvals_even.devaddr)  //
1212                       ? info->keyvals_even
1213                       : info->keyvals_odd;
1214 }
1215 
1216 //
1217 //
1218 //
1219 static void
radix_sort_vk_dispatch_indirect(VkCommandBuffer cb,radix_sort_vk_buffer_info_t const * buffer_info,VkDeviceSize offset)1220 radix_sort_vk_dispatch_indirect(VkCommandBuffer                     cb,
1221                                 radix_sort_vk_buffer_info_t const * buffer_info,
1222                                 VkDeviceSize                        offset)
1223 {
1224   vkCmdDispatchIndirect(cb, buffer_info->buffer, buffer_info->offset + offset);
1225 }
1226 
1227 //
1228 //
1229 //
1230 void
radix_sort_vk_sort_indirect(radix_sort_vk_t const * rs,radix_sort_vk_sort_indirect_info_t const * info,VkDevice device,VkCommandBuffer cb,VkDescriptorBufferInfo * keyvals_sorted)1231 radix_sort_vk_sort_indirect(radix_sort_vk_t const *                    rs,
1232                             radix_sort_vk_sort_indirect_info_t const * info,
1233                             VkDevice                                   device,
1234                             VkCommandBuffer                            cb,
1235                             VkDescriptorBufferInfo *                   keyvals_sorted)
1236 {
1237   struct radix_sort_vk_sort_indirect_devaddr_info const idi = {
1238     .ext               = info->ext,
1239     .key_bits          = info->key_bits,
1240     .count             = rs_get_devaddr(device, &info->count),
1241     .keyvals_even      = rs_get_devaddr(device, &info->keyvals_even),
1242     .keyvals_odd       = rs_get_devaddr(device, &info->keyvals_odd),
1243     .internal          = rs_get_devaddr(device, &info->internal),
1244     .indirect          = { .buffer  = info->indirect.buffer,
1245                            .offset  = info->indirect.offset,
1246                            .devaddr = rs_get_devaddr(device, &info->indirect) },
1247     .dispatch_indirect = radix_sort_vk_dispatch_indirect,
1248   };
1249 
1250   VkDeviceAddress idi_keyvals_sorted;
1251 
1252   radix_sort_vk_sort_indirect_devaddr(rs, &idi, device, cb, &idi_keyvals_sorted);
1253 
1254   *keyvals_sorted = (idi_keyvals_sorted == idi.keyvals_even)  //
1255                       ? info->keyvals_even
1256                       : info->keyvals_odd;
1257 }
1258 
1259 //
1260 //
1261 //
1262