// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>

#include "common/macros.h"
#include "common/util.h"
#include "common/vk/barrier.h"
#include "radix_sort_vk_devaddr.h"
#include "shaders/push.h"
#include "shaders/config.h"

#include "vk_command_buffer.h"
#include "vk_device.h"

//
//
//

#ifdef RS_VK_ENABLE_DEBUG_UTILS
#include "common/vk/debug_utils.h"
#endif

//
//
//

#ifdef RS_VK_ENABLE_EXTENSIONS
#include "radix_sort_vk_ext.h"
#endif

//
// FIXME(allanmac): memoize some of these calculations
//
void
radix_sort_vk_get_memory_requirements(radix_sort_vk_t const *               rs,
                                      uint32_t                              count,
                                      radix_sort_vk_memory_requirements_t * mr)
{
  //
  // Keyval size
  //
  mr->keyval_size = rs->config.keyval_dwords * sizeof(uint32_t);

  //
  // Subgroup and workgroup sizes
  //
  uint32_t const histo_sg_size    = 1 << rs->config.histogram.subgroup_size_log2;
  uint32_t const histo_wg_size    = 1 << rs->config.histogram.workgroup_size_log2;
  uint32_t const prefix_sg_size   = 1 << rs->config.prefix.subgroup_size_log2;
  uint32_t const scatter_wg_size  = 1 << rs->config.scatter.workgroup_size_log2;
  uint32_t const internal_sg_size = MAX_MACRO(uint32_t, histo_sg_size, prefix_sg_size);

  //
  // If for some reason count is zero then initialize appropriately.
  //
  if (count == 0)
    {
      mr->keyvals_size       = 0;
      mr->keyvals_alignment  = mr->keyval_size * histo_sg_size;
      mr->internal_size      = 0;
      mr->internal_alignment = internal_sg_size * sizeof(uint32_t);
      mr->indirect_size      = 0;
      mr->indirect_alignment = internal_sg_size * sizeof(uint32_t);
    }
  else
    {
      //
      // Keyvals
      //

      // Round up to the scatter block size.
      //
      // Then round up to the histogram block size.
      //
      // Fill the difference between this new count and the original keyval
      // count.
      //
      // How many scatter blocks?
      //
      uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
      uint32_t const scatter_blocks    = (count + scatter_block_kvs - 1) / scatter_block_kvs;
      uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;

      //
      // How many histogram blocks?
      //
      // Note that it's OK to have more max-valued digits counted by the histogram
      // than sorted by the scatters because the sort is stable.
      //
      uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
      uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
      uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;

      mr->keyvals_size      = mr->keyval_size * count_ru_histo;
      mr->keyvals_alignment = mr->keyval_size * histo_sg_size;
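
      //
      // For illustration (hypothetical config, not taken from a real target):
      // with scatter_wg_size = 256 and scatter.block_rows = 4, a scatter block
      // holds 1024 keyvals, so count = 10000 rounds up to count_ru_scatter =
      // 10240 across 10 scatter blocks.  With histo_wg_size = 256 and
      // histogram.block_rows = 8, a histogram block holds 2048 keyvals and
      // count_ru_histo = 10240 already lands on a block boundary.  For 32-bit
      // keyvals that is keyvals_size = 4 * 10240 = 40960 bytes.
      //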

      //
      // Internal
      //
      // NOTE: Assumes .histograms are before .partitions.
      //
      // Each RS_RADIX_LOG2 (8) bit pass has a zero-initialized histogram. This
      // is one RS_RADIX_SIZE histogram per keyval byte.
      //
      // The last scatter workgroup skips writing to a partition so it doesn't
      // need to be allocated.
      //
      // If the device doesn't support "sequential dispatch" of workgroups, then
      // we need a zero-initialized dword counter per radix pass in the keyval
      // to atomically acquire a virtual workgroup id.  On sequentially
      // dispatched devices, this is simply `gl_WorkGroupID.x`.
      //
      // The "internal" memory map looks like this:
      //
      //   +---------------------------------+ <-- 0
      //   | histograms[keyval_size]         |
      //   +---------------------------------+ <-- keyval_size                           * histo_size
      //   | partitions[scatter_blocks_ru-1] |
      //   +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size
      //   | workgroup_ids[keyval_size]      |
      //   +---------------------------------+ <-- (keyval_size + scatter_blocks_ru - 1) * histo_size + workgroup_ids_size
      //
      // The `.workgroup_ids[]` are located after the last partition.
      //
      VkDeviceSize const histo_size = RS_RADIX_SIZE * sizeof(uint32_t);

      mr->internal_size      = (mr->keyval_size + scatter_blocks - 1) * histo_size;
      mr->internal_alignment = internal_sg_size * sizeof(uint32_t);

      //
      // Support for nonsequential dispatch can be disabled.
      //
      VkDeviceSize const workgroup_ids_size = mr->keyval_size * sizeof(uint32_t);

      mr->internal_size += workgroup_ids_size;
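
      //
      // For example (assuming an 8-bit radix, i.e. RS_RADIX_SIZE == 256):
      // histo_size is 256 * 4 = 1024 bytes.  With 32-bit keyvals
      // (keyval_size == 4) and 10 scatter blocks, the internal buffer is
      // (4 + 10 - 1) * 1024 = 13312 bytes plus 4 * 4 = 16 bytes of
      // workgroup ids, or 13328 bytes total.
      //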

      //
      // Indirect
      //
      mr->indirect_size      = sizeof(struct rs_indirect_info);
      mr->indirect_alignment = sizeof(struct u32vec4);
    }
}

//
//
//
#ifdef RS_VK_ENABLE_DEBUG_UTILS

static void
rs_debug_utils_set(VkDevice device, struct radix_sort_vk * rs)
{
  if (pfn_vkSetDebugUtilsObjectNameEXT != NULL)
    {
      VkDebugUtilsObjectNameInfoEXT duoni = {
        .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT,
        .pNext      = NULL,
        .objectType = VK_OBJECT_TYPE_PIPELINE,
      };

      duoni.objectHandle = (uint64_t)rs->pipelines.named.init;
      duoni.pObjectName  = "radix_sort_init";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.fill;
      duoni.pObjectName  = "radix_sort_fill";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.histogram;
      duoni.pObjectName  = "radix_sort_histogram";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.prefix;
      duoni.pObjectName  = "radix_sort_prefix";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].even;
      duoni.pObjectName  = "radix_sort_scatter_0_even";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].odd;
      duoni.pObjectName  = "radix_sort_scatter_0_odd";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      if (rs->config.keyval_dwords >= 2)
        {
          duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].even;
          duoni.pObjectName  = "radix_sort_scatter_1_even";
          vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

          duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].odd;
          duoni.pObjectName  = "radix_sort_scatter_1_odd";
          vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
        }
    }
}

#endif

//
// How many pipelines are there?
//
static uint32_t
rs_pipeline_count(struct radix_sort_vk const * rs)
{
  return 1 +                            // init
         1 +                            // fill
         1 +                            // histogram
         1 +                            // prefix
         2 * rs->config.keyval_dwords;  // scatters.even/odd[keyval_dwords]
}
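
//
// For example, a two-dword (64-bit) keyval config yields 1 + 1 + 1 + 1 +
// 2 * 2 = 8 pipelines, matching the eight push constant ranges and pipeline
// create infos declared in radix_sort_vk_create() below.
//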

radix_sort_vk_t *
radix_sort_vk_create(VkDevice                           _device,
                     VkAllocationCallbacks const *      ac,
                     VkPipelineCache                    pc,
                     const uint32_t* const*             spv,
                     const uint32_t*                    spv_sizes,
                     struct radix_sort_vk_target_config config)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table *disp = &device->dispatch_table;

  //
  // Allocate radix_sort_vk
  //
  struct radix_sort_vk * const rs = calloc(1, sizeof(*rs));

  //
  // Save the config for later
  //
  rs->config = config;

  //
  // How many pipelines?
  //
  uint32_t const pipeline_count = rs_pipeline_count(rs);

  //
  // Prepare to create pipelines
  //
  VkPushConstantRange const pcr[] = {
    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_init) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_fill) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_histogram) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_prefix) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_0_even

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_0_odd

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_1_even

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,  //
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_1_odd
  };

  uint32_t spec_constants[] = {
    [RS_FILL_WORKGROUP_SIZE] = 1u << config.fill.workgroup_size_log2,
    [RS_FILL_BLOCK_ROWS] = config.fill.block_rows,
    [RS_HISTOGRAM_WORKGROUP_SIZE] = 1u << config.histogram.workgroup_size_log2,
    [RS_HISTOGRAM_SUBGROUP_SIZE_LOG2] = config.histogram.subgroup_size_log2,
    [RS_HISTOGRAM_BLOCK_ROWS] = config.histogram.block_rows,
    [RS_PREFIX_WORKGROUP_SIZE] = 1u << config.prefix.workgroup_size_log2,
    [RS_PREFIX_SUBGROUP_SIZE_LOG2] = config.prefix.subgroup_size_log2,
    [RS_SCATTER_WORKGROUP_SIZE] = 1u << config.scatter.workgroup_size_log2,
    [RS_SCATTER_SUBGROUP_SIZE_LOG2] = config.scatter.subgroup_size_log2,
    [RS_SCATTER_BLOCK_ROWS] = config.scatter.block_rows,
    [RS_SCATTER_NONSEQUENTIAL_DISPATCH] = config.nonsequential_dispatch,
  };

  VkSpecializationMapEntry spec_map[ARRAY_LENGTH_MACRO(spec_constants)];

  for (uint32_t ii = 0; ii < ARRAY_LENGTH_MACRO(spec_constants); ii++)
    {
      spec_map[ii] = (VkSpecializationMapEntry) {
        .constantID = ii,
        .offset = sizeof(uint32_t) * ii,
        .size = sizeof(uint32_t),
      };
    }

  VkSpecializationInfo spec_info = {
    .mapEntryCount = ARRAY_LENGTH_MACRO(spec_map),
    .pMapEntries = spec_map,
    .dataSize = sizeof(spec_constants),
    .pData = spec_constants,
  };

  VkPipelineLayoutCreateInfo plci = {

    .sType                  = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
    .pNext                  = NULL,
    .flags                  = 0,
    .setLayoutCount         = 0,
    .pSetLayouts            = NULL,
    .pushConstantRangeCount = 1,
    // .pPushConstantRanges = pcr + ii;
  };

  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      plci.pPushConstantRanges = pcr + ii;

      if (disp->CreatePipelineLayout(_device, &plci, NULL, rs->pipeline_layouts.handles + ii) != VK_SUCCESS)
        goto fail_layout;
    }

  //
  // Create compute pipelines
  //
  VkShaderModuleCreateInfo smci = {

    .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
    .pNext = NULL,
    .flags = 0,
    // .codeSize = ar_entries[...].size;
    // .pCode    = ar_data + ...;
  };

  VkShaderModule sms[ARRAY_LENGTH_MACRO(rs->pipelines.handles)] = {0};

  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      smci.codeSize = spv_sizes[ii];
      smci.pCode    = spv[ii];

      if (disp->CreateShaderModule(_device, &smci, ac, sms + ii) != VK_SUCCESS)
        goto fail_shader;
    }

    //
    // If necessary, set the expected subgroup size
    //
#define RS_SUBGROUP_SIZE_CREATE_INFO_SET(size_)                                              \
  {                                                                                          \
    .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT, \
    .pNext = NULL,                                                                           \
    .requiredSubgroupSize = size_,                                                           \
  }

#undef RS_SUBGROUP_SIZE_CREATE_INFO_NAME
#define RS_SUBGROUP_SIZE_CREATE_INFO_NAME(name_)                                             \
  RS_SUBGROUP_SIZE_CREATE_INFO_SET(1 << config.name_.subgroup_size_log2)

#define RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(name_) RS_SUBGROUP_SIZE_CREATE_INFO_SET(0)

  VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT const rsscis[] = {
    RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(init),       // init
    RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(fill),       // fill
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(histogram),  // histogram
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(prefix),     // prefix
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].even
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].odd
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].even
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].odd
  };

  //
  // Define compute pipeline create infos
  //
#undef RS_COMPUTE_PIPELINE_CREATE_INFO_DECL
#define RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(idx_)                                               \
  { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,                                     \
    .pNext = NULL,                                                                               \
    .flags = 0,                                                                                  \
    .stage = { .sType               = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,       \
               .pNext               = NULL,                                                      \
               .flags               = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT, \
               .stage               = VK_SHADER_STAGE_COMPUTE_BIT,                               \
               .module              = sms[idx_],                                                 \
               .pName               = "main",                                                    \
               .pSpecializationInfo = &spec_info },                                              \
                                                                                                 \
    .layout             = rs->pipeline_layouts.handles[idx_],                                    \
    .basePipelineHandle = VK_NULL_HANDLE,                                                        \
    .basePipelineIndex  = 0 }

  VkComputePipelineCreateInfo cpcis[] = {
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(0),  // init
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(1),  // fill
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(2),  // histogram
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(3),  // prefix
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(4),  // scatter[0].even
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(5),  // scatter[0].odd
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(6),  // scatter[1].even
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(7),  // scatter[1].odd
  };

  //
  // Which of these compute pipelines require subgroup size control?
  //
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      if (rsscis[ii].requiredSubgroupSize > 1)
        {
          cpcis[ii].stage.pNext = rsscis + ii;
        }
    }

  //
  // Create the compute pipelines
  //
  if (disp->CreateComputePipelines(_device, pc, pipeline_count, cpcis, ac, rs->pipelines.handles) != VK_SUCCESS)
    goto fail_pipeline;

  //
  // Shader modules can be destroyed now
  //
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyShaderModule(_device, sms[ii], ac);
    }

#ifdef RS_VK_ENABLE_DEBUG_UTILS
  //
  // Tag pipelines with names
  //
  rs_debug_utils_set(_device, rs);
#endif

  //
  // Calculate "internal" buffer offsets
  //
  size_t const keyval_bytes = rs->config.keyval_dwords * sizeof(uint32_t);

  // the .range calculation assumes an 8-bit radix
  rs->internal.histograms.offset = 0;
  rs->internal.histograms.range  = keyval_bytes * (RS_RADIX_SIZE * sizeof(uint32_t));

  //
  // NOTE(allanmac): The partitions.offset must be aligned differently if
  // RS_RADIX_LOG2 is less than the target's subgroup size log2.  At this
  // time, no GPU meets this criterion.
  //
  rs->internal.partitions.offset = rs->internal.histograms.offset + rs->internal.histograms.range;

  return rs;

fail_pipeline:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyPipeline(_device, rs->pipelines.handles[ii], ac);
    }
fail_shader:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyShaderModule(_device, sms[ii], ac);
    }
fail_layout:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyPipelineLayout(_device, rs->pipeline_layouts.handles[ii], ac);
    }

  free(rs);
  return NULL;
}

//
//
//
void
radix_sort_vk_destroy(struct radix_sort_vk * rs, VkDevice d, VkAllocationCallbacks const * const ac)
{
  VK_FROM_HANDLE(vk_device, device, d);

  const struct vk_device_dispatch_table *disp = &device->dispatch_table;

  uint32_t const pipeline_count = rs_pipeline_count(rs);

  // destroy pipelines
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyPipeline(d, rs->pipelines.handles[ii], ac);
    }

  // destroy pipeline layouts
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
    {
      disp->DestroyPipelineLayout(d, rs->pipeline_layouts.handles[ii], ac);
    }

  free(rs);
}

//
//
//
static VkDeviceAddress
rs_get_devaddr(VkDevice _device, VkDescriptorBufferInfo const * dbi)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table *disp = &device->dispatch_table;

  VkBufferDeviceAddressInfo const bdai = {

    .sType  = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
    .pNext  = NULL,
    .buffer = dbi->buffer
  };

  VkDeviceAddress const devaddr = disp->GetBufferDeviceAddress(_device, &bdai) + dbi->offset;

  return devaddr;
}

//
//
//
#ifdef RS_VK_ENABLE_EXTENSIONS

void
rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps,
                           VkCommandBuffer                       cb,
                           VkPipelineStageFlagBits               pipeline_stage)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);
  const struct vk_device_dispatch_table *disp =
    &cmd_buffer->base.device->dispatch_table;

  if ((ext_timestamps != NULL) &&
      (ext_timestamps->timestamps_set < ext_timestamps->timestamp_count))
    {
      disp->CmdWriteTimestamp(cb,
                              pipeline_stage,
                              ext_timestamps->timestamps,
                              ext_timestamps->timestamps_set++);
    }
}

#endif

//
//
//

#ifdef RS_VK_ENABLE_EXTENSIONS

struct radix_sort_vk_ext_base
{
  void *                      ext;
  enum radix_sort_vk_ext_type type;
};

#endif

//
//
//
void
radix_sort_vk_sort_devaddr(radix_sort_vk_t const *                   rs,
                           radix_sort_vk_sort_devaddr_info_t const * info,
                           VkDevice                                  _device,
                           VkCommandBuffer                           cb,
                           VkDeviceAddress *                         keyvals_sorted)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table *disp = &device->dispatch_table;

  //
  // Anything to do?
  //
  if ((info->count <= 1) || (info->key_bits == 0))
    {
      *keyvals_sorted = info->keyvals_even.devaddr;

      return;
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Any extensions?
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
    {
      struct radix_sort_vk_ext_base * const base = ext_next;

      switch (base->type)
        {
          case RS_VK_EXT_TIMESTAMPS:
            ext_timestamps                 = ext_next;
            ext_timestamps->timestamps_set = 0;
            break;
        }

      ext_next = base->ext;
    }
#endif

    ////////////////////////////////////////////////////////////////////////
    //
    // OVERVIEW
    //
    //   1. Pad the keyvals in `scatter_even`.
    //   2. Zero the `histograms` and `partitions`.
    //      --- BARRIER ---
    //   3. HISTOGRAM is dispatched before PREFIX.
    //      --- BARRIER ---
    //   4. PREFIX is dispatched before the first SCATTER.
    //      --- BARRIER ---
    //   5. One or more SCATTER dispatches.
    //
    // Note that the `partitions` buffer can be zeroed anytime before the first
    // scatter.
    //
    ////////////////////////////////////////////////////////////////////////

    //
    // Label the command buffer
    //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  VkDebugUtilsLabelEXT const label = {
    .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
    .pNext      = NULL,
    .pLabelName = "radix_sort_vk_sort",
  };

  disp->CmdBeginDebugUtilsLabelEXT(cb, &label);
#endif

  //
  // How many passes?
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;

  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even.devaddr;
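
  //
  // For example, with 32-bit keyvals (keyval_dwords == 1) and key_bits == 32,
  // there are ceil(32 / 8) == 4 passes over an 8-bit radix, so the sorted
  // keyvals end up back in the `keyvals_even` buffer.  A 24-bit key would
  // need only 3 passes and the result would land in `keyvals_odd`.
  //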

  ////////////////////////////////////////////////////////////////////////
  //
  // PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
  //
  // Pad fractional blocks with max-valued keyvals.
  //
  // Zero the histograms and partitions buffer.
  //
  // This assumes the partitions follow the histograms.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // FIXME(allanmac): Consider precomputing some of these values and hanging
  // them off `rs`.
  //

  //
  // How many scatter blocks?
  //
  uint32_t const scatter_wg_size   = 1 << rs->config.scatter.workgroup_size_log2;
  uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
  uint32_t const scatter_blocks    = (info->count + scatter_block_kvs - 1) / scatter_block_kvs;
  uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;

  //
  // How many histogram blocks?
  //
  // Note that it's OK to have more max-valued digits counted by the histogram
  // than sorted by the scatters because the sort is stable.
  //
  uint32_t const histo_wg_size   = 1 << rs->config.histogram.workgroup_size_log2;
  uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
  uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
  uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;

  //
  // Fill with max values
  //
  if (count_ru_histo > info->count)
    {
      info->fill_buffer(cb,
                        &info->keyvals_even,
                        info->count * keyval_bytes,
                        (count_ru_histo - info->count) * keyval_bytes,
                        0xFFFFFFFF);
    }

  //
  // Zero histograms and invalidate partitions.
  //
  // Note that the partition invalidation only needs to be performed once
  // because the even/odd scatter dispatches rely on the previous pass to
  // leave the partitions in an invalid state.
  //
  // Note that the last workgroup doesn't read/write a partition so it doesn't
  // need to be initialized.
  //
  uint32_t const histo_partition_count = passes + scatter_blocks - 1;
  uint32_t       pass_idx              = (keyval_bytes - passes);

  VkDeviceSize const fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));

  info->fill_buffer(cb,
                    &info->internal,
                    rs->internal.histograms.offset + fill_base,
                    histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)),
                    0);
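
  //
  // For example (again assuming an 8-bit radix, RS_RADIX_SIZE == 256): with
  // 4 passes and 10 scatter blocks, histo_partition_count is 4 + 10 - 1 = 13,
  // pass_idx starts at 0, and the fill above zeroes 13 * 256 * 4 = 13312
  // bytes starting at the histograms offset.
  //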

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: HISTOGRAM
  //
  // TODO(allanmac): All subgroups should try to process approximately the same
  // number of blocks in order to minimize tail effects.  This was implemented
  // and reverted but should be reimplemented and benchmarked later.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TRANSFER_BIT);
#endif

  vk_barrier_transfer_w_to_compute_r(cb);

  // clang-format off
  VkDeviceAddress const devaddr_histograms   = info->internal.devaddr + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even.devaddr;
  // clang-format on

  //
  // Dispatch histogram
  //
  struct rs_push_histogram const push_histogram = {

    .devaddr_histograms = devaddr_histograms,
    .devaddr_keyvals    = devaddr_keyvals_even,
    .passes             = passes
  };

  disp->CmdPushConstants(cb,
                     rs->pipeline_layouts.named.histogram,
                     VK_SHADER_STAGE_COMPUTE_BIT,
                     0,
                     sizeof(push_histogram),
                     &push_histogram);

  disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

  disp->CmdDispatch(cb, histo_blocks, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: PREFIX
  //
  // Launch one workgroup per pass.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  struct rs_push_prefix const push_prefix = {

    .devaddr_histograms = devaddr_histograms,
  };

  disp->CmdPushConstants(cb,
                     rs->pipeline_layouts.named.prefix,
                     VK_SHADER_STAGE_COMPUTE_BIT,
                     0,
                     sizeof(push_prefix),
                     &push_prefix);

  disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

  disp->CmdDispatch(cb, passes, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: SCATTER
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  // clang-format off
  uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
  VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
  VkDeviceAddress const devaddr_partitions  = info->internal.devaddr + rs->internal.partitions.offset;
  // clang-format on

  struct rs_push_scatter push_scatter = {

    .devaddr_keyvals_even = devaddr_keyvals_even,
    .devaddr_keyvals_odd  = devaddr_keyvals_odd,
    .devaddr_partitions   = devaddr_partitions,
    .devaddr_histograms   = devaddr_histograms + histogram_offset,
    .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
  };

  {
    uint32_t const pass_dword = pass_idx / 4;

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.scatter[pass_dword].even,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_scatter),
                       &push_scatter);

    disp->CmdBindPipeline(cb,
                      VK_PIPELINE_BIND_POINT_COMPUTE,
                      rs->pipelines.named.scatter[pass_dword].even);
  }

  bool is_even = true;

  while (true)
    {
      disp->CmdDispatch(cb, scatter_blocks, 1, 1);

      //
      // Continue?
      //
      if (++pass_idx >= keyval_bytes)
        break;

#ifdef RS_VK_ENABLE_EXTENSIONS
      rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif
      vk_barrier_compute_w_to_compute_r(cb);

      // clang-format off
      is_even                         ^= true;
      push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
      push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
      // clang-format on

      uint32_t const pass_dword = pass_idx / 4;

      //
      // Update push constants that changed
      //
      VkPipelineLayout const pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                          : rs->pipeline_layouts.named.scatter[pass_dword].odd;
      disp->CmdPushConstants(cb,
                         pl,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
                         sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
                         &push_scatter.devaddr_histograms);
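
      //
      // The partial push constant update above assumes `pass_offset`
      // immediately follows `devaddr_histograms` in `struct rs_push_scatter`,
      // so both fields are covered by one contiguous push constant range.
      //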

      //
      // Bind new pipeline
      //
      VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                   : rs->pipelines.named.scatter[pass_dword].odd;

      disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  disp->CmdEndDebugUtilsLabelEXT(cb);
#endif
}

//
//
//
void
radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const *                             rs,
                                    radix_sort_vk_sort_indirect_devaddr_info_t const * info,
                                    VkDevice                                            _device,
                                    VkCommandBuffer                                     cb,
                                    VkDeviceAddress *                                   keyvals_sorted)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table *disp = &device->dispatch_table;

  //
  // Anything to do?
  //
  if (info->key_bits == 0)
    {
      *keyvals_sorted = info->keyvals_even;
      return;
    }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Any extensions?
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
    {
      struct radix_sort_vk_ext_base * const base = ext_next;

      switch (base->type)
        {
          case RS_VK_EXT_TIMESTAMPS:
            ext_timestamps                 = ext_next;
            ext_timestamps->timestamps_set = 0;
            break;
        }

      ext_next = base->ext;
    }
#endif

    ////////////////////////////////////////////////////////////////////////
    //
    // OVERVIEW
    //
    //   1. Init
    //      --- BARRIER ---
    //   2. Pad the keyvals in `scatter_even`.
    //   3. Zero the `histograms` and `partitions`.
    //      --- BARRIER ---
    //   4. HISTOGRAM is dispatched before PREFIX.
    //      --- BARRIER ---
    //   5. PREFIX is dispatched before the first SCATTER.
    //      --- BARRIER ---
    //   6. One or more SCATTER dispatches.
    //
    // Note that the `partitions` buffer can be zeroed anytime before the first
    // scatter.
    //
    ////////////////////////////////////////////////////////////////////////

    //
    // Label the command buffer
    //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  VkDebugUtilsLabelEXT const label = {
    .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
    .pNext      = NULL,
    .pLabelName = "radix_sort_vk_sort_indirect",
  };

  disp->CmdBeginDebugUtilsLabelEXT(cb, &label);
#endif

  //
  // How many passes?
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
  uint32_t       pass_idx     = (keyval_bytes - passes);

  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even;

  //
  // NOTE(allanmac): Some of these initializations appear redundant but for now
  // we're going to assume the compiler will elide them.
  //
  // clang-format off
  VkDeviceAddress const devaddr_info         = info->indirect.devaddr;
  VkDeviceAddress const devaddr_count        = info->count;
  VkDeviceAddress const devaddr_histograms   = info->internal + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even;
  // clang-format on

  //
  // START
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // INIT
  //
  {
    struct rs_push_init const push_init = {

      .devaddr_info  = devaddr_info,
      .devaddr_count = devaddr_count,
      .passes        = passes
    };

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.init,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_init),
                       &push_init);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.init);

    disp->CmdDispatch(cb, 1, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_indirect_compute_r(cb);

  {
    //
    // PAD
    //
    struct rs_push_fill const push_pad = {

      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, pad),
      .devaddr_dwords = devaddr_keyvals_even,
      .dword          = 0xFFFFFFFF
    };

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.fill,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_pad),
                       &push_pad);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.pad));
  }

  //
  // ZERO
  //
  {
    VkDeviceSize const histo_offset = pass_idx * (sizeof(uint32_t) * RS_RADIX_SIZE);

    struct rs_push_fill const push_zero = {

      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, zero),
      .devaddr_dwords = devaddr_histograms + histo_offset,
      .dword          = 0
    };

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.fill,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_zero),
                       &push_zero);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.zero));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // HISTOGRAM
  //
  {
    struct rs_push_histogram const push_histogram = {

      .devaddr_histograms = devaddr_histograms,
      .devaddr_keyvals    = devaddr_keyvals_even,
      .passes             = passes
    };

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.histogram,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_histogram),
                       &push_histogram);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

    info->dispatch_indirect(cb,
                            &info->indirect,
                            offsetof(struct rs_indirect_info, dispatch.histogram));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // PREFIX
  //
  {
    struct rs_push_prefix const push_prefix = {
      .devaddr_histograms = devaddr_histograms,
    };

    disp->CmdPushConstants(cb,
                       rs->pipeline_layouts.named.prefix,
                       VK_SHADER_STAGE_COMPUTE_BIT,
                       0,
                       sizeof(push_prefix),
                       &push_prefix);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

    disp->CmdDispatch(cb, passes, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // SCATTER
  //
  {
    // clang-format off
    uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
    VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
    VkDeviceAddress const devaddr_partitions  = info->internal + rs->internal.partitions.offset;
    // clang-format on

    struct rs_push_scatter push_scatter = {
      .devaddr_keyvals_even = devaddr_keyvals_even,
      .devaddr_keyvals_odd  = devaddr_keyvals_odd,
      .devaddr_partitions   = devaddr_partitions,
      .devaddr_histograms   = devaddr_histograms + histogram_offset,
      .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
    };

    {
      uint32_t const pass_dword = pass_idx / 4;

      disp->CmdPushConstants(cb,
                         rs->pipeline_layouts.named.scatter[pass_dword].even,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         0,
                         sizeof(push_scatter),
                         &push_scatter);

      disp->CmdBindPipeline(cb,
                        VK_PIPELINE_BIND_POINT_COMPUTE,
                        rs->pipelines.named.scatter[pass_dword].even);
    }

    bool is_even = true;

    while (true)
      {
        info->dispatch_indirect(cb,
                                &info->indirect,
                                offsetof(struct rs_indirect_info, dispatch.scatter));

        //
        // Continue?
        //
        if (++pass_idx >= keyval_bytes)
          break;

#ifdef RS_VK_ENABLE_EXTENSIONS
        rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

        vk_barrier_compute_w_to_compute_r(cb);

        // clang-format off
        is_even                         ^= true;
        push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
        push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
        // clang-format on

        uint32_t const pass_dword = pass_idx / 4;

        //
        // Update push constants that changed
        //
        VkPipelineLayout const pl = is_even
                                      ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                      : rs->pipeline_layouts.named.scatter[pass_dword].odd;
        disp->CmdPushConstants(
          cb,
          pl,
          VK_SHADER_STAGE_COMPUTE_BIT,
          OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
          sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
          &push_scatter.devaddr_histograms);

        //
        // Bind new pipeline
        //
        VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                     : rs->pipelines.named.scatter[pass_dword].odd;

        disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
      }
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  disp->CmdEndDebugUtilsLabelEXT(cb);
#endif
}

//
//
//
static void
radix_sort_vk_fill_buffer(VkCommandBuffer                     cb,
                          radix_sort_vk_buffer_info_t const * buffer_info,
                          VkDeviceSize                        offset,
                          VkDeviceSize                        size,
                          uint32_t                            data)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);
  const struct vk_device_dispatch_table *disp =
    &cmd_buffer->base.device->dispatch_table;

  disp->CmdFillBuffer(cb, buffer_info->buffer, buffer_info->offset + offset, size, data);
}

//
//
//
void
radix_sort_vk_sort(radix_sort_vk_t const *           rs,
                   radix_sort_vk_sort_info_t const * info,
                   VkDevice                          device,
                   VkCommandBuffer                   cb,
                   VkDescriptorBufferInfo *          keyvals_sorted)
{
  struct radix_sort_vk_sort_devaddr_info const di = {
    .ext          = info->ext,
    .key_bits     = info->key_bits,
    .count        = info->count,
    .keyvals_even = { .buffer  = info->keyvals_even.buffer,
                      .offset  = info->keyvals_even.offset,
                      .devaddr = rs_get_devaddr(device, &info->keyvals_even) },
    .keyvals_odd  = rs_get_devaddr(device, &info->keyvals_odd),
    .internal     = { .buffer  = info->internal.buffer,
                      .offset  = info->internal.offset,
                      .devaddr = rs_get_devaddr(device, &info->internal), },
    .fill_buffer  = radix_sort_vk_fill_buffer,
  };

  VkDeviceAddress di_keyvals_sorted;

  radix_sort_vk_sort_devaddr(rs, &di, device, cb, &di_keyvals_sorted);

  *keyvals_sorted = (di_keyvals_sorted == di.keyvals_even.devaddr)  //
                      ? info->keyvals_even
                      : info->keyvals_odd;
}
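
//
// A minimal usage sketch (illustrative only -- buffer creation, memory
// binding, and error handling are omitted; the info fields below are the
// ones consumed by the wrapper above):
//
//   radix_sort_vk_memory_requirements_t mr;
//
//   radix_sort_vk_get_memory_requirements(rs, count, &mr);
//
//   // ... allocate keyvals_even/keyvals_odd/internal buffers using
//   // mr.keyvals_size, mr.internal_size and their alignments ...
//
//   radix_sort_vk_sort_info_t const info = {
//     .ext          = NULL,
//     .key_bits     = 32,
//     .count        = count,
//     .keyvals_even = keyvals_even_dbi,   // VkDescriptorBufferInfo
//     .keyvals_odd  = keyvals_odd_dbi,
//     .internal     = internal_dbi,
//   };
//
//   VkDescriptorBufferInfo keyvals_sorted;
//
//   radix_sort_vk_sort(rs, &info, device, cb, &keyvals_sorted);
//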

//
//
//
static void
radix_sort_vk_dispatch_indirect(VkCommandBuffer                     cb,
                                radix_sort_vk_buffer_info_t const * buffer_info,
                                VkDeviceSize                        offset)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);
  const struct vk_device_dispatch_table *disp =
    &cmd_buffer->base.device->dispatch_table;

  disp->CmdDispatchIndirect(cb, buffer_info->buffer, buffer_info->offset + offset);
}

//
//
//
void
radix_sort_vk_sort_indirect(radix_sort_vk_t const *                    rs,
                            radix_sort_vk_sort_indirect_info_t const * info,
                            VkDevice                                   device,
                            VkCommandBuffer                            cb,
                            VkDescriptorBufferInfo *                   keyvals_sorted)
{
  struct radix_sort_vk_sort_indirect_devaddr_info const idi = {
    .ext               = info->ext,
    .key_bits          = info->key_bits,
    .count             = rs_get_devaddr(device, &info->count),
    .keyvals_even      = rs_get_devaddr(device, &info->keyvals_even),
    .keyvals_odd       = rs_get_devaddr(device, &info->keyvals_odd),
    .internal          = rs_get_devaddr(device, &info->internal),
    .indirect          = { .buffer  = info->indirect.buffer,
                           .offset  = info->indirect.offset,
                           .devaddr = rs_get_devaddr(device, &info->indirect) },
    .dispatch_indirect = radix_sort_vk_dispatch_indirect,
  };

  VkDeviceAddress idi_keyvals_sorted;

  radix_sort_vk_sort_indirect_devaddr(rs, &idi, device, cb, &idi_keyvals_sorted);

  *keyvals_sorted = (idi_keyvals_sorted == idi.keyvals_even)  //
                      ? info->keyvals_even
                      : info->keyvals_odd;
}

//
//
//