1 //
2 // Copyright (c) 2021 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "subgroup_common_templates.h"
19 #include "harness/typeWrappers.h"
20 #include <bitset>
21
22 namespace {
23 // Test for ballot functions
24 template <typename Ty> struct BALLOT
25 {
gen__anon3aaeb11a0111::BALLOT26 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
27 {
28 // no work here
29 int gws = test_params.global_workgroup_size;
30 int lws = test_params.local_workgroup_size;
31 int sbs = test_params.subgroup_size;
32 int non_uniform_size = gws % lws;
33 log_info(" sub_group_ballot...\n");
34 if (non_uniform_size)
35 {
36 log_info(" non uniform work group size mode ON\n");
37 }
38 }
39
chk__anon3aaeb11a0111::BALLOT40 static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
41 const WorkGroupParams &test_params)
42 {
43 int wi_id, wg_id, sb_id;
44 int gws = test_params.global_workgroup_size;
45 int lws = test_params.local_workgroup_size;
46 int sbs = test_params.subgroup_size;
47 int sb_number = (lws + sbs - 1) / sbs;
48 int current_sbs = 0;
49 cl_uint expected_result, device_result;
50 int non_uniform_size = gws % lws;
51 int wg_number = gws / lws;
52 wg_number = non_uniform_size ? wg_number + 1 : wg_number;
53 int last_subgroup_size = 0;
54
55 for (wg_id = 0; wg_id < wg_number; ++wg_id)
56 { // for each work_group
57 if (non_uniform_size && wg_id == wg_number - 1)
58 {
59 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
60 last_subgroup_size);
61 }
62
63 for (wi_id = 0; wi_id < lws; ++wi_id)
64 { // inside the work_group
65 // read device outputs for work_group
66 my[wi_id] = y[wi_id];
67 }
68
69 for (sb_id = 0; sb_id < sb_number; ++sb_id)
70 { // for each subgroup
71 int wg_offset = sb_id * sbs;
72 if (last_subgroup_size && sb_id == sb_number - 1)
73 {
74 current_sbs = last_subgroup_size;
75 }
76 else
77 {
78 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
79 }
80 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
81 {
82 device_result = my[wg_offset + wi_id];
83 expected_result = 1;
84 if (!compare(device_result, expected_result))
85 {
86 log_error(
87 "ERROR: sub_group_ballot mismatch for local id "
88 "%d in sub group %d in group %d obtained {%d}, "
89 "expected {%d} \n",
90 wi_id, sb_id, wg_id, device_result,
91 expected_result);
92 return TEST_FAIL;
93 }
94 }
95 }
96 y += lws;
97 m += 4 * lws;
98 }
99 log_info(" sub_group_ballot... passed\n");
100 return TEST_PASS;
101 }
102 };
103
104 // Test for bit extract ballot functions
105 template <typename Ty, BallotOp operation> struct BALLOT_BIT_EXTRACT
106 {
gen__anon3aaeb11a0111::BALLOT_BIT_EXTRACT107 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
108 {
109 int wi_id, sb_id, wg_id, l;
110 int gws = test_params.global_workgroup_size;
111 int lws = test_params.local_workgroup_size;
112 int sbs = test_params.subgroup_size;
113 int sb_number = (lws + sbs - 1) / sbs;
114 int wg_number = gws / lws;
115 int limit_sbs = sbs > 100 ? 100 : sbs;
116 int non_uniform_size = gws % lws;
117 log_info(" sub_group_%s(%s)...\n", operation_names(operation),
118 TypeManager<Ty>::name());
119
120 if (non_uniform_size)
121 {
122 log_info(" non uniform work group size mode ON\n");
123 }
124
125 for (wg_id = 0; wg_id < wg_number; ++wg_id)
126 { // for each work_group
127 for (sb_id = 0; sb_id < sb_number; ++sb_id)
128 { // for each subgroup
129 int wg_offset = sb_id * sbs;
130 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
131 // rand index to bit extract
132 int index_for_odd = (int)(genrand_int32(gMTdata) & 0x7fffffff)
133 % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
134 int index_for_even = (int)(genrand_int32(gMTdata) & 0x7fffffff)
135 % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
136 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
137 {
138 // index of the third element int the vector.
139 int midx = 4 * wg_offset + 4 * wi_id + 2;
140 // storing information about index to bit extract
141 m[midx] = (cl_int)index_for_odd;
142 m[++midx] = (cl_int)index_for_even;
143 }
144 set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
145 }
146
147 // Now map into work group using map from device
148 for (wi_id = 0; wi_id < lws; ++wi_id)
149 {
150 x[wi_id] = t[wi_id];
151 }
152
153 x += lws;
154 m += 4 * lws;
155 }
156 }
157
chk__anon3aaeb11a0111::BALLOT_BIT_EXTRACT158 static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
159 const WorkGroupParams &test_params)
160 {
161 int wi_id, wg_id, l, sb_id;
162 int gws = test_params.global_workgroup_size;
163 int lws = test_params.local_workgroup_size;
164 int sbs = test_params.subgroup_size;
165 int sb_number = (lws + sbs - 1) / sbs;
166 int wg_number = gws / lws;
167 cl_uint4 expected_result, device_result;
168 int last_subgroup_size = 0;
169 int current_sbs = 0;
170 int non_uniform_size = gws % lws;
171
172 for (wg_id = 0; wg_id < wg_number; ++wg_id)
173 { // for each work_group
174 if (non_uniform_size && wg_id == wg_number - 1)
175 {
176 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
177 last_subgroup_size);
178 }
179 // Map to array indexed to array indexed by local ID and sub group
180 for (wi_id = 0; wi_id < lws; ++wi_id)
181 { // inside the work_group
182 // read host inputs for work_group
183 mx[wi_id] = x[wi_id];
184 // read device outputs for work_group
185 my[wi_id] = y[wi_id];
186 }
187
188 for (sb_id = 0; sb_id < sb_number; ++sb_id)
189 { // for each subgroup
190 int wg_offset = sb_id * sbs;
191 if (last_subgroup_size && sb_id == sb_number - 1)
192 {
193 current_sbs = last_subgroup_size;
194 }
195 else
196 {
197 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
198 }
199 // take index of array where info which work_item will
200 // be broadcast its value is stored
201 int midx = 4 * wg_offset + 2;
202 // take subgroup local id of this work_item
203 int index_for_odd = (int)m[midx];
204 int index_for_even = (int)m[++midx];
205
206 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
207 { // for each subgroup
208 int bit_value = 0;
209 // from which value of bitfield bit
210 // verification will be done
211 int take_shift =
212 (wi_id & 1) ? index_for_odd % 32 : index_for_even % 32;
213 int bit_mask = 1 << take_shift;
214
215 if (wi_id < 32)
216 (mx[wg_offset + wi_id].s0 & bit_mask) > 0
217 ? bit_value = 1
218 : bit_value = 0;
219 if (wi_id >= 32 && wi_id < 64)
220 (mx[wg_offset + wi_id].s1 & bit_mask) > 0
221 ? bit_value = 1
222 : bit_value = 0;
223 if (wi_id >= 64 && wi_id < 96)
224 (mx[wg_offset + wi_id].s2 & bit_mask) > 0
225 ? bit_value = 1
226 : bit_value = 0;
227 if (wi_id >= 96 && wi_id < 128)
228 (mx[wg_offset + wi_id].s3 & bit_mask) > 0
229 ? bit_value = 1
230 : bit_value = 0;
231
232 if (wi_id & 1)
233 {
234 bit_value ? expected_result = { 1, 0, 0, 1 }
235 : expected_result = { 0, 0, 0, 1 };
236 }
237 else
238 {
239 bit_value ? expected_result = { 1, 0, 0, 2 }
240 : expected_result = { 0, 0, 0, 2 };
241 }
242
243 device_result = my[wg_offset + wi_id];
244 if (!compare(device_result, expected_result))
245 {
246 log_error(
247 "ERROR: sub_group_%s mismatch for local id %d in "
248 "sub group %d in group %d obtained {%d, %d, %d, "
249 "%d}, expected {%d, %d, %d, %d}\n",
250 operation_names(operation), wi_id, sb_id, wg_id,
251 device_result.s0, device_result.s1,
252 device_result.s2, device_result.s3,
253 expected_result.s0, expected_result.s1,
254 expected_result.s2, expected_result.s3);
255 return TEST_FAIL;
256 }
257 }
258 }
259 x += lws;
260 y += lws;
261 m += 4 * lws;
262 }
263 log_info(" sub_group_%s(%s)... passed\n", operation_names(operation),
264 TypeManager<Ty>::name());
265 return TEST_PASS;
266 }
267 };
268
269 template <typename Ty, BallotOp operation> struct BALLOT_INVERSE
270 {
gen__anon3aaeb11a0111::BALLOT_INVERSE271 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
272 {
273 int gws = test_params.global_workgroup_size;
274 int lws = test_params.local_workgroup_size;
275 int sbs = test_params.subgroup_size;
276 int non_uniform_size = gws % lws;
277 log_info(" sub_group_inverse_ballot...\n");
278 if (non_uniform_size)
279 {
280 log_info(" non uniform work group size mode ON\n");
281 }
282 // no work here
283 }
284
chk__anon3aaeb11a0111::BALLOT_INVERSE285 static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
286 const WorkGroupParams &test_params)
287 {
288 int wi_id, wg_id, sb_id;
289 int gws = test_params.global_workgroup_size;
290 int lws = test_params.local_workgroup_size;
291 int sbs = test_params.subgroup_size;
292 int sb_number = (lws + sbs - 1) / sbs;
293 cl_uint4 expected_result, device_result;
294 int non_uniform_size = gws % lws;
295 int wg_number = gws / lws;
296 int last_subgroup_size = 0;
297 int current_sbs = 0;
298 if (non_uniform_size) wg_number++;
299
300 for (wg_id = 0; wg_id < wg_number; ++wg_id)
301 { // for each work_group
302 if (non_uniform_size && wg_id == wg_number - 1)
303 {
304 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
305 last_subgroup_size);
306 }
307 // Map to array indexed to array indexed by local ID and sub group
308 for (wi_id = 0; wi_id < lws; ++wi_id)
309 { // inside the work_group
310 mx[wi_id] = x[wi_id]; // read host inputs for work_group
311 my[wi_id] = y[wi_id]; // read device outputs for work_group
312 }
313
314 for (sb_id = 0; sb_id < sb_number; ++sb_id)
315 { // for each subgroup
316 int wg_offset = sb_id * sbs;
317 if (last_subgroup_size && sb_id == sb_number - 1)
318 {
319 current_sbs = last_subgroup_size;
320 }
321 else
322 {
323 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
324 }
325 // take index of array where info which work_item will
326 // be broadcast its value is stored
327 int midx = 4 * wg_offset + 2;
328 // take subgroup local id of this work_item
329 // Check result
330 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
331 { // for each subgroup work item
332
333 wi_id & 1 ? expected_result = { 1, 0, 0, 1 }
334 : expected_result = { 1, 0, 0, 2 };
335
336 device_result = my[wg_offset + wi_id];
337 if (!compare(device_result, expected_result))
338 {
339 log_error(
340 "ERROR: sub_group_%s mismatch for local id %d in "
341 "sub group %d in group %d obtained {%d, %d, %d, "
342 "%d}, expected {%d, %d, %d, %d}\n",
343 operation_names(operation), wi_id, sb_id, wg_id,
344 device_result.s0, device_result.s1,
345 device_result.s2, device_result.s3,
346 expected_result.s0, expected_result.s1,
347 expected_result.s2, expected_result.s3);
348 return TEST_FAIL;
349 }
350 }
351 }
352 x += lws;
353 y += lws;
354 m += 4 * lws;
355 }
356
357 log_info(" sub_group_inverse_ballot... passed\n");
358 return TEST_PASS;
359 }
360 };
361
362
363 // Test for bit count/inclusive and exclusive scan/ find lsb msb ballot function
364 template <typename Ty, BallotOp operation> struct BALLOT_COUNT_SCAN_FIND
365 {
gen__anon3aaeb11a0111::BALLOT_COUNT_SCAN_FIND366 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
367 {
368 int wi_id, wg_id, sb_id;
369 int gws = test_params.global_workgroup_size;
370 int lws = test_params.local_workgroup_size;
371 int sbs = test_params.subgroup_size;
372 int sb_number = (lws + sbs - 1) / sbs;
373 int non_uniform_size = gws % lws;
374 int wg_number = gws / lws;
375 int last_subgroup_size = 0;
376 int current_sbs = 0;
377
378 log_info(" sub_group_%s(%s)...\n", operation_names(operation),
379 TypeManager<Ty>::name());
380 if (non_uniform_size)
381 {
382 log_info(" non uniform work group size mode ON\n");
383 wg_number++;
384 }
385 int e;
386 for (wg_id = 0; wg_id < wg_number; ++wg_id)
387 { // for each work_group
388 if (non_uniform_size && wg_id == wg_number - 1)
389 {
390 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
391 last_subgroup_size);
392 }
393 for (sb_id = 0; sb_id < sb_number; ++sb_id)
394 { // for each subgroup
395 int wg_offset = sb_id * sbs;
396 if (last_subgroup_size && sb_id == sb_number - 1)
397 {
398 current_sbs = last_subgroup_size;
399 }
400 else
401 {
402 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
403 }
404 if (operation == BallotOp::ballot_bit_count
405 || operation == BallotOp::ballot_inclusive_scan
406 || operation == BallotOp::ballot_exclusive_scan)
407 {
408 set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
409 }
410 else if (operation == BallotOp::ballot_find_lsb
411 || operation == BallotOp::ballot_find_msb)
412 {
413 // Regarding to the spec, find lsb and find msb result is
414 // undefined behavior if input value is zero, so generate
415 // only non-zero values.
416 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
417 {
418 char x = (genrand_int32(gMTdata)) & 0xff;
419 // undefined behaviour in case of 0;
420 x = x ? x : 1;
421 memset(&t[wg_offset + wi_id], x, sizeof(Ty));
422 }
423 }
424 else
425 {
426 log_error("Unknown operation...");
427 }
428 }
429
430 // Now map into work group using map from device
431 for (wi_id = 0; wi_id < lws; ++wi_id)
432 {
433 x[wi_id] = t[wi_id];
434 }
435
436 x += lws;
437 m += 4 * lws;
438 }
439 }
440
getImportantBits__anon3aaeb11a0111::BALLOT_COUNT_SCAN_FIND441 static bs128 getImportantBits(cl_uint sub_group_local_id,
442 cl_uint sub_group_size)
443 {
444 bs128 mask;
445 if (operation == BallotOp::ballot_bit_count
446 || operation == BallotOp::ballot_find_lsb
447 || operation == BallotOp::ballot_find_msb)
448 {
449 for (cl_uint i = 0; i < sub_group_size; ++i) mask.set(i);
450 }
451 else if (operation == BallotOp::ballot_inclusive_scan
452 || operation == BallotOp::ballot_exclusive_scan)
453 {
454 for (cl_uint i = 0; i <= sub_group_local_id; ++i) mask.set(i);
455 if (operation == BallotOp::ballot_exclusive_scan)
456 mask.reset(sub_group_local_id);
457 }
458 return mask;
459 }
460
chk__anon3aaeb11a0111::BALLOT_COUNT_SCAN_FIND461 static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
462 const WorkGroupParams &test_params)
463 {
464 int wi_id, wg_id, sb_id;
465 int gws = test_params.global_workgroup_size;
466 int lws = test_params.local_workgroup_size;
467 int sbs = test_params.subgroup_size;
468 int sb_number = (lws + sbs - 1) / sbs;
469 int non_uniform_size = gws % lws;
470 int wg_number = gws / lws;
471 wg_number = non_uniform_size ? wg_number + 1 : wg_number;
472 cl_uint4 expected_result, device_result;
473 int last_subgroup_size = 0;
474 int current_sbs = 0;
475
476 for (wg_id = 0; wg_id < wg_number; ++wg_id)
477 { // for each work_group
478 if (non_uniform_size && wg_id == wg_number - 1)
479 {
480 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
481 last_subgroup_size);
482 }
483 // Map to array indexed to array indexed by local ID and sub group
484 for (wi_id = 0; wi_id < lws; ++wi_id)
485 { // inside the work_group
486 // read host inputs for work_group
487 mx[wi_id] = x[wi_id];
488 // read device outputs for work_group
489 my[wi_id] = y[wi_id];
490 }
491
492 for (sb_id = 0; sb_id < sb_number; ++sb_id)
493 { // for each subgroup
494 int wg_offset = sb_id * sbs;
495 if (last_subgroup_size && sb_id == sb_number - 1)
496 {
497 current_sbs = last_subgroup_size;
498 }
499 else
500 {
501 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
502 }
503 // Check result
504 expected_result = { 0, 0, 0, 0 };
505 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
506 { // for subgroup element
507 bs128 bs;
508 // convert cl_uint4 input into std::bitset<128>
509 bs |= bs128(mx[wg_offset + wi_id].s0)
510 | (bs128(mx[wg_offset + wi_id].s1) << 32)
511 | (bs128(mx[wg_offset + wi_id].s2) << 64)
512 | (bs128(mx[wg_offset + wi_id].s3) << 96);
513 bs &= getImportantBits(wi_id, current_sbs);
514 device_result = my[wg_offset + wi_id];
515 if (operation == BallotOp::ballot_inclusive_scan
516 || operation == BallotOp::ballot_exclusive_scan
517 || operation == BallotOp::ballot_bit_count)
518 {
519 expected_result.s0 = bs.count();
520 if (!compare(device_result, expected_result))
521 {
522 log_error("ERROR: sub_group_%s "
523 "mismatch for local id %d in sub group "
524 "%d in group %d obtained {%d, %d, %d, "
525 "%d}, expected {%d, %d, %d, %d}\n",
526 operation_names(operation), wi_id, sb_id,
527 wg_id, device_result.s0, device_result.s1,
528 device_result.s2, device_result.s3,
529 expected_result.s0, expected_result.s1,
530 expected_result.s2, expected_result.s3);
531 return TEST_FAIL;
532 }
533 }
534 else if (operation == BallotOp::ballot_find_lsb)
535 {
536 for (int id = 0; id < current_sbs; ++id)
537 {
538 if (bs.test(id))
539 {
540 expected_result.s0 = id;
541 break;
542 }
543 }
544 if (!compare(device_result, expected_result))
545 {
546 log_error("ERROR: sub_group_ballot_find_lsb "
547 "mismatch for local id %d in sub group "
548 "%d in group %d obtained {%d, %d, %d, "
549 "%d}, expected {%d, %d, %d, %d}\n",
550 wi_id, sb_id, wg_id, device_result.s0,
551 device_result.s1, device_result.s2,
552 device_result.s3, expected_result.s0,
553 expected_result.s1, expected_result.s2,
554 expected_result.s3);
555 return TEST_FAIL;
556 }
557 }
558 else if (operation == BallotOp::ballot_find_msb)
559 {
560 for (int id = current_sbs - 1; id >= 0; --id)
561 {
562 if (bs.test(id))
563 {
564 expected_result.s0 = id;
565 break;
566 }
567 }
568 if (!compare(device_result, expected_result))
569 {
570 log_error("ERROR: sub_group_ballot_find_msb "
571 "mismatch for local id %d in sub group "
572 "%d in group %d obtained {%d, %d, %d, "
573 "%d}, expected {%d, %d, %d, %d}\n",
574 wi_id, sb_id, wg_id, device_result.s0,
575 device_result.s1, device_result.s2,
576 device_result.s3, expected_result.s0,
577 expected_result.s1, expected_result.s2,
578 expected_result.s3);
579 return TEST_FAIL;
580 }
581 }
582 }
583 }
584 x += lws;
585 y += lws;
586 m += 4 * lws;
587 }
588 log_info(" sub_group_ballot_%s(%s)... passed\n",
589 operation_names(operation), TypeManager<Ty>::name());
590 return TEST_PASS;
591 }
592 };
593
594 // test mask functions
595 template <typename Ty, BallotOp operation> struct SMASK
596 {
gen__anon3aaeb11a0111::SMASK597 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
598 {
599 int wi_id, wg_id, l, sb_id;
600 int gws = test_params.global_workgroup_size;
601 int lws = test_params.local_workgroup_size;
602 int sbs = test_params.subgroup_size;
603 int sb_number = (lws + sbs - 1) / sbs;
604 int wg_number = gws / lws;
605 log_info(" get_sub_group_%s_mask...\n", operation_names(operation));
606 for (wg_id = 0; wg_id < wg_number; ++wg_id)
607 { // for each work_group
608 for (sb_id = 0; sb_id < sb_number; ++sb_id)
609 { // for each subgroup
610 int wg_offset = sb_id * sbs;
611 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
612 // Produce expected masks for each work item in the subgroup
613 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
614 {
615 int midx = 4 * wg_offset + 4 * wi_id;
616 cl_uint max_sub_group_size = m[midx + 2];
617 cl_uint4 expected_mask = { 0 };
618 expected_mask = generate_bit_mask(
619 wi_id, operation_names(operation), max_sub_group_size);
620 set_value(t[wg_offset + wi_id], expected_mask);
621 }
622 }
623
624 // Now map into work group using map from device
625 for (wi_id = 0; wi_id < lws; ++wi_id)
626 {
627 x[wi_id] = t[wi_id];
628 }
629 x += lws;
630 m += 4 * lws;
631 }
632 }
633
chk__anon3aaeb11a0111::SMASK634 static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
635 const WorkGroupParams &test_params)
636 {
637 int wi_id, wg_id, sb_id;
638 int gws = test_params.global_workgroup_size;
639 int lws = test_params.local_workgroup_size;
640 int sbs = test_params.subgroup_size;
641 int sb_number = (lws + sbs - 1) / sbs;
642 Ty expected_result, device_result;
643 int wg_number = gws / lws;
644
645 for (wg_id = 0; wg_id < wg_number; ++wg_id)
646 { // for each work_group
647 for (wi_id = 0; wi_id < lws; ++wi_id)
648 { // inside the work_group
649 mx[wi_id] = x[wi_id]; // read host inputs for work_group
650 my[wi_id] = y[wi_id]; // read device outputs for work_group
651 }
652
653 for (sb_id = 0; sb_id < sb_number; ++sb_id)
654 {
655 int wg_offset = sb_id * sbs;
656 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
657
658 // Check result
659 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
660 { // inside the subgroup
661 expected_result =
662 mx[wg_offset + wi_id]; // read host input for subgroup
663 device_result =
664 my[wg_offset
665 + wi_id]; // read device outputs for subgroup
666 if (!compare(device_result, expected_result))
667 {
668 log_error("ERROR: get_sub_group_%s_mask... mismatch "
669 "for local id %d in sub group %d in group "
670 "%d, obtained %d, expected %d\n",
671 operation_names(operation), wi_id, sb_id,
672 wg_id, device_result, expected_result);
673 return TEST_FAIL;
674 }
675 }
676 }
677 x += lws;
678 y += lws;
679 m += 4 * lws;
680 }
681 log_info(" get_sub_group_%s_mask... passed\n",
682 operation_names(operation));
683 return TEST_PASS;
684 }
685 };
686
// OpenCL C kernel for the sub_group_non_uniform_broadcast test (passed to
// run_impl by run_non_uniform_broadcast_for_type). Work items whose
// xy[gid].x is below NR_OF_ACTIVE_WORK_ITEMS broadcast from the subgroup
// local id in xy[gid].z; the others broadcast from xy[gid].w.
// NOTE(review): XY, NR_OF_ACTIVE_WORK_ITEMS and Type are not defined in
// this string; presumably the harness prepends them. TODO confirm.
static const char *bcast_non_uniform_source =
    "__kernel void test_bcast_non_uniform(const __global Type *in, __global "
    "int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
    " out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);\n"
    " } else {\n"
    " out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);\n"
    " }\n"
    "}\n";

