// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>

#include "common/macros.h"
#include "common/util.h"
#include "common/vk/barrier.h"
#include "radix_sort_vk_devaddr.h"
#include "shaders/push.h"
#include "shaders/config.h"

#include "vk_command_buffer.h"
#include "vk_device.h"

//
//
//

#ifdef RS_VK_ENABLE_DEBUG_UTILS
#include "common/vk/debug_utils.h"
#endif

//
//
//

#ifdef RS_VK_ENABLE_EXTENSIONS
#include "radix_sort_vk_ext.h"
#endif

//
// FIXME(allanmac): memoize some of these calculations
//
void
radix_sort_vk_get_memory_requirements(radix_sort_vk_t const *               rs,
                                      uint32_t                              count,
                                      radix_sort_vk_memory_requirements_t * mr)
{
  //
  // Keyval size
  //
  mr->keyval_size = rs->config.keyval_dwords * sizeof(uint32_t);

  //
  // Subgroup and workgroup sizes
  //
  uint32_t const histo_sg_size    = 1 << rs->config.histogram.subgroup_size_log2;
  uint32_t const histo_wg_size    = 1 << rs->config.histogram.workgroup_size_log2;
  uint32_t const prefix_sg_size   = 1 << rs->config.prefix.subgroup_size_log2;
  uint32_t const scatter_wg_size  = 1 << rs->config.scatter.workgroup_size_log2;
  uint32_t const internal_sg_size = MAX_MACRO(uint32_t, histo_sg_size, prefix_sg_size);

  //
  // If, for some reason, count is zero, then initialize the requirements
  // appropriately.
  //
  if (count == 0)
  {
    mr->keyvals_size       = 0;
    mr->keyvals_alignment  = mr->keyval_size * histo_sg_size;
    mr->internal_size      = 0;
    mr->internal_alignment = internal_sg_size * sizeof(uint32_t);
    mr->indirect_size      = 0;
    mr->indirect_alignment = internal_sg_size * sizeof(uint32_t);
  }
  else
  {
    //
    // Keyvals
    //
    // Round up to the scatter block size, then round up to the histogram
    // block size.  The difference between this rounded-up count and the
    // original keyval count is filled with max-valued keyvals.
    //

    //
    // How many scatter blocks?
    //
    uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
    uint32_t const scatter_blocks    = (count + scatter_block_kvs - 1) / scatter_block_kvs;
    uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;

    //
    // How many histogram blocks?
    //
    // Note that it's OK to have more max-valued digits counted by the
    // histogram than sorted by the scatters because the sort is stable.
    //
    uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
    uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
    uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;
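
    //
    // For example (a sketch assuming one possible config): with a scatter
    // workgroup size of 256 and 16 block rows, each scatter block covers
    // 4096 keyvals, so count == 100000 gives scatter_blocks == 25 and
    // count_ru_scatter == 102400.  With an identical histogram block size,
    // count_ru_histo == 102400 as well.
    //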

    mr->keyvals_size      = mr->keyval_size * count_ru_histo;
    mr->keyvals_alignment = mr->keyval_size * histo_sg_size;

    //
    // Internal
    //
    // NOTE: Assumes .histograms are before .partitions.
    //
    // Each RS_RADIX_LOG2 (8) bit pass has a zero-initialized RS_RADIX_SIZE
    // histogram -- one histogram per keyval byte.
    //
    // The last scatter workgroup skips writing to a partition, so only
    // (scatter_blocks - 1) partitions need to be allocated.
    //
    // If the device doesn't support "sequential dispatch" of workgroups,
    // then we need a zero-initialized dword counter per radix pass in the
    // keyval to atomically acquire a virtual workgroup id.  On sequentially
    // dispatched devices, this is simply `gl_WorkGroupID.x`.
    //
    // The "internal" memory map looks like this:
    //
    //   +---------------------------------+ <-- 0
    //   | histograms[keyval_size]         |
    //   +---------------------------------+ <-- keyval_size * histo_size
    //   | partitions[scatter_blocks - 1]  |
    //   +---------------------------------+ <-- (keyval_size + scatter_blocks - 1) * histo_size
    //   | workgroup_ids[keyval_size]      |
    //   +---------------------------------+ <-- (keyval_size + scatter_blocks - 1) * histo_size + workgroup_ids_size
    //
    // The `.workgroup_ids[]` are located after the last partition.
    //
    VkDeviceSize const histo_size = RS_RADIX_SIZE * sizeof(uint32_t);

    mr->internal_size      = (mr->keyval_size + scatter_blocks - 1) * histo_size;
    mr->internal_alignment = internal_sg_size * sizeof(uint32_t);
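
    //
    // For example: RS_RADIX_LOG2 == 8 implies RS_RADIX_SIZE == 256, so each
    // histogram or partition occupies 256 * 4 == 1024 bytes.  For 32-bit
    // keyvals (keyval_size == 4) and scatter_blocks == 25, internal_size is
    // (4 + 25 - 1) * 1024 == 28672 bytes before the workgroup ids are
    // appended below.
    //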

    //
    // Support for nonsequential dispatch can be disabled.
    //
    VkDeviceSize const workgroup_ids_size = mr->keyval_size * sizeof(uint32_t);

    mr->internal_size += workgroup_ids_size;

    //
    // Indirect
    //
    mr->indirect_size      = sizeof(struct rs_indirect_info);
    mr->indirect_alignment = sizeof(struct u32vec4);
  }
}

//
//
//
#ifdef RS_VK_ENABLE_DEBUG_UTILS

static void
rs_debug_utils_set(VkDevice device, struct radix_sort_vk * rs)
{
  if (pfn_vkSetDebugUtilsObjectNameEXT != NULL)
  {
    VkDebugUtilsObjectNameInfoEXT duoni = {
      .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT,
      .pNext      = NULL,
      .objectType = VK_OBJECT_TYPE_PIPELINE,
    };

    duoni.objectHandle = (uint64_t)rs->pipelines.named.init;
    duoni.pObjectName  = "radix_sort_init";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    duoni.objectHandle = (uint64_t)rs->pipelines.named.fill;
    duoni.pObjectName  = "radix_sort_fill";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    duoni.objectHandle = (uint64_t)rs->pipelines.named.histogram;
    duoni.pObjectName  = "radix_sort_histogram";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    duoni.objectHandle = (uint64_t)rs->pipelines.named.prefix;
    duoni.pObjectName  = "radix_sort_prefix";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].even;
    duoni.pObjectName  = "radix_sort_scatter_0_even";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[0].odd;
    duoni.pObjectName  = "radix_sort_scatter_0_odd";
    vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

    if (rs->config.keyval_dwords >= 2)
    {
      duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].even;
      duoni.pObjectName  = "radix_sort_scatter_1_even";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));

      duoni.objectHandle = (uint64_t)rs->pipelines.named.scatter[1].odd;
      duoni.pObjectName  = "radix_sort_scatter_1_odd";
      vk_ok(pfn_vkSetDebugUtilsObjectNameEXT(device, &duoni));
    }
  }
}

#endif

//
// How many pipelines are there?
//
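// For example, keyval_dwords == 2 (64-bit keyvals) yields
// 1 + 1 + 1 + 1 + 2 * 2 == 8 pipelines.
//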
static uint32_t
rs_pipeline_count(struct radix_sort_vk const * rs)
{
  return 1 +                            // init
         1 +                            // fill
         1 +                            // histogram
         1 +                            // prefix
         2 * rs->config.keyval_dwords;  // scatters.even/odd[keyval_dwords]
}

radix_sort_vk_t *
radix_sort_vk_create(VkDevice                           _device,
                     VkAllocationCallbacks const *      ac,
                     VkPipelineCache                    pc,
                     uint32_t const * const *           spv,
                     uint32_t const *                   spv_sizes,
                     struct radix_sort_vk_target_config config)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table * disp = &device->dispatch_table;

  //
  // Allocate radix_sort_vk
  //
  struct radix_sort_vk * const rs = calloc(1, sizeof(*rs));

  //
  // Save the config for later
  //
  rs->config = config;

  //
  // How many pipelines?
  //
  uint32_t const pipeline_count = rs_pipeline_count(rs);

  //
  // Prepare to create pipelines
  //
  VkPushConstantRange const pcr[] = {
    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_init) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_fill) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_histogram) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_prefix) },

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_0_even

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_0_odd

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_1_even

    { .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
      .offset     = 0,
      .size       = sizeof(struct rs_push_scatter) },  // scatter_1_odd
  };

  uint32_t spec_constants[] = {
    [RS_FILL_WORKGROUP_SIZE]            = 1u << config.fill.workgroup_size_log2,
    [RS_FILL_BLOCK_ROWS]                = config.fill.block_rows,
    [RS_HISTOGRAM_WORKGROUP_SIZE]       = 1u << config.histogram.workgroup_size_log2,
    [RS_HISTOGRAM_SUBGROUP_SIZE_LOG2]   = config.histogram.subgroup_size_log2,
    [RS_HISTOGRAM_BLOCK_ROWS]           = config.histogram.block_rows,
    [RS_PREFIX_WORKGROUP_SIZE]          = 1u << config.prefix.workgroup_size_log2,
    [RS_PREFIX_SUBGROUP_SIZE_LOG2]      = config.prefix.subgroup_size_log2,
    [RS_SCATTER_WORKGROUP_SIZE]         = 1u << config.scatter.workgroup_size_log2,
    [RS_SCATTER_SUBGROUP_SIZE_LOG2]     = config.scatter.subgroup_size_log2,
    [RS_SCATTER_BLOCK_ROWS]             = config.scatter.block_rows,
    [RS_SCATTER_NONSEQUENTIAL_DISPATCH] = config.nonsequential_dispatch,
  };

  VkSpecializationMapEntry spec_map[ARRAY_LENGTH_MACRO(spec_constants)];

  for (uint32_t ii = 0; ii < ARRAY_LENGTH_MACRO(spec_constants); ii++)
  {
    spec_map[ii] = (VkSpecializationMapEntry){
      .constantID = ii,
      .offset     = sizeof(uint32_t) * ii,
      .size       = sizeof(uint32_t),
    };
  }

  VkSpecializationInfo spec_info = {
    .mapEntryCount = ARRAY_LENGTH_MACRO(spec_map),
    .pMapEntries   = spec_map,
    .dataSize      = sizeof(spec_constants),
    .pData         = spec_constants,
  };

  VkPipelineLayoutCreateInfo plci = {
    .sType                  = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
    .pNext                  = NULL,
    .flags                  = 0,
    .setLayoutCount         = 0,
    .pSetLayouts            = NULL,
    .pushConstantRangeCount = 1,
    // .pPushConstantRanges is set per-pipeline in the loop below
  };

  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    plci.pPushConstantRanges = pcr + ii;

    if (disp->CreatePipelineLayout(_device, &plci, NULL, rs->pipeline_layouts.handles + ii) !=
        VK_SUCCESS)
      goto fail_layout;
  }

  //
  // Create compute pipelines
  //
  VkShaderModuleCreateInfo smci = {
    .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
    .pNext = NULL,
    .flags = 0,
    // .codeSize and .pCode are set per-module in the loop below
  };

  VkShaderModule sms[ARRAY_LENGTH_MACRO(rs->pipelines.handles)] = { 0 };

  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    smci.codeSize = spv_sizes[ii];
    smci.pCode    = spv[ii];

    if (disp->CreateShaderModule(_device, &smci, ac, sms + ii) != VK_SUCCESS)
      goto fail_shader;
  }

  //
  // If necessary, set the expected subgroup size
  //
#define RS_SUBGROUP_SIZE_CREATE_INFO_SET(size_)                                               \
  {                                                                                           \
    .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT,  \
    .pNext = NULL,                                                                            \
    .requiredSubgroupSize = size_,                                                            \
  }

#undef RS_SUBGROUP_SIZE_CREATE_INFO_NAME
#define RS_SUBGROUP_SIZE_CREATE_INFO_NAME(name_)                                              \
  RS_SUBGROUP_SIZE_CREATE_INFO_SET(1 << config.name_.subgroup_size_log2)

#define RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(name_) RS_SUBGROUP_SIZE_CREATE_INFO_SET(0)

  VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT const rsscis[] = {
    RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(init),       // init
    RS_SUBGROUP_SIZE_CREATE_INFO_ZERO(fill),       // fill
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(histogram),  // histogram
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(prefix),     // prefix
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].even
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[0].odd
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].even
    RS_SUBGROUP_SIZE_CREATE_INFO_NAME(scatter),    // scatter[1].odd
  };

  //
  // Define compute pipeline create infos
  //
#undef RS_COMPUTE_PIPELINE_CREATE_INFO_DECL
#define RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(idx_)                                            \
  { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,                                  \
    .pNext = NULL,                                                                            \
    .flags = 0,                                                                               \
    .stage = { .sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,                 \
               .pNext  = NULL,                                                                \
               .flags  = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT,      \
               .stage  = VK_SHADER_STAGE_COMPUTE_BIT,                                         \
               .module = sms[idx_],                                                           \
               .pName  = "main",                                                              \
               .pSpecializationInfo = &spec_info },                                           \
                                                                                              \
    .layout             = rs->pipeline_layouts.handles[idx_],                                 \
    .basePipelineHandle = VK_NULL_HANDLE,                                                     \
    .basePipelineIndex  = 0 }

  VkComputePipelineCreateInfo cpcis[] = {
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(0),  // init
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(1),  // fill
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(2),  // histogram
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(3),  // prefix
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(4),  // scatter[0].even
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(5),  // scatter[0].odd
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(6),  // scatter[1].even
    RS_COMPUTE_PIPELINE_CREATE_INFO_DECL(7),  // scatter[1].odd
  };

  //
  // Which of these compute pipelines require subgroup size control?
  //
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    if (rsscis[ii].requiredSubgroupSize > 1)
    {
      cpcis[ii].stage.pNext = rsscis + ii;
    }
  }

  //
  // Create the compute pipelines
  //
  if (disp->CreateComputePipelines(_device, pc, pipeline_count, cpcis, ac, rs->pipelines.handles) !=
      VK_SUCCESS)
    goto fail_pipeline;

  //
  // Shader modules can be destroyed now
  //
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyShaderModule(_device, sms[ii], ac);
  }

#ifdef RS_VK_ENABLE_DEBUG_UTILS
  //
  // Tag pipelines with names
  //
  rs_debug_utils_set(_device, rs);
#endif

  //
  // Calculate "internal" buffer offsets
  //
  size_t const keyval_bytes = rs->config.keyval_dwords * sizeof(uint32_t);

  // The .range calculation assumes an 8-bit radix.
  rs->internal.histograms.offset = 0;
  rs->internal.histograms.range  = keyval_bytes * (RS_RADIX_SIZE * sizeof(uint32_t));

  //
  // NOTE(allanmac): The partitions.offset would have to be aligned
  // differently if RS_RADIX_LOG2 were less than the target's subgroup size
  // log2.  At this time, no known GPU meets this criterion.
  //
  rs->internal.partitions.offset = rs->internal.histograms.offset + rs->internal.histograms.range;

  return rs;

fail_pipeline:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyPipeline(_device, rs->pipelines.handles[ii], ac);
  }
fail_shader:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyShaderModule(_device, sms[ii], ac);
  }
fail_layout:
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyPipelineLayout(_device, rs->pipeline_layouts.handles[ii], ac);
  }

  free(rs);
  return NULL;
}

//
//
//
void
radix_sort_vk_destroy(struct radix_sort_vk *              rs,
                      VkDevice                            d,
                      VkAllocationCallbacks const * const ac)
{
  VK_FROM_HANDLE(vk_device, device, d);

  const struct vk_device_dispatch_table * disp = &device->dispatch_table;

  uint32_t const pipeline_count = rs_pipeline_count(rs);

  // Destroy pipelines
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyPipeline(d, rs->pipelines.handles[ii], ac);
  }

  // Destroy pipeline layouts
  for (uint32_t ii = 0; ii < pipeline_count; ii++)
  {
    disp->DestroyPipelineLayout(d, rs->pipeline_layouts.handles[ii], ac);
  }

  free(rs);
}

//
//
//
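// Convert a buffer/offset pair into a single device address.
//
// Note that Vulkan requires `dbi->buffer` to have been created with the
// VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT usage flag in order for
// vkGetBufferDeviceAddress() to return a valid address.
//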
static VkDeviceAddress
rs_get_devaddr(VkDevice _device, VkDescriptorBufferInfo const * dbi)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table * disp = &device->dispatch_table;

  VkBufferDeviceAddressInfo const bdai = {
    .sType  = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
    .pNext  = NULL,
    .buffer = dbi->buffer,
  };

  VkDeviceAddress const devaddr = disp->GetBufferDeviceAddress(_device, &bdai) + dbi->offset;

  return devaddr;
}

//
//
//
#ifdef RS_VK_ENABLE_EXTENSIONS

void
rs_ext_cmd_write_timestamp(struct radix_sort_vk_ext_timestamps * ext_timestamps,
                           VkCommandBuffer                       cb,
                           VkPipelineStageFlagBits               pipeline_stage)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);

  const struct vk_device_dispatch_table * disp = &cmd_buffer->base.device->dispatch_table;

  if ((ext_timestamps != NULL) &&
      (ext_timestamps->timestamps_set < ext_timestamps->timestamp_count))
  {
    disp->CmdWriteTimestamp(cb,
                            pipeline_stage,
                            ext_timestamps->timestamps,
                            ext_timestamps->timestamps_set++);
  }
}

#endif

//
//
//

#ifdef RS_VK_ENABLE_EXTENSIONS

struct radix_sort_vk_ext_base
{
  void *                      ext;
  enum radix_sort_vk_ext_type type;
};

#endif

//
//
//
void
radix_sort_vk_sort_devaddr(radix_sort_vk_t const *                   rs,
                           radix_sort_vk_sort_devaddr_info_t const * info,
                           VkDevice                                  _device,
                           VkCommandBuffer                           cb,
                           VkDeviceAddress *                         keyvals_sorted)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table * disp = &device->dispatch_table;

  //
  // Anything to do?
  //
  if ((info->count <= 1) || (info->key_bits == 0))
  {
    *keyvals_sorted = info->keyvals_even.devaddr;

    return;
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Any extensions?
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
  {
    struct radix_sort_vk_ext_base * const base = ext_next;

    switch (base->type)
    {
      case RS_VK_EXT_TIMESTAMPS:
        ext_timestamps                 = ext_next;
        ext_timestamps->timestamps_set = 0;
        break;
    }

    ext_next = base->ext;
  }
#endif

  ////////////////////////////////////////////////////////////////////////
  //
  // OVERVIEW
  //
  //   1. Pad the keyvals in `scatter_even`.
  //   2. Zero the `histograms` and `partitions`.
  //      --- BARRIER ---
  //   3. HISTOGRAM is dispatched before PREFIX.
  //      --- BARRIER ---
  //   4. PREFIX is dispatched before the first SCATTER.
  //      --- BARRIER ---
  //   5. One or more SCATTER dispatches.
  //
  // Note that the `partitions` buffer can be zeroed anytime before the
  // first scatter.
  //
  ////////////////////////////////////////////////////////////////////////

  //
  // Label the command buffer
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  VkDebugUtilsLabelEXT const label = {
    .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
    .pNext      = NULL,
    .pLabelName = "radix_sort_vk_sort",
  };

  disp->CmdBeginDebugUtilsLabelEXT(cb, &label);
#endif

  //
  // How many passes?
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
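
  //
  // For example: 32-bit keyvals (keyval_bytes == 4) sorted on
  // info->key_bits == 24 require ceil(24 / RS_RADIX_LOG2) == ceil(24 / 8)
  // == 3 passes of the 8-bit radix sort.
  //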

  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even.devaddr;

  ////////////////////////////////////////////////////////////////////////
  //
  // PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
  //
  // Pad fractional blocks with max-valued keyvals.
  //
  // Zero the histograms and partitions buffer.
  //
  // This assumes the partitions follow the histograms.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // FIXME(allanmac): Consider precomputing some of these values and hanging
  // them off `rs`.
  //

  //
  // How many scatter blocks?
  //
  uint32_t const scatter_wg_size   = 1 << rs->config.scatter.workgroup_size_log2;
  uint32_t const scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
  uint32_t const scatter_blocks    = (info->count + scatter_block_kvs - 1) / scatter_block_kvs;
  uint32_t const count_ru_scatter  = scatter_blocks * scatter_block_kvs;

  //
  // How many histogram blocks?
  //
  // Note that it's OK to have more max-valued digits counted by the
  // histogram than sorted by the scatters because the sort is stable.
  //
  uint32_t const histo_wg_size   = 1 << rs->config.histogram.workgroup_size_log2;
  uint32_t const histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
  uint32_t const histo_blocks    = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
  uint32_t const count_ru_histo  = histo_blocks * histo_block_kvs;

  //
  // Fill with max values
  //
  if (count_ru_histo > info->count)
  {
    info->fill_buffer(cb,
                      &info->keyvals_even,
                      info->count * keyval_bytes,
                      (count_ru_histo - info->count) * keyval_bytes,
                      0xFFFFFFFF);
  }

  //
  // Zero histograms and invalidate partitions.
  //
  // Note that the partition invalidation only needs to be performed once
  // because the even/odd scatter dispatches rely on the previous pass to
  // leave the partitions in an invalid state.
  //
  // Note that the last workgroup doesn't read/write a partition so it
  // doesn't need to be initialized.
  //
  uint32_t const histo_partition_count = passes + scatter_blocks - 1;

  uint32_t pass_idx = (keyval_bytes - passes);

  VkDeviceSize const fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
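
  //
  // For example: with keyval_bytes == 4 and passes == 3, pass_idx starts
  // at 1, so the fill below skips the first histogram (it belongs to a
  // pass that will never be dispatched) and zeroes the remaining `passes`
  // histograms plus the `scatter_blocks - 1` partitions that follow them.
  //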

  info->fill_buffer(cb,
                    &info->internal,
                    rs->internal.histograms.offset + fill_base,
                    histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)),
                    0);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: HISTOGRAM
  //
  // TODO(allanmac): All subgroups should try to process approximately the
  // same number of blocks in order to minimize tail effects.  This was
  // implemented and reverted but should be reimplemented and benchmarked
  // later.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TRANSFER_BIT);
#endif

  vk_barrier_transfer_w_to_compute_r(cb);

  // clang-format off
  VkDeviceAddress const devaddr_histograms   = info->internal.devaddr + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even.devaddr;
  // clang-format on

  //
  // Dispatch histogram
  //
  struct rs_push_histogram const push_histogram = {
    .devaddr_histograms = devaddr_histograms,
    .devaddr_keyvals    = devaddr_keyvals_even,
    .passes             = passes,
  };

  disp->CmdPushConstants(cb,
                         rs->pipeline_layouts.named.histogram,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         0,
                         sizeof(push_histogram),
                         &push_histogram);

  disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

  disp->CmdDispatch(cb, histo_blocks, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: PREFIX
  //
  // Launch one workgroup per pass.
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  struct rs_push_prefix const push_prefix = {
    .devaddr_histograms = devaddr_histograms,
  };

  disp->CmdPushConstants(cb,
                         rs->pipeline_layouts.named.prefix,
                         VK_SHADER_STAGE_COMPUTE_BIT,
                         0,
                         sizeof(push_prefix),
                         &push_prefix);

  disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

  disp->CmdDispatch(cb, passes, 1, 1);

  ////////////////////////////////////////////////////////////////////////
  //
  // Pipeline: SCATTER
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  // clang-format off
  uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
  VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
  VkDeviceAddress const devaddr_partitions  = info->internal.devaddr + rs->internal.partitions.offset;
  // clang-format on

  struct rs_push_scatter push_scatter = {
    .devaddr_keyvals_even = devaddr_keyvals_even,
    .devaddr_keyvals_odd  = devaddr_keyvals_odd,
    .devaddr_partitions   = devaddr_partitions,
    .devaddr_histograms   = devaddr_histograms + histogram_offset,
    .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
  };
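
  //
  // For example: with 64-bit keyvals and all 64 key bits sorted, pass_idx
  // runs from 0 to 7.  Passes 0-3 select the scatter[0] even/odd pipelines
  // (pass_dword == 0) and passes 4-7 select scatter[1] (pass_dword == 1),
  // with pass_offset cycling through bit offsets 0, 8, 16 and 24 within
  // the selected dword.
  //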

  {
    uint32_t const pass_dword = pass_idx / 4;

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.scatter[pass_dword].even,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_scatter),
                           &push_scatter);

    disp->CmdBindPipeline(cb,
                          VK_PIPELINE_BIND_POINT_COMPUTE,
                          rs->pipelines.named.scatter[pass_dword].even);
  }

  bool is_even = true;

  while (true)
  {
    disp->CmdDispatch(cb, scatter_blocks, 1, 1);

    //
    // Continue?
    //
    if (++pass_idx >= keyval_bytes)
      break;

#ifdef RS_VK_ENABLE_EXTENSIONS
    rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

    vk_barrier_compute_w_to_compute_r(cb);

    // clang-format off
    is_even                         ^= true;
    push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
    push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
    // clang-format on

    uint32_t const pass_dword = pass_idx / 4;

    //
    // Update push constants that changed
    //
    VkPipelineLayout const pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                        : rs->pipeline_layouts.named.scatter[pass_dword].odd;

    disp->CmdPushConstants(cb,
                           pl,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
                           sizeof(push_scatter.devaddr_histograms) +
                             sizeof(push_scatter.pass_offset),
                           &push_scatter.devaddr_histograms);

    //
    // Bind new pipeline
    //
    VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                 : rs->pipelines.named.scatter[pass_dword].odd;

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  disp->CmdEndDebugUtilsLabelEXT(cb);
#endif
}

//
//
//
void
radix_sort_vk_sort_indirect_devaddr(radix_sort_vk_t const *                            rs,
                                    radix_sort_vk_sort_indirect_devaddr_info_t const * info,
                                    VkDevice                                           _device,
                                    VkCommandBuffer                                    cb,
                                    VkDeviceAddress *                                  keyvals_sorted)
{
  VK_FROM_HANDLE(vk_device, device, _device);

  const struct vk_device_dispatch_table * disp = &device->dispatch_table;

  //
  // Anything to do?
  //
  if (info->key_bits == 0)
  {
    *keyvals_sorted = info->keyvals_even;

    return;
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  //
  // Any extensions?
  //
  struct radix_sort_vk_ext_timestamps * ext_timestamps = NULL;

  void * ext_next = info->ext;

  while (ext_next != NULL)
  {
    struct radix_sort_vk_ext_base * const base = ext_next;

    switch (base->type)
    {
      case RS_VK_EXT_TIMESTAMPS:
        ext_timestamps                 = ext_next;
        ext_timestamps->timestamps_set = 0;
        break;
    }

    ext_next = base->ext;
  }
#endif

  ////////////////////////////////////////////////////////////////////////
  //
  // OVERVIEW
  //
  //   1. Init
  //      --- BARRIER ---
  //   2. Pad the keyvals in `scatter_even`.
  //   3. Zero the `histograms` and `partitions`.
  //      --- BARRIER ---
  //   4. HISTOGRAM is dispatched before PREFIX.
  //      --- BARRIER ---
  //   5. PREFIX is dispatched before the first SCATTER.
  //      --- BARRIER ---
  //   6. One or more SCATTER dispatches.
  //
  // Note that the `partitions` buffer can be zeroed anytime before the
  // first scatter.
  //
  ////////////////////////////////////////////////////////////////////////

  //
  // Label the command buffer
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  VkDebugUtilsLabelEXT const label = {
    .sType      = VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
    .pNext      = NULL,
    .pLabelName = "radix_sort_vk_sort_indirect",
  };

  disp->CmdBeginDebugUtilsLabelEXT(cb, &label);
#endif

  //
  // How many passes?
  //
  uint32_t const keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
  uint32_t const keyval_bits  = keyval_bytes * 8;
  uint32_t const key_bits     = MIN_MACRO(uint32_t, info->key_bits, keyval_bits);
  uint32_t const passes       = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;

  uint32_t pass_idx = (keyval_bytes - passes);

  *keyvals_sorted = ((passes & 1) != 0) ? info->keyvals_odd : info->keyvals_even;

  //
  // NOTE(allanmac): Some of these initializations appear redundant but for
  // now we're going to assume the compiler will elide them.
  //
  // clang-format off
  VkDeviceAddress const devaddr_info         = info->indirect.devaddr;
  VkDeviceAddress const devaddr_count        = info->count;
  VkDeviceAddress const devaddr_histograms   = info->internal + rs->internal.histograms.offset;
  VkDeviceAddress const devaddr_keyvals_even = info->keyvals_even;
  // clang-format on

  //
  // START
  //
#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT);
#endif

  //
  // INIT
  //
  {
    struct rs_push_init const push_init = {
      .devaddr_info  = devaddr_info,
      .devaddr_count = devaddr_count,
      .passes        = passes,
    };

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.init,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_init),
                           &push_init);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.init);

    disp->CmdDispatch(cb, 1, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_indirect_compute_r(cb);

  //
  // PAD
  //
  {
    struct rs_push_fill const push_pad = {
      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, pad),
      .devaddr_dwords = devaddr_keyvals_even,
      .dword          = 0xFFFFFFFF,
    };

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.fill,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_pad),
                           &push_pad);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.pad));
  }

  //
  // ZERO
  //
  {
    VkDeviceSize const histo_offset = pass_idx * (sizeof(uint32_t) * RS_RADIX_SIZE);

    struct rs_push_fill const push_zero = {
      .devaddr_info   = devaddr_info + offsetof(struct rs_indirect_info, zero),
      .devaddr_dwords = devaddr_histograms + histo_offset,
      .dword          = 0,
    };

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.fill,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_zero),
                           &push_zero);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.fill);

    info->dispatch_indirect(cb, &info->indirect, offsetof(struct rs_indirect_info, dispatch.zero));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // HISTOGRAM
  //
  {
    struct rs_push_histogram const push_histogram = {
      .devaddr_histograms = devaddr_histograms,
      .devaddr_keyvals    = devaddr_keyvals_even,
      .passes             = passes,
    };

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.histogram,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_histogram),
                           &push_histogram);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);

    info->dispatch_indirect(cb,
                            &info->indirect,
                            offsetof(struct rs_indirect_info, dispatch.histogram));
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // PREFIX
  //
  {
    struct rs_push_prefix const push_prefix = {
      .devaddr_histograms = devaddr_histograms,
    };

    disp->CmdPushConstants(cb,
                           rs->pipeline_layouts.named.prefix,
                           VK_SHADER_STAGE_COMPUTE_BIT,
                           0,
                           sizeof(push_prefix),
                           &push_prefix);

    disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);

    disp->CmdDispatch(cb, passes, 1, 1);
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  vk_barrier_compute_w_to_compute_r(cb);

  //
  // SCATTER
  //
  {
    // clang-format off
    uint32_t        const histogram_offset    = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
    VkDeviceAddress const devaddr_keyvals_odd = info->keyvals_odd;
    VkDeviceAddress const devaddr_partitions  = info->internal + rs->internal.partitions.offset;
    // clang-format on

    struct rs_push_scatter push_scatter = {
      .devaddr_keyvals_even = devaddr_keyvals_even,
      .devaddr_keyvals_odd  = devaddr_keyvals_odd,
      .devaddr_partitions   = devaddr_partitions,
      .devaddr_histograms   = devaddr_histograms + histogram_offset,
      .pass_offset          = (pass_idx & 3) * RS_RADIX_LOG2,
    };

    {
      uint32_t const pass_dword = pass_idx / 4;

      disp->CmdPushConstants(cb,
                             rs->pipeline_layouts.named.scatter[pass_dword].even,
                             VK_SHADER_STAGE_COMPUTE_BIT,
                             0,
                             sizeof(push_scatter),
                             &push_scatter);

      disp->CmdBindPipeline(cb,
                            VK_PIPELINE_BIND_POINT_COMPUTE,
                            rs->pipelines.named.scatter[pass_dword].even);
    }

    bool is_even = true;

    while (true)
    {
      info->dispatch_indirect(cb,
                              &info->indirect,
                              offsetof(struct rs_indirect_info, dispatch.scatter));

      //
      // Continue?
      //
      if (++pass_idx >= keyval_bytes)
        break;

#ifdef RS_VK_ENABLE_EXTENSIONS
      rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

      vk_barrier_compute_w_to_compute_r(cb);

      // clang-format off
      is_even                         ^= true;
      push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
      push_scatter.pass_offset         = (pass_idx & 3) * RS_RADIX_LOG2;
      // clang-format on

      uint32_t const pass_dword = pass_idx / 4;

      //
      // Update push constants that changed
      //
      VkPipelineLayout const pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even  //
                                          : rs->pipeline_layouts.named.scatter[pass_dword].odd;

      disp->CmdPushConstants(cb,
                             pl,
                             VK_SHADER_STAGE_COMPUTE_BIT,
                             OFFSETOF_MACRO(struct rs_push_scatter, devaddr_histograms),
                             sizeof(push_scatter.devaddr_histograms) +
                               sizeof(push_scatter.pass_offset),
                             &push_scatter.devaddr_histograms);

      //
      // Bind new pipeline
      //
      VkPipeline const p = is_even ? rs->pipelines.named.scatter[pass_dword].even  //
                                   : rs->pipelines.named.scatter[pass_dword].odd;

      disp->CmdBindPipeline(cb, VK_PIPELINE_BIND_POINT_COMPUTE, p);
    }
  }

#ifdef RS_VK_ENABLE_EXTENSIONS
  rs_ext_cmd_write_timestamp(ext_timestamps, cb, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT);
#endif

  //
  // End the label
  //
#ifdef RS_VK_ENABLE_DEBUG_UTILS
  disp->CmdEndDebugUtilsLabelEXT(cb);
#endif
}

//
//
//
static void
radix_sort_vk_fill_buffer(VkCommandBuffer                     cb,
                          radix_sort_vk_buffer_info_t const * buffer_info,
                          VkDeviceSize                        offset,
                          VkDeviceSize                        size,
                          uint32_t                            data)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);

  const struct vk_device_dispatch_table * disp = &cmd_buffer->base.device->dispatch_table;

  disp->CmdFillBuffer(cb, buffer_info->buffer, buffer_info->offset + offset, size, data);
}

//
//
//
void
radix_sort_vk_sort(radix_sort_vk_t const *           rs,
                   radix_sort_vk_sort_info_t const * info,
                   VkDevice                          device,
                   VkCommandBuffer                   cb,
                   VkDescriptorBufferInfo *          keyvals_sorted)
{
  struct radix_sort_vk_sort_devaddr_info const di = {
    .ext          = info->ext,
    .key_bits     = info->key_bits,
    .count        = info->count,
    .keyvals_even = { .buffer  = info->keyvals_even.buffer,
                      .offset  = info->keyvals_even.offset,
                      .devaddr = rs_get_devaddr(device, &info->keyvals_even) },
    .keyvals_odd  = rs_get_devaddr(device, &info->keyvals_odd),
    .internal     = { .buffer  = info->internal.buffer,
                      .offset  = info->internal.offset,
                      .devaddr = rs_get_devaddr(device, &info->internal) },
    .fill_buffer  = radix_sort_vk_fill_buffer,
  };

  VkDeviceAddress di_keyvals_sorted;

  radix_sort_vk_sort_devaddr(rs, &di, device, cb, &di_keyvals_sorted);

  *keyvals_sorted = (di_keyvals_sorted == di.keyvals_even.devaddr) ? info->keyvals_even
                                                                   : info->keyvals_odd;
}
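
//
// A minimal usage sketch of the direct-dispatch path (the buffer names
// below are hypothetical; allocation and error handling are omitted):
//
//   radix_sort_vk_memory_requirements_t mr;
//
//   radix_sort_vk_get_memory_requirements(rs, count, &mr);
//
//   // ... allocate/bind buffers satisfying mr's sizes and alignments ...
//
//   radix_sort_vk_sort_info_t const info = {
//     .ext          = NULL,
//     .key_bits     = 32,
//     .count        = count,
//     .keyvals_even = keyvals_even_dbi,
//     .keyvals_odd  = keyvals_odd_dbi,
//     .internal     = internal_dbi,
//   };
//
//   VkDescriptorBufferInfo keyvals_sorted;
//
//   radix_sort_vk_sort(rs, &info, device, cb, &keyvals_sorted);
//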

//
//
//
static void
radix_sort_vk_dispatch_indirect(VkCommandBuffer                     cb,
                                radix_sort_vk_buffer_info_t const * buffer_info,
                                VkDeviceSize                        offset)
{
  VK_FROM_HANDLE(vk_command_buffer, cmd_buffer, cb);

  const struct vk_device_dispatch_table * disp = &cmd_buffer->base.device->dispatch_table;

  disp->CmdDispatchIndirect(cb, buffer_info->buffer, buffer_info->offset + offset);
}

//
//
//
void
radix_sort_vk_sort_indirect(radix_sort_vk_t const *                    rs,
                            radix_sort_vk_sort_indirect_info_t const * info,
                            VkDevice                                   device,
                            VkCommandBuffer                            cb,
                            VkDescriptorBufferInfo *                   keyvals_sorted)
{
  struct radix_sort_vk_sort_indirect_devaddr_info const idi = {
    .ext               = info->ext,
    .key_bits          = info->key_bits,
    .count             = rs_get_devaddr(device, &info->count),
    .keyvals_even      = rs_get_devaddr(device, &info->keyvals_even),
    .keyvals_odd       = rs_get_devaddr(device, &info->keyvals_odd),
    .internal          = rs_get_devaddr(device, &info->internal),
    .indirect          = { .buffer  = info->indirect.buffer,
                           .offset  = info->indirect.offset,
                           .devaddr = rs_get_devaddr(device, &info->indirect) },
    .dispatch_indirect = radix_sort_vk_dispatch_indirect,
  };

  VkDeviceAddress idi_keyvals_sorted;

  radix_sort_vk_sort_indirect_devaddr(rs, &idi, device, cb, &idi_keyvals_sorted);

  *keyvals_sorted = (idi_keyvals_sorted == idi.keyvals_even) ? info->keyvals_even
                                                             : info->keyvals_odd;
}

//
//
//