/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 *
 */

//
//
//

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <getopt.h>
#include <inttypes.h>

//
//
//

#include "networks.h"
#include "common/util.h"
#include "common/macros.h"

//
//
//

#undef  HSG_OP_EXPAND_X
#define HSG_OP_EXPAND_X(t) #t ,

char const * const
hsg_op_type_string[] =
{
  HSG_OP_EXPAND_ALL()
};
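
//
// With HSG_OP_EXPAND_X() redefined to stringize its argument,
// HSG_OP_EXPAND_ALL() expands into one name per op type, giving a
// string table indexed by HSG_OP_TYPE_* -- it's what hsg_op_debug()
// uses to print the per-op instruction counts.
//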

//
//
//

#define EXIT() (struct hsg_op){ HSG_OP_TYPE_EXIT }

#define END() (struct hsg_op){ HSG_OP_TYPE_END }
#define BEGIN() (struct hsg_op){ HSG_OP_TYPE_BEGIN }
#define ELSE() (struct hsg_op){ HSG_OP_TYPE_ELSE }

#define TARGET_BEGIN() (struct hsg_op){ HSG_OP_TYPE_TARGET_BEGIN }
#define TARGET_END() (struct hsg_op){ HSG_OP_TYPE_TARGET_END }

#define TRANSPOSE_KERNEL_PROTO() (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO }
#define TRANSPOSE_KERNEL_PREAMBLE() (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE }
#define TRANSPOSE_KERNEL_BODY() (struct hsg_op){ HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY }

#define BS_KERNEL_PROTO(i) (struct hsg_op){ HSG_OP_TYPE_BS_KERNEL_PROTO, { i } }
#define BS_KERNEL_PREAMBLE(i) (struct hsg_op){ HSG_OP_TYPE_BS_KERNEL_PREAMBLE, { i } }

#define BC_KERNEL_PROTO(i) (struct hsg_op){ HSG_OP_TYPE_BC_KERNEL_PROTO, { i } }
#define BC_KERNEL_PREAMBLE(i) (struct hsg_op){ HSG_OP_TYPE_BC_KERNEL_PREAMBLE, { i } }

#define FM_KERNEL_PROTO(s,r) (struct hsg_op){ HSG_OP_TYPE_FM_KERNEL_PROTO, { s, r } }
#define FM_KERNEL_PREAMBLE(l,r) (struct hsg_op){ HSG_OP_TYPE_FM_KERNEL_PREAMBLE, { l, r } }

#define HM_KERNEL_PROTO(s) (struct hsg_op){ HSG_OP_TYPE_HM_KERNEL_PROTO, { s } }
#define HM_KERNEL_PREAMBLE(l) (struct hsg_op){ HSG_OP_TYPE_HM_KERNEL_PREAMBLE, { l } }

#define BX_REG_GLOBAL_LOAD(n,v) (struct hsg_op){ HSG_OP_TYPE_BX_REG_GLOBAL_LOAD, { n, v } }
#define BX_REG_GLOBAL_STORE(n) (struct hsg_op){ HSG_OP_TYPE_BX_REG_GLOBAL_STORE, { n } }

#define FM_REG_GLOBAL_LOAD_LEFT(n,i) (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT, { n, i } }
#define FM_REG_GLOBAL_STORE_LEFT(n,i) (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT, { n, i } }
#define FM_REG_GLOBAL_LOAD_RIGHT(n,i) (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT, { n, i } }
#define FM_REG_GLOBAL_STORE_RIGHT(n,i) (struct hsg_op){ HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT, { n, i } }
#define FM_MERGE_RIGHT_PRED(n,s) (struct hsg_op){ HSG_OP_TYPE_FM_MERGE_RIGHT_PRED, { n, s } }

#define HM_REG_GLOBAL_LOAD(n,i) (struct hsg_op){ HSG_OP_TYPE_HM_REG_GLOBAL_LOAD, { n, i } }
#define HM_REG_GLOBAL_STORE(n,i) (struct hsg_op){ HSG_OP_TYPE_HM_REG_GLOBAL_STORE, { n, i } }

#define SLAB_FLIP(f) (struct hsg_op){ HSG_OP_TYPE_SLAB_FLIP, { f } }
#define SLAB_HALF(h) (struct hsg_op){ HSG_OP_TYPE_SLAB_HALF, { h } }

#define CMP_FLIP(a,b,c) (struct hsg_op){ HSG_OP_TYPE_CMP_FLIP, { a, b, c } }
#define CMP_HALF(a,b) (struct hsg_op){ HSG_OP_TYPE_CMP_HALF, { a, b } }

#define CMP_XCHG(a,b,p) (struct hsg_op){ HSG_OP_TYPE_CMP_XCHG, { a, b, p } }

#define BS_REG_SHARED_STORE_V(m,i,r) (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_STORE_V, { m, i, r } }
#define BS_REG_SHARED_LOAD_V(m,i,r) (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_V, { m, i, r } }
#define BC_REG_SHARED_LOAD_V(m,i,r) (struct hsg_op){ HSG_OP_TYPE_BC_REG_SHARED_LOAD_V, { m, i, r } }

#define BX_REG_SHARED_STORE_LEFT(r,i,p) (struct hsg_op){ HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT, { r, i, p } }
#define BS_REG_SHARED_STORE_RIGHT(r,i,p) (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT, { r, i, p } }

#define BS_REG_SHARED_LOAD_LEFT(r,i,p) (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT, { r, i, p } }
#define BS_REG_SHARED_LOAD_RIGHT(r,i,p) (struct hsg_op){ HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT, { r, i, p } }

#define BC_REG_GLOBAL_LOAD_LEFT(r,i,p) (struct hsg_op){ HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT, { r, i, p } }

#define REG_F_PREAMBLE(s) (struct hsg_op){ HSG_OP_TYPE_REG_F_PREAMBLE, { s } }
#define REG_SHARED_STORE_F(r,i,s) (struct hsg_op){ HSG_OP_TYPE_REG_SHARED_STORE_F, { r, i, s } }
#define REG_SHARED_LOAD_F(r,i,s) (struct hsg_op){ HSG_OP_TYPE_REG_SHARED_LOAD_F, { r, i, s } }
#define REG_GLOBAL_STORE_F(r,i,s) (struct hsg_op){ HSG_OP_TYPE_REG_GLOBAL_STORE_F, { r, i, s } }

#define BLOCK_SYNC() (struct hsg_op){ HSG_OP_TYPE_BLOCK_SYNC }

#define BS_FRAC_PRED(m,w) (struct hsg_op){ HSG_OP_TYPE_BS_FRAC_PRED, { m, w } }

#define BS_MERGE_H_PREAMBLE(i) (struct hsg_op){ HSG_OP_TYPE_BS_MERGE_H_PREAMBLE, { i } }
#define BC_MERGE_H_PREAMBLE(i) (struct hsg_op){ HSG_OP_TYPE_BC_MERGE_H_PREAMBLE, { i } }

#define BX_MERGE_H_PRED(p) (struct hsg_op){ HSG_OP_TYPE_BX_MERGE_H_PRED, { p } }

#define BS_ACTIVE_PRED(m,l) (struct hsg_op){ HSG_OP_TYPE_BS_ACTIVE_PRED, { m, l } }

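//
// Each macro above builds an hsg_op compound literal -- an opcode tag
// plus up to three operands -- so the generator routines below can
// append ops to the instruction stream with hsg_op(ops,MACRO(...)).
//
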
//
// DEFAULTS
//

static
struct hsg_config hsg_config =
{
  .merge = {
    .flip = {
      .warps = 1,
      .lo    = 1,
      .hi    = 1
    },
    .half = {
      .warps = 1,
      .lo    = 1,
      .hi    = 1
    },
  },

  .block = {
    .warps_min    = 1,          // min warps for a block that uses smem barriers
    .warps_max    = UINT32_MAX, // max warps for the entire multiprocessor
    .warps_mod    = 2,          // the number of warps necessary to load balance horizontal merging

    .smem_min     = 0,
    .smem_quantum = 1,

    .smem_bs      = 49152,
    .smem_bc      = UINT32_MAX  // implies field not set
  },

  .warp = {
    .lanes      = 32,
    .lanes_log2 = 5,
  },

  .thread = {
    .regs = 24,
    .xtra = 0
  },

  .type = {
    .words = 2
  }
};

//
// ZERO HSG_MERGE STRUCT
//

static
struct hsg_merge hsg_merge[MERGE_LEVELS_MAX_LOG2] = { 0 };

//
// STATS ON INSTRUCTIONS
//

static uint32_t hsg_op_type_counts[HSG_OP_TYPE_COUNT] = { 0 };

//
//
//

static
void
hsg_op_debug()
{
  uint32_t total = 0;

  for (hsg_op_type t=HSG_OP_TYPE_EXIT; t<HSG_OP_TYPE_COUNT; t++)
    {
      uint32_t const count = hsg_op_type_counts[t];

      total += count;

      fprintf(stderr,"%-37s : %u\n",hsg_op_type_string[t],count);
    }

  fprintf(stderr,"%-37s : %u\n\n\n","TOTAL",total);
}

//
//
//

static
void
hsg_config_init_shared()
{
  //
  // The assumption here is that a proper smem_bs value was provided
  // that represents the maximum fraction of the multiprocessor's
  // available shared memory that can be accessed by the initial block
  // sorting kernel.
  //
  // With CUDA devices this is 48KB out of 48KB, 64KB or 96KB.
  //
  // Intel subslices are a little trickier and the minimum allocation
  // is 4KB and the maximum is 64KB on pre-Skylake IGPs.  Sizes are
  // allocated in 1KB increments.  If a maximum of two block sorters
  // can occupy a subslice then each should be assigned 32KB of shared
  // memory.
  //
  // News Flash: apparently GEN9+ IGPs can allocate 1KB of SMEM per
  // workgroup so all the previously written logic to support this
  // issue is being removed.
  //
  uint32_t const bs_keys = hsg_config.block.smem_bs / (hsg_config.type.words * sizeof(uint32_t));

  hsg_config.warp.skpw_bs = bs_keys / hsg_merge[0].warps;
}
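
//
// For example, with the defaults above (smem_bs = 49152 bytes, 2-word
// keys), bs_keys = 49152 / (2 * 4) = 6144 keys per block-sorting
// block, so a hypothetical 16-warp block sorter would get
// skpw_bs = 6144 / 16 = 384 keys per warp.
//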

static
void
hsg_merge_levels_init_shared(struct hsg_merge * const merge)
{
  {
    //
    // What is the max amount of shared in each possible bs block config?
    //
    // The provided smem_bs size will be allocated for each sorting block.
    //
    uint32_t const bs_threads   = merge->warps << hsg_config.warp.lanes_log2;
    uint32_t const bs_keys      = hsg_config.block.smem_bs / (hsg_config.type.words * sizeof(uint32_t));
    uint32_t const bs_kpt       = bs_keys / bs_threads;
    uint32_t const bs_kpt_mod   = (bs_kpt / hsg_config.block.warps_mod) * hsg_config.block.warps_mod;
    uint32_t const bs_rows_even = bs_kpt_mod & ~1; // must be even because flip merge only works on row pairs

    // this is a showstopper
    if (bs_rows_even < 2)
      {
        fprintf(stderr,"Error: need at least 2 rows of shared memory.\n");
        exit(-1);
      }

    // clamp to number of registers
    merge->rows_bs = MIN_MACRO(bs_rows_even, hsg_config.thread.regs);
  }
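
  //
  // Example with the defaults (smem_bs = 49152, 2-word keys, 32-lane
  // warps, warps_mod = 2, regs = 24) and an 8-warp block sorter:
  // bs_threads = 256, bs_keys = 6144, bs_kpt = 24, bs_kpt_mod = 24,
  // bs_rows_even = 24, so rows_bs = MIN(24,24) = 24.
  //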

  //
  // smem key allocation rule for BC kernels is that a single block
  // can't allocate more than smem_bs and must allocate at least
  // smem_min in smem_quantum steps.
  //
  // Note that BC blocks will always be less than or equal to BS
  // blocks.
  //
  {
    //
    // if merge->warps is not pow2 then we're going to skip creating a bc elsewhere
    //
    uint32_t const bc_warps_min  = MAX_MACRO(merge->warps,hsg_config.block.warps_min);
    uint32_t const bc_threads    = bc_warps_min << hsg_config.warp.lanes_log2;
    uint32_t const bc_block_rd   = (((hsg_config.block.smem_bc * bc_warps_min) / hsg_config.block.warps_max) /
                                    hsg_config.block.smem_quantum) * hsg_config.block.smem_quantum;
    uint32_t const bc_block_max  = MAX_MACRO(bc_block_rd,hsg_config.block.smem_min);
    uint32_t const bc_block_smem = MIN_MACRO(bc_block_max,hsg_config.block.smem_bs);

    // what is the max amount of shared in each possible bc block config?
    uint32_t const bc_keys       = bc_block_smem / (hsg_config.type.words * sizeof(uint32_t));
    uint32_t const bc_kpt        = bc_keys / bc_threads;
    uint32_t const bc_kpt_mod    = (bc_kpt / hsg_config.block.warps_mod) * hsg_config.block.warps_mod;

    merge->rows_bc = MIN_MACRO(bc_kpt_mod, hsg_config.thread.regs);
    merge->skpw_bc = bc_keys / bc_warps_min;
  }
}

//
//
//

static
void
hsg_merge_levels_init_1(struct hsg_merge * const merge, uint32_t const warps, uint32_t const level, uint32_t const offset)
{
  uint32_t const even_odd = warps & 1;

  merge->levels[level].evenodds[even_odd]++;
  merge->levels[level].networks[even_odd] = warps;

  if (warps == 1)
    return;

  merge->levels[level].active.b64 |= BITS_TO_MASK_AT_64(warps,offset);

  uint32_t const count = merge->levels[level].count++;
  uint32_t const index = (1 << level) + count;
  uint32_t const bit   = 1 << count;

  merge->levels[level].evenodd_masks[even_odd] |= bit;

  if (count > 0)
    {
      // offset from network to left of this network
      uint32_t const diff = offset - merge->offsets[index-1];

      uint32_t const diff_0 = merge->levels[level].diffs[0];
      uint32_t const diff_1 = merge->levels[level].diffs[1];

      uint32_t diff_idx = UINT32_MAX;

      if ((diff_0 == 0) || (diff_0 == diff)) {
        diff_idx = 0;
      } else if ((diff_1 == 0) || (diff_1 == diff)) {
        diff_idx = 1;
      } else {
        fprintf(stderr, "*** MORE THAN TWO DIFFS ***\n");
        exit(-1);
      }

      merge->levels[level].diffs     [diff_idx]  = diff;
      merge->levels[level].diff_masks[diff_idx] |= 1 << (count-1);
    }

  merge->networks[index] = warps;
  merge->offsets [index] = offset;

  uint32_t const l = (warps+1)/2; // lower/larger on left
  uint32_t const r = (warps+0)/2; // higher/smaller on right

  hsg_merge_levels_init_1(merge,l,level+1,offset);
  hsg_merge_levels_init_1(merge,r,level+1,offset+l);
}
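
//
// For example, a 6-warp merge records a 6-slab network at level 0,
// two 3-slab networks at level 1, and (2,1,2,1) at level 2, with
// offsets[]/networks[] recording where each sub-network starts and
// how many slabs it spans.
//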

static
void
hsg_merge_levels_debug(struct hsg_merge * const merge)
{
  for (uint32_t level=0; level<MERGE_LEVELS_MAX_LOG2; level++)
    {
      uint32_t count = merge->levels[level].count;

      if (count == 0)
        break;

      fprintf(stderr,
              "%-4u : %016" PRIX64 " \n",
              count,
              merge->levels[level].active.b64);

      fprintf(stderr,
              "%-4u : %08X (%2u)\n"
              "%-4u : %08X (%2u)\n",
              merge->levels[level].diffs[0],
              merge->levels[level].diff_masks[0],
              POPCOUNT_MACRO(merge->levels[level].diff_masks[0]),
              merge->levels[level].diffs[1],
              merge->levels[level].diff_masks[1],
              POPCOUNT_MACRO(merge->levels[level].diff_masks[1]));

      fprintf(stderr,
              "EVEN : %08X (%2u)\n"
              "ODD : %08X (%2u)\n",
              merge->levels[level].evenodd_masks[0],
              POPCOUNT_MACRO(merge->levels[level].evenodd_masks[0]),
              merge->levels[level].evenodd_masks[1],
              POPCOUNT_MACRO(merge->levels[level].evenodd_masks[1]));

      for (uint32_t ii=0; ii<2; ii++)
        {
          if (merge->levels[level].networks[ii] > 1)
            {
              fprintf(stderr,
                      "%-4s : ( %2u x %2u )\n",
                      (ii == 0) ? "EVEN" : "ODD",
                      merge->levels[level].evenodds[ii],
                      merge->levels[level].networks[ii]);
            }
        }

      uint32_t index = 1 << level;

      while (count-- > 0)
        {
          fprintf(stderr,
                  "[ %2u %2u ] ",
                  merge->offsets [index],
                  merge->networks[index]);

          index += 1;
        }

      fprintf(stderr,"\n\n");
    }
}

static
void
hsg_merge_levels_hint(struct hsg_merge * const merge, bool const autotune)
{
  // clamp against merge levels
  for (uint32_t level=0; level<MERGE_LEVELS_MAX_LOG2; level++)
    {
      // max network
      uint32_t const n_max = MAX_MACRO(merge->levels[level].networks[0],
                                       merge->levels[level].networks[1]);

      if (n_max <= (merge->rows_bs + hsg_config.thread.xtra))
        break;

      if (autotune)
        {
          hsg_config.thread.xtra = n_max - merge->rows_bs;

          uint32_t const r_total = hsg_config.thread.regs + hsg_config.thread.xtra;
          uint32_t const r_limit = (hsg_config.type.words == 1) ? 120 : 58;

          if (r_total <= r_limit)
            {
              fprintf(stderr,"autotune: %u + %u\n",
                      hsg_config.thread.regs,
                      hsg_config.thread.xtra);
              break;
            }
          else
            {
              fprintf(stderr,"skipping autotune: %u + %u > %u\n",
                      hsg_config.thread.regs,
                      hsg_config.thread.xtra,
                      r_limit);
              exit(-1);
            }
        }

      fprintf(stderr,"*** HINT *** Try extra registers: %u\n",
              n_max - merge->rows_bs);

      exit(-1);
    }
}

//
//
//

static
struct hsg_op *
hsg_op(struct hsg_op * ops, struct hsg_op const opcode)
{
  hsg_op_type_counts[opcode.type] += 1;

  *ops = opcode;

  return ops+1;
}
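
//
// hsg_op() appends a single op to the in-memory instruction stream,
// bumps that op type's counter for the stats dump, and returns the
// advanced write pointer -- every emitter below threads ops through
// in the same way.
//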

static
struct hsg_op *
hsg_exit(struct hsg_op * ops)
{
  return hsg_op(ops,EXIT());
}

static
struct hsg_op *
hsg_end(struct hsg_op * ops)
{
  return hsg_op(ops,END());
}

static
struct hsg_op *
hsg_begin(struct hsg_op * ops)
{
  return hsg_op(ops,BEGIN());
}

static
struct hsg_op *
hsg_else(struct hsg_op * ops)
{
  return hsg_op(ops,ELSE());
}

static
struct hsg_op *
hsg_network_copy(struct hsg_op * ops,
                 struct hsg_network const * const nets,
                 uint32_t const idx,
                 uint32_t const prefix)
{
  uint32_t const len = nets[idx].length;
  struct hsg_op const * const cxa = nets[idx].network;

  for (uint32_t ii=0; ii<len; ii++)
    {
      struct hsg_op const * const cx = cxa + ii;

      ops = hsg_op(ops,CMP_XCHG(cx->a,cx->b,prefix));
    }

  return ops;
}

static
struct hsg_op *
hsg_thread_sort(struct hsg_op * ops)
{
  uint32_t const idx = hsg_config.thread.regs / 2 - 1;

  return hsg_network_copy(ops,hsg_networks_sorting,idx,UINT32_MAX);
}
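
//
// The sorting-network table appears to be indexed by register count:
// with the default of 24 registers per thread, idx = 24/2 - 1 = 11
// presumably selects the 24-element network in hsg_networks_sorting[].
//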

static
struct hsg_op *
hsg_thread_merge_prefix(struct hsg_op * ops, uint32_t const network, uint32_t const prefix)
{
  if (network <= 1)
    return ops;

  return hsg_network_copy(ops,hsg_networks_merging,network-2,prefix);
}

static
struct hsg_op *
hsg_thread_merge(struct hsg_op * ops, uint32_t const network)
{
  return hsg_thread_merge_prefix(ops,network,UINT32_MAX);
}

static
struct hsg_op *
hsg_thread_merge_offset_prefix(struct hsg_op * ops, uint32_t const offset, uint32_t const network, uint32_t const prefix)
{
  if (network <= 1)
    return ops;

  uint32_t const idx = network - 2;
  uint32_t const len = hsg_networks_merging[idx].length;
  struct hsg_op const * const cxa = hsg_networks_merging[idx].network;

  for (uint32_t ii=0; ii<len; ii++)
    {
      struct hsg_op const * const cx = cxa + ii;

      ops = hsg_op(ops,CMP_XCHG(offset + cx->a,offset + cx->b,prefix));
    }

  return ops;
}

static
struct hsg_op *
hsg_thread_merge_offset(struct hsg_op * ops, uint32_t const offset, uint32_t const network)
{
  return hsg_thread_merge_offset_prefix(ops,offset,network,UINT32_MAX);
}

static
struct hsg_op *
hsg_thread_merge_left_right_prefix(struct hsg_op * ops, uint32_t const left, uint32_t const right, uint32_t const prefix)
{
  for (uint32_t l=left,r=left+1; r<=left+right; l--,r++)
    {
      ops = hsg_op(ops,CMP_XCHG(l,r,prefix));
    }

  return ops;
}

static
struct hsg_op *
hsg_thread_merge_left_right(struct hsg_op * ops, uint32_t const left, uint32_t const right)
{
  return hsg_thread_merge_left_right_prefix(ops,left,right,UINT32_MAX);
}

static
struct hsg_op *
hsg_warp_half_network(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,CMP_HALF(r-1,r));

  return ops;
}

static
struct hsg_op *
hsg_warp_half_downto(struct hsg_op * ops, uint32_t h)
{
  //
  // *** from h: downto[f/2,1)
  // ****   lane_half(h)
  //
  for (; h > 1; h/=2)
    {
      ops = hsg_begin(ops);

      ops = hsg_op(ops,SLAB_HALF(h));
      ops = hsg_warp_half_network(ops);

      ops = hsg_end(ops);
    }

  return ops;
}

static
struct hsg_op *
hsg_warp_flip_network(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n/2; r++)
    ops = hsg_op(ops,CMP_FLIP(r-1,r,n+1-r));

  return ops;
}

static
struct hsg_op *
hsg_warp_flip(struct hsg_op * ops, uint32_t f)
{
  ops = hsg_begin(ops);

  ops = hsg_op(ops,SLAB_FLIP(f));
  ops = hsg_warp_flip_network(ops);

  ops = hsg_end(ops);

  return ops;
}

static
struct hsg_op *
hsg_bx_warp_load(struct hsg_op * ops, const int32_t vin_or_vout)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,BX_REG_GLOBAL_LOAD(r,vin_or_vout));

  return ops;
}

static
struct hsg_op *
hsg_bx_warp_store(struct hsg_op * ops)
{
  uint32_t const n = hsg_config.thread.regs;

  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,BX_REG_GLOBAL_STORE(r));

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_warp_transpose(struct hsg_op * ops)
{
  // func proto
  ops = hsg_op(ops,TRANSPOSE_KERNEL_PROTO());

  // begin
  ops = hsg_begin(ops);

  // preamble
  ops = hsg_op(ops,TRANSPOSE_KERNEL_PREAMBLE());

  // load
  ops = hsg_bx_warp_load(ops,1); // 1 = load from vout[]

  // emit transpose blend and remap macros ...
  ops = hsg_op(ops,TRANSPOSE_KERNEL_BODY());

  // ... done!
  ops = hsg_end(ops);

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_warp_half(struct hsg_op * ops, uint32_t const h)
{
  //
  // *** from h: downto[f/2,1)
  // ****   lane_half(h)
  // ***  thread_merge
  //
  ops = hsg_warp_half_downto(ops,h);
  ops = hsg_thread_merge(ops,hsg_config.thread.regs);

  return ops;
}

static
struct hsg_op *
hsg_warp_merge(struct hsg_op * ops)
{
  //
  // *  from f: upto[2,warp.lanes]
  // **   lane_flip(f)
  // ***  from h: downto[f/2,1)
  // ****   lane_half(h)
  // ***  thread_merge
  //
  uint32_t const level = hsg_config.warp.lanes;

  for (uint32_t f=2; f<=level; f*=2)
    {
      ops = hsg_warp_flip(ops,f);
      ops = hsg_warp_half(ops,f/2);
    }

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_bc_half_merge_level(struct hsg_op * ops,
                        struct hsg_merge const * const merge,
                        uint32_t const r_lo,
                        uint32_t const s_count)
{
  // guaranteed to be an even network
  uint32_t const net_even = merge->levels[0].networks[0];

  // min of warps in block and remaining horizontal rows
  uint32_t const active = MIN_MACRO(s_count, net_even);

  // conditional on blockIdx.x
  if (active < merge->warps)
    ops = hsg_op(ops,BX_MERGE_H_PRED(active)); // FIXME BX_MERGE

  // body begin
  ops = hsg_begin(ops);

  // scale for min block
  uint32_t const scale = net_even >= hsg_config.block.warps_min ? 1 : hsg_config.block.warps_min / net_even;

  // loop if more smem rows than warps
  for (uint32_t rr=0; rr<s_count; rr+=active)
    {
      // body begin
      ops = hsg_begin(ops);

      // skip down slab
      uint32_t const gmem_base = r_lo - 1 + rr;

      // load registers horizontally -- striding across slabs
      for (uint32_t ll=1; ll<=net_even; ll++)
        ops = hsg_op(ops,BC_REG_GLOBAL_LOAD_LEFT(ll,gmem_base+(ll-1)*hsg_config.thread.regs,0));

      // merge all registers
      ops = hsg_thread_merge_prefix(ops,net_even,0);

      // if we're looping then there is a base
      uint32_t const smem_base = rr * net_even * scale;

      // store all registers
      for (uint32_t ll=1; ll<=net_even; ll++)
        ops = hsg_op(ops,BX_REG_SHARED_STORE_LEFT(ll,smem_base+ll-1,0));

      // body end
      ops = hsg_end(ops);
    }

  // body end
  ops = hsg_end(ops);

  return ops;
}

static
struct hsg_op *
hsg_bc_half_merge(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  //
  // will only be called with merge->warps >= 2
  //
  uint32_t const warps = MAX_MACRO(merge->warps,hsg_config.block.warps_min);

  // guaranteed to be an even network
  uint32_t const net_even = merge->levels[0].networks[0];

  // set up left SMEM pointer
  ops = hsg_op(ops,BC_MERGE_H_PREAMBLE(merge->index));

  // trim to number of warps in block -- FIXME -- try make this a
  // multiple of local processor count (Intel = 8, NVIDIA = 4)
  uint32_t const s_max = merge->rows_bc;

  // for all the registers
  for (uint32_t r_lo = 1; r_lo <= hsg_config.thread.regs; r_lo += s_max)
    {
      // compute store count
      uint32_t const r_rem   = hsg_config.thread.regs + 1 - r_lo;
      uint32_t const s_count = MIN_MACRO(s_max,r_rem);

      // block sync -- can skip if first
      if (r_lo > 1)
        ops = hsg_op(ops,BLOCK_SYNC());

      // merge loop
      ops = hsg_bc_half_merge_level(ops,merge,r_lo,s_count);

      // block sync
      ops = hsg_op(ops,BLOCK_SYNC());

      // load rows from shared
      for (uint32_t c=0; c<s_count; c++)
        ops = hsg_op(ops,BC_REG_SHARED_LOAD_V(warps,r_lo+c,c));
    }

  return ops;
}

//
//
//

static
struct hsg_op *
hsg_bs_flip_merge_level(struct hsg_op * ops,
                        struct hsg_merge const * const merge,
                        uint32_t const level,
                        uint32_t const s_pairs)
{
  //
  // Note there are a number of ways to flip merge these warps.  There
  // is a magic number in the merge structure that indicates which
  // warp to activate as well as what network size to invoke.
  //
  // This more complex scheme was used in the past.
  //
  // The newest scheme is far dumber/simpler and simply directs a warp
  // to gather up the network associated with a row and merge them.
  //
  // This scheme may use more registers per thread but not all
  // compilers are high quality.
  //
  // If there are more warps than smem row pairs to merge then we
  // disable the spare warps.
  //
  // If there are more row pairs than warps then each warp works on
  // an equal number of rows.
  //
  // Note that it takes two warps to flip merge two smem rows.
  //
  // FIXME -- We may want to apply the warp smem "mod" value here to
  // attempt to balance the load>merge>store operations across the
  // multiprocessor cores.
  //
  // FIXME -- the old scheme attempted to keep all the warps active
  // but the iteration logic was more complex.  See 2016 checkins.
  //

  // where are we in computed merge?
  uint32_t const count = merge->levels[level].count;
  uint32_t const index = 1 << level;

  uint32_t s_rows = s_pairs * 2;
  uint32_t base   = 0;

  while (s_rows > 0)
    {
      uint32_t active = merge->warps;

      // disable warps if necessary
      if (merge->warps > s_rows) {
        active = s_rows;
        ops    = hsg_op(ops,BX_MERGE_H_PRED(active));
      }

      // body begin
      ops = hsg_begin(ops);

      // how many equal number of rows to merge?
      uint32_t loops = s_rows / active;

      // decrement
      s_rows -= loops * active;

      for (uint32_t ss=0; ss<loops; ss++)
        {
          // load all registers
          for (uint32_t ii=0; ii<count; ii++)
            {
              // body begin
              ops = hsg_begin(ops);

              uint32_t const offset  = merge->offsets [index+ii];
              uint32_t const network = merge->networks[index+ii];
              uint32_t const lo      = (network + 1) / 2;

              for (uint32_t ll=1; ll<=lo; ll++)
                ops = hsg_op(ops,BS_REG_SHARED_LOAD_LEFT(ll,base+offset+ll-1,ii));

              for (uint32_t rr=lo+1; rr<=network; rr++)
                ops = hsg_op(ops,BS_REG_SHARED_LOAD_RIGHT(rr,base+offset+rr-1,ii));

              // compare left and right
              ops = hsg_thread_merge_left_right_prefix(ops,lo,network-lo,ii);

              // right merging network
              ops = hsg_thread_merge_offset_prefix(ops,lo,network-lo,ii);

              // left merging network
              ops = hsg_thread_merge_prefix(ops,lo,ii);

              for (uint32_t ll=1; ll<=lo; ll++)
                ops = hsg_op(ops,BX_REG_SHARED_STORE_LEFT(ll,base+offset+ll-1,ii));

              for (uint32_t rr=lo+1; rr<=network; rr++)
                ops = hsg_op(ops,BS_REG_SHARED_STORE_RIGHT(rr,base+offset+rr-1,ii));

              // body end
              ops = hsg_end(ops);
            }

          base += active * merge->warps;
        }

      // body end
      ops = hsg_end(ops);
    }

  return ops;
}

static
struct hsg_op *
hsg_bs_flip_merge(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // set up horizontal smem pointer
  ops = hsg_op(ops,BS_MERGE_H_PREAMBLE(merge->index));

  // begin merge
  uint32_t level = MERGE_LEVELS_MAX_LOG2;

  while (level-- > 0)
    {
      uint32_t const count = merge->levels[level].count;

      if (count == 0)
        continue;

      uint32_t const r_mid       = hsg_config.thread.regs/2 + 1;
      uint32_t const s_pairs_max = merge->rows_bs/2; // this is warp mod

      // for all the registers
      for (uint32_t r_lo=1; r_lo<r_mid; r_lo+=s_pairs_max)
        {
          uint32_t r_hi = hsg_config.thread.regs + 1 - r_lo;

          // compute store count
          uint32_t const s_pairs = MIN_MACRO(s_pairs_max,r_mid - r_lo);

          // store rows to shared
          for (uint32_t c=0; c<s_pairs; c++)
            {
              ops = hsg_op(ops,BS_REG_SHARED_STORE_V(merge->index,r_lo+c,c*2+0));
              ops = hsg_op(ops,BS_REG_SHARED_STORE_V(merge->index,r_hi-c,c*2+1));
            }

          // block sync
          ops = hsg_op(ops,BLOCK_SYNC());

          // merge loop
          ops = hsg_bs_flip_merge_level(ops,merge,level,s_pairs);

          // block sync
          ops = hsg_op(ops,BLOCK_SYNC());

          // load rows from shared
          for (uint32_t c=0; c<s_pairs; c++)
            {
              ops = hsg_op(ops,BS_REG_SHARED_LOAD_V(merge->index,r_lo+c,c*2+0));
              ops = hsg_op(ops,BS_REG_SHARED_LOAD_V(merge->index,r_hi-c,c*2+1));
            }
        }

      // conditionally clean -- no-op if equal to number of warps/block
      if (merge->levels[level].active.b64 != BITS_TO_MASK_64(merge->warps))
        ops = hsg_op(ops,BS_ACTIVE_PRED(merge->index,level));

      // clean warp
      ops = hsg_begin(ops);
      ops = hsg_warp_half(ops,hsg_config.warp.lanes);
      ops = hsg_end(ops);
    }

  return ops;
}

/*

//
// DELETE ME WHEN READY
//

static
struct hsg_op *
hsg_bs_flip_merge_all(struct hsg_op * ops, const struct hsg_merge * const merge)
{
  for (uint32_t merge_idx=0; merge_idx<MERGE_LEVELS_MAX_LOG2; merge_idx++)
    {
      const struct hsg_merge * const m = merge + merge_idx;

      if (m->warps < 2)
        break;

      ops = hsg_op(ops,BS_FRAC_PRED(merge_idx,m->warps));
      ops = hsg_begin(ops);
      ops = hsg_bs_flip_merge(ops,m);
      ops = hsg_end(ops);
    }

  return ops;
}
*/

//
// GENERATE SORT KERNEL
//

static
struct hsg_op *
hsg_bs_sort(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // func proto
  ops = hsg_op(ops,BS_KERNEL_PROTO(merge->index));

  // begin
  ops = hsg_begin(ops);

  // shared declare
  ops = hsg_op(ops,BS_KERNEL_PREAMBLE(merge->index));

  // load
  ops = hsg_bx_warp_load(ops,0); // 0 = load from vin[]

  // thread sorting network
  ops = hsg_thread_sort(ops);

  // warp merging network
  ops = hsg_warp_merge(ops);

  // slab merging network
  if (merge->warps > 1)
    ops = hsg_bs_flip_merge(ops,merge);

  // store
  ops = hsg_bx_warp_store(ops);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE SORT KERNELS
//

static
struct hsg_op *
hsg_bs_sort_all(struct hsg_op * ops)
{
  uint32_t merge_idx = MERGE_LEVELS_MAX_LOG2;

  while (merge_idx-- > 0)
    {
      struct hsg_merge const * const m = hsg_merge + merge_idx;

      if (m->warps == 0)
        continue;

      ops = hsg_bs_sort(ops,m);
    }

  return ops;
}

//
// GENERATE CLEAN KERNEL FOR A POWER-OF-TWO
//

static
struct hsg_op *
hsg_bc_clean(struct hsg_op * ops, struct hsg_merge const * const merge)
{
  // func proto
  ops = hsg_op(ops,BC_KERNEL_PROTO(merge->index));

  // begin
  ops = hsg_begin(ops);

  // shared declare
  ops = hsg_op(ops,BC_KERNEL_PREAMBLE(merge->index));

  // if warps == 1 then smem isn't used for merging
  if (merge->warps == 1)
    {
      // load slab directly
      ops = hsg_bx_warp_load(ops,1); // load from vout[]
    }
  else
    {
      // block merging network -- strided load of slabs
      ops = hsg_bc_half_merge(ops,merge);
    }

  // clean warp
  ops = hsg_begin(ops);
  ops = hsg_warp_half(ops,hsg_config.warp.lanes);
  ops = hsg_end(ops);

  // store
  ops = hsg_bx_warp_store(ops);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE CLEAN KERNELS
//

static
struct hsg_op *
hsg_bc_clean_all(struct hsg_op * ops)
{
  uint32_t merge_idx = MERGE_LEVELS_MAX_LOG2;

  while (merge_idx-- > 0)
    {
      struct hsg_merge const * const m = hsg_merge + merge_idx;

      if (m->warps == 0)
        continue;

      // only generate pow2 clean kernels less than or equal to max
      // warps in a block, on the assumption that we would've generated
      // a wider sort kernel if we could have, so a wider clean kernel
      // isn't a feasible size
      if (!is_pow2_u32(m->warps))
        continue;

      ops = hsg_bc_clean(ops,m);
    }

  return ops;
}

//
// GENERATE FLIP MERGE KERNEL
//

static
struct hsg_op *
hsg_fm_thread_load_left(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_LOAD_LEFT(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_store_left(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_STORE_LEFT(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_load_right(struct hsg_op * ops, uint32_t const half_span, uint32_t const half_case)
{
  for (uint32_t r=0; r<half_case; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_LOAD_RIGHT(r,half_span+1+r));

  return ops;
}

static
struct hsg_op *
hsg_fm_thread_store_right(struct hsg_op * ops, uint32_t const half_span, uint32_t const half_case)
{
  for (uint32_t r=0; r<half_case; r++)
    ops = hsg_op(ops,FM_REG_GLOBAL_STORE_RIGHT(r,half_span+1+r));

  return ops;
}

static
struct hsg_op *
hsg_fm_merge(struct hsg_op * ops,
             uint32_t const scale_log2,
             uint32_t const span_left,
             uint32_t const span_right)
{
  // func proto
  ops = hsg_op(ops,FM_KERNEL_PROTO(scale_log2,msb_idx_u32(pow2_ru_u32(span_right))));

  // begin
  ops = hsg_begin(ops);

  // preamble for loading/storing
  ops = hsg_op(ops,FM_KERNEL_PREAMBLE(span_left,span_right));

  // load left span
  ops = hsg_fm_thread_load_left(ops,span_left);

  // load right span
  ops = hsg_fm_thread_load_right(ops,span_left,span_right);

  // compare left and right
  ops = hsg_thread_merge_left_right(ops,span_left,span_right);

  // left merging network
  ops = hsg_thread_merge(ops,span_left);

  // right merging network
  ops = hsg_thread_merge_offset(ops,span_left,span_right);

  // store
  ops = hsg_fm_thread_store_left(ops,span_left);

  // store
  ops = hsg_fm_thread_store_right(ops,span_left,span_right);

  // end
  ops = hsg_end(ops);

  return ops;
}

static
struct hsg_op *
hsg_fm_merge_all(struct hsg_op * ops, uint32_t const scale_log2, uint32_t const warps)
{
  uint32_t const span_left    = (warps << scale_log2) / 2;
  uint32_t const span_left_ru = pow2_ru_u32(span_left);

  for (uint32_t span_right=1; span_right<=span_left_ru; span_right*=2)
    ops = hsg_fm_merge(ops,scale_log2,span_left,MIN_MACRO(span_left,span_right));

  return ops;
}
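
//
// For example, with warps = 16 and scale_log2 = 1 this emits
// flip-merge kernels for span_left = (16 << 1) / 2 = 16 and
// span_right = 1, 2, 4, 8 and 16, covering every power-of-two
// right-hand remainder up to a full left span.
//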

//
// GENERATE HALF MERGE KERNELS
//

static
struct hsg_op *
hsg_hm_thread_load(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,HM_REG_GLOBAL_LOAD(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_hm_thread_store(struct hsg_op * ops, uint32_t const n)
{
  for (uint32_t r=1; r<=n; r++)
    ops = hsg_op(ops,HM_REG_GLOBAL_STORE(r,r-1));

  return ops;
}

static
struct hsg_op *
hsg_hm_merge(struct hsg_op * ops, uint32_t const scale_log2, uint32_t const warps_pow2)
{
  uint32_t const span = warps_pow2 << scale_log2;

  // func proto
  ops = hsg_op(ops,HM_KERNEL_PROTO(scale_log2));

  // begin
  ops = hsg_begin(ops);

  // preamble for loading/storing
  ops = hsg_op(ops,HM_KERNEL_PREAMBLE(span/2));

  // load
  ops = hsg_hm_thread_load(ops,span);

  // thread merging network
  ops = hsg_thread_merge(ops,span);

  // store
  ops = hsg_hm_thread_store(ops,span);

  // end
  ops = hsg_end(ops);

  return ops;
}

//
// GENERATE MERGE KERNELS
//

static
struct hsg_op *
hsg_xm_merge_all(struct hsg_op * ops)
{
  uint32_t const warps      = hsg_merge[0].warps;
  uint32_t const warps_pow2 = pow2_rd_u32(warps);

  //
  // GENERATE FLIP MERGE KERNELS
  //
  for (uint32_t scale_log2=hsg_config.merge.flip.lo; scale_log2<=hsg_config.merge.flip.hi; scale_log2++)
    ops = hsg_fm_merge_all(ops,scale_log2,warps);

  //
  // GENERATE HALF MERGE KERNELS
  //
  for (uint32_t scale_log2=hsg_config.merge.half.lo; scale_log2<=hsg_config.merge.half.hi; scale_log2++)
    ops = hsg_hm_merge(ops,scale_log2,warps_pow2);

  return ops;
}

//
//
//

static
struct hsg_op const *
hsg_op_translate_depth(hsg_target_pfn target_pfn,
                       struct hsg_target * const target,
                       struct hsg_config const * const config,
                       struct hsg_merge const * const merge,
                       struct hsg_op const * ops,
                       uint32_t const depth)
{
  while (ops->type != HSG_OP_TYPE_EXIT)
    {
      switch (ops->type)
        {
        case HSG_OP_TYPE_END:
          target_pfn(target,config,merge,ops,depth-1);
          return ops + 1;

        case HSG_OP_TYPE_BEGIN:
          target_pfn(target,config,merge,ops,depth);
          ops = hsg_op_translate_depth(target_pfn,target,config,merge,ops+1,depth+1);
          break;

        default:
          target_pfn(target,config,merge,ops++,depth);
        }
    }

  return ops;
}
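
//
// Recursive descent over the op stream: BEGIN opens a nested scope at
// depth+1, END closes it and returns to the caller, and every other
// op is handed to the target backend at the current nesting depth,
// which the backends can use to indent/brace the emitted source.
//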

static
void
hsg_op_translate(hsg_target_pfn target_pfn,
                 struct hsg_target * const target,
                 struct hsg_config const * const config,
                 struct hsg_merge const * const merge,
                 struct hsg_op const * ops)
{
  hsg_op_translate_depth(target_pfn,target,config,merge,ops,0);
}

//
//
//

int
main(int argc, char * argv[])
{
  //
  // PROCESS OPTIONS
  //
  int32_t      opt      = 0;
  bool         verbose  = false;
  bool         autotune = false;
  char const * arch     = "undefined";
  struct hsg_target target = { .define = NULL };

  while ((opt = getopt(argc,argv,"hva:g:G:s:S:w:b:B:m:M:k:r:x:t:f:F:c:C:p:P:D:z")) != EOF)
    {
      switch (opt)
        {
        case 'h':
          fprintf(stderr,"Help goes here...\n");
          return EXIT_FAILURE;

        case 'v':
          verbose = true;
          break;

        case 'a':
          arch = optarg;
          break;

        case 'g':
          hsg_config.block.smem_min = atoi(optarg);
          break;

        case 'G':
          hsg_config.block.smem_quantum = atoi(optarg);
          break;

        case 's':
          hsg_config.block.smem_bs = atoi(optarg);

          // set smem_bc if not already set
          if (hsg_config.block.smem_bc == UINT32_MAX)
            hsg_config.block.smem_bc = hsg_config.block.smem_bs;
          break;

        case 'S':
          hsg_config.block.smem_bc = atoi(optarg);
          break;

        case 'w':
          hsg_config.warp.lanes      = atoi(optarg);
          hsg_config.warp.lanes_log2 = msb_idx_u32(hsg_config.warp.lanes);
          break;

        case 'b':
          // maximum warps in a workgroup / cta / thread block
          {
            uint32_t const warps = atoi(optarg);

            // must always be even
            if ((warps & 1) != 0)
              {
                fprintf(stderr,"Error: -b must be even.\n");
                return EXIT_FAILURE;
              }

            hsg_merge[0].index = 0;
            hsg_merge[0].warps = warps;

            // set warps_max if not already set
            if (hsg_config.block.warps_max == UINT32_MAX)
              hsg_config.block.warps_max = pow2_ru_u32(warps);
          }
          break;

        case 'B':
          // maximum warps that can fit in a multiprocessor
          hsg_config.block.warps_max = atoi(optarg);
          break;

        case 'm':
          // blocks using smem barriers must have at least this many warps
          hsg_config.block.warps_min = atoi(optarg);
          break;

        case 'M':
          // the number of warps necessary to load balance horizontal merging
          hsg_config.block.warps_mod = atoi(optarg);
          break;

        case 'r':
          {
            uint32_t const regs = atoi(optarg);

            if ((regs & 1) != 0)
              {
                fprintf(stderr,"Error: -r must be even.\n");
                return EXIT_FAILURE;
              }

            hsg_config.thread.regs = regs;
          }
          break;

        case 'x':
          hsg_config.thread.xtra = atoi(optarg);
          break;

        case 't':
          hsg_config.type.words = atoi(optarg);
          break;

        case 'f':
          hsg_config.merge.flip.lo = atoi(optarg);
          break;

        case 'F':
          hsg_config.merge.flip.hi = atoi(optarg);
          break;

        case 'c':
          hsg_config.merge.half.lo = atoi(optarg);
          break;

        case 'C':
          hsg_config.merge.half.hi = atoi(optarg);
          break;

        case 'p':
          hsg_config.merge.flip.warps = atoi(optarg);
          break;

        case 'P':
          hsg_config.merge.half.warps = atoi(optarg);
          break;

        case 'D':
          target.define = optarg;
          break;

        case 'z':
          autotune = true;
          break;
        }
    }

  //
  // INIT MERGE
  //
  uint32_t const warps_ru_pow2 = pow2_ru_u32(hsg_merge[0].warps);

  for (uint32_t ii=1; ii<MERGE_LEVELS_MAX_LOG2; ii++)
    {
      hsg_merge[ii].index = ii;
      hsg_merge[ii].warps = warps_ru_pow2 >> ii;
    }
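
  //
  // For example, -b 12 rounds up to warps_ru_pow2 = 16, so the
  // smaller merge configs get 8, 4, 2 and 1 warps -- hsg_merge[0]
  // keeps the exact (possibly non-pow2) -b value.
  //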

  //
  // WHICH ARCH TARGET?
  //
  hsg_target_pfn hsg_target_pfn;

  if (strcmp(arch,"debug") == 0)
    hsg_target_pfn = hsg_target_debug;
  else if (strcmp(arch,"cuda") == 0)
    hsg_target_pfn = hsg_target_cuda;
  else if (strcmp(arch,"opencl") == 0)
    hsg_target_pfn = hsg_target_opencl;
  else if (strcmp(arch,"glsl") == 0)
    hsg_target_pfn = hsg_target_glsl;
  else {
    fprintf(stderr,"Invalid arch: %s\n",arch);
    exit(EXIT_FAILURE);
  }

  if (verbose)
    fprintf(stderr,"Target: %s\n",arch);

  //
  // INIT SMEM KEY ALLOCATION
  //
  hsg_config_init_shared();

  //
  // INIT MERGE MAGIC
  //
  for (uint32_t ii=0; ii<MERGE_LEVELS_MAX_LOG2; ii++)
    {
      struct hsg_merge * const merge = hsg_merge + ii;

      if (merge->warps == 0)
        break;

      fprintf(stderr,">>> Generating: %1u %5u %5u %3u %3u ...\n",
              hsg_config.type.words,
              hsg_config.block.smem_bs,
              hsg_config.block.smem_bc,
              hsg_config.thread.regs,
              merge->warps);

      hsg_merge_levels_init_shared(merge);

      hsg_merge_levels_init_1(merge,merge->warps,0,0);

      hsg_merge_levels_hint(merge,autotune);

      //
      // THESE ARE FOR DEBUG/INSPECTION
      //
      if (verbose)
        {
          hsg_merge_levels_debug(merge);
        }
    }

  if (verbose)
    fprintf(stderr,"\n\n");

  //
  // GENERATE THE OPCODES
  //
  uint32_t        const op_count  = 1<<17;
  struct hsg_op * const ops_begin = malloc(sizeof(*ops_begin) * op_count);
  struct hsg_op *       ops       = ops_begin;

  //
  // OPEN INITIAL FILES AND APPEND HEADER
  //
  ops = hsg_op(ops,TARGET_BEGIN());

  //
  // GENERATE SORT KERNEL
  //
  ops = hsg_bs_sort_all(ops);

  //
  // GENERATE CLEAN KERNELS
  //
  ops = hsg_bc_clean_all(ops);

  //
  // GENERATE MERGE KERNELS
  //
  ops = hsg_xm_merge_all(ops);

  //
  // GENERATE TRANSPOSE KERNEL
  //
  ops = hsg_warp_transpose(ops);

  //
  // APPEND FOOTER AND CLOSE INITIAL FILES
  //
  ops = hsg_op(ops,TARGET_END());

  //
  // ... WE'RE DONE!
  //
  ops = hsg_exit(ops);

  //
  // APPLY TARGET TRANSLATOR TO ACCUMULATED OPS
  //
  hsg_op_translate(hsg_target_pfn,&target,&hsg_config,hsg_merge,ops_begin);

  //
  // DUMP INSTRUCTION COUNTS
  //
  if (verbose)
    hsg_op_debug();

  return EXIT_SUCCESS;
}

//
//
//