// OpenCL C kernel for sub_group_broadcast_first. Both branches make the
// same call; the NR_OF_ACTIVE_WORK_ITEMS split only changes which work
// items are active when the broadcast executes.
static const char *bcast_first_source =
    "__kernel void test_bcast_first(const __global Type *in, __global int4 "
    "*xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {\n"
    " out[gid] = sub_group_broadcast_first(x);\n"
    " } else {\n"
    " out[gid] = sub_group_broadcast_first(x);\n"
    " }\n"
    "}\n";
714
// OpenCL C kernels for the ballot bit count / scan / find tests. Each one
// passes the work item's own uint4 input to the function under test and
// writes the scalar result to component .x of the output, leaving .y/.z/.w
// zero. The matching checker expects exactly that shape.

// sub_group_ballot_bit_count
static const char *ballot_bit_count_source =
    "__kernel void test_sub_group_ballot_bit_count(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_bit_count(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

// sub_group_ballot_inclusive_scan
static const char *ballot_inclusive_scan_source =
    "__kernel void test_sub_group_ballot_inclusive_scan(const __global Type "
    "*in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_inclusive_scan(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

// sub_group_ballot_exclusive_scan
static const char *ballot_exclusive_scan_source =
    "__kernel void test_sub_group_ballot_exclusive_scan(const __global Type "
    "*in, __global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_exclusive_scan(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

