1 //
2 // Copyright (c) 2020 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBGROUPCOMMONTEMPLATES_H
17 #define SUBGROUPCOMMONTEMPLATES_H
18
19 #include "typeWrappers.h"
20 #include "CL/cl_half.h"
21 #include "subhelpers.h"
22 #include <set>
23 #include <algorithm>
24 #include <random>
25
generate_bit_mask(cl_uint subgroup_local_id,const std::string & mask_type,cl_uint max_sub_group_size)26 static cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
27 const std::string &mask_type,
28 cl_uint max_sub_group_size)
29 {
30 bs128 mask128;
31 cl_uint4 mask;
32 cl_uint pos = subgroup_local_id;
33 if (mask_type == "eq") mask128.set(pos);
34 if (mask_type == "le" || mask_type == "lt")
35 {
36 for (cl_uint i = 0; i <= pos; i++) mask128.set(i);
37 if (mask_type == "lt") mask128.reset(pos);
38 }
39 if (mask_type == "ge" || mask_type == "gt")
40 {
41 for (cl_uint i = pos; i < max_sub_group_size; i++) mask128.set(i);
42 if (mask_type == "gt") mask128.reset(pos);
43 }
44
45 // convert std::bitset<128> to uint4
46 auto const uint_mask = bs128{ static_cast<unsigned long>(-1) };
47 mask.s0 = (mask128 & uint_mask).to_ulong();
48 mask128 >>= 32;
49 mask.s1 = (mask128 & uint_mask).to_ulong();
50 mask128 >>= 32;
51 mask.s2 = (mask128 & uint_mask).to_ulong();
52 mask128 >>= 32;
53 mask.s3 = (mask128 & uint_mask).to_ulong();
54
55 return mask;
56 }
57
// DESCRIPTION :
// sub_group_broadcast - each work_item registers its own value.
// All work_items in the subgroup take one value from only one (any) work_item
// sub_group_broadcast_first - same as type 0. All work_items in the
// subgroup take only one value from only one chosen (the smallest subgroup ID)
// work_item
// sub_group_non_uniform_broadcast - same as type 0 but
// only 4 work_items from the subgroup enter the code (are active)
66 template <typename Ty, SubgroupsBroadcastOp operation> struct BC
67 {
log_testBC68 static void log_test(const WorkGroupParams &test_params,
69 const char *extra_text)
70 {
71 log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
72 TypeManager<Ty>::name(), extra_text);
73 }
74
genBC75 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
76 {
77 int i, ii, j, k, n;
78 int ng = test_params.global_workgroup_size;
79 int nw = test_params.local_workgroup_size;
80 int ns = test_params.subgroup_size;
81 int nj = (nw + ns - 1) / ns;
82 int d = ns > 100 ? 100 : ns;
83 int non_uniform_size = ng % nw;
84 ng = ng / nw;
85 int last_subgroup_size = 0;
86 ii = 0;
87
88 if (non_uniform_size)
89 {
90 ng++;
91 }
92 for (k = 0; k < ng; ++k)
93 { // for each work_group
94 if (non_uniform_size && k == ng - 1)
95 {
96 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
97 last_subgroup_size);
98 }
99 for (j = 0; j < nj; ++j)
100 { // for each subgroup
101 ii = j * ns;
102 if (last_subgroup_size && j == nj - 1)
103 {
104 n = last_subgroup_size;
105 }
106 else
107 {
108 n = ii + ns > nw ? nw - ii : ns;
109 }
110 int bcast_if = 0;
111 int bcast_elseif = 0;
112 int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
113 % (d > n ? n : d);
114 // l - calculate subgroup local id from which value will be
115 // broadcasted (one the same value for whole subgroup)
116 if (operation != SubgroupsBroadcastOp::broadcast)
117 {
118 // reduce brodcasting index in case of non_uniform and
119 // last workgroup last subgroup
120 if (last_subgroup_size && j == nj - 1
121 && last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
122 {
123 bcast_if = bcast_index % last_subgroup_size;
124 bcast_elseif = bcast_if;
125 }
126 else
127 {
128 bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
129 bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
130 + bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
131 }
132 }
133
134 for (i = 0; i < n; ++i)
135 {
136 if (operation == SubgroupsBroadcastOp::broadcast)
137 {
138 int midx = 4 * ii + 4 * i + 2;
139 m[midx] = (cl_int)bcast_index;
140 }
141 else
142 {
143 if (i < NR_OF_ACTIVE_WORK_ITEMS)
144 {
145 // index of the third
146 // element int the vector.
147 int midx = 4 * ii + 4 * i + 2;
148 // storing information about
149 // broadcasting index -
150 // earlier calculated
151 m[midx] = (cl_int)bcast_if;
152 }
153 else
154 { // index of the third
155 // element int the vector.
156 int midx = 4 * ii + 4 * i + 3;
157 m[midx] = (cl_int)bcast_elseif;
158 }
159 }
160
161 // calculate value for broadcasting
162 cl_ulong number = genrand_int64(gMTdata);
163 set_value(t[ii + i], number);
164 }
165 }
166 // Now map into work group using map from device
167 for (j = 0; j < nw; ++j)
168 { // for each element in work_group
169 // calculate index as number of subgroup
170 // plus subgroup local id
171 x[j] = t[j];
172 }
173 x += nw;
174 m += 4 * nw;
175 }
176 }
177
chkBC178 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
179 const WorkGroupParams &test_params)
180 {
181 int ii, i, j, k, l, n;
182 int ng = test_params.global_workgroup_size;
183 int nw = test_params.local_workgroup_size;
184 int ns = test_params.subgroup_size;
185 int nj = (nw + ns - 1) / ns;
186 Ty tr, rr;
187 int non_uniform_size = ng % nw;
188 ng = ng / nw;
189 int last_subgroup_size = 0;
190 if (non_uniform_size) ng++;
191
192 for (k = 0; k < ng; ++k)
193 { // for each work_group
194 if (non_uniform_size && k == ng - 1)
195 {
196 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
197 last_subgroup_size);
198 }
199 for (j = 0; j < nw; ++j)
200 { // inside the work_group
201 mx[j] = x[j]; // read host inputs for work_group
202 my[j] = y[j]; // read device outputs for work_group
203 }
204
205 for (j = 0; j < nj; ++j)
206 { // for each subgroup
207 ii = j * ns;
208 if (last_subgroup_size && j == nj - 1)
209 {
210 n = last_subgroup_size;
211 }
212 else
213 {
214 n = ii + ns > nw ? nw - ii : ns;
215 }
216
217 // Check result
218 if (operation == SubgroupsBroadcastOp::broadcast_first)
219 {
220 int lowest_active_id = -1;
221 for (i = 0; i < n; ++i)
222 {
223
224 lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
225 ? 0
226 : NR_OF_ACTIVE_WORK_ITEMS;
227 // findout if broadcasted
228 // value is the same
229 tr = mx[ii + lowest_active_id];
230 // findout if broadcasted to all
231 rr = my[ii + i];
232
233 if (!compare(rr, tr))
234 {
235 log_error(
236 "ERROR: sub_group_broadcast_first(%s) "
237 "mismatch "
238 "for local id %d in sub group %d in group "
239 "%d\n",
240 TypeManager<Ty>::name(), i, j, k);
241 return TEST_FAIL;
242 }
243 }
244 }
245 else
246 {
247 for (i = 0; i < n; ++i)
248 {
249 if (operation == SubgroupsBroadcastOp::broadcast)
250 {
251 int midx = 4 * ii + 4 * i + 2;
252 l = (int)m[midx];
253 tr = mx[ii + l];
254 }
255 else
256 {
257 if (i < NR_OF_ACTIVE_WORK_ITEMS)
258 { // take index of array where info
259 // which work_item will be
260 // broadcast its value is stored
261 int midx = 4 * ii + 4 * i + 2;
262 // take subgroup local id of
263 // this work_item
264 l = (int)m[midx];
265 // take value generated on host
266 // for this work_item
267 tr = mx[ii + l];
268 }
269 else
270 {
271 int midx = 4 * ii + 4 * i + 3;
272 l = (int)m[midx];
273 tr = mx[ii + l];
274 }
275 }
276 rr = my[ii + i]; // read device outputs for
277 // work_item in the subgroup
278
279 if (!compare(rr, tr))
280 {
281 log_error("ERROR: sub_group_%s(%s) "
282 "mismatch for local id %d in sub "
283 "group %d in group %d - got %lu "
284 "expected %lu\n",
285 operation_names(operation),
286 TypeManager<Ty>::name(), i, j, k, rr, tr);
287 return TEST_FAIL;
288 }
289 }
290 }
291 }
292 x += nw;
293 y += nw;
294 m += 4 * nw;
295 }
296 return TEST_PASS;
297 }
298 };
299
to_float(subgroups::cl_half x)300 static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
301
to_half(float x)302 static subgroups::cl_half to_half(float x)
303 {
304 subgroups::cl_half value;
305 value.data = cl_half_from_float(x, g_rounding_mode);
306 return value;
307 }
308
309 // for integer types
calculate(Ty a,Ty b,ArithmeticOp operation)310 template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
311 {
312 switch (operation)
313 {
314 case ArithmeticOp::add_: return a + b;
315 case ArithmeticOp::max_: return a > b ? a : b;
316 case ArithmeticOp::min_: return a < b ? a : b;
317 case ArithmeticOp::mul_: return a * b;
318 case ArithmeticOp::and_: return a & b;
319 case ArithmeticOp::or_: return a | b;
320 case ArithmeticOp::xor_: return a ^ b;
321 case ArithmeticOp::logical_and: return a && b;
322 case ArithmeticOp::logical_or: return a || b;
323 case ArithmeticOp::logical_xor: return !a ^ !b;
324 default: log_error("Unknown operation request\n"); break;
325 }
326 return 0;
327 }
328 // Specialize for floating points.
329 template <>
calculate(cl_double a,cl_double b,ArithmeticOp operation)330 inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
331 {
332 switch (operation)
333 {
334 case ArithmeticOp::add_: {
335 return a + b;
336 }
337 case ArithmeticOp::max_: {
338 return a > b ? a : b;
339 }
340 case ArithmeticOp::min_: {
341 return a < b ? a : b;
342 }
343 case ArithmeticOp::mul_: {
344 return a * b;
345 }
346 default: log_error("Unknown operation request\n"); break;
347 }
348 return 0;
349 }
350
351 template <>
calculate(cl_float a,cl_float b,ArithmeticOp operation)352 inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
353 {
354 switch (operation)
355 {
356 case ArithmeticOp::add_: {
357 return a + b;
358 }
359 case ArithmeticOp::max_: {
360 return a > b ? a : b;
361 }
362 case ArithmeticOp::min_: {
363 return a < b ? a : b;
364 }
365 case ArithmeticOp::mul_: {
366 return a * b;
367 }
368 default: log_error("Unknown operation request\n"); break;
369 }
370 return 0;
371 }
372
373 template <>
calculate(subgroups::cl_half a,subgroups::cl_half b,ArithmeticOp operation)374 inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
375 ArithmeticOp operation)
376 {
377 switch (operation)
378 {
379 case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
380 case ArithmeticOp::max_:
381 return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
382 case ArithmeticOp::min_:
383 return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
384 case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
385 default: log_error("Unknown operation request\n"); break;
386 }
387 return to_half(0);
388 }
389
is_floating_point()390 template <typename Ty> bool is_floating_point()
391 {
392 return std::is_floating_point<Ty>::value
393 || std::is_same<Ty, subgroups::cl_half>::value;
394 }
395
// Limit possible input values to avoid arithmetic rounding/overflow issues:
// a few work-items of each subgroup get distinct small values,
// the rest of the work-items are set to 1,
// and the values are then shuffled.
fill_and_shuffle_safe_values(std::vector<cl_ulong> & safe_values,int sb_size)400 static void fill_and_shuffle_safe_values(std::vector<cl_ulong> &safe_values,
401 int sb_size)
402 {
403 // max product is 720, cl_half has enough precision for it
404 const std::vector<cl_ulong> non_one_values{ 2, 3, 4, 5, 6 };
405
406 if (sb_size <= non_one_values.size())
407 {
408 safe_values.assign(non_one_values.begin(),
409 non_one_values.begin() + sb_size);
410 }
411 else
412 {
413 safe_values.assign(sb_size, 1);
414 std::copy(non_one_values.begin(), non_one_values.end(),
415 safe_values.begin());
416 }
417
418 std::mt19937 mersenne_twister_engine(10000);
419 std::shuffle(safe_values.begin(), safe_values.end(),
420 mersenne_twister_engine);
421 };
422
423 template <typename Ty, ArithmeticOp operation>
generate_inputs(Ty * x,Ty * t,cl_int * m,int ns,int nw,int ng)424 void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
425 {
426 int nj = (nw + ns - 1) / ns;
427
428 std::vector<cl_ulong> safe_values;
429 if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_)
430 {
431 fill_and_shuffle_safe_values(safe_values, ns);
432 }
433
434 for (int k = 0; k < ng; ++k)
435 {
436 for (int j = 0; j < nj; ++j)
437 {
438 int ii = j * ns;
439 int n = ii + ns > nw ? nw - ii : ns;
440
441 for (int i = 0; i < n; ++i)
442 {
443 cl_ulong out_value;
444 if (operation == ArithmeticOp::mul_
445 || operation == ArithmeticOp::add_)
446 {
447 out_value = safe_values[i];
448 }
449 else
450 {
451 out_value = genrand_int64(gMTdata) % (32 * n);
452 if ((operation == ArithmeticOp::logical_and
453 || operation == ArithmeticOp::logical_or
454 || operation == ArithmeticOp::logical_xor)
455 && ((out_value >> 32) & 1) == 0)
456 out_value = 0; // increase probability of false
457 }
458 set_value(t[ii + i], out_value);
459 }
460 }
461
462 // Now map into work group using map from device
463 for (int j = 0; j < nw; ++j)
464 {
465 x[j] = t[j];
466 }
467
468 x += nw;
469 m += 4 * nw;
470 }
471 }
472
473 template <typename Ty, ShuffleOp operation> struct SHF
474 {
log_testSHF475 static void log_test(const WorkGroupParams &test_params,
476 const char *extra_text)
477 {
478 log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
479 TypeManager<Ty>::name(), extra_text);
480 }
481
genSHF482 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
483 {
484 int i, ii, j, k, n;
485 cl_uint l;
486 int nw = test_params.local_workgroup_size;
487 int ns = test_params.subgroup_size;
488 int ng = test_params.global_workgroup_size;
489 int nj = (nw + ns - 1) / ns;
490 ii = 0;
491 ng = ng / nw;
492 for (k = 0; k < ng; ++k)
493 { // for each work_group
494 for (j = 0; j < nj; ++j)
495 { // for each subgroup
496 ii = j * ns;
497 n = ii + ns > nw ? nw - ii : ns;
498 for (i = 0; i < n; ++i)
499 {
500 int midx = 4 * ii + 4 * i + 2;
501 l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1)
502 % (ns * 2 + 1))
503 - 1;
504 switch (operation)
505 {
506 case ShuffleOp::shuffle:
507 case ShuffleOp::shuffle_xor:
508 case ShuffleOp::shuffle_up:
509 case ShuffleOp::shuffle_down:
510 // storing information about shuffle index/delta
511 m[midx] = (cl_int)l;
512 break;
513 case ShuffleOp::rotate:
514 case ShuffleOp::clustered_rotate:
515 // Storing information about rotate delta.
516 // The delta must be the same for each thread in
517 // the subgroup.
518 if (i == 0)
519 {
520 m[midx] = (cl_int)l;
521 }
522 else
523 {
524 m[midx] = m[midx - 4];
525 }
526 break;
527 default: break;
528 }
529 cl_ulong number = genrand_int64(gMTdata);
530 set_value(t[ii + i], number);
531 }
532 }
533 // Now map into work group using map from device
534 for (j = 0; j < nw; ++j)
535 { // for each element in work_group
536 x[j] = t[j];
537 }
538 x += nw;
539 m += 4 * nw;
540 }
541 }
542
chkSHF543 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
544 const WorkGroupParams &test_params)
545 {
546 int ii, i, j, k, n;
547 cl_uint l;
548 int nw = test_params.local_workgroup_size;
549 int ns = test_params.subgroup_size;
550 int ng = test_params.global_workgroup_size;
551 int nj = (nw + ns - 1) / ns;
552 Ty tr, rr;
553 ng = ng / nw;
554
555 for (k = 0; k < ng; ++k)
556 { // for each work_group
557 for (j = 0; j < nw; ++j)
558 { // inside the work_group
559 mx[j] = x[j]; // read host inputs for work_group
560 my[j] = y[j]; // read device outputs for work_group
561 }
562
563 for (j = 0; j < nj; ++j)
564 { // for each subgroup
565 ii = j * ns;
566 n = ii + ns > nw ? nw - ii : ns;
567
568 for (i = 0; i < n; ++i)
569 { // inside the subgroup
570 // shuffle index storage
571 int midx = 4 * ii + 4 * i + 2;
572 l = m[midx];
573 rr = my[ii + i];
574 cl_uint tr_idx;
575 bool skip = false;
576 switch (operation)
577 {
578 // shuffle basic - treat l as index
579 case ShuffleOp::shuffle: tr_idx = l; break;
580 // shuffle xor - treat l as mask
581 case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break;
582 // shuffle up - treat l as delta
583 case ShuffleOp::shuffle_up:
584 if (l >= ns) skip = true;
585 tr_idx = i - l;
586 break;
587 // shuffle down - treat l as delta
588 case ShuffleOp::shuffle_down:
589 if (l >= ns) skip = true;
590 tr_idx = i + l;
591 break;
592 // rotate - treat l as delta
593 case ShuffleOp::rotate:
594 tr_idx = (i + l) % test_params.subgroup_size;
595 break;
596 case ShuffleOp::clustered_rotate: {
597 tr_idx = ((i & ~(test_params.cluster_size - 1))
598 + ((i + l) % test_params.cluster_size));
599 break;
600 }
601 default: break;
602 }
603
604 if (!skip && tr_idx < n)
605 {
606 tr = mx[ii + tr_idx];
607
608 if (!compare(rr, tr))
609 {
610 log_error("ERROR: sub_group_%s(%s) mismatch for "
611 "local id %d in sub group %d in group "
612 "%d\n",
613 operation_names(operation),
614 TypeManager<Ty>::name(), i, j, k);
615 return TEST_FAIL;
616 }
617 }
618 }
619 }
620 x += nw;
621 y += nw;
622 m += 4 * nw;
623 }
624 return TEST_PASS;
625 }
626 };
627
628 template <typename Ty, ArithmeticOp operation> struct SCEX_NU
629 {
log_testSCEX_NU630 static void log_test(const WorkGroupParams &test_params,
631 const char *extra_text)
632 {
633 std::string func_name = (test_params.all_work_item_masks.size() > 0
634 ? "sub_group_non_uniform_scan_exclusive"
635 : "sub_group_scan_exclusive");
636 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
637 operation_names(operation), TypeManager<Ty>::name(),
638 extra_text);
639 }
640
genSCEX_NU641 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
642 {
643 int nw = test_params.local_workgroup_size;
644 int ns = test_params.subgroup_size;
645 int ng = test_params.global_workgroup_size;
646 ng = ng / nw;
647 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
648 }
649
chkSCEX_NU650 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
651 const WorkGroupParams &test_params)
652 {
653 int ii, i, j, k, n;
654 int nw = test_params.local_workgroup_size;
655 int ns = test_params.subgroup_size;
656 int ng = test_params.global_workgroup_size;
657 bs128 work_items_mask = test_params.work_items_mask;
658 int nj = (nw + ns - 1) / ns;
659 Ty tr, rr;
660 ng = ng / nw;
661
662 std::string func_name = (test_params.all_work_item_masks.size() > 0
663 ? "sub_group_non_uniform_scan_exclusive"
664 : "sub_group_scan_exclusive");
665
666 // for uniform case take into consideration all workitems
667 if (!work_items_mask.any())
668 {
669 work_items_mask.set();
670 }
671 for (k = 0; k < ng; ++k)
672 { // for each work_group
673 // Map to array indexed to array indexed by local ID and sub group
674 for (j = 0; j < nw; ++j)
675 { // inside the work_group
676 mx[j] = x[j]; // read host inputs for work_group
677 my[j] = y[j]; // read device outputs for work_group
678 }
679 for (j = 0; j < nj; ++j)
680 {
681 ii = j * ns;
682 n = ii + ns > nw ? nw - ii : ns;
683 std::set<int> active_work_items;
684 for (i = 0; i < n; ++i)
685 {
686 if (work_items_mask.test(i))
687 {
688 active_work_items.insert(i);
689 }
690 }
691 if (active_work_items.empty())
692 {
693 continue;
694 }
695 else
696 {
697 tr = TypeManager<Ty>::identify_limits(operation);
698 for (const int &active_work_item : active_work_items)
699 {
700 rr = my[ii + active_work_item];
701 if (!compare_ordered(rr, tr))
702 {
703 log_error(
704 "ERROR: %s_%s(%s) "
705 "mismatch for local id %d in sub group %d in "
706 "group %d Expected: %d Obtained: %d\n",
707 func_name.c_str(), operation_names(operation),
708 TypeManager<Ty>::name(), i, j, k, tr, rr);
709 return TEST_FAIL;
710 }
711 tr = calculate<Ty>(tr, mx[ii + active_work_item],
712 operation);
713 }
714 }
715 }
716 x += nw;
717 y += nw;
718 m += 4 * nw;
719 }
720
721 return TEST_PASS;
722 }
723 };
724
725 // Test for scan inclusive non uniform functions
726 template <typename Ty, ArithmeticOp operation> struct SCIN_NU
727 {
log_testSCIN_NU728 static void log_test(const WorkGroupParams &test_params,
729 const char *extra_text)
730 {
731 std::string func_name = (test_params.all_work_item_masks.size() > 0
732 ? "sub_group_non_uniform_scan_inclusive"
733 : "sub_group_scan_inclusive");
734 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
735 operation_names(operation), TypeManager<Ty>::name(),
736 extra_text);
737 }
738
genSCIN_NU739 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
740 {
741 int nw = test_params.local_workgroup_size;
742 int ns = test_params.subgroup_size;
743 int ng = test_params.global_workgroup_size;
744 ng = ng / nw;
745 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
746 }
747
chkSCIN_NU748 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
749 const WorkGroupParams &test_params)
750 {
751 int ii, i, j, k, n;
752 int nw = test_params.local_workgroup_size;
753 int ns = test_params.subgroup_size;
754 int ng = test_params.global_workgroup_size;
755 bs128 work_items_mask = test_params.work_items_mask;
756
757 int nj = (nw + ns - 1) / ns;
758 Ty tr, rr;
759 ng = ng / nw;
760
761 std::string func_name = (test_params.all_work_item_masks.size() > 0
762 ? "sub_group_non_uniform_scan_inclusive"
763 : "sub_group_scan_inclusive");
764
765 // for uniform case take into consideration all workitems
766 if (!work_items_mask.any())
767 {
768 work_items_mask.set();
769 }
770 // std::bitset<32> mask32(use_work_items_mask);
771 // for (int k) mask32.count();
772 for (k = 0; k < ng; ++k)
773 { // for each work_group
774 // Map to array indexed to array indexed by local ID and sub group
775 for (j = 0; j < nw; ++j)
776 { // inside the work_group
777 mx[j] = x[j]; // read host inputs for work_group
778 my[j] = y[j]; // read device outputs for work_group
779 }
780 for (j = 0; j < nj; ++j)
781 {
782 ii = j * ns;
783 n = ii + ns > nw ? nw - ii : ns;
784 std::set<int> active_work_items;
785 int catch_frist_active = -1;
786
787 for (i = 0; i < n; ++i)
788 {
789 if (work_items_mask.test(i))
790 {
791 if (catch_frist_active == -1)
792 {
793 catch_frist_active = i;
794 }
795 active_work_items.insert(i);
796 }
797 }
798 if (active_work_items.empty())
799 {
800 continue;
801 }
802 else
803 {
804 tr = TypeManager<Ty>::identify_limits(operation);
805 for (const int &active_work_item : active_work_items)
806 {
807 rr = my[ii + active_work_item];
808 if (active_work_items.size() == 1)
809 {
810 tr = mx[ii + catch_frist_active];
811 }
812 else
813 {
814 tr = calculate<Ty>(tr, mx[ii + active_work_item],
815 operation);
816 }
817 if (!compare_ordered<Ty>(rr, tr))
818 {
819 log_error(
820 "ERROR: %s_%s(%s) "
821 "mismatch for local id %d in sub group %d "
822 "in "
823 "group %d Expected: %d Obtained: %d\n",
824 func_name.c_str(), operation_names(operation),
825 TypeManager<Ty>::name(), active_work_item, j, k,
826 tr, rr);
827 return TEST_FAIL;
828 }
829 }
830 }
831 }
832 x += nw;
833 y += nw;
834 m += 4 * nw;
835 }
836
837 return TEST_PASS;
838 }
839 };
840
841 // Test for reduce non uniform functions
842 template <typename Ty, ArithmeticOp operation> struct RED_NU
843 {
log_testRED_NU844 static void log_test(const WorkGroupParams &test_params,
845 const char *extra_text)
846 {
847 std::string func_name = (test_params.all_work_item_masks.size() > 0
848 ? "sub_group_non_uniform_reduce"
849 : "sub_group_reduce");
850 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
851 operation_names(operation), TypeManager<Ty>::name(),
852 extra_text);
853 }
854
genRED_NU855 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
856 {
857 int nw = test_params.local_workgroup_size;
858 int ns = test_params.subgroup_size;
859 int ng = test_params.global_workgroup_size;
860 ng = ng / nw;
861 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
862 }
863
chkRED_NU864 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
865 const WorkGroupParams &test_params)
866 {
867 int ii, i, j, k, n;
868 int nw = test_params.local_workgroup_size;
869 int ns = test_params.subgroup_size;
870 int ng = test_params.global_workgroup_size;
871 bs128 work_items_mask = test_params.work_items_mask;
872 int nj = (nw + ns - 1) / ns;
873 ng = ng / nw;
874 Ty tr, rr;
875
876 std::string func_name = (test_params.all_work_item_masks.size() > 0
877 ? "sub_group_non_uniform_reduce"
878 : "sub_group_reduce");
879
880 for (k = 0; k < ng; ++k)
881 {
882 // Map to array indexed to array indexed by local ID and sub
883 // group
884 for (j = 0; j < nw; ++j)
885 {
886 mx[j] = x[j];
887 my[j] = y[j];
888 }
889
890 if (!work_items_mask.any())
891 {
892 work_items_mask.set();
893 }
894
895 for (j = 0; j < nj; ++j)
896 {
897 ii = j * ns;
898 n = ii + ns > nw ? nw - ii : ns;
899 std::set<int> active_work_items;
900 int catch_frist_active = -1;
901 for (i = 0; i < n; ++i)
902 {
903 if (work_items_mask.test(i))
904 {
905 if (catch_frist_active == -1)
906 {
907 catch_frist_active = i;
908 tr = mx[ii + i];
909 active_work_items.insert(i);
910 continue;
911 }
912 active_work_items.insert(i);
913 tr = calculate<Ty>(tr, mx[ii + i], operation);
914 }
915 }
916
917 if (active_work_items.empty())
918 {
919 continue;
920 }
921
922 for (const int &active_work_item : active_work_items)
923 {
924 rr = my[ii + active_work_item];
925 if (!compare_ordered<Ty>(rr, tr))
926 {
927 log_error("ERROR: %s_%s(%s) "
928 "mismatch for local id %d in sub group %d in "
929 "group %d Expected: %d Obtained: %d\n",
930 func_name.c_str(), operation_names(operation),
931 TypeManager<Ty>::name(), active_work_item, j,
932 k, tr, rr);
933 return TEST_FAIL;
934 }
935 }
936 }
937 x += nw;
938 y += nw;
939 m += 4 * nw;
940 }
941
942 return TEST_PASS;
943 }
944 };
945
946 #endif
947