1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can
5 * be found in the LICENSE file.
6 *
7 */
8
9 #include <stdio.h>
10 #include <stdlib.h>
11
12 //
13 //
14 //
15
16 #include "gen.h"
17 #include "transpose.h"
18
19 #include "common/util.h"
20 #include "common/macros.h"
21
22 //
23 //
24 //
25
//
// Context handed (as an opaque pointer) to the transpose blend and
// remap callbacks below.
//
struct hsg_transpose_state
{
  FILE                    * header; // stream the HS_TRANSPOSE_* macro lines are written to
  struct hsg_config const * config; // generator config; the remap callback reads warp.lanes_log2
};
31
//
// Map a column log2 value to a lowercase register-name prefix.
//
// A value of 0 maps to 'r' and each increment advances one letter,
// wrapping from 'z' back to 'a'. The arithmetic is unsigned, so the
// argument (cols_log2 - 1) with cols_log2 == 0 wraps and yields 'q'.
//
static char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  uint32_t const base_offset = (uint32_t)('r' - 'a');
  uint32_t const letter_idx  = (base_offset + cols_log2) % 26;

  return (char)('a' + letter_idx);
}
38
39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42 uint32_t const row_ll, // lower-left
43 uint32_t const row_ur, // upper-right
44 void * blend)
45 {
46 struct hsg_transpose_state * const state = blend;
47
48 // we're starting register names at '1' for now
49 fprintf(state->header,
50 " HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51 hsg_transpose_reg_prefix(cols_log2-1),
52 hsg_transpose_reg_prefix(cols_log2),
53 cols_log2,row_ll+1,row_ur+1);
54 }
55
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59 uint32_t const row_to,
60 void * remap)
61 {
62 struct hsg_transpose_state * const state = remap;
63
64 // we're starting register names at '1' for now
65 fprintf(state->header,
66 " HS_TRANSPOSE_REMAP( %c, %3u, %3u ) \\\n",
67 hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68 row_from+1,row_to+1);
69 }
70
71 //
72 //
73 //
74
//
// Write the copyright / license banner, followed by a blank line, to
// the given stream.
//
static void
hsg_copyright(FILE * file)
{
  // the banner contains no conversion specifiers, so it is written verbatim
  static char const banner[] =
    "// \n"
    "// Copyright 2016 Google Inc. \n"
    "// \n"
    "// Use of this source code is governed by a BSD-style \n"
    "// license that can be found in the LICENSE file. \n"
    "// \n"
    "\n";

  fputs(banner,file);
}
88
89 //
90 //
91 //
92
//
// Per-target output state for the CUDA target: the two streams the
// generated text is written to. hsg_target_cuda() allocates this and
// opens both files on HSG_OP_TYPE_TARGET_BEGIN, then closes them and
// frees it on HSG_OP_TYPE_TARGET_END.
//
struct hsg_target_state
{
  FILE * header; // generated config header ("hs_cuda_config.h")
  FILE * source; // generated .cu source file
};
98
99 //
100 //
101 //
102
103 void
hsg_target_cuda(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)104 hsg_target_cuda(struct hsg_target * const target,
105 struct hsg_config const * const config,
106 struct hsg_merge const * const merge,
107 struct hsg_op const * const ops,
108 uint32_t const depth)
109 {
110 switch (ops->type)
111 {
112 case HSG_OP_TYPE_END:
113 fprintf(target->state->source,
114 "}\n");
115 break;
116
117 case HSG_OP_TYPE_BEGIN:
118 fprintf(target->state->source,
119 "{\n");
120 break;
121
122 case HSG_OP_TYPE_ELSE:
123 fprintf(target->state->source,
124 "else\n");
125 break;
126
127 case HSG_OP_TYPE_TARGET_BEGIN:
128 {
129 // allocate state
130 target->state = malloc(sizeof(*target->state));
131
132 //
133 // Note that we're generating file names with different
134 // suffixes despite storing them in different directories
135 // because NVCC on Visual Studio appears to overwrite .cu
136 // files with the same name but different paths.
137 //
138 // I would prefer to the original layout:
139 //
140 // path/to/<type>/hs_cuda.[cu|config|h]
141 //
142
143 // allocate files
144 target->state->header = fopen("hs_cuda_config.h","wb");
145 target->state->source = fopen((config->type.words == 1) ?
146 "hs_cuda_u32.cu" : "hs_cuda_u64.cu",
147 "wb");
148
149 // initialize header
150 uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
151
152 hsg_copyright(target->state->header);
153
154 fprintf(target->state->header,
155 "#ifndef HS_CUDA_CONFIG_ONCE \n"
156 "#define HS_CUDA_CONFIG_ONCE \n"
157 " \n"
158 "#define HS_SLAB_THREADS_LOG2 %u \n"
159 "#define HS_SLAB_THREADS (1 << HS_SLAB_THREADS_LOG2) \n"
160 "#define HS_SLAB_WIDTH_LOG2 %u \n"
161 "#define HS_SLAB_WIDTH (1 << HS_SLAB_WIDTH_LOG2) \n"
162 "#define HS_SLAB_HEIGHT %u \n"
163 "#define HS_SLAB_KEYS (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
164 "#define HS_REG_LAST(c) c##%u \n"
165 "#define HS_KEY_TYPE_PRETTY %s \n"
166 "#define HS_KEY_WORDS %u \n"
167 "#define HS_VAL_WORDS 0 \n"
168 "#define HS_BS_SLABS %u \n"
169 "#define HS_BS_SLABS_LOG2_RU %u \n"
170 "#define HS_BC_SLABS_LOG2_MAX %u \n"
171 "#define HS_FM_BLOCK_HEIGHT %u \n"
172 "#define HS_FM_SCALE_MIN %u \n"
173 "#define HS_FM_SCALE_MAX %u \n"
174 "#define HS_HM_BLOCK_HEIGHT %u \n"
175 "#define HS_HM_SCALE_MIN %u \n"
176 "#define HS_HM_SCALE_MAX %u \n"
177 "#define HS_EMPTY \n"
178 " \n",
179 config->warp.lanes_log2,
180 config->warp.lanes_log2,
181 config->thread.regs,
182 config->thread.regs,
183 (config->type.words == 1) ? "u32" : "u64",
184 config->type.words,
185 merge->warps,
186 msb_idx_u32(pow2_ru_u32(merge->warps)),
187 bc_max,
188 config->merge.flip.warps,
189 config->merge.flip.lo,
190 config->merge.flip.hi,
191 config->merge.half.warps,
192 config->merge.half.lo,
193 config->merge.half.hi);
194
195 if (target->define != NULL)
196 fprintf(target->state->header,"#define %s\n\n",target->define);
197
198 fprintf(target->state->header,
199 "#define HS_SLAB_ROWS() \\\n");
200
201 for (uint32_t ii=1; ii<=config->thread.regs; ii++)
202 fprintf(target->state->header,
203 " HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
204
205 fprintf(target->state->header,
206 " HS_EMPTY\n"
207 " \n");
208
209 fprintf(target->state->header,
210 "#define HS_TRANSPOSE_SLAB() \\\n");
211
212 for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
213 fprintf(target->state->header,
214 " HS_TRANSPOSE_STAGE( %u ) \\\n",ii);
215
216 struct hsg_transpose_state state[1] =
217 {
218 { .header = target->state->header,
219 .config = config
220 }
221 };
222
223 hsg_transpose(config->warp.lanes_log2,
224 config->thread.regs,
225 hsg_transpose_blend,state,
226 hsg_transpose_remap,state);
227
228 fprintf(target->state->header,
229 " HS_EMPTY\n"
230 " \n");
231
232 hsg_copyright(target->state->source);
233
234 fprintf(target->state->source,
235 "#ifdef __cplusplus \n"
236 "extern \"C\" { \n"
237 "#endif \n"
238 " \n"
239 "#include \"hs_cuda.h\" \n"
240 " \n"
241 "#ifdef __cplusplus \n"
242 "} \n"
243 "#endif \n"
244 " \n"
245 "#include \"hs_cuda_config.h\" \n"
246 " \n"
247 "#include \"../hs_cuda_macros.h\" \n"
248 " \n"
249 "// \n"
250 "// \n"
251 "// \n");
252 }
253 break;
254
255 case HSG_OP_TYPE_TARGET_END:
256 // decorate the files
257 fprintf(target->state->header,
258 "#endif \n"
259 " \n"
260 "// \n"
261 "// \n"
262 "// \n"
263 " \n");
264 fprintf(target->state->source,
265 " \n"
266 "// \n"
267 "// \n"
268 "// \n"
269 " \n"
270 "#include \"../../hs_cuda.inl\" \n"
271 " \n"
272 "// \n"
273 "// \n"
274 "// \n"
275 " \n");
276
277 // close files
278 fclose(target->state->header);
279 fclose(target->state->source);
280
281 // free state
282 free(target->state);
283 break;
284
285 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
286 {
287 fprintf(target->state->source,
288 "\nHS_TRANSPOSE_KERNEL_PROTO()\n");
289 }
290 break;
291
292 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
293 {
294 fprintf(target->state->source,
295 "HS_SLAB_GLOBAL_PREAMBLE();\n");
296 }
297 break;
298
299 case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
300 {
301 fprintf(target->state->source,
302 "HS_TRANSPOSE_SLAB();\n");
303 }
304 break;
305
306 case HSG_OP_TYPE_BS_KERNEL_PROTO:
307 {
308 struct hsg_merge const * const m = merge + ops->a;
309
310 uint32_t const bs = pow2_ru_u32(m->warps);
311 uint32_t const msb = msb_idx_u32(bs);
312
313 if (ops->a == 0)
314 {
315 fprintf(target->state->source,
316 "\nHS_BS_KERNEL_PROTO(%u,%u)\n",
317 m->warps,msb);
318 }
319 else
320 {
321 fprintf(target->state->source,
322 "\nHS_OFFSET_BS_KERNEL_PROTO(%u,%u)\n",
323 m->warps,msb);
324 }
325 }
326 break;
327
328 case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
329 {
330 struct hsg_merge const * const m = merge + ops->a;
331
332 if (m->warps > 1)
333 {
334 fprintf(target->state->source,
335 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
336 m->warps * config->warp.lanes,
337 m->rows_bs);
338 }
339
340 if (ops->a == 0)
341 {
342 fprintf(target->state->source,
343 "HS_SLAB_GLOBAL_PREAMBLE();\n");
344 }
345 else
346 {
347 fprintf(target->state->source,
348 "HS_OFFSET_SLAB_GLOBAL_PREAMBLE();\n");
349 }
350 }
351 break;
352
353 case HSG_OP_TYPE_BC_KERNEL_PROTO:
354 {
355 struct hsg_merge const * const m = merge + ops->a;
356
357 uint32_t const msb = msb_idx_u32(m->warps);
358
359 fprintf(target->state->source,
360 "\nHS_BC_KERNEL_PROTO(%u,%u)\n",
361 m->warps,msb);
362 }
363 break;
364
365 case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
366 {
367 struct hsg_merge const * const m = merge + ops->a;
368
369 if (m->warps > 1)
370 {
371 fprintf(target->state->source,
372 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
373 m->warps * config->warp.lanes,
374 m->rows_bc);
375 }
376
377 fprintf(target->state->source,
378 "HS_SLAB_GLOBAL_PREAMBLE();\n");
379 }
380 break;
381
382 case HSG_OP_TYPE_FM_KERNEL_PROTO:
383 {
384 uint32_t const span_left = (merge[0].warps << ops->a) / 2;
385 uint32_t const span_right = 1 << ops->b;
386
387 // uint32_t const msb = msb_idx_u32(pow2_ru_u32(merge[0].warps));
388 // if ((ops->a + ops->b - 1) == msb)
389
390 if (span_right == span_left)
391 {
392 fprintf(target->state->source,
393 "\nHS_FM_KERNEL_PROTO(%u,%u)\n",
394 ops->a,ops->b);
395 }
396 else
397 {
398 fprintf(target->state->source,
399 "\nHS_OFFSET_FM_KERNEL_PROTO(%u,%u)\n",
400 ops->a,ops->b);
401 }
402 }
403 break;
404
405 case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
406 {
407 uint32_t const msb = msb_idx_u32(pow2_ru_u32(merge[0].warps));
408
409 if (ops->a == ops->b) // equal left and right spans
410 {
411 fprintf(target->state->source,
412 "HS_FM_PREAMBLE(%u);\n",
413 ops->a);
414 }
415 else // right span is lesser pow2
416 {
417 fprintf(target->state->source,
418 "HS_OFFSET_FM_PREAMBLE(%u);\n",
419 ops->a);
420 }
421 }
422 break;
423
424 case HSG_OP_TYPE_HM_KERNEL_PROTO:
425 {
426 fprintf(target->state->source,
427 "\nHS_HM_KERNEL_PROTO(%u)\n",
428 ops->a);
429 }
430 break;
431
432 case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
433 fprintf(target->state->source,
434 "HS_HM_PREAMBLE(%u);\n",
435 ops->a);
436 break;
437
438 case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
439 {
440 static char const * const vstr[] = { "vin", "vout" };
441
442 fprintf(target->state->source,
443 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
444 ops->n,vstr[ops->v],ops->n-1);
445 }
446 break;
447
448 case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
449 fprintf(target->state->source,
450 "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
451 ops->n-1,ops->n);
452 break;
453
454 case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
455 fprintf(target->state->source,
456 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
457 ops->a,ops->b);
458 break;
459
460 case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
461 fprintf(target->state->source,
462 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
463 ops->b,ops->a);
464 break;
465
466 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
467 fprintf(target->state->source,
468 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
469 ops->a,ops->b);
470 break;
471
472 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
473 fprintf(target->state->source,
474 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
475 ops->b,ops->a);
476 break;
477
478 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
479 fprintf(target->state->source,
480 "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
481 ops->b,ops->a);
482 break;
483
484 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
485 fprintf(target->state->source,
486 "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
487 ops->a,ops->b);
488 break;
489
490 case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
491 {
492 if (ops->a <= ops->b)
493 {
494 fprintf(target->state->source,
495 "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
496 }
497 else if (ops->b > 1)
498 {
499 fprintf(target->state->source,
500 "else if (fm_frac == %u)\n",
501 ops->b);
502 }
503 else
504 {
505 fprintf(target->state->source,
506 "else\n");
507 }
508 }
509 break;
510
511 case HSG_OP_TYPE_SLAB_FLIP:
512 fprintf(target->state->source,
513 "HS_SLAB_FLIP_PREAMBLE(%u);\n",
514 ops->n-1);
515 break;
516
517 case HSG_OP_TYPE_SLAB_HALF:
518 fprintf(target->state->source,
519 "HS_SLAB_HALF_PREAMBLE(%u);\n",
520 ops->n / 2);
521 break;
522
523 case HSG_OP_TYPE_CMP_FLIP:
524 fprintf(target->state->source,
525 "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
526 break;
527
528 case HSG_OP_TYPE_CMP_HALF:
529 fprintf(target->state->source,
530 "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
531 break;
532
533 case HSG_OP_TYPE_CMP_XCHG:
534 if (ops->c == UINT32_MAX)
535 {
536 fprintf(target->state->source,
537 "HS_CMP_XCHG(r%-3u,r%-3u);\n",
538 ops->a,ops->b);
539 }
540 else
541 {
542 fprintf(target->state->source,
543 "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
544 ops->c,ops->a,ops->c,ops->b);
545 }
546 break;
547
548 case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
549 fprintf(target->state->source,
550 "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
551 merge[ops->a].warps,ops->c,ops->b);
552 break;
553
554 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
555 fprintf(target->state->source,
556 "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
557 ops->b,merge[ops->a].warps,ops->c);
558 break;
559
560 case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
561 fprintf(target->state->source,
562 "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
563 ops->b,ops->a,ops->c);
564 break;
565
566 case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
567 fprintf(target->state->source,
568 "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
569 ops->b * config->warp.lanes,
570 ops->c,
571 ops->a);
572 break;
573
574 case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
575 fprintf(target->state->source,
576 "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
577 ops->b * config->warp.lanes,
578 ops->c,
579 ops->a);
580 break;
581
582 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
583 fprintf(target->state->source,
584 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
585 ops->c,
586 ops->a,
587 ops->b * config->warp.lanes);
588 break;
589
590 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
591 fprintf(target->state->source,
592 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
593 ops->c,
594 ops->a,
595 ops->b * config->warp.lanes);
596 break;
597
598 case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
599 fprintf(target->state->source,
600 "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
601 ops->c,
602 ops->a,
603 ops->b);
604 break;
605
606 case HSG_OP_TYPE_BLOCK_SYNC:
607 fprintf(target->state->source,
608 "HS_BLOCK_BARRIER();\n");
609 //
610 // FIXME - Named barriers to allow coordinating warps to proceed?
611 //
612 break;
613
614 case HSG_OP_TYPE_BS_FRAC_PRED:
615 {
616 if (ops->m == 0)
617 {
618 fprintf(target->state->source,
619 "if (warp_idx < bs_full)\n");
620 }
621 else
622 {
623 fprintf(target->state->source,
624 "else if (bs_frac == %u)\n",
625 ops->w);
626 }
627 }
628 break;
629
630 case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
631 {
632 struct hsg_merge const * const m = merge + ops->a;
633
634 fprintf(target->state->source,
635 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
636 m->warps);
637 }
638 break;
639
640 case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
641 {
642 struct hsg_merge const * const m = merge + ops->a;
643
644 fprintf(target->state->source,
645 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
646 m->warps);
647 }
648 break;
649
650 case HSG_OP_TYPE_BX_MERGE_H_PRED:
651 fprintf(target->state->source,
652 "if (HS_WARP_ID_X() < %u)\n",
653 ops->a);
654 break;
655
656 case HSG_OP_TYPE_BS_ACTIVE_PRED:
657 {
658 struct hsg_merge const * const m = merge + ops->a;
659
660 if (m->warps <= 32)
661 {
662 fprintf(target->state->source,
663 "if (((1u << HS_WARP_ID_X()) & 0x%08X) != 0)\n",
664 m->levels[ops->b].active.b32a2[0]);
665 }
666 else
667 {
668 fprintf(target->state->source,
669 "if (((1UL << HS_WARP_ID_X()) & 0x%08X%08XL) != 0L)\n",
670 m->levels[ops->b].active.b32a2[1],
671 m->levels[ops->b].active.b32a2[0]);
672 }
673 }
674 break;
675
676 default:
677 fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
678 exit(EXIT_FAILURE);
679 break;
680 }
681 }
682
683 //
684 //
685 //
686