// sub_group_ballot_find_lsb
static const char *ballot_find_lsb_source =
    "__kernel void test_sub_group_ballot_find_lsb(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);\n"
    " value = (uint4)(sub_group_ballot_find_lsb(x),0,0,0);\n"
    " out[gid] = value;\n"
    "}\n";

// sub_group_ballot_find_msb. Some lines here lack a trailing "\n". That is
// harmless to the OpenCL C compiler because each statement ends in ';'.
static const char *ballot_find_msb_source =
    "__kernel void test_sub_group_ballot_find_msb(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(0,0,0,0);"
    " value = (uint4)(sub_group_ballot_find_msb(x),0,0,0);"
    " out[gid] = value ;"
    "}\n";
774
// OpenCL C kernels for the get_sub_group_<op>_mask tests. Each stores
// get_max_sub_group_size() into xy[gid].z (SMASK::gen reads it back from
// element 2 of the map entry) and writes the queried mask to the output.

// get_sub_group_ge_mask
static const char *get_subgroup_ge_mask_source =
    "__kernel void test_get_sub_group_ge_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_ge_mask();"
    " out[gid] = mask;\n"
    "}\n";

// get_sub_group_gt_mask
static const char *get_subgroup_gt_mask_source =
    "__kernel void test_get_sub_group_gt_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_gt_mask();"
    " out[gid] = mask;\n"
    "}\n";

