1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can
5 * be found in the LICENSE file.
6 *
7 */
8
9 #include <stdio.h>
10 #include <stdlib.h>
11
12 //
13 //
14 //
15
16 #include "gen.h"
17 #include "transpose.h"
18
19 #include "common/util.h"
20 #include "common/macros.h"
21
22 //
23 //
24 //
25
//
// Context handed (as a void pointer) to the transpose blend/remap
// callbacks below.
//
struct hsg_transpose_state
{
  // stream the HS_TRANSPOSE_BLEND/HS_TRANSPOSE_REMAP lines are printed to
  FILE                    *header;

  // generator configuration; the remap callback reads warp.lanes_log2
  struct hsg_config const *config;
};
31
//
// Map a transpose stage number to a single lowercase register-name
// prefix letter.
//
// Stage 0 maps to 'r' and each following stage maps to the next
// letter, wrapping from 'z' back to 'a'.  The arithmetic is unsigned,
// so passing (0 - 1) wraps and yields 'q'.
//
static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  // distance of the selected letter from 'a', always in [0,25]
  uint32_t const letter_index = (('r' - 'a') + cols_log2) % 26;

  return (char)('a' + letter_index);
}
38
39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42 uint32_t const row_ll, // lower-left
43 uint32_t const row_ur, // upper-right
44 void * blend)
45 {
46 struct hsg_transpose_state * const state = blend;
47
48 // we're starting register names at '1' for now
49 fprintf(state->header,
50 " HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51 hsg_transpose_reg_prefix(cols_log2-1),
52 hsg_transpose_reg_prefix(cols_log2),
53 cols_log2,row_ll+1,row_ur+1);
54 }
55
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59 uint32_t const row_to,
60 void * remap)
61 {
62 struct hsg_transpose_state * const state = remap;
63
64 // we're starting register names at '1' for now
65 fprintf(state->header,
66 " HS_TRANSPOSE_REMAP( %c, %3u, %3u ) \\\n",
67 hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68 row_from+1,row_to+1);
69 }
70
71 //
72 //
73 //
74
//
// Write the copyright banner (as `//` line comments followed by a
// blank line) to `file`.
//
static
void
hsg_copyright(FILE * file)
{
  static char const banner[] =
    "// \n"
    "// Copyright 2016 Google Inc. \n"
    "// \n"
    "// Use of this source code is governed by a BSD-style \n"
    "// license that can be found in the LICENSE file. \n"
    "// \n"
    "\n";

  // the banner has no conversion specifiers so it is written verbatim
  fputs(banner,file);
}
88
//
// Write the common prologue of a generated kernel source to `file`:
// the three #include lines (hs_config.h, hs_glsl_preamble.h,
// hs_glsl_macros.h) with their comments and a trailing separator.
//
static
void
hsg_macros(FILE * file)
{
  static char const * const prologue[] =
    {
      "// target-specific config \n",
      "#include \"hs_config.h\" \n",
      " \n",
      "// GLSL preamble \n",
      "#include \"hs_glsl_preamble.h\"\n",
      " \n",
      "// arch/target-specific macros \n",
      "#include \"hs_glsl_macros.h\" \n",
      " \n",
      "// \n",
      "// \n",
      "// \n",
      "\n"
    };

  size_t const count = sizeof(prologue) / sizeof(prologue[0]);

  // lines are emitted in table order and contain no conversion specifiers
  for (size_t idx = 0; idx < count; idx++)
    fputs(prologue[idx],file);
}
108
109 //
110 //
111 //
112
//
// Output streams owned by the GLSL target while it is generating.
//
struct hsg_target_state
{
  FILE *header;  // "hs_config.h"  -- open from TARGET_BEGIN to TARGET_END
  FILE *modules; // "hs_modules.h" -- open from TARGET_BEGIN to TARGET_END
  FILE *source;  // kernel ".comp" file currently being written
};
119
120 //
121 //
122 //
123
124 void
hsg_target_glsl(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)125 hsg_target_glsl(struct hsg_target * const target,
126 struct hsg_config const * const config,
127 struct hsg_merge const * const merge,
128 struct hsg_op const * const ops,
129 uint32_t const depth)
130 {
131 switch (ops->type)
132 {
133 case HSG_OP_TYPE_END:
134 fprintf(target->state->source,
135 "}\n");
136
137 if (depth == 0) {
138 fclose(target->state->source);
139 target->state->source = NULL;
140 }
141 break;
142
143 case HSG_OP_TYPE_BEGIN:
144 fprintf(target->state->source,
145 "{\n");
146 break;
147
148 case HSG_OP_TYPE_ELSE:
149 fprintf(target->state->source,
150 "else\n");
151 break;
152
153 case HSG_OP_TYPE_TARGET_BEGIN:
154 {
155 // allocate state
156 target->state = malloc(sizeof(*target->state));
157
158 // allocate files
159 target->state->header = fopen("hs_config.h", "wb");
160 target->state->modules = fopen("hs_modules.h","wb");
161
162 hsg_copyright(target->state->header);
163 hsg_copyright(target->state->modules);
164
165 // initialize header
166 uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
167
168 fprintf(target->state->header,
169 "#ifndef HS_GLSL_ONCE \n"
170 "#define HS_GLSL_ONCE \n"
171 " \n"
172 "#define HS_SLAB_THREADS_LOG2 %u \n"
173 "#define HS_SLAB_THREADS (1 << HS_SLAB_THREADS_LOG2) \n"
174 "#define HS_SLAB_WIDTH_LOG2 %u \n"
175 "#define HS_SLAB_WIDTH (1 << HS_SLAB_WIDTH_LOG2) \n"
176 "#define HS_SLAB_HEIGHT %u \n"
177 "#define HS_SLAB_KEYS (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
178 "#define HS_REG_LAST(c) c##%u \n"
179 "#define HS_KEY_WORDS %u \n"
180 "#define HS_VAL_WORDS 0 \n"
181 "#define HS_BS_SLABS %u \n"
182 "#define HS_BS_SLABS_LOG2_RU %u \n"
183 "#define HS_BC_SLABS_LOG2_MAX %u \n"
184 "#define HS_FM_BLOCK_HEIGHT %u \n"
185 "#define HS_FM_SCALE_MIN %u \n"
186 "#define HS_FM_SCALE_MAX %u \n"
187 "#define HS_HM_BLOCK_HEIGHT %u \n"
188 "#define HS_HM_SCALE_MIN %u \n"
189 "#define HS_HM_SCALE_MAX %u \n"
190 "#define HS_EMPTY \n"
191 " \n",
192 config->warp.lanes_log2, // FIXME -- this matters for SIMD
193 config->warp.lanes_log2,
194 config->thread.regs,
195 config->thread.regs,
196 config->type.words,
197 merge->warps,
198 msb_idx_u32(pow2_ru_u32(merge->warps)),
199 bc_max,
200 config->merge.flip.warps,
201 config->merge.flip.lo,
202 config->merge.flip.hi,
203 config->merge.half.warps,
204 config->merge.half.lo,
205 config->merge.half.hi);
206
207 if (target->define != NULL)
208 fprintf(target->state->header,"#define %s\n\n",target->define);
209
210 fprintf(target->state->header,
211 "#define HS_SLAB_ROWS() \\\n");
212
213 for (uint32_t ii=1; ii<=config->thread.regs; ii++)
214 fprintf(target->state->header,
215 " HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
216
217 fprintf(target->state->header,
218 " HS_EMPTY\n"
219 " \n");
220
221 fprintf(target->state->header,
222 "#define HS_TRANSPOSE_SLAB() \\\n");
223
224 for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
225 fprintf(target->state->header,
226 " HS_TRANSPOSE_STAGE( %u ) \\\n",ii);
227
228 struct hsg_transpose_state state[1] =
229 {
230 { .header = target->state->header,
231 .config = config
232 }
233 };
234
235 hsg_transpose(config->warp.lanes_log2,
236 config->thread.regs,
237 hsg_transpose_blend,state,
238 hsg_transpose_remap,state);
239
240 fprintf(target->state->header,
241 " HS_EMPTY\n"
242 " \n");
243 }
244 break;
245
246 case HSG_OP_TYPE_TARGET_END:
247 // decorate the files
248 fprintf(target->state->header,
249 "#endif \n"
250 " \n"
251 "// \n"
252 "// \n"
253 "// \n"
254 " \n");
255
256 // close files
257 fclose(target->state->header);
258 fclose(target->state->modules);
259
260 // free state
261 free(target->state);
262 break;
263
264 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
265 {
266 fprintf(target->state->modules,
267 "#include \"hs_transpose.len.xxd\"\n,\n"
268 "#include \"hs_transpose.spv.xxd\"\n,\n");
269
270 target->state->source = fopen("hs_transpose.comp","w+");
271
272 hsg_copyright(target->state->source);
273
274 hsg_macros(target->state->source);
275
276 fprintf(target->state->source,
277 "HS_TRANSPOSE_KERNEL_PROTO()\n");
278 }
279 break;
280
281 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
282 {
283 fprintf(target->state->source,
284 "HS_SUBGROUP_PREAMBLE();\n");
285
286 fprintf(target->state->source,
287 "HS_SLAB_GLOBAL_PREAMBLE();\n");
288 }
289 break;
290
291 case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
292 {
293 fprintf(target->state->source,
294 "HS_TRANSPOSE_SLAB()\n");
295 }
296 break;
297
298 case HSG_OP_TYPE_BS_KERNEL_PROTO:
299 {
300 struct hsg_merge const * const m = merge + ops->a;
301
302 uint32_t const bs = pow2_ru_u32(m->warps);
303 uint32_t const msb = msb_idx_u32(bs);
304
305 fprintf(target->state->modules,
306 "#include \"hs_bs_%u.len.xxd\"\n,\n"
307 "#include \"hs_bs_%u.spv.xxd\"\n,\n",
308 msb,
309 msb);
310
311 char filename[] = { "hs_bs_XX.comp" };
312 sprintf(filename,"hs_bs_%u.comp",msb);
313
314 target->state->source = fopen(filename,"w+");
315
316 hsg_copyright(target->state->source);
317
318 hsg_macros(target->state->source);
319
320 if (m->warps > 1)
321 {
322 fprintf(target->state->source,
323 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
324 m->warps * config->warp.lanes,
325 m->rows_bs);
326 }
327
328 fprintf(target->state->source,
329 "HS_BS_KERNEL_PROTO(%u,%u)\n",
330 m->warps,msb);
331 }
332 break;
333
334 case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
335 {
336 fprintf(target->state->source,
337 "HS_SUBGROUP_PREAMBLE();\n");
338
339 fprintf(target->state->source,
340 "HS_SLAB_GLOBAL_PREAMBLE();\n");
341 }
342 break;
343
344 case HSG_OP_TYPE_BC_KERNEL_PROTO:
345 {
346 struct hsg_merge const * const m = merge + ops->a;
347
348 uint32_t const msb = msb_idx_u32(m->warps);
349
350 fprintf(target->state->modules,
351 "#include \"hs_bc_%u.len.xxd\"\n,\n"
352 "#include \"hs_bc_%u.spv.xxd\"\n,\n",
353 msb,
354 msb);
355
356 char filename[] = { "hs_bc_XX.comp" };
357 sprintf(filename,"hs_bc_%u.comp",msb);
358
359 target->state->source = fopen(filename,"w+");
360
361 hsg_copyright(target->state->source);
362
363 hsg_macros(target->state->source);
364
365 if (m->warps > 1)
366 {
367 fprintf(target->state->source,
368 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
369 m->warps * config->warp.lanes,
370 m->rows_bc);
371 }
372
373 fprintf(target->state->source,
374 "HS_BC_KERNEL_PROTO(%u,%u)\n",
375 m->warps,msb);
376 }
377 break;
378
379 case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
380 {
381 fprintf(target->state->source,
382 "HS_SUBGROUP_PREAMBLE()\n");
383
384 fprintf(target->state->source,
385 "HS_SLAB_GLOBAL_PREAMBLE();\n");
386 }
387 break;
388
389 case HSG_OP_TYPE_FM_KERNEL_PROTO:
390 {
391 fprintf(target->state->modules,
392 "#include \"hs_fm_%u_%u.len.xxd\"\n,\n"
393 "#include \"hs_fm_%u_%u.spv.xxd\"\n,\n",
394 ops->a,ops->b,
395 ops->a,ops->b);
396
397 char filename[] = { "hs_fm_X_XX.comp" };
398 sprintf(filename,"hs_fm_%u_%u.comp",ops->a,ops->b);
399
400 target->state->source = fopen(filename,"w+");
401
402 hsg_copyright(target->state->source);
403
404 hsg_macros(target->state->source);
405
406 fprintf(target->state->source,
407 "HS_FM_KERNEL_PROTO(%u,%u)\n",
408 ops->a,ops->b);
409 }
410 break;
411
412 case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
413 {
414 fprintf(target->state->source,
415 "HS_SUBGROUP_PREAMBLE()\n");
416
417 fprintf(target->state->source,
418 "HS_FM_PREAMBLE(%u);\n",
419 ops->a);
420 }
421 break;
422
423 case HSG_OP_TYPE_HM_KERNEL_PROTO:
424 {
425 fprintf(target->state->modules,
426 "#include \"hs_hm_%u.len.xxd\"\n,\n"
427 "#include \"hs_hm_%u.spv.xxd\"\n,\n",
428 ops->a,
429 ops->a);
430
431 char filename[] = { "hs_hm_X.comp" };
432 sprintf(filename,"hs_hm_%u.comp",ops->a);
433
434 target->state->source = fopen(filename,"w+");
435
436 hsg_copyright(target->state->source);
437
438 hsg_macros(target->state->source);
439
440 fprintf(target->state->source,
441 "HS_HM_KERNEL_PROTO(%u)\n",
442 ops->a);
443 }
444 break;
445
446 case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
447 {
448 fprintf(target->state->source,
449 "HS_SUBGROUP_PREAMBLE()\n");
450
451 fprintf(target->state->source,
452 "HS_HM_PREAMBLE(%u);\n",
453 ops->a);
454 }
455 break;
456
457 case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
458 {
459 static char const * const vstr[] = { "vin", "vout" };
460
461 fprintf(target->state->source,
462 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
463 ops->n,vstr[ops->v],ops->n-1);
464 }
465 break;
466
467 case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
468 fprintf(target->state->source,
469 "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
470 ops->n-1,ops->n);
471 break;
472
473 case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
474 fprintf(target->state->source,
475 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
476 ops->a,ops->b);
477 break;
478
479 case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
480 fprintf(target->state->source,
481 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
482 ops->b,ops->a);
483 break;
484
485 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
486 fprintf(target->state->source,
487 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
488 ops->a,ops->b);
489 break;
490
491 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
492 fprintf(target->state->source,
493 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
494 ops->b,ops->a);
495 break;
496
497 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
498 fprintf(target->state->source,
499 "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
500 ops->b,ops->a);
501 break;
502
503 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
504 fprintf(target->state->source,
505 "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
506 ops->a,ops->b);
507 break;
508
509 case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
510 {
511 if (ops->a <= ops->b)
512 {
513 fprintf(target->state->source,
514 "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
515 }
516 else if (ops->b > 1)
517 {
518 fprintf(target->state->source,
519 "else if (fm_frac == %u)\n",
520 ops->b);
521 }
522 else
523 {
524 fprintf(target->state->source,
525 "else\n");
526 }
527 }
528 break;
529
530 case HSG_OP_TYPE_SLAB_FLIP:
531 fprintf(target->state->source,
532 "HS_SLAB_FLIP_PREAMBLE(%u);\n",
533 ops->n-1);
534 break;
535
536 case HSG_OP_TYPE_SLAB_HALF:
537 fprintf(target->state->source,
538 "HS_SLAB_HALF_PREAMBLE(%u);\n",
539 ops->n / 2);
540 break;
541
542 case HSG_OP_TYPE_CMP_FLIP:
543 fprintf(target->state->source,
544 "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
545 break;
546
547 case HSG_OP_TYPE_CMP_HALF:
548 fprintf(target->state->source,
549 "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
550 break;
551
552 case HSG_OP_TYPE_CMP_XCHG:
553 if (ops->c == UINT32_MAX)
554 {
555 fprintf(target->state->source,
556 "HS_CMP_XCHG(r%-3u,r%-3u);\n",
557 ops->a,ops->b);
558 }
559 else
560 {
561 fprintf(target->state->source,
562 "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
563 ops->c,ops->a,ops->c,ops->b);
564 }
565 break;
566
567 case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
568 fprintf(target->state->source,
569 "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
570 merge[ops->a].warps,ops->c,ops->b);
571 break;
572
573 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
574 fprintf(target->state->source,
575 "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
576 ops->b,merge[ops->a].warps,ops->c);
577 break;
578
579 case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
580 fprintf(target->state->source,
581 "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
582 ops->b,ops->a,ops->c);
583 break;
584
585 case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
586 fprintf(target->state->source,
587 "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
588 ops->b * config->warp.lanes,
589 ops->c,
590 ops->a);
591 break;
592
593 case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
594 fprintf(target->state->source,
595 "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
596 ops->b * config->warp.lanes,
597 ops->c,
598 ops->a);
599 break;
600
601 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
602 fprintf(target->state->source,
603 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
604 ops->c,
605 ops->a,
606 ops->b * config->warp.lanes);
607 break;
608
609 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
610 fprintf(target->state->source,
611 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
612 ops->c,
613 ops->a,
614 ops->b * config->warp.lanes);
615 break;
616
617 case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
618 fprintf(target->state->source,
619 "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
620 ops->c,
621 ops->a,
622 ops->b);
623 break;
624
625 case HSG_OP_TYPE_BLOCK_SYNC:
626 fprintf(target->state->source,
627 "HS_BLOCK_BARRIER();\n");
628 //
629 // FIXME - Named barriers to allow coordinating warps to proceed?
630 //
631 break;
632
633 case HSG_OP_TYPE_BS_FRAC_PRED:
634 {
635 if (ops->m == 0)
636 {
637 fprintf(target->state->source,
638 "if (warp_idx < bs_full)\n");
639 }
640 else
641 {
642 fprintf(target->state->source,
643 "else if (bs_frac == %u)\n",
644 ops->w);
645 }
646 }
647 break;
648
649 case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
650 {
651 struct hsg_merge const * const m = merge + ops->a;
652
653 fprintf(target->state->source,
654 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
655 m->warps);
656 }
657 break;
658
659 case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
660 {
661 struct hsg_merge const * const m = merge + ops->a;
662
663 fprintf(target->state->source,
664 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
665 m->warps);
666 }
667 break;
668
669 case HSG_OP_TYPE_BX_MERGE_H_PRED:
670 fprintf(target->state->source,
671 "if (HS_SUBGROUP_ID() < %u)\n",
672 ops->a);
673 break;
674
675 case HSG_OP_TYPE_BS_ACTIVE_PRED:
676 {
677 struct hsg_merge const * const m = merge + ops->a;
678
679 if (m->warps <= 32)
680 {
681 fprintf(target->state->source,
682 "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n",
683 m->levels[ops->b].active.b32a2[0]);
684 }
685 else
686 {
687 fprintf(target->state->source,
688 "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n",
689 m->levels[ops->b].active.b32a2[1],
690 m->levels[ops->b].active.b32a2[0]);
691 }
692 }
693 break;
694
695 default:
696 fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
697 exit(EXIT_FAILURE);
698 break;
699 }
700 }
701
702 //
703 //
704 //
705