• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2020 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBGROUPCOMMONTEMPLATES_H
17 #define SUBGROUPCOMMONTEMPLATES_H
18 
19 #include "typeWrappers.h"
20 #include <bitset>
21 #include "CL/cl_half.h"
22 #include "subhelpers.h"
23 
24 #include <set>
25 
26 typedef std::bitset<128> bs128;
generate_bit_mask(cl_uint subgroup_local_id,const std::string & mask_type,cl_uint max_sub_group_size)27 static cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
28                                   const std::string &mask_type,
29                                   cl_uint max_sub_group_size)
30 {
31     bs128 mask128;
32     cl_uint4 mask;
33     cl_uint pos = subgroup_local_id;
34     if (mask_type == "eq") mask128.set(pos);
35     if (mask_type == "le" || mask_type == "lt")
36     {
37         for (cl_uint i = 0; i <= pos; i++) mask128.set(i);
38         if (mask_type == "lt") mask128.reset(pos);
39     }
40     if (mask_type == "ge" || mask_type == "gt")
41     {
42         for (cl_uint i = pos; i < max_sub_group_size; i++) mask128.set(i);
43         if (mask_type == "gt") mask128.reset(pos);
44     }
45 
46     // convert std::bitset<128> to uint4
47     auto const uint_mask = bs128{ static_cast<unsigned long>(-1) };
48     mask.s0 = (mask128 & uint_mask).to_ulong();
49     mask128 >>= 32;
50     mask.s1 = (mask128 & uint_mask).to_ulong();
51     mask128 >>= 32;
52     mask.s2 = (mask128 & uint_mask).to_ulong();
53     mask128 >>= 32;
54     mask.s3 = (mask128 & uint_mask).to_ulong();
55 
56     return mask;
57 }
58 
59 // DESCRIPTION :
60 // sub_group_broadcast - each work_item registers it's own value.
61 // All work_items in subgroup takes one value from only one (any) work_item
62 // sub_group_broadcast_first - same as type 0. All work_items in
63 // subgroup takes only one value from only one chosen (the smallest subgroup ID)
64 // work_item
65 // sub_group_non_uniform_broadcast - same as type 0 but
66 // only 4 work_items from subgroup enter the code (are active)
67 template <typename Ty, SubgroupsBroadcastOp operation> struct BC
68 {
genBC69     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
70     {
71         int i, ii, j, k, n;
72         int ng = test_params.global_workgroup_size;
73         int nw = test_params.local_workgroup_size;
74         int ns = test_params.subgroup_size;
75         int nj = (nw + ns - 1) / ns;
76         int d = ns > 100 ? 100 : ns;
77         int non_uniform_size = ng % nw;
78         ng = ng / nw;
79         int last_subgroup_size = 0;
80         ii = 0;
81 
82         log_info("  sub_group_%s(%s)...\n", operation_names(operation),
83                  TypeManager<Ty>::name());
84         if (non_uniform_size)
85         {
86             log_info("  non uniform work group size mode ON\n");
87             ng++;
88         }
89         for (k = 0; k < ng; ++k)
90         { // for each work_group
91             if (non_uniform_size && k == ng - 1)
92             {
93                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
94                                           last_subgroup_size);
95             }
96             for (j = 0; j < nj; ++j)
97             { // for each subgroup
98                 ii = j * ns;
99                 if (last_subgroup_size && j == nj - 1)
100                 {
101                     n = last_subgroup_size;
102                 }
103                 else
104                 {
105                     n = ii + ns > nw ? nw - ii : ns;
106                 }
107                 int bcast_if = 0;
108                 int bcast_elseif = 0;
109                 int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
110                     % (d > n ? n : d);
111                 // l - calculate subgroup local id from which value will be
112                 // broadcasted (one the same value for whole subgroup)
113                 if (operation != SubgroupsBroadcastOp::broadcast)
114                 {
115                     // reduce brodcasting index in case of non_uniform and
116                     // last workgroup last subgroup
117                     if (last_subgroup_size && j == nj - 1
118                         && last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
119                     {
120                         bcast_if = bcast_index % last_subgroup_size;
121                         bcast_elseif = bcast_if;
122                     }
123                     else
124                     {
125                         bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
126                         bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
127                             + bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
128                     }
129                 }
130 
131                 for (i = 0; i < n; ++i)
132                 {
133                     if (operation == SubgroupsBroadcastOp::broadcast)
134                     {
135                         int midx = 4 * ii + 4 * i + 2;
136                         m[midx] = (cl_int)bcast_index;
137                     }
138                     else
139                     {
140                         if (i < NR_OF_ACTIVE_WORK_ITEMS)
141                         {
142                             // index of the third
143                             // element int the vector.
144                             int midx = 4 * ii + 4 * i + 2;
145                             // storing information about
146                             // broadcasting index -
147                             // earlier calculated
148                             m[midx] = (cl_int)bcast_if;
149                         }
150                         else
151                         { // index of the third
152                           // element int the vector.
153                             int midx = 4 * ii + 4 * i + 3;
154                             m[midx] = (cl_int)bcast_elseif;
155                         }
156                     }
157 
158                     // calculate value for broadcasting
159                     cl_ulong number = genrand_int64(gMTdata);
160                     set_value(t[ii + i], number);
161                 }
162             }
163             // Now map into work group using map from device
164             for (j = 0; j < nw; ++j)
165             { // for each element in work_group
166                 // calculate index as number of subgroup
167                 // plus subgroup local id
168                 x[j] = t[j];
169             }
170             x += nw;
171             m += 4 * nw;
172         }
173     }
174 
chkBC175     static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
176                    const WorkGroupParams &test_params)
177     {
178         int ii, i, j, k, l, n;
179         int ng = test_params.global_workgroup_size;
180         int nw = test_params.local_workgroup_size;
181         int ns = test_params.subgroup_size;
182         int nj = (nw + ns - 1) / ns;
183         Ty tr, rr;
184         int non_uniform_size = ng % nw;
185         ng = ng / nw;
186         int last_subgroup_size = 0;
187         if (non_uniform_size) ng++;
188 
189         for (k = 0; k < ng; ++k)
190         { // for each work_group
191             if (non_uniform_size && k == ng - 1)
192             {
193                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
194                                           last_subgroup_size);
195             }
196             for (j = 0; j < nw; ++j)
197             { // inside the work_group
198                 mx[j] = x[j]; // read host inputs for work_group
199                 my[j] = y[j]; // read device outputs for work_group
200             }
201 
202             for (j = 0; j < nj; ++j)
203             { // for each subgroup
204                 ii = j * ns;
205                 if (last_subgroup_size && j == nj - 1)
206                 {
207                     n = last_subgroup_size;
208                 }
209                 else
210                 {
211                     n = ii + ns > nw ? nw - ii : ns;
212                 }
213 
214                 // Check result
215                 if (operation == SubgroupsBroadcastOp::broadcast_first)
216                 {
217                     int lowest_active_id = -1;
218                     for (i = 0; i < n; ++i)
219                     {
220 
221                         lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
222                             ? 0
223                             : NR_OF_ACTIVE_WORK_ITEMS;
224                         //  findout if broadcasted
225                         //  value is the same
226                         tr = mx[ii + lowest_active_id];
227                         //  findout if broadcasted to all
228                         rr = my[ii + i];
229 
230                         if (!compare(rr, tr))
231                         {
232                             log_error(
233                                 "ERROR: sub_group_broadcast_first(%s) "
234                                 "mismatch "
235                                 "for local id %d in sub group %d in group "
236                                 "%d\n",
237                                 TypeManager<Ty>::name(), i, j, k);
238                             return TEST_FAIL;
239                         }
240                     }
241                 }
242                 else
243                 {
244                     for (i = 0; i < n; ++i)
245                     {
246                         if (operation == SubgroupsBroadcastOp::broadcast)
247                         {
248                             int midx = 4 * ii + 4 * i + 2;
249                             l = (int)m[midx];
250                             tr = mx[ii + l];
251                         }
252                         else
253                         {
254                             if (i < NR_OF_ACTIVE_WORK_ITEMS)
255                             { // take index of array where info
256                               // which work_item will be
257                               // broadcast its value is stored
258                                 int midx = 4 * ii + 4 * i + 2;
259                                 // take subgroup local id of
260                                 // this work_item
261                                 l = (int)m[midx];
262                                 // take value generated on host
263                                 // for this work_item
264                                 tr = mx[ii + l];
265                             }
266                             else
267                             {
268                                 int midx = 4 * ii + 4 * i + 3;
269                                 l = (int)m[midx];
270                                 tr = mx[ii + l];
271                             }
272                         }
273                         rr = my[ii + i]; // read device outputs for
274                                          // work_item in the subgroup
275 
276                         if (!compare(rr, tr))
277                         {
278                             log_error("ERROR: sub_group_%s(%s) "
279                                       "mismatch for local id %d in sub "
280                                       "group %d in group %d - got %lu "
281                                       "expected %lu\n",
282                                       operation_names(operation),
283                                       TypeManager<Ty>::name(), i, j, k, rr, tr);
284                             return TEST_FAIL;
285                         }
286                     }
287                 }
288             }
289             x += nw;
290             y += nw;
291             m += 4 * nw;
292         }
293         log_info("  sub_group_%s(%s)... passed\n", operation_names(operation),
294                  TypeManager<Ty>::name());
295         return TEST_PASS;
296     }
297 };
298 
to_float(subgroups::cl_half x)299 static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
300 
to_half(float x)301 static subgroups::cl_half to_half(float x)
302 {
303     subgroups::cl_half value;
304     value.data = cl_half_from_float(x, CL_HALF_RTE);
305     return value;
306 }
307 
308 // for integer types
calculate(Ty a,Ty b,ArithmeticOp operation)309 template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
310 {
311     switch (operation)
312     {
313         case ArithmeticOp::add_: return a + b;
314         case ArithmeticOp::max_: return a > b ? a : b;
315         case ArithmeticOp::min_: return a < b ? a : b;
316         case ArithmeticOp::mul_: return a * b;
317         case ArithmeticOp::and_: return a & b;
318         case ArithmeticOp::or_: return a | b;
319         case ArithmeticOp::xor_: return a ^ b;
320         case ArithmeticOp::logical_and: return a && b;
321         case ArithmeticOp::logical_or: return a || b;
322         case ArithmeticOp::logical_xor: return !a ^ !b;
323         default: log_error("Unknown operation request"); break;
324     }
325     return 0;
326 }
327 // Specialize for floating points.
328 template <>
calculate(cl_double a,cl_double b,ArithmeticOp operation)329 inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
330 {
331     switch (operation)
332     {
333         case ArithmeticOp::add_: {
334             return a + b;
335         }
336         case ArithmeticOp::max_: {
337             return a > b ? a : b;
338         }
339         case ArithmeticOp::min_: {
340             return a < b ? a : b;
341         }
342         case ArithmeticOp::mul_: {
343             return a * b;
344         }
345         default: log_error("Unknown operation request"); break;
346     }
347     return 0;
348 }
349 
350 template <>
calculate(cl_float a,cl_float b,ArithmeticOp operation)351 inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
352 {
353     switch (operation)
354     {
355         case ArithmeticOp::add_: {
356             return a + b;
357         }
358         case ArithmeticOp::max_: {
359             return a > b ? a : b;
360         }
361         case ArithmeticOp::min_: {
362             return a < b ? a : b;
363         }
364         case ArithmeticOp::mul_: {
365             return a * b;
366         }
367         default: log_error("Unknown operation request"); break;
368     }
369     return 0;
370 }
371 
372 template <>
calculate(subgroups::cl_half a,subgroups::cl_half b,ArithmeticOp operation)373 inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
374                                     ArithmeticOp operation)
375 {
376     switch (operation)
377     {
378         case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
379         case ArithmeticOp::max_:
380             return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
381         case ArithmeticOp::min_:
382             return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
383         case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
384         default: log_error("Unknown operation request"); break;
385     }
386     return to_half(0);
387 }
388 
is_floating_point()389 template <typename Ty> bool is_floating_point()
390 {
391     return std::is_floating_point<Ty>::value
392         || std::is_same<Ty, subgroups::cl_half>::value;
393 }
394 
395 template <typename Ty, ArithmeticOp operation>
genrand(Ty * x,Ty * t,cl_int * m,int ns,int nw,int ng)396 void genrand(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
397 {
398     int nj = (nw + ns - 1) / ns;
399 
400     for (int k = 0; k < ng; ++k)
401     {
402         for (int j = 0; j < nj; ++j)
403         {
404             int ii = j * ns;
405             int n = ii + ns > nw ? nw - ii : ns;
406 
407             for (int i = 0; i < n; ++i)
408             {
409                 cl_ulong out_value;
410                 double y;
411                 if (operation == ArithmeticOp::mul_
412                     || operation == ArithmeticOp::add_)
413                 {
414                     // work around to avoid overflow, do not use 0 for
415                     // multiplication
416                     out_value = (genrand_int32(gMTdata) % 4) + 1;
417                 }
418                 else
419                 {
420                     out_value = genrand_int64(gMTdata) % (32 * n);
421                     if ((operation == ArithmeticOp::logical_and
422                          || operation == ArithmeticOp::logical_or
423                          || operation == ArithmeticOp::logical_xor)
424                         && ((out_value >> 32) & 1) == 0)
425                         out_value = 0; // increase probability of false
426                 }
427                 set_value(t[ii + i], out_value);
428             }
429         }
430 
431         // Now map into work group using map from device
432         for (int j = 0; j < nw; ++j)
433         {
434             x[j] = t[j];
435         }
436 
437         x += nw;
438         m += 4 * nw;
439     }
440 }
441 
442 template <typename Ty, ShuffleOp operation> struct SHF
443 {
genSHF444     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
445     {
446         int i, ii, j, k, l, n, delta;
447         int nw = test_params.local_workgroup_size;
448         int ns = test_params.subgroup_size;
449         int ng = test_params.global_workgroup_size;
450         int nj = (nw + ns - 1) / ns;
451         int d = ns > 100 ? 100 : ns;
452         ii = 0;
453         ng = ng / nw;
454         log_info("  sub_group_%s(%s)...\n", operation_names(operation),
455                  TypeManager<Ty>::name());
456         for (k = 0; k < ng; ++k)
457         { // for each work_group
458             for (j = 0; j < nj; ++j)
459             { // for each subgroup
460                 ii = j * ns;
461                 n = ii + ns > nw ? nw - ii : ns;
462                 for (i = 0; i < n; ++i)
463                 {
464                     int midx = 4 * ii + 4 * i + 2;
465                     l = (int)(genrand_int32(gMTdata) & 0x7fffffff)
466                         % (d > n ? n : d);
467                     switch (operation)
468                     {
469                         case ShuffleOp::shuffle:
470                         case ShuffleOp::shuffle_xor:
471                             // storing information about shuffle index
472                             m[midx] = (cl_int)l;
473                             break;
474                         case ShuffleOp::shuffle_up:
475                             delta = l; // calculate delta for shuffle up
476                             if (i - delta < 0)
477                             {
478                                 delta = i;
479                             }
480                             m[midx] = (cl_int)delta;
481                             break;
482                         case ShuffleOp::shuffle_down:
483                             delta = l; // calculate delta for shuffle down
484                             if (i + delta >= n)
485                             {
486                                 delta = n - 1 - i;
487                             }
488                             m[midx] = (cl_int)delta;
489                             break;
490                         default: break;
491                     }
492                     cl_ulong number = genrand_int64(gMTdata);
493                     set_value(t[ii + i], number);
494                 }
495             }
496             // Now map into work group using map from device
497             for (j = 0; j < nw; ++j)
498             { // for each element in work_group
499                 x[j] = t[j];
500             }
501             x += nw;
502             m += 4 * nw;
503         }
504     }
505 
chkSHF506     static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
507                    const WorkGroupParams &test_params)
508     {
509         int ii, i, j, k, l, n;
510         int nw = test_params.local_workgroup_size;
511         int ns = test_params.subgroup_size;
512         int ng = test_params.global_workgroup_size;
513         int nj = (nw + ns - 1) / ns;
514         Ty tr, rr;
515         ng = ng / nw;
516 
517         for (k = 0; k < ng; ++k)
518         { // for each work_group
519             for (j = 0; j < nw; ++j)
520             { // inside the work_group
521                 mx[j] = x[j]; // read host inputs for work_group
522                 my[j] = y[j]; // read device outputs for work_group
523             }
524 
525             for (j = 0; j < nj; ++j)
526             { // for each subgroup
527                 ii = j * ns;
528                 n = ii + ns > nw ? nw - ii : ns;
529 
530                 for (i = 0; i < n; ++i)
531                 { // inside the subgroup
532                   // shuffle index storage
533                     int midx = 4 * ii + 4 * i + 2;
534                     l = (int)m[midx];
535                     rr = my[ii + i];
536                     switch (operation)
537                     {
538                         // shuffle basic - treat l as index
539                         case ShuffleOp::shuffle: tr = mx[ii + l]; break;
540                         // shuffle up - treat l as delta
541                         case ShuffleOp::shuffle_up: tr = mx[ii + i - l]; break;
542                         // shuffle up - treat l as delta
543                         case ShuffleOp::shuffle_down:
544                             tr = mx[ii + i + l];
545                             break;
546                         // shuffle xor - treat l as mask
547                         case ShuffleOp::shuffle_xor:
548                             tr = mx[ii + (i ^ l)];
549                             break;
550                         default: break;
551                     }
552 
553                     if (!compare(rr, tr))
554                     {
555                         log_error("ERROR: sub_group_%s(%s) mismatch for "
556                                   "local id %d in sub group %d in group %d\n",
557                                   operation_names(operation),
558                                   TypeManager<Ty>::name(), i, j, k);
559                         return TEST_FAIL;
560                     }
561                 }
562             }
563             x += nw;
564             y += nw;
565             m += 4 * nw;
566         }
567         log_info("  sub_group_%s(%s)... passed\n", operation_names(operation),
568                  TypeManager<Ty>::name());
569         return TEST_PASS;
570     }
571 };
572 
573 template <typename Ty, ArithmeticOp operation> struct SCEX_NU
574 {
genSCEX_NU575     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
576     {
577         int nw = test_params.local_workgroup_size;
578         int ns = test_params.subgroup_size;
579         int ng = test_params.global_workgroup_size;
580         uint32_t work_items_mask = test_params.work_items_mask;
581         ng = ng / nw;
582         std::string func_name;
583         work_items_mask ? func_name = "sub_group_non_uniform_scan_exclusive"
584                         : func_name = "sub_group_scan_exclusive";
585         log_info("  %s_%s(%s)...\n", func_name.c_str(),
586                  operation_names(operation), TypeManager<Ty>::name());
587         log_info("  test params: global size = %d local size = %d subgroups "
588                  "size = %d work item mask = 0x%x \n",
589                  test_params.global_workgroup_size, nw, ns, work_items_mask);
590         genrand<Ty, operation>(x, t, m, ns, nw, ng);
591     }
592 
chkSCEX_NU593     static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
594                    const WorkGroupParams &test_params)
595     {
596         int ii, i, j, k, n;
597         int nw = test_params.local_workgroup_size;
598         int ns = test_params.subgroup_size;
599         int ng = test_params.global_workgroup_size;
600         uint32_t work_items_mask = test_params.work_items_mask;
601         int nj = (nw + ns - 1) / ns;
602         Ty tr, rr;
603         ng = ng / nw;
604 
605         std::string func_name;
606         work_items_mask ? func_name = "sub_group_non_uniform_scan_exclusive"
607                         : func_name = "sub_group_scan_exclusive";
608 
609         uint32_t use_work_items_mask;
610         // for uniform case take into consideration all workitems
611         use_work_items_mask = !work_items_mask ? 0xFFFFFFFF : work_items_mask;
612         for (k = 0; k < ng; ++k)
613         { // for each work_group
614             // Map to array indexed to array indexed by local ID and sub group
615             for (j = 0; j < nw; ++j)
616             { // inside the work_group
617                 mx[j] = x[j]; // read host inputs for work_group
618                 my[j] = y[j]; // read device outputs for work_group
619             }
620             for (j = 0; j < nj; ++j)
621             {
622                 ii = j * ns;
623                 n = ii + ns > nw ? nw - ii : ns;
624                 std::set<int> active_work_items;
625                 for (i = 0; i < n; ++i)
626                 {
627                     uint32_t check_work_item = 1 << (i % 32);
628                     if (use_work_items_mask & check_work_item)
629                     {
630                         active_work_items.insert(i);
631                     }
632                 }
633                 if (active_work_items.empty())
634                 {
635                     log_info("  No acitve workitems in workgroup id = %d "
636                              "subgroup id = %d - no calculation\n",
637                              k, j);
638                     continue;
639                 }
640                 else if (active_work_items.size() == 1)
641                 {
642                     log_info("  One active workitem in workgroup id = %d "
643                              "subgroup id = %d - no calculation\n",
644                              k, j);
645                     continue;
646                 }
647                 else
648                 {
649                     tr = TypeManager<Ty>::identify_limits(operation);
650                     int idx = 0;
651                     for (const int &active_work_item : active_work_items)
652                     {
653                         rr = my[ii + active_work_item];
654                         if (idx == 0) continue;
655 
656                         if (!compare_ordered(rr, tr))
657                         {
658                             log_error(
659                                 "ERROR: %s_%s(%s) "
660                                 "mismatch for local id %d in sub group %d in "
661                                 "group %d Expected: %d Obtained: %d\n",
662                                 func_name.c_str(), operation_names(operation),
663                                 TypeManager<Ty>::name(), i, j, k, tr, rr);
664                             return TEST_FAIL;
665                         }
666                         tr = calculate<Ty>(tr, mx[ii + active_work_item],
667                                            operation);
668                         idx++;
669                     }
670                 }
671             }
672             x += nw;
673             y += nw;
674             m += 4 * nw;
675         }
676 
677         log_info("  %s_%s(%s)... passed\n", func_name.c_str(),
678                  operation_names(operation), TypeManager<Ty>::name());
679         return TEST_PASS;
680     }
681 };
682 
683 // Test for scan inclusive non uniform functions
684 template <typename Ty, ArithmeticOp operation> struct SCIN_NU
685 {
genSCIN_NU686     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
687     {
688         int nw = test_params.local_workgroup_size;
689         int ns = test_params.subgroup_size;
690         int ng = test_params.global_workgroup_size;
691         uint32_t work_items_mask = test_params.work_items_mask;
692         ng = ng / nw;
693         std::string func_name;
694         work_items_mask ? func_name = "sub_group_non_uniform_scan_inclusive"
695                         : func_name = "sub_group_scan_inclusive";
696 
697         genrand<Ty, operation>(x, t, m, ns, nw, ng);
698         log_info("  %s_%s(%s)...\n", func_name.c_str(),
699                  operation_names(operation), TypeManager<Ty>::name());
700         log_info("  test params: global size = %d local size = %d subgroups "
701                  "size = %d work item mask = 0x%x \n",
702                  test_params.global_workgroup_size, nw, ns, work_items_mask);
703     }
704 
chkSCIN_NU705     static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
706                    const WorkGroupParams &test_params)
707     {
708         int ii, i, j, k, n;
709         int nw = test_params.local_workgroup_size;
710         int ns = test_params.subgroup_size;
711         int ng = test_params.global_workgroup_size;
712         uint32_t work_items_mask = test_params.work_items_mask;
713         int nj = (nw + ns - 1) / ns;
714         Ty tr, rr;
715         ng = ng / nw;
716 
717         std::string func_name;
718         work_items_mask ? func_name = "sub_group_non_uniform_scan_inclusive"
719                         : func_name = "sub_group_scan_inclusive";
720 
721         uint32_t use_work_items_mask;
722         // for uniform case take into consideration all workitems
723         use_work_items_mask = !work_items_mask ? 0xFFFFFFFF : work_items_mask;
724         // std::bitset<32> mask32(use_work_items_mask);
725         // for (int k) mask32.count();
726         for (k = 0; k < ng; ++k)
727         { // for each work_group
728             // Map to array indexed to array indexed by local ID and sub group
729             for (j = 0; j < nw; ++j)
730             { // inside the work_group
731                 mx[j] = x[j]; // read host inputs for work_group
732                 my[j] = y[j]; // read device outputs for work_group
733             }
734             for (j = 0; j < nj; ++j)
735             {
736                 ii = j * ns;
737                 n = ii + ns > nw ? nw - ii : ns;
738                 std::set<int> active_work_items;
739                 int catch_frist_active = -1;
740 
741                 for (i = 0; i < n; ++i)
742                 {
743                     uint32_t check_work_item = 1 << (i % 32);
744                     if (use_work_items_mask & check_work_item)
745                     {
746                         if (catch_frist_active == -1)
747                         {
748                             catch_frist_active = i;
749                         }
750                         active_work_items.insert(i);
751                     }
752                 }
753                 if (active_work_items.empty())
754                 {
755                     log_info("  No acitve workitems in workgroup id = %d "
756                              "subgroup id = %d - no calculation\n",
757                              k, j);
758                     continue;
759                 }
760                 else
761                 {
762                     tr = TypeManager<Ty>::identify_limits(operation);
763                     for (const int &active_work_item : active_work_items)
764                     {
765                         rr = my[ii + active_work_item];
766                         if (active_work_items.size() == 1)
767                         {
768                             tr = mx[ii + catch_frist_active];
769                         }
770                         else
771                         {
772                             tr = calculate<Ty>(tr, mx[ii + active_work_item],
773                                                operation);
774                         }
775                         if (!compare_ordered<Ty>(rr, tr))
776                         {
777                             log_error(
778                                 "ERROR: %s_%s(%s) "
779                                 "mismatch for local id %d in sub group %d "
780                                 "in "
781                                 "group %d Expected: %d Obtained: %d\n",
782                                 func_name.c_str(), operation_names(operation),
783                                 TypeManager<Ty>::name(), active_work_item, j, k,
784                                 tr, rr);
785                             return TEST_FAIL;
786                         }
787                     }
788                 }
789             }
790             x += nw;
791             y += nw;
792             m += 4 * nw;
793         }
794 
795         log_info("  %s_%s(%s)... passed\n", func_name.c_str(),
796                  operation_names(operation), TypeManager<Ty>::name());
797         return TEST_PASS;
798     }
799 };
800 
// Test for the sub-group reduce functions (uniform and non-uniform variants)
802 template <typename Ty, ArithmeticOp operation> struct RED_NU
803 {
804 
genRED_NU805     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
806     {
807         int nw = test_params.local_workgroup_size;
808         int ns = test_params.subgroup_size;
809         int ng = test_params.global_workgroup_size;
810         uint32_t work_items_mask = test_params.work_items_mask;
811         ng = ng / nw;
812         std::string func_name;
813 
814         work_items_mask ? func_name = "sub_group_non_uniform_reduce"
815                         : func_name = "sub_group_reduce";
816         log_info("  %s_%s(%s)...\n", func_name.c_str(),
817                  operation_names(operation), TypeManager<Ty>::name());
818         log_info("  test params: global size = %d local size = %d subgroups "
819                  "size = %d work item mask = 0x%x \n",
820                  test_params.global_workgroup_size, nw, ns, work_items_mask);
821         genrand<Ty, operation>(x, t, m, ns, nw, ng);
822     }
823 
chkRED_NU824     static int chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
825                    const WorkGroupParams &test_params)
826     {
827         int ii, i, j, k, n;
828         int nw = test_params.local_workgroup_size;
829         int ns = test_params.subgroup_size;
830         int ng = test_params.global_workgroup_size;
831         uint32_t work_items_mask = test_params.work_items_mask;
832         int nj = (nw + ns - 1) / ns;
833         ng = ng / nw;
834         Ty tr, rr;
835 
836         std::string func_name;
837         work_items_mask ? func_name = "sub_group_non_uniform_reduce"
838                         : func_name = "sub_group_reduce";
839 
840         for (k = 0; k < ng; ++k)
841         {
842             // Map to array indexed to array indexed by local ID and sub
843             // group
844             for (j = 0; j < nw; ++j)
845             {
846                 mx[j] = x[j];
847                 my[j] = y[j];
848             }
849 
850             uint32_t use_work_items_mask;
851             use_work_items_mask =
852                 !work_items_mask ? 0xFFFFFFFF : work_items_mask;
853 
854             for (j = 0; j < nj; ++j)
855             {
856                 ii = j * ns;
857                 n = ii + ns > nw ? nw - ii : ns;
858                 std::set<int> active_work_items;
859                 int catch_frist_active = -1;
860                 for (i = 0; i < n; ++i)
861                 {
862                     uint32_t check_work_item = 1 << (i % 32);
863                     if (use_work_items_mask & check_work_item)
864                     {
865                         if (catch_frist_active == -1)
866                         {
867                             catch_frist_active = i;
868                             tr = mx[ii + i];
869                             active_work_items.insert(i);
870                             continue;
871                         }
872                         active_work_items.insert(i);
873                         tr = calculate<Ty>(tr, mx[ii + i], operation);
874                     }
875                 }
876 
877                 if (active_work_items.empty())
878                 {
879                     log_info("  No acitve workitems in workgroup id = %d "
880                              "subgroup id = %d - no calculation\n",
881                              k, j);
882                     continue;
883                 }
884 
885                 for (const int &active_work_item : active_work_items)
886                 {
887                     rr = my[ii + active_work_item];
888                     if (!compare_ordered<Ty>(rr, tr))
889                     {
890                         log_error("ERROR: %s_%s(%s) "
891                                   "mismatch for local id %d in sub group %d in "
892                                   "group %d Expected: %d Obtained: %d\n",
893                                   func_name.c_str(), operation_names(operation),
894                                   TypeManager<Ty>::name(), active_work_item, j,
895                                   k, tr, rr);
896                         return TEST_FAIL;
897                     }
898                 }
899             }
900             x += nw;
901             y += nw;
902             m += 4 * nw;
903         }
904 
905         log_info("  %s_%s(%s)... passed\n", func_name.c_str(),
906                  operation_names(operation), TypeManager<Ty>::name());
907         return TEST_PASS;
908     }
909 };
910 
911 #endif
912