// get_sub_group_le_mask
static const char *get_subgroup_le_mask_source =
    "__kernel void test_get_sub_group_le_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_le_mask();"
    " out[gid] = mask;\n"
    "}\n";

// get_sub_group_lt_mask
static const char *get_subgroup_lt_mask_source =
    "__kernel void test_get_sub_group_lt_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_lt_mask();"
    " out[gid] = mask;\n"
    "}\n";

// get_sub_group_eq_mask
static const char *get_subgroup_eq_mask_source =
    "__kernel void test_get_sub_group_eq_mask(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " xy[gid].z = get_max_sub_group_size();\n"
    " Type x = in[gid];\n"
    " uint4 mask = get_sub_group_eq_mask();"
    " out[gid] = mask;\n"
    "}\n";
834
// OpenCL C kernel for the BALLOT test. It takes a full-subgroup ballot,
// then a second ballot inside a branch split on subgroup-local-id parity.
// Within each branch the partial ballot should equal the full ballot masked
// to that parity's bits (0xaaaaaaaa for odd, 0x55555555 for even). The
// kernel writes the result of all(...), which BALLOT::chk expects to be 1.
// The `lws` local is unused.
static const char *ballot_source =
    "__kernel void test_sub_group_ballot(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    "uint4 full_ballot = sub_group_ballot(1);\n"
    "uint divergence_mask;\n"
    "uint4 partial_ballot;\n"
    "uint gid = get_global_id(0);"
    "XY(xy,gid);\n"
    "if (get_sub_group_local_id() & 1) {\n"
    " divergence_mask = 0xaaaaaaaa;\n"
    " partial_ballot = sub_group_ballot(1);\n"
    "} else {\n"
    " divergence_mask = 0x55555555;\n"
    " partial_ballot = sub_group_ballot(1);\n"
    "}\n"
    " size_t lws = get_local_size(0);\n"
    "uint4 masked_ballot = full_ballot;\n"
    "masked_ballot.x &= divergence_mask;\n"
    "masked_ballot.y &= divergence_mask;\n"
    "masked_ballot.z &= divergence_mask;\n"
    "masked_ballot.w &= divergence_mask;\n"
    "out[gid] = all(masked_ballot == partial_ballot);\n"

    "} \n";

