• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can
5  * be found in the LICENSE file.
6  *
7  */
8 
9 #include <stdio.h>
10 #include <stdlib.h>
11 
12 //
13 //
14 //
15 
16 #include "gen.h"
17 #include "transpose.h"
18 
19 #include "common/util.h"
20 #include "common/macros.h"
21 
22 //
23 //
24 //
25 
//
// Context passed (as the opaque void * argument) to the hsg_transpose()
// blend/remap callbacks in this file.
//

struct hsg_transpose_state
{
  FILE                    * header; // stream receiving the emitted HS_TRANSPOSE_* macro lines
  struct hsg_config const * config; // generator config; the remap callback reads warp.lanes_log2
};
31 
//
// Map a transpose stage (log2 of the column count) to a single
// lowercase register-name prefix: stage 0 is 'r', and each further
// stage advances one letter, wrapping from 'z' back to 'a'.
//

static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  // distance from 'a', computed in unsigned arithmetic exactly like
  // ('r' + cols_log2 - 'a') so that wrap-around behaviour is unchanged
  uint32_t const offset = (uint32_t)('r' - 'a') + cols_log2;

  return (char)('a' + (offset % 26u));
}
38 
39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42                     uint32_t const row_ll, // lower-left
43                     uint32_t const row_ur, // upper-right
44                     void *         blend)
45 {
46   struct hsg_transpose_state * const state = blend;
47 
48   // we're starting register names at '1' for now
49   fprintf(state->header,
50           "  HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51           hsg_transpose_reg_prefix(cols_log2-1),
52           hsg_transpose_reg_prefix(cols_log2),
53           cols_log2,row_ll+1,row_ur+1);
54 }
55 
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59                     uint32_t const row_to,
60                     void *         remap)
61 {
62   struct hsg_transpose_state * const state = remap;
63 
64   // we're starting register names at '1' for now
65   fprintf(state->header,
66           "  HS_TRANSPOSE_REMAP( %c, %3u, %3u )        \\\n",
67           hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68           row_from+1,row_to+1);
69 }
70 
71 //
72 //
73 //
74 
//
// Write the generated-file copyright banner (six "//" comment lines
// followed by a blank line) to `file`.  Any write error is ignored.
//

static
void
hsg_copyright(FILE * file)
{
  static char const banner[] =
    "//                                                    \n"
    "// Copyright 2016 Google Inc.                         \n"
    "//                                                    \n"
    "// Use of this source code is governed by a BSD-style \n"
    "// license that can be found in the LICENSE file.     \n"
    "//                                                    \n"
    "\n";

  // the banner contains no conversion specifiers, so fputs writes
  // exactly the same bytes that fprintf(file, banner) would
  fputs(banner,file);
}
88 
89 //
90 //
91 //
92 
//
// Per-target state for the CUDA backend: the two output streams that
// hsg_target_cuda() opens on TARGET_BEGIN and closes on TARGET_END.
//

