• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2020 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBGROUPCOMMONTEMPLATES_H
17 #define SUBGROUPCOMMONTEMPLATES_H
18 
19 #include "typeWrappers.h"
20 #include "CL/cl_half.h"
21 #include "subhelpers.h"
22 #include <set>
23 #include <algorithm>
24 #include <random>
25 
generate_bit_mask(cl_uint subgroup_local_id,const std::string & mask_type,cl_uint max_sub_group_size)26 static cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
27                                   const std::string &mask_type,
28                                   cl_uint max_sub_group_size)
29 {
30     bs128 mask128;
31     cl_uint4 mask;
32     cl_uint pos = subgroup_local_id;
33     if (mask_type == "eq") mask128.set(pos);
34     if (mask_type == "le" || mask_type == "lt")
35     {
36         for (cl_uint i = 0; i <= pos; i++) mask128.set(i);
37         if (mask_type == "lt") mask128.reset(pos);
38     }
39     if (mask_type == "ge" || mask_type == "gt")
40     {
41         for (cl_uint i = pos; i < max_sub_group_size; i++) mask128.set(i);
42         if (mask_type == "gt") mask128.reset(pos);
43     }
44 
45     // convert std::bitset<128> to uint4
46     auto const uint_mask = bs128{ static_cast<unsigned long>(-1) };
47     mask.s0 = (mask128 & uint_mask).to_ulong();
48     mask128 >>= 32;
49     mask.s1 = (mask128 & uint_mask).to_ulong();
50     mask128 >>= 32;
51     mask.s2 = (mask128 & uint_mask).to_ulong();
52     mask128 >>= 32;
53     mask.s3 = (mask128 & uint_mask).to_ulong();
54 
55     return mask;
56 }
57 
58 // DESCRIPTION :
59 // sub_group_broadcast - each work_item registers it's own value.
60 // All work_items in subgroup takes one value from only one (any) work_item
61 // sub_group_broadcast_first - same as type 0. All work_items in
62 // subgroup takes only one value from only one chosen (the smallest subgroup ID)
63 // work_item
64 // sub_group_non_uniform_broadcast - same as type 0 but
65 // only 4 work_items from subgroup enter the code (are active)
66 template <typename Ty, SubgroupsBroadcastOp operation> struct BC
67 {
log_testBC68     static void log_test(const WorkGroupParams &test_params,
69                          const char *extra_text)
70     {
71         log_info("  sub_group_%s(%s)...%s\n", operation_names(operation),
72                  TypeManager<Ty>::name(), extra_text);
73     }
74 
genBC75     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
76     {
77         int i, ii, j, k, n;
78         int ng = test_params.global_workgroup_size;
79         int nw = test_params.local_workgroup_size;
80         int ns = test_params.subgroup_size;
81         int nj = (nw + ns - 1) / ns;
82         int d = ns > 100 ? 100 : ns;
83         int non_uniform_size = ng % nw;
84         ng = ng / nw;
85         int last_subgroup_size = 0;
86         ii = 0;
87 
88         if (non_uniform_size)
89         {
90             ng++;
91         }
92         for (k = 0; k < ng; ++k)
93         { // for each work_group
94             if (non_uniform_size && k == ng - 1)
95             {
96                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
97                                           last_subgroup_size);
98             }
99             for (j = 0; j < nj; ++j)
100             { // for each subgroup
101                 ii = j * ns;
102                 if (last_subgroup_size && j == nj - 1)
103                 {
104                     n = last_subgroup_size;
105                 }
106                 else
107                 {
108                     n = ii + ns > nw ? nw - ii : ns;
109                 }
110                 int bcast_if = 0;
111                 int bcast_elseif = 0;
112                 int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
113                     % (d > n ? n : d);
114                 // l - calculate subgroup local id from which value will be
115                 // broadcasted (one the same value for whole subgroup)
116                 if (operation != SubgroupsBroadcastOp::broadcast)
117                 {
118                     // reduce brodcasting index in case of non_uniform and
119                     // last workgroup last subgroup
120                     if (last_subgroup_size && j == nj - 1
121                         && last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
122                     {
123                         bcast_if = bcast_index % last_subgroup_size;
124                         bcast_elseif = bcast_if;
125                     }
126                     else
127                     {
128                         bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
129                         bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
130                             + bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
131                     }
132                 }
133 
134                 for (i = 0; i < n; ++i)
135                 {
136                     if (operation == SubgroupsBroadcastOp::broadcast)
137                     {
138                         int midx = 4 * ii + 4 * i + 2;
139                         m[midx] = (cl_int)bcast_index;
140                     }
141                     else
142                     {
143                         if (i < NR_OF_ACTIVE_WORK_ITEMS)
144                         {
145                             // index of the third
146                             // element int the vector.
147                             int midx = 4 * ii + 4 * i + 2;
148                             // storing information about
149                             // broadcasting index -
150                             // earlier calculated
151                             m[midx] = (cl_int)bcast_if;
152                         }
153                         else
154                         { // index of the third
155                           // element int the vector.
156                             int midx = 4 * ii + 4 * i + 3;
157                             m[midx] = (cl_int)bcast_elseif;
158                         }
159                     }
160 
161                     // calculate value for broadcasting
162                     cl_ulong number = genrand_int64(gMTdata);
163                     set_value(t[ii + i], number);
164                 }
165             }
166             // Now map into work group using map from device
167             for (j = 0; j < nw; ++j)
168             { // for each element in work_group
169                 // calculate index as number of subgroup
170                 // plus subgroup local id
171                 x[j] = t[j];
172             }
173             x += nw;
174             m += 4 * nw;
175         }
176     }
177 
chkBC178     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
179                            const WorkGroupParams &test_params)
180     {
181         int ii, i, j, k, l, n;
182         int ng = test_params.global_workgroup_size;
183         int nw = test_params.local_workgroup_size;
184         int ns = test_params.subgroup_size;
185         int nj = (nw + ns - 1) / ns;
186         Ty tr, rr;
187         int non_uniform_size = ng % nw;
188         ng = ng / nw;
189         int last_subgroup_size = 0;
190         if (non_uniform_size) ng++;
191 
192         for (k = 0; k < ng; ++k)
193         { // for each work_group
194             if (non_uniform_size && k == ng - 1)
195             {
196                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
197                                           last_subgroup_size);
198             }
199             for (j = 0; j < nw; ++j)
200             { // inside the work_group
201                 mx[j] = x[j]; // read host inputs for work_group
202                 my[j] = y[j]; // read device outputs for work_group
203             }
204 
205             for (j = 0; j < nj; ++j)
206             { // for each subgroup
207                 ii = j * ns;
208                 if (last_subgroup_size && j == nj - 1)
209                 {
210                     n = last_subgroup_size;
211                 }
212                 else
213                 {
214                     n = ii + ns > nw ? nw - ii : ns;
215                 }
216 
217                 // Check result
218                 if (operation == SubgroupsBroadcastOp::broadcast_first)
219                 {
220                     int lowest_active_id = -1;
221                     for (i = 0; i < n; ++i)
222                     {
223 
224                         lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
225                             ? 0
226                             : NR_OF_ACTIVE_WORK_ITEMS;
227                         //  findout if broadcasted
228                         //  value is the same
229                         tr = mx[ii + lowest_active_id];
230                         //  findout if broadcasted to all
231                         rr = my[ii + i];
232 
233                         if (!compare(rr, tr))
234                         {
235                             log_error(
236                                 "ERROR: sub_group_broadcast_first(%s) "
237                                 "mismatch "
238                                 "for local id %d in sub group %d in group "
239                                 "%d\n",
240                                 TypeManager<Ty>::name(), i, j, k);
241                             return TEST_FAIL;
242                         }
243                     }
244                 }
245                 else
246                 {
247                     for (i = 0; i < n; ++i)
248                     {
249                         if (operation == SubgroupsBroadcastOp::broadcast)
250                         {
251                             int midx = 4 * ii + 4 * i + 2;
252                             l = (int)m[midx];
253                             tr = mx[ii + l];
254                         }
255                         else
256                         {
257                             if (i < NR_OF_ACTIVE_WORK_ITEMS)
258                             { // take index of array where info
259                               // which work_item will be
260                               // broadcast its value is stored
261                                 int midx = 4 * ii + 4 * i + 2;
262                                 // take subgroup local id of
263                                 // this work_item
264                                 l = (int)m[midx];
265                                 // take value generated on host
266                                 // for this work_item
267                                 tr = mx[ii + l];
268                             }
269                             else
270                             {
271                                 int midx = 4 * ii + 4 * i + 3;
272                                 l = (int)m[midx];
273                                 tr = mx[ii + l];
274                             }
275                         }
276                         rr = my[ii + i]; // read device outputs for
277                                          // work_item in the subgroup
278 
279                         if (!compare(rr, tr))
280                         {
281                             log_error("ERROR: sub_group_%s(%s) "
282                                       "mismatch for local id %d in sub "
283                                       "group %d in group %d - got %lu "
284                                       "expected %lu\n",
285                                       operation_names(operation),
286                                       TypeManager<Ty>::name(), i, j, k, rr, tr);
287                             return TEST_FAIL;
288                         }
289                     }
290                 }
291             }
292             x += nw;
293             y += nw;
294             m += 4 * nw;
295         }
296         return TEST_PASS;
297     }
298 };
299 
to_float(subgroups::cl_half x)300 static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
301 
to_half(float x)302 static subgroups::cl_half to_half(float x)
303 {
304     subgroups::cl_half value;
305     value.data = cl_half_from_float(x, g_rounding_mode);
306     return value;
307 }
308 
309 // for integer types
calculate(Ty a,Ty b,ArithmeticOp operation)310 template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
311 {
312     switch (operation)
313     {
314         case ArithmeticOp::add_: return a + b;
315         case ArithmeticOp::max_: return a > b ? a : b;
316         case ArithmeticOp::min_: return a < b ? a : b;
317         case ArithmeticOp::mul_: return a * b;
318         case ArithmeticOp::and_: return a & b;
319         case ArithmeticOp::or_: return a | b;
320         case ArithmeticOp::xor_: return a ^ b;
321         case ArithmeticOp::logical_and: return a && b;
322         case ArithmeticOp::logical_or: return a || b;
323         case ArithmeticOp::logical_xor: return !a ^ !b;
324         default: log_error("Unknown operation request\n"); break;
325     }
326     return 0;
327 }
328 // Specialize for floating points.
329 template <>
calculate(cl_double a,cl_double b,ArithmeticOp operation)330 inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
331 {
332     switch (operation)
333     {
334         case ArithmeticOp::add_: {
335             return a + b;
336         }
337         case ArithmeticOp::max_: {
338             return a > b ? a : b;
339         }
340         case ArithmeticOp::min_: {
341             return a < b ? a : b;
342         }
343         case ArithmeticOp::mul_: {
344             return a * b;
345         }
346         default: log_error("Unknown operation request\n"); break;
347     }
348     return 0;
349 }
350 
351 template <>
calculate(cl_float a,cl_float b,ArithmeticOp operation)352 inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
353 {
354     switch (operation)
355     {
356         case ArithmeticOp::add_: {
357             return a + b;
358         }
359         case ArithmeticOp::max_: {
360             return a > b ? a : b;
361         }
362         case ArithmeticOp::min_: {
363             return a < b ? a : b;
364         }
365         case ArithmeticOp::mul_: {
366             return a * b;
367         }
368         default: log_error("Unknown operation request\n"); break;
369     }
370     return 0;
371 }
372 
373 template <>
calculate(subgroups::cl_half a,subgroups::cl_half b,ArithmeticOp operation)374 inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
375                                     ArithmeticOp operation)
376 {
377     switch (operation)
378     {
379         case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
380         case ArithmeticOp::max_:
381             return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
382         case ArithmeticOp::min_:
383             return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
384         case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
385         default: log_error("Unknown operation request\n"); break;
386     }
387     return to_half(0);
388 }
389 
is_floating_point()390 template <typename Ty> bool is_floating_point()
391 {
392     return std::is_floating_point<Ty>::value
393         || std::is_same<Ty, subgroups::cl_half>::value;
394 }
395 
396 // limit possible input values to avoid arithmetic rounding/overflow issues.
397 // for each subgroup values defined different values
398 // for rest of workitems set 1
399 // shuffle values
fill_and_shuffle_safe_values(std::vector<cl_ulong> & safe_values,int sb_size)400 static void fill_and_shuffle_safe_values(std::vector<cl_ulong> &safe_values,
401                                          int sb_size)
402 {
403     // max product is 720, cl_half has enough precision for it
404     const std::vector<cl_ulong> non_one_values{ 2, 3, 4, 5, 6 };
405 
406     if (sb_size <= non_one_values.size())
407     {
408         safe_values.assign(non_one_values.begin(),
409                            non_one_values.begin() + sb_size);
410     }
411     else
412     {
413         safe_values.assign(sb_size, 1);
414         std::copy(non_one_values.begin(), non_one_values.end(),
415                   safe_values.begin());
416     }
417 
418     std::mt19937 mersenne_twister_engine(10000);
419     std::shuffle(safe_values.begin(), safe_values.end(),
420                  mersenne_twister_engine);
421 };
422 
423 template <typename Ty, ArithmeticOp operation>
generate_inputs(Ty * x,Ty * t,cl_int * m,int ns,int nw,int ng)424 void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
425 {
426     int nj = (nw + ns - 1) / ns;
427 
428     std::vector<cl_ulong> safe_values;
429     if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_)
430     {
431         fill_and_shuffle_safe_values(safe_values, ns);
432     }
433 
434     for (int k = 0; k < ng; ++k)
435     {
436         for (int j = 0; j < nj; ++j)
437         {
438             int ii = j * ns;
439             int n = ii + ns > nw ? nw - ii : ns;
440 
441             for (int i = 0; i < n; ++i)
442             {
443                 cl_ulong out_value;
444                 if (operation == ArithmeticOp::mul_
445                     || operation == ArithmeticOp::add_)
446                 {
447                     out_value = safe_values[i];
448                 }
449                 else
450                 {
451                     out_value = genrand_int64(gMTdata) % (32 * n);
452                     if ((operation == ArithmeticOp::logical_and
453                          || operation == ArithmeticOp::logical_or
454                          || operation == ArithmeticOp::logical_xor)
455                         && ((out_value >> 32) & 1) == 0)
456                         out_value = 0; // increase probability of false
457                 }
458                 set_value(t[ii + i], out_value);
459             }
460         }
461 
462         // Now map into work group using map from device
463         for (int j = 0; j < nw; ++j)
464         {
465             x[j] = t[j];
466         }
467 
468         x += nw;
469         m += 4 * nw;
470     }
471 }
472 
473 template <typename Ty, ShuffleOp operation> struct SHF
474 {
log_testSHF475     static void log_test(const WorkGroupParams &test_params,
476                          const char *extra_text)
477     {
478         log_info("  sub_group_%s(%s)...%s\n", operation_names(operation),
479                  TypeManager<Ty>::name(), extra_text);
480     }
481 
genSHF482     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
483     {
484         int i, ii, j, k, n;
485         cl_uint l;
486         int nw = test_params.local_workgroup_size;
487         int ns = test_params.subgroup_size;
488         int ng = test_params.global_workgroup_size;
489         int nj = (nw + ns - 1) / ns;
490         ii = 0;
491         ng = ng / nw;
492         for (k = 0; k < ng; ++k)
493         { // for each work_group
494             for (j = 0; j < nj; ++j)
495             { // for each subgroup
496                 ii = j * ns;
497                 n = ii + ns > nw ? nw - ii : ns;
498                 for (i = 0; i < n; ++i)
499                 {
500                     int midx = 4 * ii + 4 * i + 2;
501                     l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1)
502                          % (ns * 2 + 1))
503                         - 1;
504                     switch (operation)
505                     {
506                         case ShuffleOp::shuffle:
507                         case ShuffleOp::shuffle_xor:
508                         case ShuffleOp::shuffle_up:
509                         case ShuffleOp::shuffle_down:
510                             // storing information about shuffle index/delta
511                             m[midx] = (cl_int)l;
512                             break;
513                         case ShuffleOp::rotate:
514                         case ShuffleOp::clustered_rotate:
515                             // Storing information about rotate delta.
516                             // The delta must be the same for each thread in
517                             // the subgroup.
518                             if (i == 0)
519                             {
520                                 m[midx] = (cl_int)l;
521                             }
522                             else
523                             {
524                                 m[midx] = m[midx - 4];
525                             }
526                             break;
527                         default: break;
528                     }
529                     cl_ulong number = genrand_int64(gMTdata);
530                     set_value(t[ii + i], number);
531                 }
532             }
533             // Now map into work group using map from device
534             for (j = 0; j < nw; ++j)
535             { // for each element in work_group
536                 x[j] = t[j];
537             }
538             x += nw;
539             m += 4 * nw;
540         }
541     }
542 
chkSHF543     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
544                            const WorkGroupParams &test_params)
545     {
546         int ii, i, j, k, n;
547         cl_uint l;
548         int nw = test_params.local_workgroup_size;
549         int ns = test_params.subgroup_size;
550         int ng = test_params.global_workgroup_size;
551         int nj = (nw + ns - 1) / ns;
552         Ty tr, rr;
553         ng = ng / nw;
554 
555         for (k = 0; k < ng; ++k)
556         { // for each work_group
557             for (j = 0; j < nw; ++j)
558             { // inside the work_group
559                 mx[j] = x[j]; // read host inputs for work_group
560                 my[j] = y[j]; // read device outputs for work_group
561             }
562 
563             for (j = 0; j < nj; ++j)
564             { // for each subgroup
565                 ii = j * ns;
566                 n = ii + ns > nw ? nw - ii : ns;
567 
568                 for (i = 0; i < n; ++i)
569                 { // inside the subgroup
570                   // shuffle index storage
571                     int midx = 4 * ii + 4 * i + 2;
572                     l = m[midx];
573                     rr = my[ii + i];
574                     cl_uint tr_idx;
575                     bool skip = false;
576                     switch (operation)
577                     {
578                         // shuffle basic - treat l as index
579                         case ShuffleOp::shuffle: tr_idx = l; break;
580                         // shuffle xor - treat l as mask
581                         case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break;
582                         // shuffle up - treat l as delta
583                         case ShuffleOp::shuffle_up:
584                             if (l >= ns) skip = true;
585                             tr_idx = i - l;
586                             break;
587                         // shuffle down - treat l as delta
588                         case ShuffleOp::shuffle_down:
589                             if (l >= ns) skip = true;
590                             tr_idx = i + l;
591                             break;
592                         // rotate - treat l as delta
593                         case ShuffleOp::rotate:
594                             tr_idx = (i + l) % test_params.subgroup_size;
595                             break;
596                         case ShuffleOp::clustered_rotate: {
597                             tr_idx = ((i & ~(test_params.cluster_size - 1))
598                                       + ((i + l) % test_params.cluster_size));
599                             break;
600                         }
601                         default: break;
602                     }
603 
604                     if (!skip && tr_idx < n)
605                     {
606                         tr = mx[ii + tr_idx];
607 
608                         if (!compare(rr, tr))
609                         {
610                             log_error("ERROR: sub_group_%s(%s) mismatch for "
611                                       "local id %d in sub group %d in group "
612                                       "%d\n",
613                                       operation_names(operation),
614                                       TypeManager<Ty>::name(), i, j, k);
615                             return TEST_FAIL;
616                         }
617                     }
618                 }
619             }
620             x += nw;
621             y += nw;
622             m += 4 * nw;
623         }
624         return TEST_PASS;
625     }
626 };
627 
628 template <typename Ty, ArithmeticOp operation> struct SCEX_NU
629 {
log_testSCEX_NU630     static void log_test(const WorkGroupParams &test_params,
631                          const char *extra_text)
632     {
633         std::string func_name = (test_params.all_work_item_masks.size() > 0
634                                      ? "sub_group_non_uniform_scan_exclusive"
635                                      : "sub_group_scan_exclusive");
636         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
637                  operation_names(operation), TypeManager<Ty>::name(),
638                  extra_text);
639     }
640 
genSCEX_NU641     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
642     {
643         int nw = test_params.local_workgroup_size;
644         int ns = test_params.subgroup_size;
645         int ng = test_params.global_workgroup_size;
646         ng = ng / nw;
647         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
648     }
649 
chkSCEX_NU650     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
651                            const WorkGroupParams &test_params)
652     {
653         int ii, i, j, k, n;
654         int nw = test_params.local_workgroup_size;
655         int ns = test_params.subgroup_size;
656         int ng = test_params.global_workgroup_size;
657         bs128 work_items_mask = test_params.work_items_mask;
658         int nj = (nw + ns - 1) / ns;
659         Ty tr, rr;
660         ng = ng / nw;
661 
662         std::string func_name = (test_params.all_work_item_masks.size() > 0
663                                      ? "sub_group_non_uniform_scan_exclusive"
664                                      : "sub_group_scan_exclusive");
665 
666         // for uniform case take into consideration all workitems
667         if (!work_items_mask.any())
668         {
669             work_items_mask.set();
670         }
671         for (k = 0; k < ng; ++k)
672         { // for each work_group
673             // Map to array indexed to array indexed by local ID and sub group
674             for (j = 0; j < nw; ++j)
675             { // inside the work_group
676                 mx[j] = x[j]; // read host inputs for work_group
677                 my[j] = y[j]; // read device outputs for work_group
678             }
679             for (j = 0; j < nj; ++j)
680             {
681                 ii = j * ns;
682                 n = ii + ns > nw ? nw - ii : ns;
683                 std::set<int> active_work_items;
684                 for (i = 0; i < n; ++i)
685                 {
686                     if (work_items_mask.test(i))
687                     {
688                         active_work_items.insert(i);
689                     }
690                 }
691                 if (active_work_items.empty())
692                 {
693                     continue;
694                 }
695                 else
696                 {
697                     tr = TypeManager<Ty>::identify_limits(operation);
698                     for (const int &active_work_item : active_work_items)
699                     {
700                         rr = my[ii + active_work_item];
701                         if (!compare_ordered(rr, tr))
702                         {
703                             log_error(
704                                 "ERROR: %s_%s(%s) "
705                                 "mismatch for local id %d in sub group %d in "
706                                 "group %d Expected: %d Obtained: %d\n",
707                                 func_name.c_str(), operation_names(operation),
708                                 TypeManager<Ty>::name(), i, j, k, tr, rr);
709                             return TEST_FAIL;
710                         }
711                         tr = calculate<Ty>(tr, mx[ii + active_work_item],
712                                            operation);
713                     }
714                 }
715             }
716             x += nw;
717             y += nw;
718             m += 4 * nw;
719         }
720 
721         return TEST_PASS;
722     }
723 };
724 
725 // Test for scan inclusive non uniform functions
726 template <typename Ty, ArithmeticOp operation> struct SCIN_NU
727 {
log_testSCIN_NU728     static void log_test(const WorkGroupParams &test_params,
729                          const char *extra_text)
730     {
731         std::string func_name = (test_params.all_work_item_masks.size() > 0
732                                      ? "sub_group_non_uniform_scan_inclusive"
733                                      : "sub_group_scan_inclusive");
734         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
735                  operation_names(operation), TypeManager<Ty>::name(),
736                  extra_text);
737     }
738 
genSCIN_NU739     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
740     {
741         int nw = test_params.local_workgroup_size;
742         int ns = test_params.subgroup_size;
743         int ng = test_params.global_workgroup_size;
744         ng = ng / nw;
745         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
746     }
747 
chkSCIN_NU748     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
749                            const WorkGroupParams &test_params)
750     {
751         int ii, i, j, k, n;
752         int nw = test_params.local_workgroup_size;
753         int ns = test_params.subgroup_size;
754         int ng = test_params.global_workgroup_size;
755         bs128 work_items_mask = test_params.work_items_mask;
756 
757         int nj = (nw + ns - 1) / ns;
758         Ty tr, rr;
759         ng = ng / nw;
760 
761         std::string func_name = (test_params.all_work_item_masks.size() > 0
762                                      ? "sub_group_non_uniform_scan_inclusive"
763                                      : "sub_group_scan_inclusive");
764 
765         // for uniform case take into consideration all workitems
766         if (!work_items_mask.any())
767         {
768             work_items_mask.set();
769         }
770         // std::bitset<32> mask32(use_work_items_mask);
771         // for (int k) mask32.count();
772         for (k = 0; k < ng; ++k)
773         { // for each work_group
774             // Map to array indexed to array indexed by local ID and sub group
775             for (j = 0; j < nw; ++j)
776             { // inside the work_group
777                 mx[j] = x[j]; // read host inputs for work_group
778                 my[j] = y[j]; // read device outputs for work_group
779             }
780             for (j = 0; j < nj; ++j)
781             {
782                 ii = j * ns;
783                 n = ii + ns > nw ? nw - ii : ns;
784                 std::set<int> active_work_items;
785                 int catch_frist_active = -1;
786 
787                 for (i = 0; i < n; ++i)
788                 {
789                     if (work_items_mask.test(i))
790                     {
791                         if (catch_frist_active == -1)
792                         {
793                             catch_frist_active = i;
794                         }
795                         active_work_items.insert(i);
796                     }
797                 }
798                 if (active_work_items.empty())
799                 {
800                     continue;
801                 }
802                 else
803                 {
804                     tr = TypeManager<Ty>::identify_limits(operation);
805                     for (const int &active_work_item : active_work_items)
806                     {
807                         rr = my[ii + active_work_item];
808                         if (active_work_items.size() == 1)
809                         {
810                             tr = mx[ii + catch_frist_active];
811                         }
812                         else
813                         {
814                             tr = calculate<Ty>(tr, mx[ii + active_work_item],
815                                                operation);
816                         }
817                         if (!compare_ordered<Ty>(rr, tr))
818                         {
819                             log_error(
820                                 "ERROR: %s_%s(%s) "
821                                 "mismatch for local id %d in sub group %d "
822                                 "in "
823                                 "group %d Expected: %d Obtained: %d\n",
824                                 func_name.c_str(), operation_names(operation),
825                                 TypeManager<Ty>::name(), active_work_item, j, k,
826                                 tr, rr);
827                             return TEST_FAIL;
828                         }
829                     }
830                 }
831             }
832             x += nw;
833             y += nw;
834             m += 4 * nw;
835         }
836 
837         return TEST_PASS;
838     }
839 };
840 
841 // Test for reduce non uniform functions
842 template <typename Ty, ArithmeticOp operation> struct RED_NU
843 {
log_testRED_NU844     static void log_test(const WorkGroupParams &test_params,
845                          const char *extra_text)
846     {
847         std::string func_name = (test_params.all_work_item_masks.size() > 0
848                                      ? "sub_group_non_uniform_reduce"
849                                      : "sub_group_reduce");
850         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
851                  operation_names(operation), TypeManager<Ty>::name(),
852                  extra_text);
853     }
854 
genRED_NU855     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
856     {
857         int nw = test_params.local_workgroup_size;
858         int ns = test_params.subgroup_size;
859         int ng = test_params.global_workgroup_size;
860         ng = ng / nw;
861         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
862     }
863 
chkRED_NU864     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
865                            const WorkGroupParams &test_params)
866     {
867         int ii, i, j, k, n;
868         int nw = test_params.local_workgroup_size;
869         int ns = test_params.subgroup_size;
870         int ng = test_params.global_workgroup_size;
871         bs128 work_items_mask = test_params.work_items_mask;
872         int nj = (nw + ns - 1) / ns;
873         ng = ng / nw;
874         Ty tr, rr;
875 
876         std::string func_name = (test_params.all_work_item_masks.size() > 0
877                                      ? "sub_group_non_uniform_reduce"
878                                      : "sub_group_reduce");
879 
880         for (k = 0; k < ng; ++k)
881         {
882             // Map to array indexed to array indexed by local ID and sub
883             // group
884             for (j = 0; j < nw; ++j)
885             {
886                 mx[j] = x[j];
887                 my[j] = y[j];
888             }
889 
890             if (!work_items_mask.any())
891             {
892                 work_items_mask.set();
893             }
894 
895             for (j = 0; j < nj; ++j)
896             {
897                 ii = j * ns;
898                 n = ii + ns > nw ? nw - ii : ns;
899                 std::set<int> active_work_items;
900                 int catch_frist_active = -1;
901                 for (i = 0; i < n; ++i)
902                 {
903                     if (work_items_mask.test(i))
904                     {
905                         if (catch_frist_active == -1)
906                         {
907                             catch_frist_active = i;
908                             tr = mx[ii + i];
909                             active_work_items.insert(i);
910                             continue;
911                         }
912                         active_work_items.insert(i);
913                         tr = calculate<Ty>(tr, mx[ii + i], operation);
914                     }
915                 }
916 
917                 if (active_work_items.empty())
918                 {
919                     continue;
920                 }
921 
922                 for (const int &active_work_item : active_work_items)
923                 {
924                     rr = my[ii + active_work_item];
925                     if (!compare_ordered<Ty>(rr, tr))
926                     {
927                         log_error("ERROR: %s_%s(%s) "
928                                   "mismatch for local id %d in sub group %d in "
929                                   "group %d Expected: %d Obtained: %d\n",
930                                   func_name.c_str(), operation_names(operation),
931                                   TypeManager<Ty>::name(), active_work_item, j,
932                                   k, tr, rr);
933                         return TEST_FAIL;
934                     }
935                 }
936             }
937             x += nw;
938             y += nw;
939             m += 4 * nw;
940         }
941 
942         return TEST_PASS;
943     }
944 };
945 
946 #endif
947