• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can
5  * be found in the LICENSE file.
6  *
7  */
8 
9 #include <stdlib.h>
10 #include <string.h>
11 #include <inttypes.h>
12 
13 #include "common/util.h"
14 #include "common/macros.h"
15 #include "common/vk/assert_vk.h"
16 
17 #include "hs_vk.h"
18 #include "hs_vk_target.h"
19 
20 //
21 // We want concurrent kernel execution to occur in a few places.
22 //
23 // The summary is:
24 //
25 //   1) If necessary, some max valued keys are written to the end of
26 //      the vin/vout buffers.
27 //
28 //   2) Blocks of slabs of keys are sorted.
29 //
30 //   3) If necessary, the blocks of slabs are merged until complete.
31 //
32 //   4) If requested, the slabs will be converted from slab ordering
33 //      to linear ordering.
34 //
35 // Below is the general "happens-before" relationship between HotSort
36 // compute kernels.
37 //
38 // Note the diagram assumes vin and vout are different buffers.  If
39 // they're not, then the first merge doesn't include the pad_vout
40 // event in the wait list.
41 //
42 //                    +----------+            +---------+
43 //                    | pad_vout |            | pad_vin |
44 //                    +----+-----+            +----+----+
45 //                         |                       |
46 //                         |                WAITFOR(pad_vin)
47 //                         |                       |
48 //                         |                 +-----v-----+
49 //                         |                 |           |
50 //                         |            +----v----+ +----v----+
51 //                         |            | bs_full | | bs_frac |
52 //                         |            +----+----+ +----+----+
53 //                         |                 |           |
54 //                         |                 +-----v-----+
55 //                         |                       |
56 //                         |  +------NO------JUST ONE BLOCK?
57 //                         | /                     |
58 //                         |/                     YES
59 //                         +                       |
60 //                         |                       v
61 //                         |         END_WITH_EVENTS(bs_full,bs_frac)
62 //                         |
63 //                         |
64 //        WAITFOR(pad_vout,bs_full,bs_frac) >>> first iteration of loop <<<
65 //                         |
66 //                         |
67 //                         +-----------<------------+
68 //                         |                        |
69 //                   +-----v-----+                  |
70 //                   |           |                  |
71 //              +----v----+ +----v----+             |
72 //              | fm_full | | fm_frac |             |
73 //              +----+----+ +----+----+             |
74 //                   |           |                  ^
75 //                   +-----v-----+                  |
76 //                         |                        |
77 //              WAITFOR(fm_full,fm_frac)            |
78 //                         |                        |
79 //                         v                        |
80 //                      +--v--+                WAITFOR(bc)
81 //                      | hm  |                     |
82 //                      +-----+                     |
83 //                         |                        |
84 //                    WAITFOR(hm)                   |
85 //                         |                        ^
86 //                      +--v--+                     |
87 //                      | bc  |                     |
88 //                      +-----+                     |
89 //                         |                        |
90 //                         v                        |
91 //                  MERGING COMPLETE?-------NO------+
92 //                         |
93 //                        YES
94 //                         |
95 //                         v
96 //                END_WITH_EVENTS(bc)
97 //
98 
99 struct hs_vk
100 {
101   VkAllocationCallbacks const * allocator;
102   VkDevice                      device;
103 
104   struct {
105     struct {
106       VkDescriptorSetLayout     vout_vin;
107     } layout;
108   } desc_set;
109 
110   struct {
111     struct {
112       VkPipelineLayout          vout_vin;
113     } layout;
114   } pipeline;
115 
116   struct hs_vk_target_config    config;
117 
118   uint32_t                      key_val_size;
119   uint32_t                      slab_keys;
120   uint32_t                      bs_slabs_log2_ru;
121   uint32_t                      bc_slabs_log2_max;
122 
123   struct {
124     uint32_t                    count;
125     VkPipeline                * bs;
126     VkPipeline                * bc;
127     VkPipeline                * fm[3];
128     VkPipeline                * hm[3];
129     VkPipeline                * transpose;
130     VkPipeline                  all[];
131   } pipelines;
132 };
133 
134 //
135 //
136 //
137 
138 struct hs_state
139 {
140   VkCommandBuffer      cb;
141 
142   // If sorting in-place, then vout == vin
143   VkBuffer             vout;
144   VkBuffer             vin;
145 
146   // bx_ru is number of rounded up warps in vin
147   uint32_t             bx_ru;
148 };
149 
150 //
151 //
152 //
153 
154 static
155 void
hs_barrier_compute_w_to_compute_r(struct hs_state * const state)156 hs_barrier_compute_w_to_compute_r(struct hs_state * const state)
157 {
158   static VkMemoryBarrier const shader_w_to_r = {
159     .sType         = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
160     .pNext         = NULL,
161     .srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT,
162     .dstAccessMask = VK_ACCESS_SHADER_READ_BIT
163   };
164 
165   vkCmdPipelineBarrier(state->cb,
166                        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
167                        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
168                        0,
169                        1,
170                        &shader_w_to_r,
171                        0,
172                        NULL,
173                        0,
174                        NULL);
175 }
176 
177 //
178 //
179 //
180 
181 static
182 void
hs_barrier_to_compute_r(struct hs_state * const state,VkPipelineStageFlags const src_stage,VkAccessFlagBits const src_access)183 hs_barrier_to_compute_r(struct hs_state    * const state,
184                         VkPipelineStageFlags const src_stage,
185                         VkAccessFlagBits     const src_access)
186 {
187   if (src_stage == 0)
188     return;
189 
190   VkMemoryBarrier const compute_r = {
191     .sType         = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
192     .pNext         = NULL,
193     .srcAccessMask = src_access,
194     .dstAccessMask = VK_ACCESS_SHADER_READ_BIT
195   };
196 
197   vkCmdPipelineBarrier(state->cb,
198                        src_stage,
199                        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
200                        0,
201                        1,
202                        &compute_r,
203                        0,
204                        NULL,
205                        0,
206                        NULL);
207 }
208 
209 //
210 //
211 //
212 
213 static
214 void
hs_barrier_to_transfer_fill(struct hs_state * const state,VkPipelineStageFlags const src_stage,VkAccessFlagBits const src_access)215 hs_barrier_to_transfer_fill(struct hs_state    * const state,
216                             VkPipelineStageFlags const src_stage,
217                             VkAccessFlagBits     const src_access)
218 {
219   if (src_stage == 0)
220     return;
221 
222   VkMemoryBarrier const fill_w = {
223     .sType         = VK_STRUCTURE_TYPE_MEMORY_BARRIER,
224     .pNext         = NULL,
225     .srcAccessMask = src_access,
226     .dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT
227   };
228 
229   vkCmdPipelineBarrier(state->cb,
230                        src_stage,
231                        VK_PIPELINE_STAGE_TRANSFER_BIT,
232                        0,
233                        1,
234                        &fill_w,
235                        0,
236                        NULL,
237                        0,
238                        NULL);
239 }
240 
241 //
242 //
243 //
244 
245 static
246 void
hs_transpose(struct hs_vk const * const hs,struct hs_state * const state)247 hs_transpose(struct hs_vk const * const hs,
248              struct hs_state    * const state)
249 {
250   hs_barrier_compute_w_to_compute_r(state);
251 
252   vkCmdBindPipeline(state->cb,
253                     VK_PIPELINE_BIND_POINT_COMPUTE,
254                     hs->pipelines.transpose[0]);
255 
256   vkCmdDispatch(state->cb,state->bx_ru,1,1);
257 }
258 
259 //
260 //
261 //
262 
263 static
264 void
hs_bc(struct hs_vk const * const hs,struct hs_state * const state,uint32_t const down_slabs,uint32_t const clean_slabs_log2)265 hs_bc(struct hs_vk const * const hs,
266       struct hs_state    * const state,
267       uint32_t             const down_slabs,
268       uint32_t             const clean_slabs_log2)
269 {
270   hs_barrier_compute_w_to_compute_r(state);
271 
272   // block clean the minimal number of down_slabs_log2 spans
273   uint32_t const frac_ru = (1u << clean_slabs_log2) - 1;
274   uint32_t const full_bc = (down_slabs + frac_ru) >> clean_slabs_log2;
275 
276   vkCmdBindPipeline(state->cb,
277                     VK_PIPELINE_BIND_POINT_COMPUTE,
278                     hs->pipelines.bc[clean_slabs_log2]);
279 
280   vkCmdDispatch(state->cb,full_bc,1,1);
281 }
282 
283 //
284 //
285 //
286 
287 static
288 uint32_t
hs_hm(struct hs_vk const * const hs,struct hs_state * const state,uint32_t const down_slabs,uint32_t const clean_slabs_log2)289 hs_hm(struct hs_vk const * const hs,
290       struct hs_state    * const state,
291       uint32_t             const down_slabs,
292       uint32_t             const clean_slabs_log2)
293 {
294   hs_barrier_compute_w_to_compute_r(state);
295 
296   // how many scaled half-merge spans are there?
297   uint32_t const frac_ru    = (1 << clean_slabs_log2) - 1;
298   uint32_t const spans      = (down_slabs + frac_ru) >> clean_slabs_log2;
299 
300   // for now, just clamp to the max
301   uint32_t const log2_rem   = clean_slabs_log2 - hs->bc_slabs_log2_max;
302   uint32_t const scale_log2 = MIN_MACRO(hs->config.merge.hm.scale_max,log2_rem);
303   uint32_t const log2_out   = log2_rem - scale_log2;
304 
305   // size the grid
306   uint32_t const slab_span  = hs->config.slab.height << log2_out;
307 
308   vkCmdBindPipeline(state->cb,
309                     VK_PIPELINE_BIND_POINT_COMPUTE,
310                     hs->pipelines.hm[scale_log2][0]);
311 
312   vkCmdDispatch(state->cb,slab_span,spans,1);
313 
314   return log2_out;
315 }
316 
317 //
318 // FIXME -- some of this logic can be skipped if BS is a power-of-two
319 //
320 
321 static
322 uint32_t
hs_fm(struct hs_vk const * const hs,struct hs_state * const state,uint32_t * const down_slabs,uint32_t const up_scale_log2)323 hs_fm(struct hs_vk const * const hs,
324       struct hs_state    * const state,
325       uint32_t           * const down_slabs,
326       uint32_t             const up_scale_log2)
327 {
328   //
329   // FIXME OPTIMIZATION: in previous HotSort launchers it's sometimes
330   // a performance win to bias toward launching the smaller flip merge
331   // kernel in order to get more warps in flight (increased
332   // occupancy).  This is useful when merging small numbers of slabs.
333   //
334   // Note that HS_FM_SCALE_MIN will always be 0 or 1.
335   //
336   // So, for now, just clamp to the max until there is a reason to
337   // restore the fancier and probably low-impact approach.
338   //
339   uint32_t const scale_log2 = MIN_MACRO(hs->config.merge.fm.scale_max,up_scale_log2);
340   uint32_t const clean_log2 = up_scale_log2 - scale_log2;
341 
342   // number of slabs in a full-sized scaled flip-merge span
343   uint32_t const full_span_slabs = hs->config.block.slabs << up_scale_log2;
344 
345   // how many full-sized scaled flip-merge spans are there?
346   uint32_t full_fm = state->bx_ru / full_span_slabs;
347   uint32_t frac_fm = 0;
348 
349   // initialize down_slabs
350   *down_slabs = full_fm * full_span_slabs;
351 
352   // how many half-size scaled + fractional scaled spans are there?
353   uint32_t const span_rem        = state->bx_ru - *down_slabs;
354   uint32_t const half_span_slabs = full_span_slabs >> 1;
355 
356   // if we have over a half-span then fractionally merge it
357   if (span_rem > half_span_slabs)
358     {
359       // the remaining slabs will be cleaned
360       *down_slabs += span_rem;
361 
362       uint32_t const frac_rem      = span_rem - half_span_slabs;
363       uint32_t const frac_rem_pow2 = pow2_ru_u32(frac_rem);
364 
365       if (frac_rem_pow2 >= half_span_slabs)
366         {
367           // bump it up to a full span
368           full_fm += 1;
369         }
370       else
371         {
372           // otherwise, add fractional
373           frac_fm  = MAX_MACRO(1,frac_rem_pow2 >> clean_log2);
374         }
375     }
376 
377   //
378   // Size the grid
379   //
380   // The simplifying choices below limit the maximum keys that can be
381   // sorted with this grid scheme to around ~2B.
382   //
383   //   .x : slab height << clean_log2  -- this is the slab span
384   //   .y : [1...65535]                -- this is the slab index
385   //   .z : ( this could also be used to further expand .y )
386   //
387   // Note that OpenCL declares a grid in terms of global threads and
388   // not grids and blocks
389   //
390 
391   //
392   // size the grid
393   //
394   uint32_t const slab_span = hs->config.slab.height << clean_log2;
395 
396   if (full_fm > 0)
397     {
398       uint32_t const full_idx = hs->bs_slabs_log2_ru - 1 + scale_log2;
399 
400       vkCmdBindPipeline(state->cb,
401                         VK_PIPELINE_BIND_POINT_COMPUTE,
402                         hs->pipelines.fm[scale_log2][full_idx]);
403 
404       vkCmdDispatch(state->cb,slab_span,full_fm,1);
405     }
406 
407   if (frac_fm > 0)
408     {
409       vkCmdBindPipeline(state->cb,
410                         VK_PIPELINE_BIND_POINT_COMPUTE,
411                         hs->pipelines.fm[scale_log2][msb_idx_u32(frac_fm)]);
412 
413       vkCmdDispatchBase(state->cb,
414                         0,full_fm,0,
415                         slab_span,1,1);
416     }
417 
418   return clean_log2;
419 }
420 
421 //
422 //
423 //
424 
425 static
426 void
hs_bs(struct hs_vk const * const hs,struct hs_state * const state,uint32_t const count_padded_in)427 hs_bs(struct hs_vk const * const hs,
428       struct hs_state    * const state,
429       uint32_t             const count_padded_in)
430 {
431   uint32_t const slabs_in = count_padded_in / hs->slab_keys;
432   uint32_t const full_bs  = slabs_in / hs->config.block.slabs;
433   uint32_t const frac_bs  = slabs_in - full_bs * hs->config.block.slabs;
434 
435   if (full_bs > 0)
436     {
437       vkCmdBindPipeline(state->cb,
438                         VK_PIPELINE_BIND_POINT_COMPUTE,
439                         hs->pipelines.bs[hs->bs_slabs_log2_ru]);
440 
441       vkCmdDispatch(state->cb,full_bs,1,1);
442     }
443 
444   if (frac_bs > 0)
445     {
446       uint32_t const frac_idx          = msb_idx_u32(frac_bs);
447       uint32_t const full_to_frac_log2 = hs->bs_slabs_log2_ru - frac_idx;
448 
449       vkCmdBindPipeline(state->cb,
450                         VK_PIPELINE_BIND_POINT_COMPUTE,
451                         hs->pipelines.bs[msb_idx_u32(frac_bs)]);
452 
453       vkCmdDispatchBase(state->cb,
454                         full_bs<<full_to_frac_log2,0,0,
455                         1,1,1);
456     }
457 }
458 
459 //
460 //
461 //
462 
463 static
464 void
hs_keyset_pre_fm(struct hs_vk const * const hs,struct hs_state * const state,uint32_t const count_lo,uint32_t const count_hi)465 hs_keyset_pre_fm(struct hs_vk const * const hs,
466                  struct hs_state    * const state,
467                  uint32_t             const count_lo,
468                  uint32_t             const count_hi)
469 {
470   uint32_t const vout_span = count_hi - count_lo;
471 
472   vkCmdFillBuffer(state->cb,
473                   state->vout,
474                   count_lo  * hs->key_val_size,
475                   vout_span * hs->key_val_size,
476                   UINT32_MAX);
477 }
478 
479 //
480 //
481 //
482 
483 static
484 void
hs_keyset_pre_bs(struct hs_vk const * const hs,struct hs_state * const state,uint32_t const count,uint32_t const count_hi)485 hs_keyset_pre_bs(struct hs_vk const * const hs,
486                  struct hs_state    * const state,
487                  uint32_t             const count,
488                  uint32_t             const count_hi)
489 {
490   uint32_t const vin_span = count_hi - count;
491 
492   vkCmdFillBuffer(state->cb,
493                   state->vin,
494                   count    * hs->key_val_size,
495                   vin_span * hs->key_val_size,
496                   UINT32_MAX);
497 }
498 
499 //
500 //
501 //
502 
503 void
hs_vk_ds_bind(struct hs_vk const * const hs,VkDescriptorSet hs_ds,VkCommandBuffer cb,VkBuffer vin,VkBuffer vout)504 hs_vk_ds_bind(struct hs_vk const * const hs,
505               VkDescriptorSet            hs_ds,
506               VkCommandBuffer            cb,
507               VkBuffer                   vin,
508               VkBuffer                   vout)
509 {
510   //
511   // initialize the HotSort descriptor set
512   //
513   VkDescriptorBufferInfo const dbi[] = {
514     {
515       .buffer = vout == VK_NULL_HANDLE ? vin : vout,
516       .offset = 0,
517       .range  = VK_WHOLE_SIZE
518     },
519     {
520       .buffer = vin,
521       .offset = 0,
522       .range  = VK_WHOLE_SIZE
523     }
524   };
525 
526   VkWriteDescriptorSet const wds[] = {
527     {
528       .sType            = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
529       .pNext            = NULL,
530       .dstSet           = hs_ds,
531       .dstBinding       = 0,
532       .dstArrayElement  = 0,
533       .descriptorCount  = 2,
534       .descriptorType   = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
535       .pImageInfo       = NULL,
536       .pBufferInfo      = dbi,
537       .pTexelBufferView = NULL
538     }
539   };
540 
541   vkUpdateDescriptorSets(hs->device,
542                          ARRAY_LENGTH_MACRO(wds),
543                          wds,
544                          0,
545                          NULL);
546 
547   //
548   // All HotSort kernels can use the same descriptor set:
549   //
550   //   {
551   //     HS_KEY_TYPE vout[];
552   //     HS_KEY_TYPE vin[];
553   //   }
554   //
555   // Note that only the bs() kernels read from vin().
556   //
557   vkCmdBindDescriptorSets(cb,
558                           VK_PIPELINE_BIND_POINT_COMPUTE,
559                           hs->pipeline.layout.vout_vin,
560                           0,
561                           1,
562                           &hs_ds,
563                           0,
564                           NULL);
565 }
566 
567 //
568 //
569 //
570 
571 void
hs_vk_sort(struct hs_vk const * const hs,VkCommandBuffer cb,VkBuffer vin,VkPipelineStageFlags const vin_src_stage,VkAccessFlagBits const vin_src_access,VkBuffer vout,VkPipelineStageFlags const vout_src_stage,VkAccessFlagBits const vout_src_access,uint32_t const count,uint32_t const count_padded_in,uint32_t const count_padded_out,bool const linearize)572 hs_vk_sort(struct hs_vk const * const hs,
573            VkCommandBuffer            cb,
574            VkBuffer                   vin,
575            VkPipelineStageFlags const vin_src_stage,
576            VkAccessFlagBits     const vin_src_access,
577            VkBuffer                   vout,
578            VkPipelineStageFlags const vout_src_stage,
579            VkAccessFlagBits     const vout_src_access,
580            uint32_t             const count,
581            uint32_t             const count_padded_in,
582            uint32_t             const count_padded_out,
583            bool                 const linearize)
584 {
585   // is this sort in place?
586   bool const is_in_place = (vout == VK_NULL_HANDLE);
587 
588   //
589   // create some common state
590   //
591   struct hs_state state = {
592     .cb    = cb,
593     .vin   = vin,
594     .vout  = is_in_place ? vin : vout,
595     .bx_ru = (count + hs->slab_keys - 1) / hs->slab_keys
596   };
597 
598   // initialize vin
599   uint32_t const count_hi          = is_in_place ? count_padded_out : count_padded_in;
600   bool     const is_pre_sort_reqd  = count_hi > count;
601   bool     const is_pre_merge_reqd = !is_in_place && (count_padded_out > count_padded_in);
602 
603   //
604   // pre-sort  keyset needs to happen before bs()
605   // pre-merge keyset needs to happen before fm()
606   //
607 
608   VkPipelineStageFlags bs_src_stage  = 0;
609   VkAccessFlagBits     bs_src_access = 0;
610 
611   // initialize any trailing keys in vin before sorting
612   if (is_pre_sort_reqd)
613     {
614       hs_barrier_to_transfer_fill(&state,vin_src_stage,vin_src_access);
615 
616       hs_keyset_pre_bs(hs,&state,count,count_hi);
617 
618       bs_src_stage  |= VK_PIPELINE_STAGE_TRANSFER_BIT;
619       bs_src_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
620     }
621   else
622     {
623       bs_src_stage  = vin_src_stage;
624       bs_src_access = vin_src_access;
625     }
626 
627   hs_barrier_to_compute_r(&state,bs_src_stage,bs_src_access);
628 
629   // sort blocks of slabs... after hs_keyset_pre_sort()
630   hs_bs(hs,&state,count_padded_in);
631 
632   VkPipelineStageFlags fm_src_stage  = VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
633   VkAccessFlagBits     fm_src_access = VK_ACCESS_SHADER_READ_BIT;
634 
635   // initialize any trailing keys in vout before merging
636   if (is_pre_merge_reqd)
637     {
638       hs_barrier_to_transfer_fill(&state,vout_src_stage,vout_src_access);
639 
640       hs_keyset_pre_fm(hs,&state,count_padded_in,count_padded_out);
641 
642       fm_src_stage  |= VK_PIPELINE_STAGE_TRANSFER_BIT;
643       fm_src_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
644     }
645   else
646     {
647       fm_src_stage  |= vout_src_stage;
648       fm_src_access |= vout_src_access;
649     }
650 
651   //
652   // if this was a single bs block then there is no merging
653   //
654   if (state.bx_ru > hs->config.block.slabs)
655     {
656       hs_barrier_to_compute_r(&state,fm_src_stage,fm_src_access);
657 
658       //
659       // otherwise, merge sorted spans of slabs until done
660       //
661       int32_t up_scale_log2 = 1;
662 
663       while (true)
664         {
665           uint32_t down_slabs;
666 
667           // flip merge slabs -- return span of slabs that must be cleaned
668           uint32_t clean_slabs_log2 = hs_fm(hs,&state,
669                                             &down_slabs,
670                                             up_scale_log2);
671 
672           // if span is gt largest slab block cleaner then half merge
673           while (clean_slabs_log2 > hs->bc_slabs_log2_max)
674             {
675               clean_slabs_log2 = hs_hm(hs,&state,
676                                        down_slabs,
677                                        clean_slabs_log2);
678             }
679 
680           // launch clean slab grid -- is it the final launch?
681           hs_bc(hs,&state,down_slabs,clean_slabs_log2);
682 
683           // was this the final block clean?
684           if (((uint32_t)hs->config.block.slabs << up_scale_log2) >= state.bx_ru)
685             break;
686 
687           // otherwise, merge twice as many slabs
688           up_scale_log2 += 1;
689 
690           // drop a barrier
691           hs_barrier_compute_w_to_compute_r(&state);
692         }
693     }
694 
695   // slabs or linear?
696   if (linearize)
697     hs_transpose(hs,&state);
698 }
699 
700 //
701 //
702 //
703 
704 #ifdef HS_VK_VERBOSE_STATISTICS_AMD
705 
706 #include <stdio.h>
707 
708 static
709 void
hs_vk_verbose_statistics_amd(VkDevice device,struct hs_vk const * const hs)710 hs_vk_verbose_statistics_amd(VkDevice device, struct hs_vk const * const hs)
711 {
712   PFN_vkGetShaderInfoAMD vkGetShaderInfoAMD =
713     (PFN_vkGetShaderInfoAMD)
714     vkGetDeviceProcAddr(device,"vkGetShaderInfoAMD");
715 
716   if (vkGetShaderInfoAMD == NULL)
717     return;
718 
719   fprintf(stdout,
720           "                                   PHY   PHY  AVAIL AVAIL\n"
721           "VGPRs SGPRs LDS_MAX LDS/WG  SPILL VGPRs SGPRs VGPRs SGPRs  WORKGROUP_SIZE\n");
722 
723   for (uint32_t ii=0; ii<hs->pipelines.count; ii++)
724     {
725       VkShaderStatisticsInfoAMD ssi_amd;
726       size_t                    ssi_amd_size = sizeof(ssi_amd);
727 
728       if (vkGetShaderInfoAMD(hs->device,
729                              hs->pipelines.all[ii],
730                              VK_SHADER_STAGE_COMPUTE_BIT,
731                              VK_SHADER_INFO_TYPE_STATISTICS_AMD,
732                              &ssi_amd_size,
733                              &ssi_amd) == VK_SUCCESS)
734         {
735           fprintf(stdout,
736                   "%5" PRIu32 " "
737                   "%5" PRIu32 "   "
738                   "%5" PRIu32 " "
739 
740                   "%6zu "
741                   "%6zu "
742 
743                   "%5" PRIu32 " "
744                   "%5" PRIu32 " "
745                   "%5" PRIu32 " "
746                   "%5" PRIu32 "  "
747 
748                   "( %6" PRIu32 ", " "%6" PRIu32 ", " "%6" PRIu32 " )\n",
749                   ssi_amd.resourceUsage.numUsedVgprs,
750                   ssi_amd.resourceUsage.numUsedSgprs,
751                   ssi_amd.resourceUsage.ldsSizePerLocalWorkGroup,
752                   ssi_amd.resourceUsage.ldsUsageSizeInBytes,    // size_t
753                   ssi_amd.resourceUsage.scratchMemUsageInBytes, // size_t
754                   ssi_amd.numPhysicalVgprs,
755                   ssi_amd.numPhysicalSgprs,
756                   ssi_amd.numAvailableVgprs,
757                   ssi_amd.numAvailableSgprs,
758                   ssi_amd.computeWorkGroupSize[0],
759                   ssi_amd.computeWorkGroupSize[1],
760                   ssi_amd.computeWorkGroupSize[2]);
761         }
762     }
763 }
764 
765 #endif
766 
767 //
768 //
769 //
770 
771 #ifdef HS_VK_VERBOSE_DISASSEMBLY_AMD
772 
773 #include <stdio.h>
774 
775 static
776 void
hs_vk_verbose_disassembly_amd(VkDevice device,struct hs_vk const * const hs)777 hs_vk_verbose_disassembly_amd(VkDevice device, struct hs_vk const * const hs)
778 {
779   PFN_vkGetShaderInfoAMD vkGetShaderInfoAMD =
780     (PFN_vkGetShaderInfoAMD)
781     vkGetDeviceProcAddr(device,"vkGetShaderInfoAMD");
782 
783   if (vkGetShaderInfoAMD == NULL)
784     return;
785 
786   for (uint32_t ii=0; ii<hs->pipelines.count; ii++)
787     {
788       size_t disassembly_amd_size;
789 
790       if (vkGetShaderInfoAMD(hs->device,
791                              hs->pipelines.all[ii],
792                              VK_SHADER_STAGE_COMPUTE_BIT,
793                              VK_SHADER_INFO_TYPE_DISASSEMBLY_AMD,
794                              &disassembly_amd_size,
795                              NULL) == VK_SUCCESS)
796         {
797           void * disassembly_amd = malloc(disassembly_amd_size);
798 
799           if (vkGetShaderInfoAMD(hs->device,
800                                  hs->pipelines.all[ii],
801                                  VK_SHADER_STAGE_COMPUTE_BIT,
802                                  VK_SHADER_INFO_TYPE_DISASSEMBLY_AMD,
803                                  &disassembly_amd_size,
804                                  disassembly_amd) == VK_SUCCESS)
805             {
806               fprintf(stdout,"%s",(char*)disassembly_amd);
807             }
808 
809           free(disassembly_amd);
810         }
811     }
812 }
813 
814 #endif
815 
816 //
817 //
818 //
819 
820 struct hs_vk *
hs_vk_create(struct hs_vk_target const * const target,VkDevice device,VkAllocationCallbacks const * allocator,VkPipelineCache pipeline_cache)821 hs_vk_create(struct hs_vk_target   const * const target,
822              VkDevice                            device,
823              VkAllocationCallbacks const *       allocator,
824              VkPipelineCache                     pipeline_cache)
825 {
826   //
827   // we reference these values a lot
828   //
829   uint32_t const bs_slabs_log2_ru  = msb_idx_u32(pow2_ru_u32(target->config.block.slabs));
830   uint32_t const bc_slabs_log2_max = msb_idx_u32(pow2_rd_u32(target->config.block.slabs));
831 
832   //
833   // how many kernels will be created?
834   //
835   uint32_t const count_bs    = bs_slabs_log2_ru + 1;
836   uint32_t const count_bc    = bc_slabs_log2_max + 1;
837   uint32_t       count_fm[3] = { 0 };
838   uint32_t       count_hm[3] = { 0 };
839 
840   // guaranteed to be in range [0,2]
841   for (uint32_t scale = target->config.merge.fm.scale_min;
842        scale <= target->config.merge.fm.scale_max;
843        scale++)
844     {
845       uint32_t fm_left = (target->config.block.slabs / 2) << scale;
846 
847       count_fm[scale] = msb_idx_u32(pow2_ru_u32(fm_left)) + 1;
848     }
849 
850   // guaranteed to be in range [0,2]
851   for (uint32_t scale = target->config.merge.hm.scale_min;
852        scale <= target->config.merge.hm.scale_max;
853        scale++)
854     {
855       count_hm[scale] = 1;
856     }
857 
858   uint32_t const count_bc_fm_hm_transpose =
859     + count_bc
860     + count_fm[0] + count_fm[1] + count_fm[2]
861     + count_hm[0] + count_hm[1] + count_hm[2] +
862     1; // transpose
863 
864   uint32_t const count_all = count_bs + count_bc_fm_hm_transpose;
865 
866   //
867   // allocate hs_vk
868   //
869   struct hs_vk * hs;
870 
871   if (allocator == NULL)
872     {
873       hs = malloc(sizeof(*hs) + sizeof(VkPipeline*) * count_all);
874     }
875   else
876     {
877       hs = allocator->pfnAllocation(NULL,
878                                     sizeof(*hs) + sizeof(VkPipeline*) * count_all,
879                                     0,
880                                     VK_SYSTEM_ALLOCATION_SCOPE_INSTANCE);
881     }
882 
883   // save device & allocator
884   hs->device    = device;
885   hs->allocator = allocator;
886 
887   //
888   // create one descriptor set layout
889   //
890   static VkDescriptorSetLayoutBinding const dslb_vout_vin[] = {
891     {
892       .binding            = 0, // vout
893       .descriptorType     = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
894       .descriptorCount    = 1,
895       .stageFlags         = VK_SHADER_STAGE_COMPUTE_BIT,
896       .pImmutableSamplers = NULL
897     },
898     {
899       .binding            = 1, // vin
900       .descriptorType     = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
901       .descriptorCount    = 1,
902       .stageFlags         = VK_SHADER_STAGE_COMPUTE_BIT,
903       .pImmutableSamplers = NULL
904     }
905   };
906 
907   static VkDescriptorSetLayoutCreateInfo const dscli = {
908     .sType        = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
909     .pNext        = NULL,
910     .flags        = 0,
911     .bindingCount = 2, // 0:vout[], 1:vin[]
912     .pBindings    = dslb_vout_vin
913   };
914 
915   vk(CreateDescriptorSetLayout(device,
916                                &dscli,
917                                allocator,
918                                &hs->desc_set.layout.vout_vin));
919 
920   //
921   // create one pipeline layout
922   //
923   VkPipelineLayoutCreateInfo plci = {
924     .sType                  = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
925     .pNext                  = NULL,
926     .flags                  = 0,
927     .setLayoutCount         = 1,
928     .pSetLayouts            = &hs->desc_set.layout.vout_vin,
929     .pushConstantRangeCount = 0,
930     .pPushConstantRanges    = NULL
931   };
932 
933   vk(CreatePipelineLayout(device,
934                           &plci,
935                           allocator,
936                           &hs->pipeline.layout.vout_vin));
937 
938   //
939   // copy the config from the target -- we need these values later
940   //
941   memcpy(&hs->config,&target->config,sizeof(hs->config));
942 
943   // save some frequently used calculated values
944   hs->key_val_size      = (target->config.words.key + target->config.words.val) * 4;
945   hs->slab_keys         = target->config.slab.height << target->config.slab.width_log2;
946   hs->bs_slabs_log2_ru  = bs_slabs_log2_ru;
947   hs->bc_slabs_log2_max = bc_slabs_log2_max;
948 
949   // save kernel count
950   hs->pipelines.count   = count_all;
951 
952   //
953   // create all the compute pipelines by reusing this info
954   //
955   VkComputePipelineCreateInfo cpci = {
956     .sType                 = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
957     .pNext                 = NULL,
958     .flags                 = VK_PIPELINE_CREATE_DISPATCH_BASE, // | VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT,
959     .stage = {
960       .sType               = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
961       .pNext               = NULL,
962       .flags               = 0,
963       .stage               = VK_SHADER_STAGE_COMPUTE_BIT,
964       .module              = VK_NULL_HANDLE,
965       .pName               = "main",
966       .pSpecializationInfo = NULL
967     },
968     .layout                = hs->pipeline.layout.vout_vin,
969     .basePipelineHandle    = VK_NULL_HANDLE,
970     .basePipelineIndex     = 0
971   };
972 
973   //
974   // Create a shader module, use it to create a pipeline... and
975   // dispose of the shader module.
976   //
977   // The BS     compute shaders have the same layout
978   // The non-BS compute shaders have the same layout
979   //
980   VkShaderModuleCreateInfo smci = {
981     .sType    = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
982     .pNext    = NULL,
983     .flags    = 0,
984     .codeSize = 0,
985     .pCode    = (uint32_t const *)target->modules // FIXME -- unfortunate typecast
986   };
987 
988   //
989   // bs kernels have layout: (vout,vin)
990   // remaining  have layout: (vout)
991   //
992   for (uint32_t ii=0; ii<count_all; ii++)
993     {
994       // convert bytes to words
995       uint32_t const * const module = smci.pCode + smci.codeSize / sizeof(*module);
996 
997       smci.codeSize = NTOHL_MACRO(module[0]);
998       smci.pCode    = module + 1;
999 
1000       vk(CreateShaderModule(device,
1001                             &smci,
1002                             allocator,
1003                             &cpci.stage.module));
1004 
1005       vk(CreateComputePipelines(device,
1006                                 pipeline_cache,
1007                                 1,
1008                                 &cpci,
1009                                 allocator,
1010                                 hs->pipelines.all+ii));
1011 
1012       vkDestroyShaderModule(device,
1013                             cpci.stage.module,
1014                             allocator);
1015     }
1016 
1017   //
1018   // initialize pointers to pipeline handles
1019   //
1020   VkPipeline * pipeline_next = hs->pipelines.all;
1021 
1022   // BS
1023   hs->pipelines.bs        = pipeline_next;
1024   pipeline_next          += count_bs;
1025 
1026   // BC
1027   hs->pipelines.bc        = pipeline_next;
1028   pipeline_next          += count_bc;
1029 
1030   // FM[0]
1031   hs->pipelines.fm[0]     = count_fm[0] ? pipeline_next : NULL;
1032   pipeline_next          += count_fm[0];
1033 
1034   // FM[1]
1035   hs->pipelines.fm[1]     = count_fm[1] ? pipeline_next : NULL;
1036   pipeline_next          += count_fm[1];
1037 
1038   // FM[2]
1039   hs->pipelines.fm[2]     = count_fm[2] ? pipeline_next : NULL;
1040   pipeline_next          += count_fm[2];
1041 
1042   // HM[0]
1043   hs->pipelines.hm[0]     = count_hm[0] ? pipeline_next : NULL;
1044   pipeline_next          += count_hm[0];
1045 
1046   // HM[1]
1047   hs->pipelines.hm[1]     = count_hm[1] ? pipeline_next : NULL;
1048   pipeline_next          += count_hm[1];
1049 
1050   // HM[2]
1051   hs->pipelines.hm[2]     = count_hm[2] ? pipeline_next : NULL;
1052   pipeline_next          += count_hm[2];
1053 
1054   // TRANSPOSE
1055   hs->pipelines.transpose = pipeline_next;
1056   pipeline_next          += 1;
1057 
1058   //
1059   // optionally dump pipeline stats
1060   //
1061 #ifdef HS_VK_VERBOSE_STATISTICS_AMD
1062   hs_vk_verbose_statistics_amd(device,hs);
1063 #endif
1064 #ifdef HS_VK_VERBOSE_DISASSEMBLY_AMD
1065   hs_vk_verbose_disassembly_amd(device,hs);
1066 #endif
1067 
1068   //
1069   //
1070   //
1071 
1072   return hs;
1073 }
1074 
1075 //
1076 //
1077 //
1078 
1079 void
hs_vk_release(struct hs_vk * const hs)1080 hs_vk_release(struct hs_vk * const hs)
1081 {
1082   vkDestroyDescriptorSetLayout(hs->device,
1083                                hs->desc_set.layout.vout_vin,
1084                                hs->allocator);
1085 
1086   vkDestroyPipelineLayout(hs->device,
1087                           hs->pipeline.layout.vout_vin,
1088                           hs->allocator);
1089 
1090   for (uint32_t ii=0; ii<hs->pipelines.count; ii++)
1091     {
1092       vkDestroyPipeline(hs->device,
1093                         hs->pipelines.all[ii],
1094                         hs->allocator);
1095     }
1096 
1097   if (hs->allocator == NULL)
1098     {
1099       free(hs);
1100     }
1101   else
1102     {
1103       hs->allocator->pfnFree(NULL,hs);
1104     }
1105 }
1106 
1107 //
1108 // Allocate a per-thread descriptor set for the vin and vout
1109 // VkBuffers.  Note that HotSort uses only one descriptor set.
1110 //
1111 
1112 VkDescriptorSet
hs_vk_ds_alloc(struct hs_vk const * const hs,VkDescriptorPool desc_pool)1113 hs_vk_ds_alloc(struct hs_vk const * const hs, VkDescriptorPool desc_pool)
1114 {
1115   VkDescriptorSetAllocateInfo const ds_alloc_info = {
1116     .sType              = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,
1117     .pNext              = NULL,
1118     .descriptorPool     = desc_pool,
1119     .descriptorSetCount = 1,
1120     .pSetLayouts        = &hs->desc_set.layout.vout_vin
1121   };
1122 
1123   VkDescriptorSet hs_ds;
1124 
1125   vk(AllocateDescriptorSets(hs->device,
1126                             &ds_alloc_info,
1127                             &hs_ds));
1128 
1129   return hs_ds;
1130 }
1131 
1132 //
1133 //
1134 //
1135 
1136 void
hs_vk_pad(struct hs_vk const * const hs,uint32_t const count,uint32_t * const count_padded_in,uint32_t * const count_padded_out)1137 hs_vk_pad(struct hs_vk const * const hs,
1138           uint32_t             const count,
1139           uint32_t           * const count_padded_in,
1140           uint32_t           * const count_padded_out)
1141 {
1142   //
1143   // round up the count to slabs
1144   //
1145   uint32_t const slabs_ru        = (count + hs->slab_keys - 1) / hs->slab_keys;
1146   uint32_t const blocks          = slabs_ru / hs->config.block.slabs;
1147   uint32_t const block_slabs     = blocks * hs->config.block.slabs;
1148   uint32_t const slabs_ru_rem    = slabs_ru - block_slabs;
1149   uint32_t const slabs_ru_rem_ru = MIN_MACRO(pow2_ru_u32(slabs_ru_rem),hs->config.block.slabs);
1150 
1151   *count_padded_in  = (block_slabs + slabs_ru_rem_ru) * hs->slab_keys;
1152   *count_padded_out = *count_padded_in;
1153 
1154   //
1155   // will merging be required?
1156   //
1157   if (slabs_ru > hs->config.block.slabs)
1158     {
1159       // more than one block
1160       uint32_t const blocks_lo       = pow2_rd_u32(blocks);
1161       uint32_t const block_slabs_lo  = blocks_lo * hs->config.block.slabs;
1162       uint32_t const block_slabs_rem = slabs_ru - block_slabs_lo;
1163 
1164       if (block_slabs_rem > 0)
1165         {
1166           uint32_t const block_slabs_rem_ru     = pow2_ru_u32(block_slabs_rem);
1167 
1168           uint32_t const block_slabs_hi         = MAX_MACRO(block_slabs_rem_ru,
1169                                                             blocks_lo << (1 - hs->config.merge.fm.scale_min));
1170 
1171           uint32_t const block_slabs_padded_out = MIN_MACRO(block_slabs_lo+block_slabs_hi,
1172                                                             block_slabs_lo*2); // clamp non-pow2 blocks
1173 
1174           *count_padded_out = block_slabs_padded_out * hs->slab_keys;
1175         }
1176     }
1177 }
1178 
1179 //
1180 //
1181 //
1182