/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can
 * be found in the LICENSE file.
 *
 */

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

//
//
//

#include "gen.h"
#include "transpose.h"

#include "common/util.h"
#include "common/macros.h"

//
//
//

//
// Context passed through hsg_transpose() as the opaque `void *` argument of
// the hsg_transpose_blend() / hsg_transpose_remap() callbacks.
//
struct hsg_transpose_state
{
  FILE                    * header; // stream the generated macro lines go to
  struct hsg_config const * config; // generator configuration (read-only)
};

//
// Map a transpose level to the one-letter register-name prefix used by the
// generated HS_TRANSPOSE_* macro invocations.  Level 0 maps to 'r', and each
// following level advances one letter, wrapping from 'z' back to 'a'.
//
// The arithmetic is unsigned, so cols_log2 == UINT32_MAX (what "0 - 1"
// yields when a caller asks for the level before level 0) maps to 'q'.
//
static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  // 26 letters in 'a'..'z'; the result is always in that range
  return (char)('a' + (('r' + cols_log2 - 'a') % 26));
}

39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42                     uint32_t const row_ll, // lower-left
43                     uint32_t const row_ur, // upper-right
44                     void *         blend)
45 {
46   struct hsg_transpose_state * const state = blend;
47 
48   // we're starting register names at '1' for now
49   fprintf(state->header,
50           "  HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51           hsg_transpose_reg_prefix(cols_log2-1),
52           hsg_transpose_reg_prefix(cols_log2),
53           cols_log2,row_ll+1,row_ur+1);
54 }
55 
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59                     uint32_t const row_to,
60                     void *         remap)
61 {
62   struct hsg_transpose_state * const state = remap;
63 
64   // we're starting register names at '1' for now
65   fprintf(state->header,
66           "  HS_TRANSPOSE_REMAP( %c, %3u, %3u )        \\\n",
67           hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68           row_from+1,row_to+1);
69 }
70 
71 //
72 //
73 //
74 
//
// Write the license banner that heads every generated file (hs_config.h,
// hs_modules.h and each .comp shader) to `file`.
//
static
void
hsg_copyright(FILE * file)
{
  // fputs(): the banner is plain text and must not be parsed as a format
  fputs("//                                                              \n"
        "// Copyright 2016 Google Inc.                                   \n"
        "//                                                              \n"
        "// Use of this source code is governed by a BSD-style           \n"
        "// license that can be found in the LICENSE file.               \n"
        "//                                                              \n"
        "\n",
        file);
}

//
// Write the #include preamble shared by every generated .comp shader:
// hs_config.h (the target config this generator writes), followed by
// hs_glsl_preamble.h and hs_glsl_macros.h.
//
static
void
hsg_macros(FILE * file)
{
  // fputs(): the preamble is plain text and must not be parsed as a format
  fputs("// target-specific config      \n"
        "#include \"hs_config.h\"       \n"
        "                               \n"
        "// GLSL preamble               \n"
        "#include \"hs_glsl_preamble.h\"\n"
        "                               \n"
        "// arch/target-specific macros \n"
        "#include \"hs_glsl_macros.h\"  \n"
        "                               \n"
        "//                             \n"
        "//                             \n"
        "//                             \n"
        "\n",
        file);
}

//
//
//

//
// GLSL backend state: allocated by hsg_target_glsl() on
// HSG_OP_TYPE_TARGET_BEGIN and freed on HSG_OP_TYPE_TARGET_END.
//
struct hsg_target_state
{
  FILE * header;  // hs_config.h
  FILE * modules; // hs_modules.h
  FILE * source;  // .comp shader currently being emitted (NULL once closed)
};

//
//
//