// OpenCL C kernel for sub_group_inverse_ballot. Odd subgroup local ids pass
// an all-0xAAAAAAAA mask and even ids an all-0x55555555 mask. Odd ids write
// {1,0,0,1} or {0,0,0,1}, and even ids {1,0,0,2} or {0,0,0,2}, depending on
// the result. The initial (10,0,0,0) is overwritten on every path.
static const char *ballot_source_inverse =
    "__kernel void test_sub_group_ballot_inverse(const __global "
    "Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint4 value = (uint4)(10,0,0,0);\n"
    " if (get_sub_group_local_id() & 1) {"
    " uint4 partial_ballot_mask = "
    "(uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);"
    " if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
    " value = (uint4)(1,0,0,1);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,1);\n"
    " }\n"
    " } else {\n"
    " uint4 partial_ballot_mask = "
    "(uint4)(0x55555555,0x55555555,0x55555555,0x55555555);"
    " if (sub_group_inverse_ballot(partial_ballot_mask)) {\n"
    " value = (uint4)(1,0,0,2);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,2);\n"
    " }\n"
    " }\n"
    " out[gid] = value;\n"
    "}\n";

// OpenCL C kernel for sub_group_ballot_bit_extract on each work item's own
// input. Odd subgroup local ids use index xy[gid].z and even ids use
// xy[gid].w; BALLOT_BIT_EXTRACT::gen writes both into the map. The output
// is {bit,0,0,1} for odd ids and {bit,0,0,2} for even ids. The `index`
// local is unused.
static const char *ballot_bit_extract_source =
    "__kernel void test_sub_group_ballot_bit_extract(const __global Type *in, "
    "__global int4 *xy, __global Type *out)\n"
    "{\n"
    " int gid = get_global_id(0);\n"
    " XY(xy,gid);\n"
    " Type x = in[gid];\n"
    " uint index = xy[gid].z;\n"
    " uint4 value = (uint4)(10,0,0,0);\n"
    " if (get_sub_group_local_id() & 1) {"
    " if (sub_group_ballot_bit_extract(x, xy[gid].z)) {\n"
    " value = (uint4)(1,0,0,1);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,1);\n"
    " }\n"
    " } else {\n"
    " if (sub_group_ballot_bit_extract(x, xy[gid].w)) {\n"
    " value = (uint4)(1,0,0,2);\n"
    " } else {\n"
    " value = (uint4)(0,0,0,2);\n"
    " }\n"
    " }\n"
    " out[gid] = value;\n"
    "}\n";
