1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  *
7  */
8 
9 //
10 //
11 //
12 
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <getopt.h>
#include <inttypes.h>

//
//
//

#include "networks.h"
#include "common/util.h"
#include "common/macros.h"

//
//
//

#undef  HSG_OP_EXPAND_X
#define HSG_OP_EXPAND_X(t) #t ,

char const * const
hsg_op_type_string[] =
  {
    HSG_OP_EXPAND_ALL()
  };
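
//
// HSG_OP_EXPAND_ALL() is an X-macro list (declared in networks.h, not
// shown here) that invokes HSG_OP_EXPAND_X() once per op type, so
// with the stringifying definition above the table presumably expands
// to something like:
//
//   char const * const hsg_op_type_string[] =
//     {
//       "HSG_OP_TYPE_EXIT",
//       "HSG_OP_TYPE_END",
//       "HSG_OP_TYPE_BEGIN",
//       ...
//     };
//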

//
//
//

#define EXIT()                            (struct hsg_op){ HSG_OP_TYPE_EXIT                                     }

#define END()                             (struct hsg_op){ HSG_OP_TYPE_END                                      }
#define BEGIN()                           (struct hsg_op){ HSG_OP_TYPE_BEGIN                                    }
#define ELSE()                            (struct hsg_op){ HSG_OP_TYPE_ELSE                                     }

#define TARGET_BEGIN()                    (struct hsg_op){ HSG_OP_TYPE_TARGET_BEGIN                             }
#define TARGET_END()                      (struct hsg_op){ HSG_OP_TYPE_TARGET_END                               }

#define TRANSPOSE_KERNEL_PROTO()          (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO                   }
#define TRANSPOSE_KERNEL_PREAMBLE()       (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE                }
#define TRANSPOSE_KERNEL_BODY()           (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY                    }

#define BS_KERNEL_PROTO(i)                (struct hsg_op){ HSG_OP_TYPE_BS_KERNEL_PROTO,             { i       } }
#define BS_KERNEL_PREAMBLE(i)             (struct hsg_op){ HSG_OP_TYPE_BS_KERNEL_PREAMBLE,          { i       } }

#define BC_KERNEL_PROTO(i)                (struct hsg_op){ HSG_OP_TYPE_BC_KERNEL_PROTO,             { i       } }
#define BC_KERNEL_PREAMBLE(i)             (struct hsg_op){ HSG_OP_TYPE_BC_KERNEL_PREAMBLE,          { i       } }

#define FM_KERNEL_PROTO(s,r)              (struct hsg_op){ HSG_OP_TYPE_FM_KERNEL_PROTO,             { s, r    } }
#define FM_KERNEL_PREAMBLE(l,r)           (struct hsg_op){ HSG_OP_TYPE_FM_KERNEL_PREAMBLE,          { l, r    } }

#define HM_KERNEL_PROTO(s)                (struct hsg_op){ HSG_OP_TYPE_HM_KERNEL_PROTO,             { s       } }
#define HM_KERNEL_PREAMBLE(l)             (struct hsg_op){ HSG_OP_TYPE_HM_KERNEL_PREAMBLE,          { l       } }

#define BX_REG_GLOBAL_LOAD(n,v)           (struct hsg_op){ HSG_OP_TYPE_BX_REG_GLOBAL_LOAD,          { n, v    } }
#define BX_REG_GLOBAL_STORE(n)            (struct hsg_op){ HSG_OP_TYPE_BX_REG_GLOBAL_STORE,         { n       } }

#define FM_REG_GLOBAL_LOAD_LEFT(n,i)      (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT,     { n, i    } }
#define FM_REG_GLOBAL_STORE_LEFT(n,i)     (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT,    { n, i    } }
#define FM_REG_GLOBAL_LOAD_RIGHT(n,i)     (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT,    { n, i    } }
#define FM_REG_GLOBAL_STORE_RIGHT(n,i)    (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT,   { n, i    } }
#define FM_MERGE_RIGHT_PRED(n,s)          (struct hsg_op){ HSG_OP_TYPE_FM_MERGE_RIGHT_PRED,         { n, s    } }

#define HM_REG_GLOBAL_LOAD(n,i)           (struct hsg_op){ HSG_OP_TYPE_HM_REG_GLOBAL_LOAD,          { n, i    } }
#define HM_REG_GLOBAL_STORE(n,i)          (struct hsg_op){ HSG_OP_TYPE_HM_REG_GLOBAL_STORE,         { n, i    } }

#define SLAB_FLIP(f)                      (struct hsg_op){ HSG_OP_TYPE_SLAB_FLIP,                   { f       } }
#define SLAB_HALF(h)                      (struct hsg_op){ HSG_OP_TYPE_SLAB_HALF,                   { h       } }

#define CMP_FLIP(a,b,c)                   (struct hsg_op){ HSG_OP_TYPE_CMP_FLIP,                    { a, b, c } }
#define CMP_HALF(a,b)                     (struct hsg_op){ HSG_OP_TYPE_CMP_HALF,                    { a, b    } }

#define CMP_XCHG(a,b,p)                   (struct hsg_op){ HSG_OP_TYPE_CMP_XCHG,                    { a, b, p } }

#define BS_REG_SHARED_STORE_V(m,i,r)      (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_STORE_V,       { m, i, r } }
#define BS_REG_SHARED_LOAD_V(m,i,r)       (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_V,        { m, i, r } }
#define BC_REG_SHARED_LOAD_V(m,i,r)       (struct hsg_op){ HSG_OP_TYPE_BC_REG_SHARED_LOAD_V,        { m, i, r } }

#define BX_REG_SHARED_STORE_LEFT(r,i,p)   (struct hsg_op){ HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT,    { r, i, p } }
#define BS_REG_SHARED_STORE_RIGHT(r,i,p)  (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT,   { r, i, p } }

#define BS_REG_SHARED_LOAD_LEFT(r,i,p)    (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT,     { r, i, p } }
#define BS_REG_SHARED_LOAD_RIGHT(r,i,p)   (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT,    { r, i, p } }

#define BC_REG_GLOBAL_LOAD_LEFT(r,i,p)    (struct hsg_op){ HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT,     { r, i, p } }

#define REG_F_PREAMBLE(s)                 (struct hsg_op){ HSG_OP_TYPE_REG_F_PREAMBLE,              { s       } }
#define REG_SHARED_STORE_F(r,i,s)         (struct hsg_op){ HSG_OP_TYPE_REG_SHARED_STORE_F,          { r, i, s } }
#define REG_SHARED_LOAD_F(r,i,s)          (struct hsg_op){ HSG_OP_TYPE_REG_SHARED_LOAD_F,           { r, i, s } }
#define REG_GLOBAL_STORE_F(r,i,s)         (struct hsg_op){ HSG_OP_TYPE_REG_GLOBAL_STORE_F,          { r, i, s } }

#define BLOCK_SYNC()                      (struct hsg_op){ HSG_OP_TYPE_BLOCK_SYNC                               }

#define BS_FRAC_PRED(m,w)                 (struct hsg_op){ HSG_OP_TYPE_BS_FRAC_PRED,                { m, w    } }

#define BS_MERGE_H_PREAMBLE(i)            (struct hsg_op){ HSG_OP_TYPE_BS_MERGE_H_PREAMBLE,         { i       } }
#define BC_MERGE_H_PREAMBLE(i)            (struct hsg_op){ HSG_OP_TYPE_BC_MERGE_H_PREAMBLE,         { i       } }

#define BX_MERGE_H_PRED(p)                (struct hsg_op){ HSG_OP_TYPE_BX_MERGE_H_PRED,             { p       } }

#define BS_ACTIVE_PRED(m,l)               (struct hsg_op){ HSG_OP_TYPE_BS_ACTIVE_PRED,              { m, l    } }

//
// DEFAULTS
//

static
struct hsg_config hsg_config =
  {
    .merge  = {
      .flip = {
        .warps      = 1,
        .lo         = 1,
        .hi         = 1
      },
      .half =  {
        .warps      = 1,
        .lo         = 1,
        .hi         = 1
      },
    },

    .block  = {
      .warps_min    = 1,          // min warps for a block that uses smem barriers
      .warps_max    = UINT32_MAX, // max warps for the entire multiprocessor
      .warps_mod    = 2,          // the number of warps necessary to load balance horizontal merging

      .smem_min     = 0,
      .smem_quantum = 1,

      .smem_bs      = 49152,
      .smem_bc      = UINT32_MAX  // implies field not set
    },

    .warp   = {
      .lanes        = 32,
      .lanes_log2   = 5,
    },

    .thread = {
      .regs         = 24,
      .xtra         = 0
    },

    .type   = {
      .words        = 2
    }
  };
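
//
// These defaults can be overridden from the command line -- see the
// getopt() loop in main() below.  A hypothetical invocation (binary
// name illustrative only) targeting a 16-warp block of 2-word keys
// might look like:
//
//   hs_gen -a cuda -b 16 -r 24 -s 49152 -t 2
//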

//
// ZERO HSG_MERGE STRUCT
//

static
struct hsg_merge hsg_merge[MERGE_LEVELS_MAX_LOG2] = { 0 };

//
// STATS ON INSTRUCTIONS
//

static hsg_op_type hsg_op_type_counts[HSG_OP_TYPE_COUNT] = { 0 };

//
//
//

static
void
hsg_op_debug()
{
  uint32_t total = 0;

  for (hsg_op_type t=HSG_OP_TYPE_EXIT; t<HSG_OP_TYPE_COUNT; t++)
    {
      uint32_t const count = hsg_op_type_counts[t];

      total += count;

      fprintf(stderr,"%-37s : %u\n",hsg_op_type_string[t],count);
    }

  fprintf(stderr,"%-37s : %u\n\n\n","TOTAL",total);
}

//
//
//

static
void
hsg_config_init_shared()
{
  //
  // The assumption here is that a proper smem_bs value was provided
  // that represents the maximum fraction of the multiprocessor's
  // available shared memory that can be accessed by the initial block
  // sorting kernel.
  //
  // With CUDA devices this is 48KB out of 48KB, 64KB or 96KB.
  //
  // Intel subslices are a little trickier and the minimum allocation
  // is 4KB and the maximum is 64KB on pre-Skylake IGPs.  Sizes are
  // allocated in 1KB increments.  If a maximum of two block sorters
  // can occupy a subslice then each should be assigned 32KB of shared
  // memory.
  //
  // News Flash: apparently GEN9+ IGPs can allocate 1KB of SMEM per
  // workgroup so all the previously written logic to support this
  // issue is being removed.
  //
  uint32_t const bs_keys = hsg_config.block.smem_bs / (hsg_config.type.words * sizeof(uint32_t));

  hsg_config.warp.skpw_bs = bs_keys / hsg_merge[0].warps;
}
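
//
// A worked example with the defaults above: smem_bs = 49152 and
// type.words = 2 give
//
//   bs_keys = 49152 / (2 * sizeof(uint32_t)) = 49152 / 8 = 6144
//
// so a hypothetical 16-warp block sorter (-b 16) yields
//
//   skpw_bs = 6144 / 16 = 384 slab keys per warp
//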

static
void
hsg_merge_levels_init_shared(struct hsg_merge * const merge)
{
  {
    //
    // What is the max amount of shared in each possible bs block config?
    //
    // The provided smem_bs size will be allocated for each sorting block.
    //
    uint32_t const bs_threads   = merge->warps << hsg_config.warp.lanes_log2;
    uint32_t const bs_keys      = hsg_config.block.smem_bs / (hsg_config.type.words * sizeof(uint32_t));
    uint32_t const bs_kpt       = bs_keys / bs_threads;
    uint32_t const bs_kpt_mod   = (bs_kpt / hsg_config.block.warps_mod) * hsg_config.block.warps_mod;
    uint32_t const bs_rows_even = bs_kpt_mod & ~1; // must be even because flip merge only works on row pairs

    // this is a showstopper
    if (bs_rows_even < 2)
      {
        fprintf(stderr,"Error: need at least 2 rows of shared memory.\n");
        exit(-1);
      }

    // clamp to number of registers
    merge->rows_bs = MIN_MACRO(bs_rows_even, hsg_config.thread.regs);
  }

  //
  // smem key allocation rule for BC kernels is that a single block
  // can't allocate more than smem_bs and must allocate at least
  // smem_min in smem_quantum steps.
  //
  // Note that BC blocks will always be less than or equal to BS
  // blocks.
  //
  {
    //
    // if merge->warps is not pow2 then we're going to skip creating a bc elsewhere
    //
    uint32_t const bc_warps_min  = MAX_MACRO(merge->warps,hsg_config.block.warps_min);
    uint32_t const bc_threads    = bc_warps_min << hsg_config.warp.lanes_log2;
    uint32_t const bc_block_rd   = (((hsg_config.block.smem_bc * bc_warps_min) / hsg_config.block.warps_max) /
                                    hsg_config.block.smem_quantum) * hsg_config.block.smem_quantum;
    uint32_t const bc_block_max  = MAX_MACRO(bc_block_rd,hsg_config.block.smem_min);
    uint32_t const bc_block_smem = MIN_MACRO(bc_block_max,hsg_config.block.smem_bs);

    // what is the max amount of shared in each possible bc block config?
    uint32_t const bc_keys       = bc_block_smem / (hsg_config.type.words * sizeof(uint32_t));
    uint32_t const bc_kpt        = bc_keys / bc_threads;
    uint32_t const bc_kpt_mod    = (bc_kpt / hsg_config.block.warps_mod) * hsg_config.block.warps_mod;

    merge->rows_bc = MIN_MACRO(bc_kpt_mod, hsg_config.thread.regs);
    merge->skpw_bc = bc_keys / bc_warps_min;
  }
}
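
//
// Continuing the worked example: merge->warps = 16, lanes_log2 = 5,
// smem_bs = 49152, words = 2, warps_mod = 2 and regs = 24 give
//
//   bs_threads   = 16 << 5       = 512
//   bs_keys      = 49152 / 8     = 6144
//   bs_kpt       = 6144 / 512    = 12
//   bs_kpt_mod   = (12 / 2) * 2  = 12
//   bs_rows_even = 12 & ~1       = 12
//   rows_bs      = MIN(12,24)    = 12
//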

//
//
//

static
void
hsg_merge_levels_init_1(struct hsg_merge * const merge, uint32_t const warps, uint32_t const level, uint32_t const offset)
{
  uint32_t const even_odd = warps & 1;

  merge->levels[level].evenodds[even_odd]++;
  merge->levels[level].networks[even_odd] = warps;

  if (warps == 1)
    return;

  merge->levels[level].active.b64 |= BITS_TO_MASK_AT_64(warps,offset);

  uint32_t const count = merge->levels[level].count++;
  uint32_t const index = (1 << level) + count;
  uint32_t const bit   = 1 << count;

  merge->levels[level].evenodd_masks[even_odd] |= bit;

  if (count > 0)
    {
      // offset from network to left of this network
      uint32_t const diff   = offset - merge->offsets[index-1];

      uint32_t const diff_0 = merge->levels[level].diffs[0];
      uint32_t const diff_1 = merge->levels[level].diffs[1];

      uint32_t diff_idx = UINT32_MAX;

      if        ((diff_0 == 0) || (diff_0 == diff)) {
        diff_idx = 0;
      } else if ((diff_1 == 0) || (diff_1 == diff)) {
        diff_idx = 1;
      } else {
        fprintf(stderr, "*** MORE THAN TWO DIFFS ***\n");
        exit(-1);
      }

      merge->levels[level].diffs     [diff_idx]  = diff;
      merge->levels[level].diff_masks[diff_idx] |= 1 << (count-1);
    }

  merge->networks[index] = warps;
  merge->offsets [index] = offset;

  uint32_t const l = (warps+1)/2; // lower/larger  on left
  uint32_t const r = (warps+0)/2; // higher/smaller on right

  hsg_merge_levels_init_1(merge,l,level+1,offset);
  hsg_merge_levels_init_1(merge,r,level+1,offset+l);
}
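
//
// The recursion above splits a network of N warps into a larger-or-
// equal left half l = (N+1)/2 and a smaller-or-equal right half
// r = N/2.  For example, warps = 6 at level 0 produces (showing only
// the split shape):
//
//   level 0 : [ 6 ]
//   level 1 : [ 3 3 ]        l = (6+1)/2 = 3, r = 6/2 = 3
//   level 2 : [ 2 1 2 1 ]    each 3 splits into 2 + 1
//   level 3 : [ 1 1 1 1 ]    recursion stops once warps == 1
//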

static
void
hsg_merge_levels_debug(struct hsg_merge * const merge)
{
  for (uint32_t level=0; level<MERGE_LEVELS_MAX_LOG2; level++)
    {
      uint32_t count = merge->levels[level].count;

      if (count == 0)
        break;

      fprintf(stderr,
              "%-4u : %016" PRIX64 " \n",
              count,
              merge->levels[level].active.b64);

      fprintf(stderr,
              "%-4u : %08X (%2u)\n"
              "%-4u : %08X (%2u)\n",
              merge->levels[level].diffs[0],
              merge->levels[level].diff_masks[0],
              POPCOUNT_MACRO(merge->levels[level].diff_masks[0]),
              merge->levels[level].diffs[1],
              merge->levels[level].diff_masks[1],
              POPCOUNT_MACRO(merge->levels[level].diff_masks[1]));

      fprintf(stderr,
              "EVEN : %08X (%2u)\n"
              "ODD  : %08X (%2u)\n",
              merge->levels[level].evenodd_masks[0],
              POPCOUNT_MACRO(merge->levels[level].evenodd_masks[0]),
              merge->levels[level].evenodd_masks[1],
              POPCOUNT_MACRO(merge->levels[level].evenodd_masks[1]));

      for (uint32_t ii=0; ii<2; ii++)
        {
          if (merge->levels[level].networks[ii] > 1)
            {
              fprintf(stderr,
                      "%-4s : ( %2u x %2u )\n",
                      (ii == 0) ? "EVEN" : "ODD",
                      merge->levels[level].evenodds[ii],
                      merge->levels[level].networks[ii]);
            }
        }

      uint32_t index = 1 << level;

      while (count-- > 0)
        {
          fprintf(stderr,
                  "[ %2u %2u ] ",
                  merge->offsets [index],
                  merge->networks[index]);

          index += 1;
        }

      fprintf(stderr,"\n\n");
    }
}

static
void
hsg_merge_levels_hint(struct hsg_merge * const merge, bool const autotune)
{
  // clamp against merge levels
  for (uint32_t level=0; level<MERGE_LEVELS_MAX_LOG2; level++)
    {
      // max network
      uint32_t const n_max = MAX_MACRO(merge->levels[level].networks[0],
                                       merge->levels[level].networks[1]);

      if (n_max <= (merge->rows_bs + hsg_config.thread.xtra))
        break;

      if (autotune)
        {
          hsg_config.thread.xtra = n_max - merge->rows_bs;

          uint32_t const r_total = hsg_config.thread.regs + hsg_config.thread.xtra;
          uint32_t const r_limit = (hsg_config.type.words == 1) ? 120 : 58;

          if (r_total <= r_limit)
            {
              fprintf(stderr,"autotune: %u + %u\n",
                      hsg_config.thread.regs,
                      hsg_config.thread.xtra);
              break;
            }
          else
            {
              fprintf(stderr,"skipping autotune: %u + %u > %u\n",
                      hsg_config.thread.regs,
                      hsg_config.thread.xtra,
                      r_limit);
              exit(-1);
            }
        }

      fprintf(stderr,"*** HINT *** Try extra registers: %u\n",
              n_max - merge->rows_bs);

      exit(-1);
    }
}

//
//
//

static
struct hsg_op *
hsg_op(struct hsg_op * ops, struct hsg_op const opcode)
{
  hsg_op_type_counts[opcode.type] += 1;

  *ops = opcode;

  return ops+1;
}

static
struct hsg_op *
hsg_exit(struct hsg_op * ops)
{
  return hsg_op(ops,EXIT());
}

static
struct hsg_op *
hsg_end(struct hsg_op * ops)
{
  return hsg_op(ops,END());
}

static
struct hsg_op *
hsg_begin(struct hsg_op * ops)
{
  return hsg_op(ops,BEGIN());
}

static
struct hsg_op *
hsg_else(struct hsg_op * ops)
{
  return hsg_op(ops,ELSE());
}

static
struct hsg_op *
hsg_network_copy(struct hsg_op            *       ops,
                 struct hsg_network const * const nets,
                 uint32_t                   const idx,
                 uint32_t                   const prefix)
{
  uint32_t              const len = nets[idx].length;
  struct hsg_op const * const cxa = nets[idx].network;

  for (uint32_t ii=0; ii<len; ii++)
    {
      struct hsg_op const * const cx = cxa + ii;

      ops = hsg_op(ops,CMP_XCHG(cx->a,cx->b,prefix));
    }

  return ops;
}

static
struct hsg_op *
hsg_thread_sort(struct hsg_op * ops)
{
  uint32_t const idx = hsg_config.thread.regs / 2 - 1;

  return hsg_network_copy(ops,hsg_networks_sorting,idx,UINT32_MAX);
}
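
//
// With the default regs = 24 this selects index 24/2 - 1 = 11, which
// suggests hsg_networks_sorting[] (declared in networks.h, not shown
// here) holds even-sized sorting networks: size 2 at index 0, size 4
// at index 1, ... size 24 at index 11.
//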

static
struct hsg_op *
hsg_thread_merge_prefix(struct hsg_op * ops, uint32_t const network, uint32_t const prefix)
{
  if (network <= 1)
    return ops;

  return hsg_network_copy(ops,hsg_networks_merging,network-2,prefix);
}

static
struct hsg_op *
hsg_thread_merge(struct hsg_op * ops, uint32_t const network)
{
  return hsg_thread_merge_prefix(ops,network,UINT32_MAX);
}

static
struct hsg_op *
hsg_thread_merge_offset_prefix(struct hsg_op * ops, uint32_t const offset, uint32_t const network, uint32_t const prefix)
{
  if (network <= 1)
    return ops;

  uint32_t              const idx = network - 2;
  uint32_t              const len = hsg_networks_merging[idx].length;
  struct hsg_op const * const cxa = hsg_networks_merging[idx].network;

  for (uint32_t ii=0; ii<len; ii++)
    {
      struct hsg_op const * const cx = cxa + ii;

      ops = hsg_op(ops,CMP_XCHG(offset + cx->a,offset + cx->b,prefix));
    }

  return ops;
}

static
struct hsg_op *
hsg_thread_merge_offset(struct hsg_op * ops, uint32_t const offset, uint32_t const network)
{
  return hsg_thread_merge_offset_prefix(ops,offset,network,UINT32_MAX);
}

static
struct hsg_op *
hsg_thread_merge_left_right_prefix(struct hsg_op * ops, uint32_t const left, uint32_t const right, uint32_t const prefix)
{
  for (uint32_t l=left,r=left+1; r<=left+right; l--,r++)
    {
      ops = hsg_op(ops,CMP_XCHG(l,r,prefix));
    }

  return ops;
}
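
//
// The loop above fans outward from the left/right boundary.  For
// example, left = 4 and right = 4 emit the compare-exchanges:
//
//   CMP_XCHG(4,5)  CMP_XCHG(3,6)  CMP_XCHG(2,7)  CMP_XCHG(1,8)
//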

static
struct hsg_op *
hsg_thread_merge_left_right(struct hsg_op * ops, uint32_t const left, uint32_t const right)
{
  return hsg_thread_merge_left_right_prefix(ops,left,right,UINT32_MAX);
}

static
struct hsg_op *
hsg_warp_half_network(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,CMP_HALF(r-1,r));

  return ops;
}

static
struct hsg_op *
hsg_warp_half_downto(struct hsg_op * ops, uint32_t h)
{
  //
  // *** from h: downto[f/2,1)
  // **** lane_half(h)
  //
  for (; h > 1; h/=2)
    {
      ops = hsg_begin(ops);

      ops = hsg_op(ops,SLAB_HALF(h));
      ops = hsg_warp_half_network(ops);

      ops = hsg_end(ops);
    }

  return ops;
}

static
struct hsg_op *
hsg_warp_flip_network(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n/2; r++)
    ops = hsg_op(ops,CMP_FLIP(r-1,r,n+1-r));

  return ops;
}

static
struct hsg_op *
hsg_warp_flip(struct hsg_op * ops, uint32_t f)
{
  ops = hsg_begin(ops);

  ops = hsg_op(ops,SLAB_FLIP(f));
  ops = hsg_warp_flip_network(ops);

  ops = hsg_end(ops);

  return ops;
}
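
//
// Flip network example: with the default regs = 24,
// hsg_warp_flip_network() emits CMP_FLIP(r-1,r,n+1-r) for r = 1..12:
//
//   (0,1,24), (1,2,23), (2,3,22), ... (11,12,13)
//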

static
struct hsg_op *
hsg_bx_warp_load(struct hsg_op * ops, const int32_t vin_or_vout)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,BX_REG_GLOBAL_LOAD(r,vin_or_vout));

  return ops;
}

static
struct hsg_op *
hsg_bx_warp_store(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,BX_REG_GLOBAL_STORE(r));

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_warp_transpose(struct hsg_op * ops)
{
  // func proto
  ops = hsg_op(ops,TRANSPOSE_KERNEL_PROTO());

  // begin
  ops = hsg_begin(ops);

  // preamble
  ops = hsg_op(ops,TRANSPOSE_KERNEL_PREAMBLE());

  // load
  ops = hsg_bx_warp_load(ops,1); // 1 = load from vout[]

  // emit transpose blend and remap macros ...
  ops = hsg_op(ops,TRANSPOSE_KERNEL_BODY());

  // ... done!
  ops = hsg_end(ops);

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_warp_half(struct hsg_op * ops, uint32_t const h)
{
  //
  // *** from h: downto[f/2,1)
  // **** lane_half(h)
  // *** thread_merge
  //
  ops = hsg_warp_half_downto(ops,h);
  ops = hsg_thread_merge(ops,hsg_config.thread.regs);

  return ops;
}

static
struct hsg_op *
hsg_warp_merge(struct hsg_op * ops)
{
  //
  // * from f: upto[2,warp.lanes]
  // ** lane_flip(f)
  // *** from h: downto[f/2,1)
  // **** lane_half(h)
  // *** thread_merge
  //
  uint32_t const level = hsg_config.warp.lanes;

  for (uint32_t f=2; f<=level; f*=2)
    {
      ops = hsg_warp_flip(ops,f);
      ops = hsg_warp_half(ops,f/2);
    }

  return ops;
}
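
//
// With the default warp.lanes = 32 the schedule above is:
//
//   flip( 2)                                      thread_merge
//   flip( 4) half( 2)                             thread_merge
//   flip( 8) half( 4) half( 2)                    thread_merge
//   flip(16) half( 8) half( 4) half( 2)           thread_merge
//   flip(32) half(16) half( 8) half( 4) half( 2)  thread_merge
//
// since hsg_warp_half_downto() stops once h reaches 1 and defers the
// final step to the per-thread merging network.
//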

//
//
//

static
struct hsg_op *
hsg_bc_half_merge_level(struct hsg_op          *       ops,
                        struct hsg_merge const * const merge,
                        uint32_t                 const r_lo,
                        uint32_t                 const s_count)
{
  // guaranteed to be an even network
  uint32_t const net_even = merge->levels[0].networks[0];

  // min of warps in block and remaining horizontal rows
  uint32_t const active = MIN_MACRO(s_count, net_even);

  // conditional on blockIdx.x
  if (active < merge->warps)
    ops = hsg_op(ops,BX_MERGE_H_PRED(active)); // FIXME BX_MERGE

  // body begin
  ops = hsg_begin(ops);

  // scale for min block
  uint32_t const scale = net_even >= hsg_config.block.warps_min ? 1 : hsg_config.block.warps_min / net_even;

  // loop if more smem rows than warps
  for (uint32_t rr=0; rr<s_count; rr+=active)
    {
      // body begin
      ops = hsg_begin(ops);

      // skip down slab
      uint32_t const gmem_base = r_lo - 1 + rr;

      // load registers horizontally -- striding across slabs
      for (uint32_t ll=1; ll<=net_even; ll++)
        ops = hsg_op(ops,BC_REG_GLOBAL_LOAD_LEFT(ll,gmem_base+(ll-1)*hsg_config.thread.regs,0));

      // merge all registers
      ops = hsg_thread_merge_prefix(ops,net_even,0);

      // if we're looping then there is a base
      uint32_t const smem_base = rr * net_even * scale;

      // store all registers
      for (uint32_t ll=1; ll<=net_even; ll++)
        ops = hsg_op(ops,BX_REG_SHARED_STORE_LEFT(ll,smem_base+ll-1,0));

      // body end
      ops = hsg_end(ops);
    }

  // body end
  ops = hsg_end(ops);

  return ops;
}

static
struct hsg_op *
hsg_bc_half_merge(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  //
  // will only be called with merge->warps >= 2
  //
  uint32_t const warps    = MAX_MACRO(merge->warps,hsg_config.block.warps_min);

  // guaranteed to be an even network
  uint32_t const net_even = merge->levels[0].networks[0];

  // set up left SMEM pointer
  ops = hsg_op(ops,BC_MERGE_H_PREAMBLE(merge->index));

  // trim to number of warps in block -- FIXME -- try to make this a
  // multiple of the local processor count (Intel = 8, NVIDIA = 4)
  uint32_t const s_max = merge->rows_bc;

  // for all the registers
  for (uint32_t r_lo = 1; r_lo <= hsg_config.thread.regs; r_lo += s_max)
    {
      // compute store count
      uint32_t const r_rem   = hsg_config.thread.regs + 1 - r_lo;
      uint32_t const s_count = MIN_MACRO(s_max,r_rem);

      // block sync -- can skip if first
      if (r_lo > 1)
        ops = hsg_op(ops,BLOCK_SYNC());

      // merge loop
      ops = hsg_bc_half_merge_level(ops,merge,r_lo,s_count);

      // block sync
      ops = hsg_op(ops,BLOCK_SYNC());

      // load rows from shared
      for (uint32_t c=0; c<s_count; c++)
        ops = hsg_op(ops,BC_REG_SHARED_LOAD_V(warps,r_lo+c,c));
    }

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_bs_flip_merge_level(struct hsg_op          *       ops,
                        struct hsg_merge const * const merge,
                        uint32_t                 const level,
                        uint32_t                 const s_pairs)
{
  //
  // Note there are a number of ways to flip merge these warps.  There
  // is a magic number in the merge structure that indicates which
  // warp to activate as well as what network size to invoke.
  //
  // This more complex scheme was used in the past.
  //
  // The newest scheme is far dumber/simpler and simply directs a warp
  // to gather up the network associated with a row and merge them.
  //
  // This scheme may use more registers per thread but not all
  // compilers are high quality.
  //
  // If there are more warps than smem row pairs to merge then we
  // disable the spare warps.
  //
  // If there are more row pairs than warps then each warp works on
  // an equal number of rows.
  //
  // Note that it takes two warps to flip merge two smem rows.
  //
  // FIXME -- We may want to apply the warp smem "mod" value here to
  // attempt to balance the load>merge>store operations across the
  // multiprocessor cores.
  //
  // FIXME -- the old scheme attempted to keep all the warps active
  // but the iteration logic was more complex.  See 2016 checkins.
  //

  // where are we in computed merge?
  uint32_t const count  = merge->levels[level].count;
  uint32_t const index  = 1 << level;

  uint32_t       s_rows = s_pairs * 2;
  uint32_t       base   = 0;

  while (s_rows > 0)
    {
      uint32_t active = merge->warps;

      // disable warps if necessary
      if (merge->warps > s_rows) {
        active = s_rows;
        ops    = hsg_op(ops,BX_MERGE_H_PRED(active));
      }

      // body begin
      ops = hsg_begin(ops);

      // each active warp merges an equal share of the remaining rows
      uint32_t loops = s_rows / active;

      // decrement
      s_rows -= loops * active;

      for (uint32_t ss=0; ss<loops; ss++)
        {
          // load all registers
          for (uint32_t ii=0; ii<count; ii++)
            {
              // body begin
              ops = hsg_begin(ops);

              uint32_t const offset  = merge->offsets [index+ii];
              uint32_t const network = merge->networks[index+ii];
              uint32_t const lo      = (network + 1) / 2;

              for (uint32_t ll=1; ll<=lo; ll++)
                ops = hsg_op(ops,BS_REG_SHARED_LOAD_LEFT(ll,base+offset+ll-1,ii));

              for (uint32_t rr=lo+1; rr<=network; rr++)
                ops = hsg_op(ops,BS_REG_SHARED_LOAD_RIGHT(rr,base+offset+rr-1,ii));

              // compare left and right
              ops = hsg_thread_merge_left_right_prefix(ops,lo,network-lo,ii);

              // right merging network
              ops = hsg_thread_merge_offset_prefix(ops,lo,network-lo,ii);

              // left merging network
              ops = hsg_thread_merge_prefix(ops,lo,ii);

              for (uint32_t ll=1; ll<=lo; ll++)
                ops = hsg_op(ops,BX_REG_SHARED_STORE_LEFT(ll,base+offset+ll-1,ii));

              for (uint32_t rr=lo+1; rr<=network; rr++)
                ops = hsg_op(ops,BS_REG_SHARED_STORE_RIGHT(rr,base+offset+rr-1,ii));

              // body end
              ops = hsg_end(ops);
            }

          base += active * merge->warps;
        }

      // body end
      ops = hsg_end(ops);
    }

  return ops;
}

static
struct hsg_op *
hsg_bs_flip_merge(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // set up horizontal smem pointer
  ops = hsg_op(ops,BS_MERGE_H_PREAMBLE(merge->index));

  // begin merge
  uint32_t level = MERGE_LEVELS_MAX_LOG2;

  while (level-- > 0)
    {
      uint32_t const count = merge->levels[level].count;

      if (count == 0)
        continue;

      uint32_t const r_mid       = hsg_config.thread.regs/2 + 1;
      uint32_t const s_pairs_max = merge->rows_bs/2; // this is warp mod

      // for all the registers
      for (uint32_t r_lo=1; r_lo<r_mid; r_lo+=s_pairs_max)
        {
          uint32_t r_hi = hsg_config.thread.regs + 1 - r_lo;

          // compute store count
          uint32_t const s_pairs = MIN_MACRO(s_pairs_max,r_mid - r_lo);

          // store rows to shared
          for (uint32_t c=0; c<s_pairs; c++)
            {
              ops = hsg_op(ops,BS_REG_SHARED_STORE_V(merge->index,r_lo+c,c*2+0));
              ops = hsg_op(ops,BS_REG_SHARED_STORE_V(merge->index,r_hi-c,c*2+1));
            }

          // block sync
          ops = hsg_op(ops,BLOCK_SYNC());

          // merge loop
          ops = hsg_bs_flip_merge_level(ops,merge,level,s_pairs);

          // block sync
          ops = hsg_op(ops,BLOCK_SYNC());

          // load rows from shared
          for (uint32_t c=0; c<s_pairs; c++)
            {
              ops = hsg_op(ops,BS_REG_SHARED_LOAD_V(merge->index,r_lo+c,c*2+0));
              ops = hsg_op(ops,BS_REG_SHARED_LOAD_V(merge->index,r_hi-c,c*2+1));
            }
        }

      // conditionally clean -- no-op if equal to number of warps/block
      if (merge->levels[level].active.b64 != BITS_TO_MASK_64(merge->warps))
        ops = hsg_op(ops,BS_ACTIVE_PRED(merge->index,level));

      // clean warp
      ops = hsg_begin(ops);
      ops = hsg_warp_half(ops,hsg_config.warp.lanes);
      ops = hsg_end(ops);
    }

  return ops;
}

/*

//
// DELETE ME WHEN READY
//

static
struct hsg_op *
hsg_bs_flip_merge_all(struct hsg_op * ops, const struct hsg_merge * const merge)
{
  for (uint32_t merge_idx=0; merge_idx<MERGE_LEVELS_MAX_LOG2; merge_idx++)
    {
      const struct hsg_merge* const m = merge + merge_idx;

      if (m->warps < 2)
        break;

      ops = hsg_op(ops,BS_FRAC_PRED(merge_idx,m->warps));
      ops = hsg_begin(ops);
      ops = hsg_bs_flip_merge(ops,m);
      ops = hsg_end(ops);
    }

  return ops;
}
*/

//
// GENERATE SORT KERNEL
//

static
struct hsg_op *
hsg_bs_sort(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // func proto
  ops = hsg_op(ops,BS_KERNEL_PROTO(merge->index));

  // begin
  ops = hsg_begin(ops);

  // shared declare
  ops = hsg_op(ops,BS_KERNEL_PREAMBLE(merge->index));

  // load
  ops = hsg_bx_warp_load(ops,0); // 0 = load from vin[]

  // thread sorting network
  ops = hsg_thread_sort(ops);

  // warp merging network
  ops = hsg_warp_merge(ops);

  // slab merging network
  if (merge->warps > 1)
    ops = hsg_bs_flip_merge(ops,merge);

  // store
  ops = hsg_bx_warp_store(ops);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE SORT KERNELS
//

static
struct hsg_op *
hsg_bs_sort_all(struct hsg_op * ops)
{
  uint32_t merge_idx = MERGE_LEVELS_MAX_LOG2;

  while (merge_idx-- > 0)
    {
      struct hsg_merge const * const m = hsg_merge + merge_idx;

      if (m->warps == 0)
        continue;

      ops = hsg_bs_sort(ops,m);
    }

  return ops;
}

//
// GENERATE CLEAN KERNEL FOR A POWER-OF-TWO
//

static
struct hsg_op *
hsg_bc_clean(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // func proto
  ops = hsg_op(ops,BC_KERNEL_PROTO(merge->index));

  // begin
  ops = hsg_begin(ops);

  // shared declare
  ops = hsg_op(ops,BC_KERNEL_PREAMBLE(merge->index));

  // if warps == 1 then smem isn't used for merging
  if (merge->warps == 1)
    {
      // load slab directly
      ops = hsg_bx_warp_load(ops,1); // load from vout[]
    }
  else
    {
      // block merging network -- strided load of slabs
      ops = hsg_bc_half_merge(ops,merge);
    }

  // clean warp
  ops = hsg_begin(ops);
  ops = hsg_warp_half(ops,hsg_config.warp.lanes);
  ops = hsg_end(ops);

  // store
  ops = hsg_bx_warp_store(ops);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE CLEAN KERNELS
//

static
struct hsg_op *
hsg_bc_clean_all(struct hsg_op * ops)
{
  uint32_t merge_idx = MERGE_LEVELS_MAX_LOG2;

  while (merge_idx-- > 0)
    {
      struct hsg_merge const * const m = hsg_merge + merge_idx;

      if (m->warps == 0)
        continue;

      // only generate pow2 clean kernels less than or equal to the
      // max warps in a block -- the assumption is that we would've
      // generated a wider sort kernel if we could have, so a wider
      // clean kernel isn't a feasible size
      if (!is_pow2_u32(m->warps))
        continue;

      ops = hsg_bc_clean(ops,m);
    }

  return ops;
}

//
// GENERATE FLIP MERGE KERNEL
//

static
struct hsg_op *
hsg_fm_thread_load_left(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_LOAD_LEFT(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_store_left(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_STORE_LEFT(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_load_right(struct hsg_op * ops, uint32_t const half_span, uint32_t const half_case)
{
  for (uint32_t r=0; r<half_case; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_LOAD_RIGHT(r,half_span+1+r));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_store_right(struct hsg_op * ops, uint32_t const half_span, uint32_t const half_case)
{
  for (uint32_t r=0; r<half_case; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_STORE_RIGHT(r,half_span+1+r));

  return ops;
}

static
struct hsg_op *
hsg_fm_merge(struct hsg_op * ops,
             uint32_t const scale_log2,
             uint32_t const span_left,
             uint32_t const span_right)
{
  // func proto
  ops = hsg_op(ops,FM_KERNEL_PROTO(scale_log2,msb_idx_u32(pow2_ru_u32(span_right))));

  // begin
  ops = hsg_begin(ops);

  // preamble for loading/storing
  ops = hsg_op(ops,FM_KERNEL_PREAMBLE(span_left,span_right));

  // load left span
  ops = hsg_fm_thread_load_left(ops,span_left);

  // load right span
  ops = hsg_fm_thread_load_right(ops,span_left,span_right);

  // compare left and right
  ops = hsg_thread_merge_left_right(ops,span_left,span_right);

  // left merging network
  ops = hsg_thread_merge(ops,span_left);

  // right merging network
  ops = hsg_thread_merge_offset(ops,span_left,span_right);

  // store
  ops = hsg_fm_thread_store_left(ops,span_left);

  // store
  ops = hsg_fm_thread_store_right(ops,span_left,span_right);

  // end
  ops = hsg_end(ops);

  return ops;
}

static
struct hsg_op *
hsg_fm_merge_all(struct hsg_op * ops, uint32_t const scale_log2, uint32_t const warps)
{
  uint32_t const span_left    = (warps << scale_log2) / 2;
  uint32_t const span_left_ru = pow2_ru_u32(span_left);

  for (uint32_t span_right=1; span_right<=span_left_ru; span_right*=2)
    ops = hsg_fm_merge(ops,scale_log2,span_left,MIN_MACRO(span_left,span_right));

  return ops;
}
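
//
// Example: warps = 16 and scale_log2 = 1 give span_left = (16<<1)/2 =
// 16, so flip merge kernels are generated for right spans of 1, 2, 4,
// 8 and 16 -- each clamped to span_left by the MIN_MACRO() above.
//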

//
// GENERATE HALF MERGE KERNELS
//

static
struct hsg_op *
hsg_hm_thread_load(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,HM_REG_GLOBAL_LOAD(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_hm_thread_store(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,HM_REG_GLOBAL_STORE(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_hm_merge(struct hsg_op * ops, uint32_t const scale_log2, uint32_t const warps_pow2)
{
  uint32_t const span = warps_pow2 << scale_log2;

  // func proto
  ops = hsg_op(ops,HM_KERNEL_PROTO(scale_log2));

  // begin
  ops = hsg_begin(ops);

  // preamble for loading/storing
  ops = hsg_op(ops,HM_KERNEL_PREAMBLE(span/2));

  // load
  ops = hsg_hm_thread_load(ops,span);

  // thread merging network
  ops = hsg_thread_merge(ops,span);

  // store
  ops = hsg_hm_thread_store(ops,span);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE MERGE KERNELS
//

static
struct hsg_op *
hsg_xm_merge_all(struct hsg_op * ops)
{
  uint32_t const warps      = hsg_merge[0].warps;
  uint32_t const warps_pow2 = pow2_rd_u32(warps);

  //
  // GENERATE FLIP MERGE KERNELS
  //
  for (uint32_t scale_log2=hsg_config.merge.flip.lo; scale_log2<=hsg_config.merge.flip.hi; scale_log2++)
    ops = hsg_fm_merge_all(ops,scale_log2,warps);

  //
  // GENERATE HALF MERGE KERNELS
  //
  for (uint32_t scale_log2=hsg_config.merge.half.lo; scale_log2<=hsg_config.merge.half.hi; scale_log2++)
    ops = hsg_hm_merge(ops,scale_log2,warps_pow2);

  return ops;
}

//
//
//

static
struct hsg_op const *
hsg_op_translate_depth(hsg_target_pfn                  target_pfn,
                       struct hsg_target       * const target,
                       struct hsg_config const * const config,
                       struct hsg_merge  const * const merge,
                       struct hsg_op     const *       ops,
                       uint32_t                  const depth)
{
  while (ops->type != HSG_OP_TYPE_EXIT)
    {
      switch (ops->type)
        {
        case HSG_OP_TYPE_END:
          target_pfn(target,config,merge,ops,depth-1);
          return ops + 1;

        case HSG_OP_TYPE_BEGIN:
          target_pfn(target,config,merge,ops,depth);
          ops = hsg_op_translate_depth(target_pfn,target,config,merge,ops+1,depth+1);
          break;

        default:
          target_pfn(target,config,merge,ops++,depth);
        }
    }

  return ops;
}
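
//
// Traversal sketch: a BEGIN op recurses one level deeper and the
// matching END returns, so for the op stream
//
//   BEGIN A BEGIN B END C END EXIT
//
// the target callback is invoked with depths
//
//   BEGIN:0  A:1  BEGIN:1  B:2  END:1  C:1  END:0
//
// which lets a target emit matching indentation and braces.
//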

static
void
hsg_op_translate(hsg_target_pfn                  target_pfn,
                 struct hsg_target       * const target,
                 struct hsg_config const * const config,
                 struct hsg_merge  const * const merge,
                 struct hsg_op     const *       ops)
{
  hsg_op_translate_depth(target_pfn,target,config,merge,ops,0);
}

//
//
//

int
main(int argc, char * argv[])
{
  //
  // PROCESS OPTIONS
  //
  int32_t           opt      = 0;
  bool              verbose  = false;
  bool              autotune = false;
  char const *      arch     = "undefined";
  struct hsg_target target   = { .define = NULL };

  while ((opt = getopt(argc,argv,"hva:g:G:s:S:w:b:B:m:M:k:r:x:t:f:F:c:C:p:P:D:z")) != EOF)
    {
      switch (opt)
        {
        case 'h':
          fprintf(stderr,"Help goes here...\n");
          return EXIT_FAILURE;

        case 'v':
          verbose = true;
          break;

        case 'a':
          arch = optarg;
          break;

        case 'g':
          hsg_config.block.smem_min = atoi(optarg);
          break;

        case 'G':
          hsg_config.block.smem_quantum = atoi(optarg);
          break;

        case 's':
          hsg_config.block.smem_bs = atoi(optarg);

          // set smem_bc if not already set
          if (hsg_config.block.smem_bc == UINT32_MAX)
            hsg_config.block.smem_bc = hsg_config.block.smem_bs;
          break;

        case 'S':
          hsg_config.block.smem_bc = atoi(optarg);
          break;

        case 'w':
          hsg_config.warp.lanes      = atoi(optarg);
          hsg_config.warp.lanes_log2 = msb_idx_u32(hsg_config.warp.lanes);
          break;

        case 'b':
          // maximum warps in a workgroup / cta / thread block
          {
            uint32_t const warps = atoi(optarg);

            // must always be even
            if ((warps & 1) != 0)
              {
                fprintf(stderr,"Error: -b must be even.\n");
                return EXIT_FAILURE;
              }

            hsg_merge[0].index = 0;
            hsg_merge[0].warps = warps;

            // set warps_max if not already set
            if (hsg_config.block.warps_max == UINT32_MAX)
              hsg_config.block.warps_max = pow2_ru_u32(warps);
          }
          break;

        case 'B':
          // maximum warps that can fit in a multiprocessor
          hsg_config.block.warps_max = atoi(optarg);
          break;

        case 'm':
          // blocks using smem barriers must have at least this many warps
          hsg_config.block.warps_min = atoi(optarg);
          break;

        case 'M':
          // the number of warps necessary to load balance horizontal merging
          hsg_config.block.warps_mod = atoi(optarg);
          break;

        case 'r':
          {
            uint32_t const regs = atoi(optarg);

            if ((regs & 1) != 0)
              {
                fprintf(stderr,"Error: -r must be even.\n");
                return EXIT_FAILURE;
              }

            hsg_config.thread.regs = regs;
          }
          break;

        case 'x':
          hsg_config.thread.xtra      = atoi(optarg);
          break;

        case 't':
          hsg_config.type.words       = atoi(optarg);
          break;

        case 'f':
          hsg_config.merge.flip.lo    = atoi(optarg);
          break;

        case 'F':
          hsg_config.merge.flip.hi    = atoi(optarg);
          break;

        case 'c':
          hsg_config.merge.half.lo    = atoi(optarg);
          break;

        case 'C':
          hsg_config.merge.half.hi    = atoi(optarg);
          break;

        case 'p':
          hsg_config.merge.flip.warps = atoi(optarg);
          break;

        case 'P':
          hsg_config.merge.half.warps = atoi(optarg);
          break;

        case 'D':
          target.define = optarg;
          break;

        case 'z':
          autotune = true;
          break;
        }
    }

  //
  // INIT MERGE
  //
  uint32_t const warps_ru_pow2 = pow2_ru_u32(hsg_merge[0].warps);

  for (uint32_t ii=1; ii<MERGE_LEVELS_MAX_LOG2; ii++)
    {
      hsg_merge[ii].index = ii;
      hsg_merge[ii].warps = warps_ru_pow2 >> ii;
    }
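
  //
  // e.g. "-b 16" leaves the remaining levels holding 8, 4, 2 and 1
  // warps, and an uneven "-b 12" rounds up (warps_ru_pow2 = 16) so
  // the remaining levels hold 8, 4, 2 and 1 as well.
  //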

  //
  // WHICH ARCH TARGET?
  //
  hsg_target_pfn hsg_target_pfn;

  if      (strcmp(arch,"debug") == 0)
    hsg_target_pfn = hsg_target_debug;
  else if (strcmp(arch,"cuda") == 0)
    hsg_target_pfn = hsg_target_cuda;
  else if (strcmp(arch,"opencl") == 0)
    hsg_target_pfn = hsg_target_opencl;
  else if (strcmp(arch,"glsl") == 0)
    hsg_target_pfn = hsg_target_glsl;
  else {
    fprintf(stderr,"Invalid arch: %s\n",arch);
    exit(EXIT_FAILURE);
  }

  if (verbose)
    fprintf(stderr,"Target: %s\n",arch);

  //
  // INIT SMEM KEY ALLOCATION
  //
  hsg_config_init_shared();

  //
  // INIT MERGE MAGIC
  //
  for (uint32_t ii=0; ii<MERGE_LEVELS_MAX_LOG2; ii++)
    {
      struct hsg_merge * const merge = hsg_merge + ii;

      if (merge->warps == 0)
        break;

      fprintf(stderr,">>> Generating: %1u %5u %5u %3u %3u ...\n",
              hsg_config.type.words,
              hsg_config.block.smem_bs,
              hsg_config.block.smem_bc,
              hsg_config.thread.regs,
              merge->warps);

      hsg_merge_levels_init_shared(merge);

      hsg_merge_levels_init_1(merge,merge->warps,0,0);

      hsg_merge_levels_hint(merge,autotune);

      //
      // THESE ARE FOR DEBUG/INSPECTION
      //
      if (verbose)
        {
          hsg_merge_levels_debug(merge);
        }
    }

  if (verbose)
    fprintf(stderr,"\n\n");

  //
  // GENERATE THE OPCODES
  //
  uint32_t        const op_count  = 1<<17;
  struct hsg_op * const ops_begin = malloc(sizeof(*ops_begin) * op_count);
  struct hsg_op *       ops       = ops_begin;

  //
  // OPEN INITIAL FILES AND APPEND HEADER
  //
  ops = hsg_op(ops,TARGET_BEGIN());

  //
  // GENERATE SORT KERNEL
  //
  ops = hsg_bs_sort_all(ops);

  //
  // GENERATE CLEAN KERNELS
  //
  ops = hsg_bc_clean_all(ops);

  //
  // GENERATE MERGE KERNELS
  //
  ops = hsg_xm_merge_all(ops);

  //
  // GENERATE TRANSPOSE KERNEL
  //
  ops = hsg_warp_transpose(ops);

  //
  // APPEND FOOTER AND CLOSE INITIAL FILES
  //
  ops = hsg_op(ops,TARGET_END());

  //
  // ... WE'RE DONE!
  //
  ops = hsg_exit(ops);

  //
  // APPLY TARGET TRANSLATOR TO ACCUMULATED OPS
  //
  hsg_op_translate(hsg_target_pfn,&target,&hsg_config,hsg_merge,ops_begin);

  //
  // DUMP INSTRUCTION COUNTS
  //
  if (verbose)
    hsg_op_debug();

  return EXIT_SUCCESS;
}

//
//
//