124 void
hsg_target_glsl(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)125 hsg_target_glsl(struct hsg_target       * const target,
126                 struct hsg_config const * const config,
127                 struct hsg_merge  const * const merge,
128                 struct hsg_op     const * const ops,
129                 uint32_t                  const depth)
130 {
131   switch (ops->type)
132     {
133     case HSG_OP_TYPE_END:
134       fprintf(target->state->source,
135               "}\n");
136 
137       if (depth == 0) {
138         fclose(target->state->source);
139         target->state->source = NULL;
140       }
141       break;
142 
143     case HSG_OP_TYPE_BEGIN:
144       fprintf(target->state->source,
145               "{\n");
146       break;
147 
148     case HSG_OP_TYPE_ELSE:
149       fprintf(target->state->source,
150               "else\n");
151       break;
152 
153     case HSG_OP_TYPE_TARGET_BEGIN:
154       {
155         // allocate state
156         target->state = malloc(sizeof(*target->state));
157 
158         // allocate files
159         target->state->header  = fopen("hs_config.h", "wb");
160         target->state->modules = fopen("hs_modules.h","wb");
161 
162         hsg_copyright(target->state->header);
163         hsg_copyright(target->state->modules);
164 
165         // initialize header
166         uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
167 
168         fprintf(target->state->header,
169                 "#ifndef HS_GLSL_ONCE                                            \n"
170                 "#define HS_GLSL_ONCE                                            \n"
171                 "                                                                \n"
172                 "#define HS_SLAB_THREADS_LOG2    %u                              \n"
173                 "#define HS_SLAB_THREADS         (1 << HS_SLAB_THREADS_LOG2)     \n"
174                 "#define HS_SLAB_WIDTH_LOG2      %u                              \n"
175                 "#define HS_SLAB_WIDTH           (1 << HS_SLAB_WIDTH_LOG2)       \n"
176                 "#define HS_SLAB_HEIGHT          %u                              \n"
177                 "#define HS_SLAB_KEYS            (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
178                 "#define HS_REG_LAST(c)          c##%u                           \n"
179                 "#define HS_KEY_WORDS            %u                              \n"
180                 "#define HS_VAL_WORDS            0                               \n"
181                 "#define HS_BS_SLABS             %u                              \n"
182                 "#define HS_BS_SLABS_LOG2_RU     %u                              \n"
183                 "#define HS_BC_SLABS_LOG2_MAX    %u                              \n"
184                 "#define HS_FM_BLOCK_HEIGHT      %u                              \n"
185                 "#define HS_FM_SCALE_MIN         %u                              \n"
186                 "#define HS_FM_SCALE_MAX         %u                              \n"
187                 "#define HS_HM_BLOCK_HEIGHT      %u                              \n"
188                 "#define HS_HM_SCALE_MIN         %u                              \n"
189                 "#define HS_HM_SCALE_MAX         %u                              \n"
190                 "#define HS_EMPTY                                                \n"
191                 "                                                                \n",
192                 config->warp.lanes_log2, // FIXME -- this matters for SIMD
193                 config->warp.lanes_log2,
194                 config->thread.regs,
195                 config->thread.regs,
196                 config->type.words,
197                 merge->warps,
198                 msb_idx_u32(pow2_ru_u32(merge->warps)),
199                 bc_max,
200                 config->merge.flip.warps,
201                 config->merge.flip.lo,
202                 config->merge.flip.hi,
203                 config->merge.half.warps,
204                 config->merge.half.lo,
205                 config->merge.half.hi);
206 
207         if (target->define != NULL)
208           fprintf(target->state->header,"#define %s\n\n",target->define);
209 
210         fprintf(target->state->header,
211                 "#define HS_SLAB_ROWS()    \\\n");
212 
213         for (uint32_t ii=1; ii<=config->thread.regs; ii++)
214           fprintf(target->state->header,
215                   "  HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
216 
217         fprintf(target->state->header,
218                 "  HS_EMPTY\n"
219                 "          \n");
220 
221         fprintf(target->state->header,
222                 "#define HS_TRANSPOSE_SLAB()                \\\n");
223 
224         for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
225           fprintf(target->state->header,
226                   "  HS_TRANSPOSE_STAGE( %u )                  \\\n",ii);
227 
228         struct hsg_transpose_state state[1] =
229           {
230            { .header = target->state->header,
231              .config = config
232            }
233           };
234 
235         hsg_transpose(config->warp.lanes_log2,
236                       config->thread.regs,
237                       hsg_transpose_blend,state,
238                       hsg_transpose_remap,state);
239 
240         fprintf(target->state->header,
241                 "  HS_EMPTY\n"
242                 "          \n");
243       }
244       break;
245 
246     case HSG_OP_TYPE_TARGET_END:
247       // decorate the files
248       fprintf(target->state->header,
249               "#endif \n"
250               "       \n"
251               "//     \n"
252               "//     \n"
253               "//     \n"
254               "       \n");
255 
256       // close files
257       fclose(target->state->header);
258       fclose(target->state->modules);
259 
260       // free state
261       free(target->state);
262       break;
263 
264     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
265       {
266         fprintf(target->state->modules,
267                 "#include \"hs_transpose.len.xxd\"\n,\n"
268                 "#include \"hs_transpose.spv.xxd\"\n,\n");
269 
270         target->state->source = fopen("hs_transpose.comp","w+");
271 
272         hsg_copyright(target->state->source);
273 
274         hsg_macros(target->state->source);
275 
276         fprintf(target->state->source,
277                 "HS_TRANSPOSE_KERNEL_PROTO()\n");
278       }
279       break;
280 
281     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
282       {
283         fprintf(target->state->source,
284                 "HS_SUBGROUP_PREAMBLE();\n");
285 
286         fprintf(target->state->source,
287                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
288       }
289       break;
290 
291     case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
292       {
293         fprintf(target->state->source,
294                 "HS_TRANSPOSE_SLAB()\n");
295       }
296       break;
297 
298     case HSG_OP_TYPE_BS_KERNEL_PROTO:
299       {
300         struct hsg_merge const * const m = merge + ops->a;
301 
302         uint32_t const bs  = pow2_ru_u32(m->warps);
303         uint32_t const msb = msb_idx_u32(bs);
304 
305         fprintf(target->state->modules,
306                 "#include \"hs_bs_%u.len.xxd\"\n,\n"
307                 "#include \"hs_bs_%u.spv.xxd\"\n,\n",
308                 msb,
309                 msb);
310 
311         char filename[] = { "hs_bs_XX.comp" };
312         sprintf(filename,"hs_bs_%u.comp",msb);
313 
314         target->state->source = fopen(filename,"w+");
315 
316         hsg_copyright(target->state->source);
317 
318         hsg_macros(target->state->source);
319 
320         if (m->warps > 1)
321           {
322             fprintf(target->state->source,
323                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
324                     m->warps * config->warp.lanes,
325                     m->rows_bs);
326           }
327 
328         fprintf(target->state->source,
329                 "HS_BS_KERNEL_PROTO(%u,%u)\n",
330                 m->warps,msb);
331       }
332       break;
333 
334     case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
335       {
336         fprintf(target->state->source,
337                 "HS_SUBGROUP_PREAMBLE();\n");
338 
339         fprintf(target->state->source,
340                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
341       }
342       break;
343 
344     case HSG_OP_TYPE_BC_KERNEL_PROTO:
345       {
346         struct hsg_merge const * const m = merge + ops->a;
347 
348         uint32_t const msb = msb_idx_u32(m->warps);
349 
350         fprintf(target->state->modules,
351                 "#include \"hs_bc_%u.len.xxd\"\n,\n"
352                 "#include \"hs_bc_%u.spv.xxd\"\n,\n",
353                 msb,
354                 msb);
355 
356         char filename[] = { "hs_bc_XX.comp" };
357         sprintf(filename,"hs_bc_%u.comp",msb);
358 
359         target->state->source = fopen(filename,"w+");
360 
361         hsg_copyright(target->state->source);
362 
363         hsg_macros(target->state->source);
364 
365         if (m->warps > 1)
366           {
367             fprintf(target->state->source,
368                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
369                     m->warps * config->warp.lanes,
370                     m->rows_bc);
371           }
372 
373         fprintf(target->state->source,
374                 "HS_BC_KERNEL_PROTO(%u,%u)\n",
375                 m->warps,msb);
376       }
377       break;
378 
379     case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
380       {
381         fprintf(target->state->source,
382                 "HS_SUBGROUP_PREAMBLE()\n");
383 
384         fprintf(target->state->source,
385                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
386       }
387       break;
388 
389     case HSG_OP_TYPE_FM_KERNEL_PROTO:
390       {
391         fprintf(target->state->modules,
392                 "#include \"hs_fm_%u_%u.len.xxd\"\n,\n"
393                 "#include \"hs_fm_%u_%u.spv.xxd\"\n,\n",
394                 ops->a,ops->b,
395                 ops->a,ops->b);
396 
397         char filename[] = { "hs_fm_X_XX.comp" };
398         sprintf(filename,"hs_fm_%u_%u.comp",ops->a,ops->b);
399 
400         target->state->source = fopen(filename,"w+");
401 
402         hsg_copyright(target->state->source);
403 
404         hsg_macros(target->state->source);
405 
406         fprintf(target->state->source,
407                 "HS_FM_KERNEL_PROTO(%u,%u)\n",
408                 ops->a,ops->b);
409       }
410       break;
411 
412     case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
413       {
414         fprintf(target->state->source,
415                 "HS_SUBGROUP_PREAMBLE()\n");
416 
417         fprintf(target->state->source,
418                 "HS_FM_PREAMBLE(%u);\n",
419                 ops->a);
420       }
421       break;
422 
423     case HSG_OP_TYPE_HM_KERNEL_PROTO:
424       {
425         fprintf(target->state->modules,
426                 "#include \"hs_hm_%u.len.xxd\"\n,\n"
427                 "#include \"hs_hm_%u.spv.xxd\"\n,\n",
428                 ops->a,
429                 ops->a);
430 
431         char filename[] = { "hs_hm_X.comp" };
432         sprintf(filename,"hs_hm_%u.comp",ops->a);
433 
434         target->state->source = fopen(filename,"w+");
435 
436         hsg_copyright(target->state->source);
437 
438         hsg_macros(target->state->source);
439 
440         fprintf(target->state->source,
441                 "HS_HM_KERNEL_PROTO(%u)\n",
442                 ops->a);
443       }
444       break;
445 
446     case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
447       {
448         fprintf(target->state->source,
449                 "HS_SUBGROUP_PREAMBLE()\n");
450 
451         fprintf(target->state->source,
452                 "HS_HM_PREAMBLE(%u);\n",
453                 ops->a);
454       }
455       break;
456 
457     case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
458       {
459         static char const * const vstr[] = { "vin", "vout" };
460 
461         fprintf(target->state->source,
462                 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
463                 ops->n,vstr[ops->v],ops->n-1);
464       }
465       break;
466 
467     case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
468       fprintf(target->state->source,
469               "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
470               ops->n-1,ops->n);
471       break;
472 
473     case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
474       fprintf(target->state->source,
475               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
476               ops->a,ops->b);
477       break;
478 
479     case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
480       fprintf(target->state->source,
481               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
482               ops->b,ops->a);
483       break;
484 
485     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
486       fprintf(target->state->source,
487               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
488               ops->a,ops->b);
489       break;
490 
491     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
492       fprintf(target->state->source,
493               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
494               ops->b,ops->a);
495       break;
496 
497     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
498       fprintf(target->state->source,
499               "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
500               ops->b,ops->a);
501       break;
502 
503     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
504       fprintf(target->state->source,
505               "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
506               ops->a,ops->b);
507       break;
508 
509     case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
510       {
511         if (ops->a <= ops->b)
512           {
513             fprintf(target->state->source,
514                     "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
515           }
516         else if (ops->b > 1)
517           {
518             fprintf(target->state->source,
519                     "else if (fm_frac == %u)\n",
520                     ops->b);
521           }
522         else
523           {
524 	    fprintf(target->state->source,
525 		    "else\n");
526           }
527       }
528       break;
529 
530     case HSG_OP_TYPE_SLAB_FLIP:
531       fprintf(target->state->source,
532               "HS_SLAB_FLIP_PREAMBLE(%u);\n",
533               ops->n-1);
534       break;
535 
536     case HSG_OP_TYPE_SLAB_HALF:
537       fprintf(target->state->source,
538               "HS_SLAB_HALF_PREAMBLE(%u);\n",
539               ops->n / 2);
540       break;
541 
542     case HSG_OP_TYPE_CMP_FLIP:
543       fprintf(target->state->source,
544               "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
545       break;
546 
547     case HSG_OP_TYPE_CMP_HALF:
548       fprintf(target->state->source,
549               "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
550       break;
551 
552     case HSG_OP_TYPE_CMP_XCHG:
553       if (ops->c == UINT32_MAX)
554         {
555           fprintf(target->state->source,
556                   "HS_CMP_XCHG(r%-3u,r%-3u);\n",
557                   ops->a,ops->b);
558         }
559       else
560         {
561           fprintf(target->state->source,
562                   "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
563                   ops->c,ops->a,ops->c,ops->b);
564         }
565       break;
566 
567     case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
568       fprintf(target->state->source,
569               "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
570               merge[ops->a].warps,ops->c,ops->b);
571       break;
572 
573     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
574       fprintf(target->state->source,
575               "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
576               ops->b,merge[ops->a].warps,ops->c);
577       break;
578 
579     case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
580       fprintf(target->state->source,
581               "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
582               ops->b,ops->a,ops->c);
583       break;
584 
585     case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
586       fprintf(target->state->source,
587               "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
588               ops->b * config->warp.lanes,
589               ops->c,
590               ops->a);
591       break;
592 
593     case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
594       fprintf(target->state->source,
595               "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
596               ops->b * config->warp.lanes,
597               ops->c,
598               ops->a);
599       break;
600 
601     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
602       fprintf(target->state->source,
603               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
604               ops->c,
605               ops->a,
606               ops->b * config->warp.lanes);
607       break;
608 
609     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
610       fprintf(target->state->source,
611               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
612               ops->c,
613               ops->a,
614               ops->b * config->warp.lanes);
615       break;
616 
617     case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
618       fprintf(target->state->source,
619               "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
620               ops->c,
621               ops->a,
622               ops->b);
623       break;
624 
625     case HSG_OP_TYPE_BLOCK_SYNC:
626       fprintf(target->state->source,
627               "HS_BLOCK_BARRIER();\n");
628       //
629       // FIXME - Named barriers to allow coordinating warps to proceed?
630       //
631       break;
632 
633     case HSG_OP_TYPE_BS_FRAC_PRED:
634       {
635         if (ops->m == 0)
636           {
637             fprintf(target->state->source,
638                     "if (warp_idx < bs_full)\n");
639           }
640         else
641           {
642             fprintf(target->state->source,
643                     "else if (bs_frac == %u)\n",
644                     ops->w);
645           }
646       }
647       break;
648 
649     case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
650       {
651         struct hsg_merge const * const m = merge + ops->a;
652 
653         fprintf(target->state->source,
654                 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
655                 m->warps);
656       }
657       break;
658 
659     case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
660       {
661         struct hsg_merge const * const m = merge + ops->a;
662 
663         fprintf(target->state->source,
664                 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
665                 m->warps);
666       }
667       break;
668 
669     case HSG_OP_TYPE_BX_MERGE_H_PRED:
670       fprintf(target->state->source,
671               "if (HS_SUBGROUP_ID() < %u)\n",
672               ops->a);
673       break;
674 
675     case HSG_OP_TYPE_BS_ACTIVE_PRED:
676       {
677         struct hsg_merge const * const m = merge + ops->a;
678 
679         if (m->warps <= 32)
680           {
681             fprintf(target->state->source,
682                     "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n",
683                     m->levels[ops->b].active.b32a2[0]);
684           }
685         else
686           {
687             fprintf(target->state->source,
688                     "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n",
689                     m->levels[ops->b].active.b32a2[1],
690                     m->levels[ops->b].active.b32a2[0]);
691           }
692       }
693       break;
694 
695     default:
696       fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
697       exit(EXIT_FAILURE);
698       break;
699     }
700 }
701 
702 //
703 //
704 //
705