914
run_non_uniform_broadcast_for_type(RunTestForType rft)915 template <typename T> int run_non_uniform_broadcast_for_type(RunTestForType rft)
916 {
917 int error =
918 rft.run_impl<T, BC<T, SubgroupsBroadcastOp::non_uniform_broadcast>>(
919 "test_bcast_non_uniform", bcast_non_uniform_source);
920 return error;
921 }
922
923
924 }
925
test_subgroup_functions_ballot(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)926 int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
927 cl_command_queue queue, int num_elements)
928 {
929 std::vector<std::string> required_extensions = { "cl_khr_subgroup_ballot" };
930 constexpr size_t global_work_size = 170;
931 constexpr size_t local_work_size = 64;
932 WorkGroupParams test_params(global_work_size, local_work_size,
933 required_extensions);
934 RunTestForType rft(device, context, queue, num_elements, test_params);
935
936 // non uniform broadcast functions
937 int error = run_non_uniform_broadcast_for_type<cl_int>(rft);
938 error |= run_non_uniform_broadcast_for_type<cl_int2>(rft);
939 error |= run_non_uniform_broadcast_for_type<subgroups::cl_int3>(rft);
940 error |= run_non_uniform_broadcast_for_type<cl_int4>(rft);
941 error |= run_non_uniform_broadcast_for_type<cl_int8>(rft);
942 error |= run_non_uniform_broadcast_for_type<cl_int16>(rft);
943
944 error |= run_non_uniform_broadcast_for_type<cl_uint>(rft);
945 error |= run_non_uniform_broadcast_for_type<cl_uint2>(rft);
946 error |= run_non_uniform_broadcast_for_type<subgroups::cl_uint3>(rft);
947 error |= run_non_uniform_broadcast_for_type<cl_uint4>(rft);
948 error |= run_non_uniform_broadcast_for_type<cl_uint8>(rft);
949 error |= run_non_uniform_broadcast_for_type<cl_uint16>(rft);
950
951 error |= run_non_uniform_broadcast_for_type<cl_char>(rft);
952 error |= run_non_uniform_broadcast_for_type<cl_char2>(rft);
953 error |= run_non_uniform_broadcast_for_type<subgroups::cl_char3>(rft);
954 error |= run_non_uniform_broadcast_for_type<cl_char4>(rft);
955 error |= run_non_uniform_broadcast_for_type<cl_char8>(rft);
956 error |= run_non_uniform_broadcast_for_type<cl_char16>(rft);
957
958 error |= run_non_uniform_broadcast_for_type<cl_uchar>(rft);
959 error |= run_non_uniform_broadcast_for_type<cl_uchar2>(rft);
960 error |= run_non_uniform_broadcast_for_type<subgroups::cl_uchar3>(rft);
961 error |= run_non_uniform_broadcast_for_type<cl_uchar4>(rft);
962 error |= run_non_uniform_broadcast_for_type<cl_uchar8>(rft);
963 error |= run_non_uniform_broadcast_for_type<cl_uchar16>(rft);
964
965 error |= run_non_uniform_broadcast_for_type<cl_short>(rft);
966 error |= run_non_uniform_broadcast_for_type<cl_short2>(rft);
967 error |= run_non_uniform_broadcast_for_type<subgroups::cl_short3>(rft);
968 error |= run_non_uniform_broadcast_for_type<cl_short4>(rft);
969 error |= run_non_uniform_broadcast_for_type<cl_short8>(rft);
970 error |= run_non_uniform_broadcast_for_type<cl_short16>(rft);
971
972 error |= run_non_uniform_broadcast_for_type<cl_ushort>(rft);
973 error |= run_non_uniform_broadcast_for_type<cl_ushort2>(rft);
974 error |= run_non_uniform_broadcast_for_type<subgroups::cl_ushort3>(rft);
975 error |= run_non_uniform_broadcast_for_type<cl_ushort4>(rft);
976 error |= run_non_uniform_broadcast_for_type<cl_ushort8>(rft);
977 error |= run_non_uniform_broadcast_for_type<cl_ushort16>(rft);
978
979 error |= run_non_uniform_broadcast_for_type<cl_long>(rft);
980 error |= run_non_uniform_broadcast_for_type<cl_long2>(rft);
981 error |= run_non_uniform_broadcast_for_type<subgroups::cl_long3>(rft);
982 error |= run_non_uniform_broadcast_for_type<cl_long4>(rft);
983 error |= run_non_uniform_broadcast_for_type<cl_long8>(rft);
984 error |= run_non_uniform_broadcast_for_type<cl_long16>(rft);
985
986 error |= run_non_uniform_broadcast_for_type<cl_ulong>(rft);
987 error |= run_non_uniform_broadcast_for_type<cl_ulong2>(rft);
988 error |= run_non_uniform_broadcast_for_type<subgroups::cl_ulong3>(rft);
989 error |= run_non_uniform_broadcast_for_type<cl_ulong4>(rft);
990 error |= run_non_uniform_broadcast_for_type<cl_ulong8>(rft);
991 error |= run_non_uniform_broadcast_for_type<cl_ulong16>(rft);
992
993 error |= run_non_uniform_broadcast_for_type<cl_float>(rft);
994 error |= run_non_uniform_broadcast_for_type<cl_float2>(rft);
995 error |= run_non_uniform_broadcast_for_type<subgroups::cl_float3>(rft);
996 error |= run_non_uniform_broadcast_for_type<cl_float4>(rft);
997 error |= run_non_uniform_broadcast_for_type<cl_float8>(rft);
998 error |= run_non_uniform_broadcast_for_type<cl_float16>(rft);
999
1000 error |= run_non_uniform_broadcast_for_type<cl_double>(rft);
1001 error |= run_non_uniform_broadcast_for_type<cl_double2>(rft);
1002 error |= run_non_uniform_broadcast_for_type<subgroups::cl_double3>(rft);
1003 error |= run_non_uniform_broadcast_for_type<cl_double4>(rft);
1004 error |= run_non_uniform_broadcast_for_type<cl_double8>(rft);
1005 error |= run_non_uniform_broadcast_for_type<cl_double16>(rft);
1006
1007 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half>(rft);
1008 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half2>(rft);
1009 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half3>(rft);
1010 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half4>(rft);
1011 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half8>(rft);
1012 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half16>(rft);
1013
1014 // broadcast first functions
1015 error |=
1016 rft.run_impl<cl_int, BC<cl_int, SubgroupsBroadcastOp::broadcast_first>>(
1017 "test_bcast_first", bcast_first_source);
1018 error |= rft.run_impl<cl_uint,
1019 BC<cl_uint, SubgroupsBroadcastOp::broadcast_first>>(
1020 "test_bcast_first", bcast_first_source);
1021 error |= rft.run_impl<cl_long,
1022 BC<cl_long, SubgroupsBroadcastOp::broadcast_first>>(
1023 "test_bcast_first", bcast_first_source);
1024 error |= rft.run_impl<cl_ulong,
1025 BC<cl_ulong, SubgroupsBroadcastOp::broadcast_first>>(
1026 "test_bcast_first", bcast_first_source);
1027 error |= rft.run_impl<cl_short,
1028 BC<cl_short, SubgroupsBroadcastOp::broadcast_first>>(
1029 "test_bcast_first", bcast_first_source);
1030 error |= rft.run_impl<cl_ushort,
1031 BC<cl_ushort, SubgroupsBroadcastOp::broadcast_first>>(
1032 "test_bcast_first", bcast_first_source);
1033 error |= rft.run_impl<cl_char,
1034 BC<cl_char, SubgroupsBroadcastOp::broadcast_first>>(
1035 "test_bcast_first", bcast_first_source);
1036 error |= rft.run_impl<cl_uchar,
1037 BC<cl_uchar, SubgroupsBroadcastOp::broadcast_first>>(
1038 "test_bcast_first", bcast_first_source);
1039 error |= rft.run_impl<cl_float,
1040 BC<cl_float, SubgroupsBroadcastOp::broadcast_first>>(
1041 "test_bcast_first", bcast_first_source);
1042 error |= rft.run_impl<cl_double,
1043 BC<cl_double, SubgroupsBroadcastOp::broadcast_first>>(
1044 "test_bcast_first", bcast_first_source);
1045 error |= rft.run_impl<
1046 subgroups::cl_half,
1047 BC<subgroups::cl_half, SubgroupsBroadcastOp::broadcast_first>>(
1048 "test_bcast_first", bcast_first_source);
1049
1050 // mask functions
1051 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::eq_mask>>(
1052 "test_get_sub_group_eq_mask", get_subgroup_eq_mask_source);
1053 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::ge_mask>>(
1054 "test_get_sub_group_ge_mask", get_subgroup_ge_mask_source);
1055 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::gt_mask>>(
1056 "test_get_sub_group_gt_mask", get_subgroup_gt_mask_source);
1057 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::le_mask>>(
1058 "test_get_sub_group_le_mask", get_subgroup_le_mask_source);
1059 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>(
1060 "test_get_sub_group_lt_mask", get_subgroup_lt_mask_source);
1061
1062 // ballot functions
1063 error |= rft.run_impl<cl_uint, BALLOT<cl_uint>>("test_sub_group_ballot",
1064 ballot_source);
1065 error |= rft.run_impl<cl_uint4,
1066 BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
1067 "test_sub_group_ballot_inverse", ballot_source_inverse);
1068 error |= rft.run_impl<
1069 cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>(
1070 "test_sub_group_ballot_bit_extract", ballot_bit_extract_source);
1071 error |= rft.run_impl<
1072 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>(
1073 "test_sub_group_ballot_bit_count", ballot_bit_count_source);
1074 error |= rft.run_impl<
1075 cl_uint4,
1076 BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>(
1077 "test_sub_group_ballot_inclusive_scan", ballot_inclusive_scan_source);
1078 error |= rft.run_impl<
1079 cl_uint4,
1080 BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>(
1081 "test_sub_group_ballot_exclusive_scan", ballot_exclusive_scan_source);
1082 error |= rft.run_impl<
1083 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>(
1084 "test_sub_group_ballot_find_lsb", ballot_find_lsb_source);
1085 error |= rft.run_impl<
1086 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>(
1087 "test_sub_group_ballot_find_msb", ballot_find_msb_source);
1088 return error;
1089 }
1090