struct hsg_target_state
{
  FILE * header; // generated configuration header
  FILE * source; // generated kernel source
};
98 
99 //
100 //
101 //
102 
103 void
hsg_target_cuda(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)104 hsg_target_cuda(struct hsg_target       * const target,
105                 struct hsg_config const * const config,
106                 struct hsg_merge  const * const merge,
107                 struct hsg_op     const * const ops,
108                 uint32_t                  const depth)
109 {
110   switch (ops->type)
111     {
112     case HSG_OP_TYPE_END:
113       fprintf(target->state->source,
114               "}\n");
115       break;
116 
117     case HSG_OP_TYPE_BEGIN:
118       fprintf(target->state->source,
119               "{\n");
120       break;
121 
122     case HSG_OP_TYPE_ELSE:
123       fprintf(target->state->source,
124               "else\n");
125       break;
126 
127     case HSG_OP_TYPE_TARGET_BEGIN:
128       {
129         // allocate state
130         target->state = malloc(sizeof(*target->state));
131 
132         //
133         // Note that we're generating file names with different
134         // suffixes despite storing them in different directories
135         // because NVCC on Visual Studio appears to overwrite .cu
136         // files with the same name but different paths.
137         //
138         // I would prefer to the original layout:
139         //
140         //   path/to/<type>/hs_cuda.[cu|config|h]
141         //
142 
143         // allocate files
144         target->state->header = fopen("hs_cuda_config.h","wb");
145         target->state->source = fopen((config->type.words == 1) ?
146                                       "hs_cuda_u32.cu" : "hs_cuda_u64.cu",
147                                       "wb");
148 
149         // initialize header
150         uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
151 
152         hsg_copyright(target->state->header);
153 
154         fprintf(target->state->header,
155                 "#ifndef HS_CUDA_CONFIG_ONCE                                     \n"
156                 "#define HS_CUDA_CONFIG_ONCE                                     \n"
157                 "                                                                \n"
158                 "#define HS_SLAB_THREADS_LOG2    %u                              \n"
159                 "#define HS_SLAB_THREADS         (1 << HS_SLAB_THREADS_LOG2)     \n"
160                 "#define HS_SLAB_WIDTH_LOG2      %u                              \n"
161                 "#define HS_SLAB_WIDTH           (1 << HS_SLAB_WIDTH_LOG2)       \n"
162                 "#define HS_SLAB_HEIGHT          %u                              \n"
163                 "#define HS_SLAB_KEYS            (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
164                 "#define HS_REG_LAST(c)          c##%u                           \n"
165                 "#define HS_KEY_TYPE_PRETTY      %s                              \n"
166                 "#define HS_KEY_WORDS            %u                              \n"
167                 "#define HS_VAL_WORDS            0                               \n"
168                 "#define HS_BS_SLABS             %u                              \n"
169                 "#define HS_BS_SLABS_LOG2_RU     %u                              \n"
170                 "#define HS_BC_SLABS_LOG2_MAX    %u                              \n"
171                 "#define HS_FM_BLOCK_HEIGHT      %u                              \n"
172                 "#define HS_FM_SCALE_MIN         %u                              \n"
173                 "#define HS_FM_SCALE_MAX         %u                              \n"
174                 "#define HS_HM_BLOCK_HEIGHT      %u                              \n"
175                 "#define HS_HM_SCALE_MIN         %u                              \n"
176                 "#define HS_HM_SCALE_MAX         %u                              \n"
177                 "#define HS_EMPTY                                                \n"
178                 "                                                                \n",
179                 config->warp.lanes_log2,
180                 config->warp.lanes_log2,
181                 config->thread.regs,
182                 config->thread.regs,
183                 (config->type.words == 1) ? "u32"      : "u64",
184                 config->type.words,
185                 merge->warps,
186                 msb_idx_u32(pow2_ru_u32(merge->warps)),
187                 bc_max,
188                 config->merge.flip.warps,
189                 config->merge.flip.lo,
190                 config->merge.flip.hi,
191                 config->merge.half.warps,
192                 config->merge.half.lo,
193                 config->merge.half.hi);
194 
195         if (target->define != NULL)
196           fprintf(target->state->header,"#define %s\n\n",target->define);
197 
198         fprintf(target->state->header,
199                 "#define HS_SLAB_ROWS()    \\\n");
200 
201         for (uint32_t ii=1; ii<=config->thread.regs; ii++)
202           fprintf(target->state->header,
203                   "  HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
204 
205         fprintf(target->state->header,
206                 "  HS_EMPTY\n"
207                 "          \n");
208 
209         fprintf(target->state->header,
210                 "#define HS_TRANSPOSE_SLAB()                \\\n");
211 
212         for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
213           fprintf(target->state->header,
214                   "  HS_TRANSPOSE_STAGE( %u )                  \\\n",ii);
215 
216         struct hsg_transpose_state state[1] =
217           {
218            { .header = target->state->header,
219              .config = config
220            }
221           };
222 
223         hsg_transpose(config->warp.lanes_log2,
224                       config->thread.regs,
225                       hsg_transpose_blend,state,
226                       hsg_transpose_remap,state);
227 
228         fprintf(target->state->header,
229                 "  HS_EMPTY\n"
230                 "          \n");
231 
232         hsg_copyright(target->state->source);
233 
234         fprintf(target->state->source,
235                 "#ifdef __cplusplus               \n"
236                 "extern \"C\" {                   \n"
237                 "#endif                           \n"
238                 "                                 \n"
239                 "#include \"hs_cuda.h\"           \n"
240                 "                                 \n"
241                 "#ifdef __cplusplus               \n"
242                 "}                                \n"
243                 "#endif                           \n"
244                 "                                 \n"
245                 "#include \"hs_cuda_config.h\"    \n"
246                 "                                 \n"
247                 "#include \"../hs_cuda_macros.h\" \n"
248                 "                                 \n"
249                 "//                               \n"
250                 "//                               \n"
251                 "//                               \n");
252       }
253       break;
254 
255     case HSG_OP_TYPE_TARGET_END:
256       // decorate the files
257       fprintf(target->state->header,
258               "#endif \n"
259               "       \n"
260               "//     \n"
261               "//     \n"
262               "//     \n"
263               "       \n");
264       fprintf(target->state->source,
265               "       \n"
266               "//     \n"
267               "//     \n"
268               "//     \n"
269               "       \n"
270               "#include \"../../hs_cuda.inl\" \n"
271               "       \n"
272               "//     \n"
273               "//     \n"
274               "//     \n"
275               "       \n");
276 
277       // close files
278       fclose(target->state->header);
279       fclose(target->state->source);
280 
281       // free state
282       free(target->state);
283       break;
284 
285     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
286       {
287         fprintf(target->state->source,
288                 "\nHS_TRANSPOSE_KERNEL_PROTO()\n");
289       }
290       break;
291 
292     case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
293       {
294         fprintf(target->state->source,
295                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
296       }
297       break;
298 
299     case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
300       {
301         fprintf(target->state->source,
302                 "HS_TRANSPOSE_SLAB();\n");
303       }
304       break;
305 
306     case HSG_OP_TYPE_BS_KERNEL_PROTO:
307       {
308         struct hsg_merge const * const m = merge + ops->a;
309 
310         uint32_t const bs  = pow2_ru_u32(m->warps);
311         uint32_t const msb = msb_idx_u32(bs);
312 
313         if (ops->a == 0)
314           {
315             fprintf(target->state->source,
316                     "\nHS_BS_KERNEL_PROTO(%u,%u)\n",
317                     m->warps,msb);
318           }
319         else
320           {
321             fprintf(target->state->source,
322                     "\nHS_OFFSET_BS_KERNEL_PROTO(%u,%u)\n",
323                     m->warps,msb);
324           }
325       }
326       break;
327 
328     case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
329       {
330         struct hsg_merge const * const m = merge + ops->a;
331 
332         if (m->warps > 1)
333           {
334             fprintf(target->state->source,
335                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
336                     m->warps * config->warp.lanes,
337                     m->rows_bs);
338           }
339 
340         if (ops->a == 0)
341           {
342             fprintf(target->state->source,
343                     "HS_SLAB_GLOBAL_PREAMBLE();\n");
344           }
345         else
346           {
347             fprintf(target->state->source,
348                     "HS_OFFSET_SLAB_GLOBAL_PREAMBLE();\n");
349           }
350       }
351       break;
352 
353     case HSG_OP_TYPE_BC_KERNEL_PROTO:
354       {
355         struct hsg_merge const * const m = merge + ops->a;
356 
357         uint32_t const msb = msb_idx_u32(m->warps);
358 
359         fprintf(target->state->source,
360                 "\nHS_BC_KERNEL_PROTO(%u,%u)\n",
361                 m->warps,msb);
362       }
363       break;
364 
365     case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
366       {
367         struct hsg_merge const * const m = merge + ops->a;
368 
369         if (m->warps > 1)
370           {
371             fprintf(target->state->source,
372                     "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
373                     m->warps * config->warp.lanes,
374                     m->rows_bc);
375           }
376 
377         fprintf(target->state->source,
378                 "HS_SLAB_GLOBAL_PREAMBLE();\n");
379       }
380       break;
381 
382     case HSG_OP_TYPE_FM_KERNEL_PROTO:
383       {
384         uint32_t const span_left  = (merge[0].warps << ops->a) / 2;
385         uint32_t const span_right = 1 << ops->b;
386 
387         // uint32_t const msb = msb_idx_u32(pow2_ru_u32(merge[0].warps));
388         // if ((ops->a + ops->b - 1) == msb)
389 
390         if (span_right == span_left)
391           {
392             fprintf(target->state->source,
393                     "\nHS_FM_KERNEL_PROTO(%u,%u)\n",
394                     ops->a,ops->b);
395           }
396         else
397           {
398             fprintf(target->state->source,
399                     "\nHS_OFFSET_FM_KERNEL_PROTO(%u,%u)\n",
400                     ops->a,ops->b);
401           }
402       }
403       break;
404 
405     case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
406       {
407         uint32_t const msb = msb_idx_u32(pow2_ru_u32(merge[0].warps));
408 
409         if (ops->a == ops->b) // equal left and right spans
410           {
411             fprintf(target->state->source,
412                     "HS_FM_PREAMBLE(%u);\n",
413                     ops->a);
414           }
415         else // right span is lesser pow2
416           {
417             fprintf(target->state->source,
418                     "HS_OFFSET_FM_PREAMBLE(%u);\n",
419                     ops->a);
420           }
421       }
422       break;
423 
424     case HSG_OP_TYPE_HM_KERNEL_PROTO:
425       {
426         fprintf(target->state->source,
427                 "\nHS_HM_KERNEL_PROTO(%u)\n",
428                 ops->a);
429       }
430       break;
431 
432     case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
433       fprintf(target->state->source,
434               "HS_HM_PREAMBLE(%u);\n",
435               ops->a);
436       break;
437 
438     case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
439       {
440         static char const * const vstr[] = { "vin", "vout" };
441 
442         fprintf(target->state->source,
443                 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
444                 ops->n,vstr[ops->v],ops->n-1);
445       }
446       break;
447 
448     case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
449       fprintf(target->state->source,
450               "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
451               ops->n-1,ops->n);
452       break;
453 
454     case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
455       fprintf(target->state->source,
456               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
457               ops->a,ops->b);
458       break;
459 
460     case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
461       fprintf(target->state->source,
462               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
463               ops->b,ops->a);
464       break;
465 
466     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
467       fprintf(target->state->source,
468               "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
469               ops->a,ops->b);
470       break;
471 
472     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
473       fprintf(target->state->source,
474               "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
475               ops->b,ops->a);
476       break;
477 
478     case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
479       fprintf(target->state->source,
480               "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
481               ops->b,ops->a);
482       break;
483 
484     case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
485       fprintf(target->state->source,
486               "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
487               ops->a,ops->b);
488       break;
489 
490     case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
491       {
492         if (ops->a <= ops->b)
493           {
494             fprintf(target->state->source,
495                     "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
496           }
497         else if (ops->b > 1)
498           {
499             fprintf(target->state->source,
500                     "else if (fm_frac == %u)\n",
501                     ops->b);
502           }
503         else
504           {
505 	    fprintf(target->state->source,
506 		    "else\n");
507           }
508       }
509       break;
510 
511     case HSG_OP_TYPE_SLAB_FLIP:
512       fprintf(target->state->source,
513               "HS_SLAB_FLIP_PREAMBLE(%u);\n",
514               ops->n-1);
515       break;
516 
517     case HSG_OP_TYPE_SLAB_HALF:
518       fprintf(target->state->source,
519               "HS_SLAB_HALF_PREAMBLE(%u);\n",
520               ops->n / 2);
521       break;
522 
523     case HSG_OP_TYPE_CMP_FLIP:
524       fprintf(target->state->source,
525               "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
526       break;
527 
528     case HSG_OP_TYPE_CMP_HALF:
529       fprintf(target->state->source,
530               "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
531       break;
532 
533     case HSG_OP_TYPE_CMP_XCHG:
534       if (ops->c == UINT32_MAX)
535         {
536           fprintf(target->state->source,
537                   "HS_CMP_XCHG(r%-3u,r%-3u);\n",
538                   ops->a,ops->b);
539         }
540       else
541         {
542           fprintf(target->state->source,
543                   "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
544                   ops->c,ops->a,ops->c,ops->b);
545         }
546       break;
547 
548     case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
549       fprintf(target->state->source,
550               "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
551               merge[ops->a].warps,ops->c,ops->b);
552       break;
553 
554     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
555       fprintf(target->state->source,
556               "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
557               ops->b,merge[ops->a].warps,ops->c);
558       break;
559 
560     case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
561       fprintf(target->state->source,
562               "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
563               ops->b,ops->a,ops->c);
564       break;
565 
566     case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
567       fprintf(target->state->source,
568               "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
569               ops->b * config->warp.lanes,
570               ops->c,
571               ops->a);
572       break;
573 
574     case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
575       fprintf(target->state->source,
576               "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
577               ops->b * config->warp.lanes,
578               ops->c,
579               ops->a);
580       break;
581 
582     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
583       fprintf(target->state->source,
584               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
585               ops->c,
586               ops->a,
587               ops->b * config->warp.lanes);
588       break;
589 
590     case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
591       fprintf(target->state->source,
592               "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
593               ops->c,
594               ops->a,
595               ops->b * config->warp.lanes);
596       break;
597 
598     case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
599       fprintf(target->state->source,
600               "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
601               ops->c,
602               ops->a,
603               ops->b);
604       break;
605 
606     case HSG_OP_TYPE_BLOCK_SYNC:
607       fprintf(target->state->source,
608               "HS_BLOCK_BARRIER();\n");
609       //
610       // FIXME - Named barriers to allow coordinating warps to proceed?
611       //
612       break;
613 
614     case HSG_OP_TYPE_BS_FRAC_PRED:
615       {
616         if (ops->m == 0)
617           {
618             fprintf(target->state->source,
619                     "if (warp_idx < bs_full)\n");
620           }
621         else
622           {
623             fprintf(target->state->source,
624                     "else if (bs_frac == %u)\n",
625                     ops->w);
626           }
627       }
628       break;
629 
630     case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
631       {
632         struct hsg_merge const * const m = merge + ops->a;
633 
634         fprintf(target->state->source,
635                 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
636                 m->warps);
637       }
638       break;
639 
640     case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
641       {
642         struct hsg_merge const * const m = merge + ops->a;
643 
644         fprintf(target->state->source,
645                 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
646                 m->warps);
647       }
648       break;
649 
650     case HSG_OP_TYPE_BX_MERGE_H_PRED:
651       fprintf(target->state->source,
652               "if (HS_WARP_ID_X() < %u)\n",
653               ops->a);
654       break;
655 
656     case HSG_OP_TYPE_BS_ACTIVE_PRED:
657       {
658         struct hsg_merge const * const m = merge + ops->a;
659 
660         if (m->warps <= 32)
661           {
662             fprintf(target->state->source,
663                     "if (((1u << HS_WARP_ID_X()) & 0x%08X) != 0)\n",
664                     m->levels[ops->b].active.b32a2[0]);
665           }
666         else
667           {
668             fprintf(target->state->source,
669                     "if (((1UL << HS_WARP_ID_X()) & 0x%08X%08XL) != 0L)\n",
670                     m->levels[ops->b].active.b32a2[1],
671                     m->levels[ops->b].active.b32a2[0]);
672           }
673       }
674       break;
675 
676     default:
677       fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
678       exit(EXIT_FAILURE);
679       break;
680     }
681 }
682 
683 //
684 //
685 //
686