1 /*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can
5 * be found in the LICENSE file.
6 *
7 */
8
9 #include <stdio.h>
10 #include <stdlib.h>
11
12 //
13 //
14 //
15
16 #include "gen.h"
17 #include "transpose.h"
18
19 #include "common/util.h"
20 #include "common/macros.h"
21
22 //
23 //
24 //
25
//
// Context handed to the transpose callbacks through their opaque
// `void *` argument.
//
struct hsg_transpose_state
{
  FILE                    *header; // stream the HS_TRANSPOSE_* macro lines are written to
  struct hsg_config const *config; // generator configuration; the remap callback reads warp.lanes_log2
};
31
//
// Map a column log2 value to a one-letter register prefix.
//
// A value of 0 yields 'r'. Each increment advances one letter and the
// result wraps from 'z' back to 'a'. The arithmetic is unsigned, so a
// caller passing `0 - 1` (UINT32_MAX) gets 'q'.
//
static
char
hsg_transpose_reg_prefix(uint32_t const cols_log2)
{
  uint32_t const alphabet = 26;
  uint32_t const origin   = 'r' - 'a'; // letter index of the starting prefix
  uint32_t const letter   = (origin + cols_log2) % alphabet;

  return (char)('a' + letter);
}
38
39 static
40 void
hsg_transpose_blend(uint32_t const cols_log2,uint32_t const row_ll,uint32_t const row_ur,void * blend)41 hsg_transpose_blend(uint32_t const cols_log2,
42 uint32_t const row_ll, // lower-left
43 uint32_t const row_ur, // upper-right
44 void * blend)
45 {
46 struct hsg_transpose_state * const state = blend;
47
48 // we're starting register names at '1' for now
49 fprintf(state->header,
50 " HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n",
51 hsg_transpose_reg_prefix(cols_log2-1),
52 hsg_transpose_reg_prefix(cols_log2),
53 cols_log2,row_ll+1,row_ur+1);
54 }
55
56 static
57 void
hsg_transpose_remap(uint32_t const row_from,uint32_t const row_to,void * remap)58 hsg_transpose_remap(uint32_t const row_from,
59 uint32_t const row_to,
60 void * remap)
61 {
62 struct hsg_transpose_state * const state = remap;
63
64 // we're starting register names at '1' for now
65 fprintf(state->header,
66 " HS_TRANSPOSE_REMAP( %c, %3u, %3u ) \\\n",
67 hsg_transpose_reg_prefix(state->config->warp.lanes_log2),
68 row_from+1,row_to+1);
69 }
70
71 //
72 //
73 //
74
//
// Write the copyright banner, as a block of `//` comments followed by
// a blank line, to `file`.
//
static
void
hsg_copyright(FILE * file)
{
  // the banner contains no conversion specifiers, so it is written verbatim
  static char const banner[] =
    "// \n"
    "// Copyright 2016 Google Inc. \n"
    "// \n"
    "// Use of this source code is governed by a BSD-style \n"
    "// license that can be found in the LICENSE file. \n"
    "// \n"
    "\n";

  fputs(banner,file);
}
88
//
// Write the common preamble of a generated source file to `file`:
// the includes of "hs_config.h" and "hs_cl_macros.h" followed by a
// comment separator.
//
static
void
hsg_macros(FILE * file)
{
  // the preamble contains no conversion specifiers, so it is written verbatim
  static char const preamble[] =
    "// target-specific config \n"
    "#include \"hs_config.h\" \n"
    " \n"
    "// arch/target-specific macros\n"
    "#include \"hs_cl_macros.h\" \n"
    " \n"
    "// \n"
    "// \n"
    "// \n";

  fputs(preamble,file);
}
104
105 //
106 //
107 //
108
//
// Per-target output state: the two files produced by the OpenCL
// target. Allocated on HSG_OP_TYPE_TARGET_BEGIN and released on
// HSG_OP_TYPE_TARGET_END.
//
struct hsg_target_state
{
  FILE *header; // "hs_config.h"   - generated configuration macros
  FILE *source; // "hs_kernels.cl" - generated kernel source
};
114
115 //
116 //
117 //
118
119 void
hsg_target_opencl(struct hsg_target * const target,struct hsg_config const * const config,struct hsg_merge const * const merge,struct hsg_op const * const ops,uint32_t const depth)120 hsg_target_opencl(struct hsg_target * const target,
121 struct hsg_config const * const config,
122 struct hsg_merge const * const merge,
123 struct hsg_op const * const ops,
124 uint32_t const depth)
125 {
126 switch (ops->type)
127 {
128 case HSG_OP_TYPE_END:
129 fprintf(target->state->source,
130 "}\n");
131 break;
132
133 case HSG_OP_TYPE_BEGIN:
134 fprintf(target->state->source,
135 "{\n");
136 break;
137
138 case HSG_OP_TYPE_ELSE:
139 fprintf(target->state->source,
140 "else\n");
141 break;
142
143 case HSG_OP_TYPE_TARGET_BEGIN:
144 {
145 // allocate state
146 target->state = malloc(sizeof(*target->state));
147
148 // allocate files
149 target->state->header = fopen("hs_config.h", "wb");
150 target->state->source = fopen("hs_kernels.cl","wb");
151
152 // initialize header
153 uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps));
154
155 hsg_copyright(target->state->header);
156
157 fprintf(target->state->header,
158 "#ifndef HS_CL_ONCE \n"
159 "#define HS_CL_ONCE \n"
160 " \n"
161 "#define HS_SLAB_THREADS_LOG2 %u \n"
162 "#define HS_SLAB_THREADS (1 << HS_SLAB_THREADS_LOG2) \n"
163 "#define HS_SLAB_WIDTH_LOG2 %u \n"
164 "#define HS_SLAB_WIDTH (1 << HS_SLAB_WIDTH_LOG2) \n"
165 "#define HS_SLAB_HEIGHT %u \n"
166 "#define HS_SLAB_KEYS (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n"
167 "#define HS_REG_LAST(c) c##%u \n"
168 "#define HS_KEY_WORDS %u \n"
169 "#define HS_VAL_WORDS 0 \n"
170 "#define HS_BS_SLABS %u \n"
171 "#define HS_BS_SLABS_LOG2_RU %u \n"
172 "#define HS_BC_SLABS_LOG2_MAX %u \n"
173 "#define HS_FM_BLOCK_HEIGHT %u \n"
174 "#define HS_FM_SCALE_MIN %u \n"
175 "#define HS_FM_SCALE_MAX %u \n"
176 "#define HS_HM_BLOCK_HEIGHT %u \n"
177 "#define HS_HM_SCALE_MIN %u \n"
178 "#define HS_HM_SCALE_MAX %u \n"
179 "#define HS_EMPTY \n"
180 " \n",
181 config->warp.lanes_log2, // FIXME - may be different on a SIMD target
182 config->warp.lanes_log2,
183 config->thread.regs,
184 config->thread.regs,
185 config->type.words,
186 merge->warps,
187 msb_idx_u32(pow2_ru_u32(merge->warps)),
188 bc_max,
189 config->merge.flip.warps,
190 config->merge.flip.lo,
191 config->merge.flip.hi,
192 config->merge.half.warps,
193 config->merge.half.lo,
194 config->merge.half.hi);
195
196 if (target->define != NULL)
197 fprintf(target->state->header,"#define %s\n\n",target->define);
198
199 fprintf(target->state->header,
200 "#define HS_SLAB_ROWS() \\\n");
201
202 for (uint32_t ii=1; ii<=config->thread.regs; ii++)
203 fprintf(target->state->header,
204 " HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1);
205
206 fprintf(target->state->header,
207 " HS_EMPTY\n"
208 " \n");
209
210 fprintf(target->state->header,
211 "#define HS_TRANSPOSE_SLAB() \\\n");
212
213 for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++)
214 fprintf(target->state->header,
215 " HS_TRANSPOSE_STAGE( %u ) \\\n",ii);
216
217 struct hsg_transpose_state state[1] =
218 {
219 { .header = target->state->header,
220 .config = config
221 }
222 };
223
224 hsg_transpose(config->warp.lanes_log2,
225 config->thread.regs,
226 hsg_transpose_blend,state,
227 hsg_transpose_remap,state);
228
229 fprintf(target->state->header,
230 " HS_EMPTY\n"
231 " \n");
232
233 hsg_copyright(target->state->source);
234
235 hsg_macros(target->state->source);
236 }
237 break;
238
239 case HSG_OP_TYPE_TARGET_END:
240 // decorate the files
241 fprintf(target->state->header,
242 "#endif \n"
243 " \n"
244 "// \n"
245 "// \n"
246 "// \n"
247 " \n");
248 fprintf(target->state->source,
249 " \n"
250 "// \n"
251 "// \n"
252 "// \n"
253 " \n");
254
255 // close files
256 fclose(target->state->header);
257 fclose(target->state->source);
258
259 // free state
260 free(target->state);
261 break;
262
263 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO:
264 {
265 fprintf(target->state->source,
266 "\nHS_TRANSPOSE_KERNEL_PROTO()\n");
267 }
268 break;
269
270 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE:
271 {
272 fprintf(target->state->source,
273 "HS_SLAB_GLOBAL_PREAMBLE();\n");
274 }
275 break;
276
277 case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY:
278 {
279 fprintf(target->state->source,
280 "HS_TRANSPOSE_SLAB()\n");
281 }
282 break;
283
284 case HSG_OP_TYPE_BS_KERNEL_PROTO:
285 {
286 struct hsg_merge const * const m = merge + ops->a;
287
288 uint32_t const bs = pow2_ru_u32(m->warps);
289 uint32_t const msb = msb_idx_u32(bs);
290
291 fprintf(target->state->source,
292 "\nHS_BS_KERNEL_PROTO(%u,%u)\n",
293 m->warps,msb);
294 }
295 break;
296
297 case HSG_OP_TYPE_BS_KERNEL_PREAMBLE:
298 {
299 struct hsg_merge const * const m = merge + ops->a;
300
301 if (m->warps > 1)
302 {
303 fprintf(target->state->source,
304 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
305 m->warps * config->warp.lanes,
306 m->rows_bs);
307 }
308
309 fprintf(target->state->source,
310 "HS_SLAB_GLOBAL_PREAMBLE();\n");
311 }
312 break;
313
314 case HSG_OP_TYPE_BC_KERNEL_PROTO:
315 {
316 struct hsg_merge const * const m = merge + ops->a;
317
318 uint32_t const msb = msb_idx_u32(m->warps);
319
320 fprintf(target->state->source,
321 "\nHS_BC_KERNEL_PROTO(%u,%u)\n",
322 m->warps,msb);
323 }
324 break;
325
326 case HSG_OP_TYPE_BC_KERNEL_PREAMBLE:
327 {
328 struct hsg_merge const * const m = merge + ops->a;
329
330 if (m->warps > 1)
331 {
332 fprintf(target->state->source,
333 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n",
334 m->warps * config->warp.lanes,
335 m->rows_bc);
336 }
337
338 fprintf(target->state->source,
339 "HS_SLAB_GLOBAL_PREAMBLE();\n");
340 }
341 break;
342
343 case HSG_OP_TYPE_FM_KERNEL_PROTO:
344 fprintf(target->state->source,
345 "\nHS_FM_KERNEL_PROTO(%u,%u)\n",
346 ops->a,ops->b);
347 break;
348
349 case HSG_OP_TYPE_FM_KERNEL_PREAMBLE:
350 fprintf(target->state->source,
351 "HS_FM_PREAMBLE(%u);\n",
352 ops->a);
353 break;
354
355 case HSG_OP_TYPE_HM_KERNEL_PROTO:
356 {
357 fprintf(target->state->source,
358 "\nHS_HM_KERNEL_PROTO(%u)\n",
359 ops->a);
360 }
361 break;
362
363 case HSG_OP_TYPE_HM_KERNEL_PREAMBLE:
364 fprintf(target->state->source,
365 "HS_HM_PREAMBLE(%u);\n",
366 ops->a);
367 break;
368
369 case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD:
370 {
371 static char const * const vstr[] = { "vin", "vout" };
372
373 fprintf(target->state->source,
374 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n",
375 ops->n,vstr[ops->v],ops->n-1);
376 }
377 break;
378
379 case HSG_OP_TYPE_BX_REG_GLOBAL_STORE:
380 fprintf(target->state->source,
381 "HS_SLAB_GLOBAL_STORE(%u,r%u);\n",
382 ops->n-1,ops->n);
383 break;
384
385 case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD:
386 fprintf(target->state->source,
387 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
388 ops->a,ops->b);
389 break;
390
391 case HSG_OP_TYPE_HM_REG_GLOBAL_STORE:
392 fprintf(target->state->source,
393 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
394 ops->b,ops->a);
395 break;
396
397 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT:
398 fprintf(target->state->source,
399 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n",
400 ops->a,ops->b);
401 break;
402
403 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT:
404 fprintf(target->state->source,
405 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n",
406 ops->b,ops->a);
407 break;
408
409 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT:
410 fprintf(target->state->source,
411 "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n",
412 ops->b,ops->a);
413 break;
414
415 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT:
416 fprintf(target->state->source,
417 "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n",
418 ops->a,ops->b);
419 break;
420
421 case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED:
422 {
423 if (ops->a <= ops->b)
424 {
425 fprintf(target->state->source,
426 "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n");
427 }
428 else if (ops->b > 1)
429 {
430 fprintf(target->state->source,
431 "else if (fm_frac == %u)\n",
432 ops->b);
433 }
434 else
435 {
436 fprintf(target->state->source,
437 "else\n");
438 }
439 }
440 break;
441
442 case HSG_OP_TYPE_SLAB_FLIP:
443 fprintf(target->state->source,
444 "HS_SLAB_FLIP_PREAMBLE(%u);\n",
445 ops->n-1);
446 break;
447
448 case HSG_OP_TYPE_SLAB_HALF:
449 fprintf(target->state->source,
450 "HS_SLAB_HALF_PREAMBLE(%u);\n",
451 ops->n / 2);
452 break;
453
454 case HSG_OP_TYPE_CMP_FLIP:
455 fprintf(target->state->source,
456 "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c);
457 break;
458
459 case HSG_OP_TYPE_CMP_HALF:
460 fprintf(target->state->source,
461 "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b);
462 break;
463
464 case HSG_OP_TYPE_CMP_XCHG:
465 if (ops->c == UINT32_MAX)
466 {
467 fprintf(target->state->source,
468 "HS_CMP_XCHG(r%-3u,r%-3u);\n",
469 ops->a,ops->b);
470 }
471 else
472 {
473 fprintf(target->state->source,
474 "HS_CMP_XCHG(r%u_%u,r%u_%u);\n",
475 ops->c,ops->a,ops->c,ops->b);
476 }
477 break;
478
479 case HSG_OP_TYPE_BS_REG_SHARED_STORE_V:
480 fprintf(target->state->source,
481 "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n",
482 merge[ops->a].warps,ops->c,ops->b);
483 break;
484
485 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V:
486 fprintf(target->state->source,
487 "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
488 ops->b,merge[ops->a].warps,ops->c);
489 break;
490
491 case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V:
492 fprintf(target->state->source,
493 "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n",
494 ops->b,ops->a,ops->c);
495 break;
496
497 case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT:
498 fprintf(target->state->source,
499 "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n",
500 ops->b * config->warp.lanes,
501 ops->c,
502 ops->a);
503 break;
504
505 case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT:
506 fprintf(target->state->source,
507 "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n",
508 ops->b * config->warp.lanes,
509 ops->c,
510 ops->a);
511 break;
512
513 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT:
514 fprintf(target->state->source,
515 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n",
516 ops->c,
517 ops->a,
518 ops->b * config->warp.lanes);
519 break;
520
521 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT:
522 fprintf(target->state->source,
523 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n",
524 ops->c,
525 ops->a,
526 ops->b * config->warp.lanes);
527 break;
528
529 case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT:
530 fprintf(target->state->source,
531 "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n",
532 ops->c,
533 ops->a,
534 ops->b);
535 break;
536
537 case HSG_OP_TYPE_BLOCK_SYNC:
538 fprintf(target->state->source,
539 "HS_BLOCK_BARRIER();\n");
540 //
541 // FIXME - Named barriers to allow coordinating warps to proceed?
542 //
543 break;
544
545 case HSG_OP_TYPE_BS_FRAC_PRED:
546 {
547 if (ops->m == 0)
548 {
549 fprintf(target->state->source,
550 "if (warp_idx < bs_full)\n");
551 }
552 else
553 {
554 fprintf(target->state->source,
555 "else if (bs_frac == %u)\n",
556 ops->w);
557 }
558 }
559 break;
560
561 case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE:
562 {
563 struct hsg_merge const * const m = merge + ops->a;
564
565 fprintf(target->state->source,
566 "HS_BS_MERGE_H_PREAMBLE(%u);\n",
567 m->warps);
568 }
569 break;
570
571 case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE:
572 {
573 struct hsg_merge const * const m = merge + ops->a;
574
575 fprintf(target->state->source,
576 "HS_BC_MERGE_H_PREAMBLE(%u);\n",
577 m->warps);
578 }
579 break;
580
581 case HSG_OP_TYPE_BX_MERGE_H_PRED:
582 fprintf(target->state->source,
583 "if (HS_SUBGROUP_ID() < %u)\n",
584 ops->a);
585 break;
586
587 case HSG_OP_TYPE_BS_ACTIVE_PRED:
588 {
589 struct hsg_merge const * const m = merge + ops->a;
590
591 if (m->warps <= 32)
592 {
593 fprintf(target->state->source,
594 "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n",
595 m->levels[ops->b].active.b32a2[0]);
596 }
597 else
598 {
599 fprintf(target->state->source,
600 "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n",
601 m->levels[ops->b].active.b32a2[1],
602 m->levels[ops->b].active.b32a2[0]);
603 }
604 }
605 break;
606
607 default:
608 fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]);
609 exit(EXIT_FAILURE);
610 break;
611 }
612 }
613
614 //
615 //
616 //
617