• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2017 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBHELPERS_H
17 #define SUBHELPERS_H
18 
19 #include "testHarness.h"
20 #include "kernelHelpers.h"
21 #include "typeWrappers.h"
22 #include "imageHelpers.h"
23 
24 #include <limits>
25 #include <vector>
26 #include <type_traits>
27 
// Fixed count of "active" work items used by the subgroup tests.
// NOTE(review): exact use is not visible in this header — confirm at call sites.
#define NR_OF_ACTIVE_WORK_ITEMS 4

// Random-number generator state shared across the test binary; defined in
// another translation unit (presumably the test's main file — verify).
extern MTdata gMTdata;
31 
// Launch/configuration parameters for one subgroup test run.
//   gws / lws  - global and local work-group sizes.
//   req_ext    - extension names the test requires (empty by default).
//   all_wim    - list of work-item masks (empty by default).
// The remaining fields start with the defaults below and are expected to be
// adjusted by the caller (NOTE(review): semantics of dynsc not visible here).
struct WorkGroupParams
{
    WorkGroupParams(size_t gws, size_t lws,
                    const std::vector<std::string> &req_ext = {},
                    const std::vector<uint32_t> &all_wim = {})
        : global_workgroup_size{ gws }, local_workgroup_size{ lws },
          required_extensions(req_ext), all_work_item_masks(all_wim)
    {}

    size_t global_workgroup_size;
    size_t local_workgroup_size;
    size_t subgroup_size = 0;
    uint32_t work_items_mask = 0;
    int dynsc = 0;
    bool use_core_subgroups = true;
    std::vector<std::string> required_extensions;
    std::vector<uint32_t> all_work_item_masks;
};
54 
// Broadcast variants under test (names via operation_names()).
enum class SubgroupsBroadcastOp
{
    broadcast, // "broadcast"
    broadcast_first, // "broadcast_first"
    non_uniform_broadcast // "non_uniform_broadcast"
};

// Non-uniform vote variants under test.
enum class NonUniformVoteOp
{
    elect, // "elect"
    all, // "all"
    any, // "any"
    all_equal // "all_equal"
};

// Ballot-related operations under test; the *_mask entries name the
// eq/ge/gt/le/lt mask queries.
enum class BallotOp
{
    ballot, // "ballot"
    inverse_ballot, // "inverse_ballot"
    ballot_bit_extract, // "bit_extract"
    ballot_bit_count, // "bit_count"
    ballot_inclusive_scan, // "inclusive_scan"
    ballot_exclusive_scan, // "exclusive_scan"
    ballot_find_lsb, // "find_lsb"
    ballot_find_msb, // "find_msb"
    eq_mask, // "eq"
    ge_mask, // "ge"
    gt_mask, // "gt"
    le_mask, // "le"
    lt_mask, // "lt"
};

// Shuffle variants under test.
enum class ShuffleOp
{
    shuffle, // "shuffle"
    shuffle_up, // "shuffle_up"
    shuffle_down, // "shuffle_down"
    shuffle_xor // "shuffle_xor"
};

// Reduction/scan operators. The trailing underscore avoids clashing with
// C++ keywords/alternative tokens (and, or, xor) and common macros (max, min).
enum class ArithmeticOp
{
    add_, // "add"
    max_, // "max"
    min_, // "min"
    mul_, // "mul"
    and_, // "and"
    or_, // "or"
    xor_, // "xor"
    logical_and, // "logical_and"
    logical_or, // "logical_or"
    logical_xor // "logical_xor"
};
108 
operation_names(ArithmeticOp operation)109 static const char *const operation_names(ArithmeticOp operation)
110 {
111     switch (operation)
112     {
113         case ArithmeticOp::add_: return "add";
114         case ArithmeticOp::max_: return "max";
115         case ArithmeticOp::min_: return "min";
116         case ArithmeticOp::mul_: return "mul";
117         case ArithmeticOp::and_: return "and";
118         case ArithmeticOp::or_: return "or";
119         case ArithmeticOp::xor_: return "xor";
120         case ArithmeticOp::logical_and: return "logical_and";
121         case ArithmeticOp::logical_or: return "logical_or";
122         case ArithmeticOp::logical_xor: return "logical_xor";
123         default: log_error("Unknown operation request"); break;
124     }
125     return "";
126 }
127 
operation_names(BallotOp operation)128 static const char *const operation_names(BallotOp operation)
129 {
130     switch (operation)
131     {
132         case BallotOp::ballot: return "ballot";
133         case BallotOp::inverse_ballot: return "inverse_ballot";
134         case BallotOp::ballot_bit_extract: return "bit_extract";
135         case BallotOp::ballot_bit_count: return "bit_count";
136         case BallotOp::ballot_inclusive_scan: return "inclusive_scan";
137         case BallotOp::ballot_exclusive_scan: return "exclusive_scan";
138         case BallotOp::ballot_find_lsb: return "find_lsb";
139         case BallotOp::ballot_find_msb: return "find_msb";
140         case BallotOp::eq_mask: return "eq";
141         case BallotOp::ge_mask: return "ge";
142         case BallotOp::gt_mask: return "gt";
143         case BallotOp::le_mask: return "le";
144         case BallotOp::lt_mask: return "lt";
145         default: log_error("Unknown operation request"); break;
146     }
147     return "";
148 }
149 
operation_names(ShuffleOp operation)150 static const char *const operation_names(ShuffleOp operation)
151 {
152     switch (operation)
153     {
154         case ShuffleOp::shuffle: return "shuffle";
155         case ShuffleOp::shuffle_up: return "shuffle_up";
156         case ShuffleOp::shuffle_down: return "shuffle_down";
157         case ShuffleOp::shuffle_xor: return "shuffle_xor";
158         default: log_error("Unknown operation request"); break;
159     }
160     return "";
161 }
162 
operation_names(NonUniformVoteOp operation)163 static const char *const operation_names(NonUniformVoteOp operation)
164 {
165     switch (operation)
166     {
167         case NonUniformVoteOp::all: return "all";
168         case NonUniformVoteOp::all_equal: return "all_equal";
169         case NonUniformVoteOp::any: return "any";
170         case NonUniformVoteOp::elect: return "elect";
171         default: log_error("Unknown operation request"); break;
172     }
173     return "";
174 }
175 
operation_names(SubgroupsBroadcastOp operation)176 static const char *const operation_names(SubgroupsBroadcastOp operation)
177 {
178     switch (operation)
179     {
180         case SubgroupsBroadcastOp::broadcast: return "broadcast";
181         case SubgroupsBroadcastOp::broadcast_first: return "broadcast_first";
182         case SubgroupsBroadcastOp::non_uniform_broadcast:
183             return "non_uniform_broadcast";
184         default: log_error("Unknown operation request"); break;
185     }
186     return "";
187 }
188 
189 class subgroupsAPI {
190 public:
subgroupsAPI(cl_platform_id platform,bool use_core_subgroups)191     subgroupsAPI(cl_platform_id platform, bool use_core_subgroups)
192     {
193         static_assert(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE
194                           == CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR,
195                       "Enums have to be the same");
196         static_assert(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE
197                           == CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR,
198                       "Enums have to be the same");
199         if (use_core_subgroups)
200         {
201             _clGetKernelSubGroupInfo_ptr = &clGetKernelSubGroupInfo;
202             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfo";
203         }
204         else
205         {
206             _clGetKernelSubGroupInfo_ptr = (clGetKernelSubGroupInfoKHR_fn)
207                 clGetExtensionFunctionAddressForPlatform(
208                     platform, "clGetKernelSubGroupInfoKHR");
209             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfoKHR";
210         }
211     }
clGetKernelSubGroupInfo_ptr()212     clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr()
213     {
214         return _clGetKernelSubGroupInfo_ptr;
215     }
216     const char *clGetKernelSubGroupInfo_name;
217 
218 private:
219     clGetKernelSubGroupInfoKHR_fn _clGetKernelSubGroupInfo_ptr;
220 };
221 
222 // Need to defined custom type for vector size = 3 and half type. This is
223 // because of 3-component types are otherwise indistinguishable from the
224 // 4-component types, and because the half type is indistinguishable from some
225 // other 16-bit type (ushort)
226 namespace subgroups {
227 struct cl_char3
228 {
229     ::cl_char3 data;
230 };
231 struct cl_uchar3
232 {
233     ::cl_uchar3 data;
234 };
235 struct cl_short3
236 {
237     ::cl_short3 data;
238 };
239 struct cl_ushort3
240 {
241     ::cl_ushort3 data;
242 };
243 struct cl_int3
244 {
245     ::cl_int3 data;
246 };
247 struct cl_uint3
248 {
249     ::cl_uint3 data;
250 };
251 struct cl_long3
252 {
253     ::cl_long3 data;
254 };
255 struct cl_ulong3
256 {
257     ::cl_ulong3 data;
258 };
259 struct cl_float3
260 {
261     ::cl_float3 data;
262 };
263 struct cl_double3
264 {
265     ::cl_double3 data;
266 };
267 struct cl_half
268 {
269     ::cl_half data;
270 };
271 struct cl_half2
272 {
273     ::cl_half2 data;
274 };
275 struct cl_half3
276 {
277     ::cl_half3 data;
278 };
279 struct cl_half4
280 {
281     ::cl_half4 data;
282 };
283 struct cl_half8
284 {
285     ::cl_half8 data;
286 };
287 struct cl_half16
288 {
289     ::cl_half16 data;
290 };
291 }
292 
int64_ok(cl_device_id device)293 static bool int64_ok(cl_device_id device)
294 {
295     char profile[128];
296     int error;
297 
298     error = clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(profile),
299                             (void *)&profile, NULL);
300     if (error)
301     {
302         log_info("clGetDeviceInfo failed with CL_DEVICE_PROFILE\n");
303         return false;
304     }
305 
306     if (strcmp(profile, "EMBEDDED_PROFILE") == 0)
307         return is_extension_available(device, "cles_khr_int64");
308 
309     return true;
310 }
311 
double_ok(cl_device_id device)312 static bool double_ok(cl_device_id device)
313 {
314     int error;
315     cl_device_fp_config c;
316     error = clGetDeviceInfo(device, CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(c),
317                             (void *)&c, NULL);
318     if (error)
319     {
320         log_info("clGetDeviceInfo failed with CL_DEVICE_DOUBLE_FP_CONFIG\n");
321         return false;
322     }
323     return c != 0;
324 }
325 
half_ok(cl_device_id device)326 static bool half_ok(cl_device_id device)
327 {
328     int error;
329     cl_device_fp_config c;
330     error = clGetDeviceInfo(device, CL_DEVICE_HALF_FP_CONFIG, sizeof(c),
331                             (void *)&c, NULL);
332     if (error)
333     {
334         log_info("clGetDeviceInfo failed with CL_DEVICE_HALF_FP_CONFIG\n");
335         return false;
336     }
337     return c != 0;
338 }
339 
340 template <typename Ty> struct CommonTypeManager
341 {
342 
nameCommonTypeManager343     static const char *name() { return ""; }
add_typedefCommonTypeManager344     static const char *add_typedef() { return "\n"; }
345     typedef std::false_type is_vector_type;
346     typedef std::false_type is_sb_vector_size3;
347     typedef std::false_type is_sb_vector_type;
348     typedef std::false_type is_sb_scalar_type;
type_supportedCommonTypeManager349     static const bool type_supported(cl_device_id) { return true; }
identify_limitsCommonTypeManager350     static const Ty identify_limits(ArithmeticOp operation)
351     {
352         switch (operation)
353         {
354             case ArithmeticOp::add_: return (Ty)0;
355             case ArithmeticOp::max_: return (std::numeric_limits<Ty>::min)();
356             case ArithmeticOp::min_: return (std::numeric_limits<Ty>::max)();
357             case ArithmeticOp::mul_: return (Ty)1;
358             case ArithmeticOp::and_: return (Ty)~0;
359             case ArithmeticOp::or_: return (Ty)0;
360             case ArithmeticOp::xor_: return (Ty)0;
361             default: log_error("Unknown operation request"); break;
362         }
363         return 0;
364     }
365 };
366 
367 template <typename> struct TypeManager;
368 
369 template <> struct TypeManager<cl_int> : public CommonTypeManager<cl_int>
370 {
371     static const char *name() { return "int"; }
372     static const char *add_typedef() { return "typedef int Type;\n"; }
373     static cl_int identify_limits(ArithmeticOp operation)
374     {
375         switch (operation)
376         {
377             case ArithmeticOp::add_: return (cl_int)0;
378             case ArithmeticOp::max_:
379                 return (std::numeric_limits<cl_int>::min)();
380             case ArithmeticOp::min_:
381                 return (std::numeric_limits<cl_int>::max)();
382             case ArithmeticOp::mul_: return (cl_int)1;
383             case ArithmeticOp::and_: return (cl_int)~0;
384             case ArithmeticOp::or_: return (cl_int)0;
385             case ArithmeticOp::xor_: return (cl_int)0;
386             case ArithmeticOp::logical_and: return (cl_int)1;
387             case ArithmeticOp::logical_or: return (cl_int)0;
388             case ArithmeticOp::logical_xor: return (cl_int)0;
389             default: log_error("Unknown operation request"); break;
390         }
391         return 0;
392     }
393 };
394 template <> struct TypeManager<cl_int2> : public CommonTypeManager<cl_int2>
395 {
396     static const char *name() { return "int2"; }
397     static const char *add_typedef() { return "typedef int2 Type;\n"; }
398     typedef std::true_type is_vector_type;
399     using scalar_type = cl_int;
400 };
401 template <>
402 struct TypeManager<subgroups::cl_int3>
403     : public CommonTypeManager<subgroups::cl_int3>
404 {
405     static const char *name() { return "int3"; }
406     static const char *add_typedef() { return "typedef int3 Type;\n"; }
407     typedef std::true_type is_sb_vector_size3;
408     using scalar_type = cl_int;
409 };
410 template <> struct TypeManager<cl_int4> : public CommonTypeManager<cl_int4>
411 {
412     static const char *name() { return "int4"; }
413     static const char *add_typedef() { return "typedef int4 Type;\n"; }
414     using scalar_type = cl_int;
415     typedef std::true_type is_vector_type;
416 };
417 template <> struct TypeManager<cl_int8> : public CommonTypeManager<cl_int8>
418 {
419     static const char *name() { return "int8"; }
420     static const char *add_typedef() { return "typedef int8 Type;\n"; }
421     using scalar_type = cl_int;
422     typedef std::true_type is_vector_type;
423 };
424 template <> struct TypeManager<cl_int16> : public CommonTypeManager<cl_int16>
425 {
426     static const char *name() { return "int16"; }
427     static const char *add_typedef() { return "typedef int16 Type;\n"; }
428     using scalar_type = cl_int;
429     typedef std::true_type is_vector_type;
430 };
431 // cl_uint
432 template <> struct TypeManager<cl_uint> : public CommonTypeManager<cl_uint>
433 {
434     static const char *name() { return "uint"; }
435     static const char *add_typedef() { return "typedef uint Type;\n"; }
436 };
437 template <> struct TypeManager<cl_uint2> : public CommonTypeManager<cl_uint2>
438 {
439     static const char *name() { return "uint2"; }
440     static const char *add_typedef() { return "typedef uint2 Type;\n"; }
441     using scalar_type = cl_uint;
442     typedef std::true_type is_vector_type;
443 };
444 template <>
445 struct TypeManager<subgroups::cl_uint3>
446     : public CommonTypeManager<subgroups::cl_uint3>
447 {
448     static const char *name() { return "uint3"; }
449     static const char *add_typedef() { return "typedef uint3 Type;\n"; }
450     typedef std::true_type is_sb_vector_size3;
451     using scalar_type = cl_uint;
452 };
453 template <> struct TypeManager<cl_uint4> : public CommonTypeManager<cl_uint4>
454 {
455     static const char *name() { return "uint4"; }
456     static const char *add_typedef() { return "typedef uint4 Type;\n"; }
457     using scalar_type = cl_uint;
458     typedef std::true_type is_vector_type;
459 };
460 template <> struct TypeManager<cl_uint8> : public CommonTypeManager<cl_uint8>
461 {
462     static const char *name() { return "uint8"; }
463     static const char *add_typedef() { return "typedef uint8 Type;\n"; }
464     using scalar_type = cl_uint;
465     typedef std::true_type is_vector_type;
466 };
467 template <> struct TypeManager<cl_uint16> : public CommonTypeManager<cl_uint16>
468 {
469     static const char *name() { return "uint16"; }
470     static const char *add_typedef() { return "typedef uint16 Type;\n"; }
471     using scalar_type = cl_uint;
472     typedef std::true_type is_vector_type;
473 };
474 // cl_short
475 template <> struct TypeManager<cl_short> : public CommonTypeManager<cl_short>
476 {
477     static const char *name() { return "short"; }
478     static const char *add_typedef() { return "typedef short Type;\n"; }
479 };
480 template <> struct TypeManager<cl_short2> : public CommonTypeManager<cl_short2>
481 {
482     static const char *name() { return "short2"; }
483     static const char *add_typedef() { return "typedef short2 Type;\n"; }
484     using scalar_type = cl_short;
485     typedef std::true_type is_vector_type;
486 };
487 template <>
488 struct TypeManager<subgroups::cl_short3>
489     : public CommonTypeManager<subgroups::cl_short3>
490 {
491     static const char *name() { return "short3"; }
492     static const char *add_typedef() { return "typedef short3 Type;\n"; }
493     typedef std::true_type is_sb_vector_size3;
494     using scalar_type = cl_short;
495 };
496 template <> struct TypeManager<cl_short4> : public CommonTypeManager<cl_short4>
497 {
498     static const char *name() { return "short4"; }
499     static const char *add_typedef() { return "typedef short4 Type;\n"; }
500     using scalar_type = cl_short;
501     typedef std::true_type is_vector_type;
502 };
503 template <> struct TypeManager<cl_short8> : public CommonTypeManager<cl_short8>
504 {
505     static const char *name() { return "short8"; }
506     static const char *add_typedef() { return "typedef short8 Type;\n"; }
507     using scalar_type = cl_short;
508     typedef std::true_type is_vector_type;
509 };
510 template <>
511 struct TypeManager<cl_short16> : public CommonTypeManager<cl_short16>
512 {
513     static const char *name() { return "short16"; }
514     static const char *add_typedef() { return "typedef short16 Type;\n"; }
515     using scalar_type = cl_short;
516     typedef std::true_type is_vector_type;
517 };
518 // cl_ushort
519 template <> struct TypeManager<cl_ushort> : public CommonTypeManager<cl_ushort>
520 {
521     static const char *name() { return "ushort"; }
522     static const char *add_typedef() { return "typedef ushort Type;\n"; }
523 };
524 template <>
525 struct TypeManager<cl_ushort2> : public CommonTypeManager<cl_ushort2>
526 {
527     static const char *name() { return "ushort2"; }
528     static const char *add_typedef() { return "typedef ushort2 Type;\n"; }
529     using scalar_type = cl_ushort;
530     typedef std::true_type is_vector_type;
531 };
532 template <>
533 struct TypeManager<subgroups::cl_ushort3>
534     : public CommonTypeManager<subgroups::cl_ushort3>
535 {
536     static const char *name() { return "ushort3"; }
537     static const char *add_typedef() { return "typedef ushort3 Type;\n"; }
538     typedef std::true_type is_sb_vector_size3;
539     using scalar_type = cl_ushort;
540 };
541 template <>
542 struct TypeManager<cl_ushort4> : public CommonTypeManager<cl_ushort4>
543 {
544     static const char *name() { return "ushort4"; }
545     static const char *add_typedef() { return "typedef ushort4 Type;\n"; }
546     using scalar_type = cl_ushort;
547     typedef std::true_type is_vector_type;
548 };
549 template <>
550 struct TypeManager<cl_ushort8> : public CommonTypeManager<cl_ushort8>
551 {
552     static const char *name() { return "ushort8"; }
553     static const char *add_typedef() { return "typedef ushort8 Type;\n"; }
554     using scalar_type = cl_ushort;
555     typedef std::true_type is_vector_type;
556 };
557 template <>
558 struct TypeManager<cl_ushort16> : public CommonTypeManager<cl_ushort16>
559 {
560     static const char *name() { return "ushort16"; }
561     static const char *add_typedef() { return "typedef ushort16 Type;\n"; }
562     using scalar_type = cl_ushort;
563     typedef std::true_type is_vector_type;
564 };
565 // cl_char
566 template <> struct TypeManager<cl_char> : public CommonTypeManager<cl_char>
567 {
568     static const char *name() { return "char"; }
569     static const char *add_typedef() { return "typedef char Type;\n"; }
570 };
571 template <> struct TypeManager<cl_char2> : public CommonTypeManager<cl_char2>
572 {
573     static const char *name() { return "char2"; }
574     static const char *add_typedef() { return "typedef char2 Type;\n"; }
575     using scalar_type = cl_char;
576     typedef std::true_type is_vector_type;
577 };
578 template <>
579 struct TypeManager<subgroups::cl_char3>
580     : public CommonTypeManager<subgroups::cl_char3>
581 {
582     static const char *name() { return "char3"; }
583     static const char *add_typedef() { return "typedef char3 Type;\n"; }
584     typedef std::true_type is_sb_vector_size3;
585     using scalar_type = cl_char;
586 };
587 template <> struct TypeManager<cl_char4> : public CommonTypeManager<cl_char4>
588 {
589     static const char *name() { return "char4"; }
590     static const char *add_typedef() { return "typedef char4 Type;\n"; }
591     using scalar_type = cl_char;
592     typedef std::true_type is_vector_type;
593 };
594 template <> struct TypeManager<cl_char8> : public CommonTypeManager<cl_char8>
595 {
596     static const char *name() { return "char8"; }
597     static const char *add_typedef() { return "typedef char8 Type;\n"; }
598     using scalar_type = cl_char;
599     typedef std::true_type is_vector_type;
600 };
601 template <> struct TypeManager<cl_char16> : public CommonTypeManager<cl_char16>
602 {
603     static const char *name() { return "char16"; }
604     static const char *add_typedef() { return "typedef char16 Type;\n"; }
605     using scalar_type = cl_char;
606     typedef std::true_type is_vector_type;
607 };
608 // cl_uchar
609 template <> struct TypeManager<cl_uchar> : public CommonTypeManager<cl_uchar>
610 {
611     static const char *name() { return "uchar"; }
612     static const char *add_typedef() { return "typedef uchar Type;\n"; }
613 };
614 template <> struct TypeManager<cl_uchar2> : public CommonTypeManager<cl_uchar2>
615 {
616     static const char *name() { return "uchar2"; }
617     static const char *add_typedef() { return "typedef uchar2 Type;\n"; }
618     using scalar_type = cl_uchar;
619     typedef std::true_type is_vector_type;
620 };
621 template <>
622 struct TypeManager<subgroups::cl_uchar3>
623     : public CommonTypeManager<subgroups::cl_char3>
624 {
625     static const char *name() { return "uchar3"; }
626     static const char *add_typedef() { return "typedef uchar3 Type;\n"; }
627     typedef std::true_type is_sb_vector_size3;
628     using scalar_type = cl_uchar;
629 };
630 template <> struct TypeManager<cl_uchar4> : public CommonTypeManager<cl_uchar4>
631 {
632     static const char *name() { return "uchar4"; }
633     static const char *add_typedef() { return "typedef uchar4 Type;\n"; }
634     using scalar_type = cl_uchar;
635     typedef std::true_type is_vector_type;
636 };
637 template <> struct TypeManager<cl_uchar8> : public CommonTypeManager<cl_uchar8>
638 {
639     static const char *name() { return "uchar8"; }
640     static const char *add_typedef() { return "typedef uchar8 Type;\n"; }
641     using scalar_type = cl_uchar;
642     typedef std::true_type is_vector_type;
643 };
644 template <>
645 struct TypeManager<cl_uchar16> : public CommonTypeManager<cl_uchar16>
646 {
647     static const char *name() { return "uchar16"; }
648     static const char *add_typedef() { return "typedef uchar16 Type;\n"; }
649     using scalar_type = cl_uchar;
650     typedef std::true_type is_vector_type;
651 };
652 // cl_long
653 template <> struct TypeManager<cl_long> : public CommonTypeManager<cl_long>
654 {
655     static const char *name() { return "long"; }
656     static const char *add_typedef() { return "typedef long Type;\n"; }
657     static const bool type_supported(cl_device_id device)
658     {
659         return int64_ok(device);
660     }
661 };
662 template <> struct TypeManager<cl_long2> : public CommonTypeManager<cl_long2>
663 {
664     static const char *name() { return "long2"; }
665     static const char *add_typedef() { return "typedef long2 Type;\n"; }
666     using scalar_type = cl_long;
667     typedef std::true_type is_vector_type;
668     static const bool type_supported(cl_device_id device)
669     {
670         return int64_ok(device);
671     }
672 };
673 template <>
674 struct TypeManager<subgroups::cl_long3>
675     : public CommonTypeManager<subgroups::cl_long3>
676 {
677     static const char *name() { return "long3"; }
678     static const char *add_typedef() { return "typedef long3 Type;\n"; }
679     typedef std::true_type is_sb_vector_size3;
680     using scalar_type = cl_long;
681     static const bool type_supported(cl_device_id device)
682     {
683         return int64_ok(device);
684     }
685 };
686 template <> struct TypeManager<cl_long4> : public CommonTypeManager<cl_long4>
687 {
688     static const char *name() { return "long4"; }
689     static const char *add_typedef() { return "typedef long4 Type;\n"; }
690     using scalar_type = cl_long;
691     typedef std::true_type is_vector_type;
692     static const bool type_supported(cl_device_id device)
693     {
694         return int64_ok(device);
695     }
696 };
697 template <> struct TypeManager<cl_long8> : public CommonTypeManager<cl_long8>
698 {
699     static const char *name() { return "long8"; }
700     static const char *add_typedef() { return "typedef long8 Type;\n"; }
701     using scalar_type = cl_long;
702     typedef std::true_type is_vector_type;
703     static const bool type_supported(cl_device_id device)
704     {
705         return int64_ok(device);
706     }
707 };
708 template <> struct TypeManager<cl_long16> : public CommonTypeManager<cl_long16>
709 {
710     static const char *name() { return "long16"; }
711     static const char *add_typedef() { return "typedef long16 Type;\n"; }
712     using scalar_type = cl_long;
713     typedef std::true_type is_vector_type;
714     static const bool type_supported(cl_device_id device)
715     {
716         return int64_ok(device);
717     }
718 };
719 // cl_ulong
720 template <> struct TypeManager<cl_ulong> : public CommonTypeManager<cl_ulong>
721 {
722     static const char *name() { return "ulong"; }
723     static const char *add_typedef() { return "typedef ulong Type;\n"; }
724     static const bool type_supported(cl_device_id device)
725     {
726         return int64_ok(device);
727     }
728 };
729 template <> struct TypeManager<cl_ulong2> : public CommonTypeManager<cl_ulong2>
730 {
731     static const char *name() { return "ulong2"; }
732     static const char *add_typedef() { return "typedef ulong2 Type;\n"; }
733     using scalar_type = cl_ulong;
734     typedef std::true_type is_vector_type;
735     static const bool type_supported(cl_device_id device)
736     {
737         return int64_ok(device);
738     }
739 };
740 template <>
741 struct TypeManager<subgroups::cl_ulong3>
742     : public CommonTypeManager<subgroups::cl_ulong3>
743 {
744     static const char *name() { return "ulong3"; }
745     static const char *add_typedef() { return "typedef ulong3 Type;\n"; }
746     typedef std::true_type is_sb_vector_size3;
747     using scalar_type = cl_ulong;
748     static const bool type_supported(cl_device_id device)
749     {
750         return int64_ok(device);
751     }
752 };
753 template <> struct TypeManager<cl_ulong4> : public CommonTypeManager<cl_ulong4>
754 {
755     static const char *name() { return "ulong4"; }
756     static const char *add_typedef() { return "typedef ulong4 Type;\n"; }
757     using scalar_type = cl_ulong;
758     typedef std::true_type is_vector_type;
759     static const bool type_supported(cl_device_id device)
760     {
761         return int64_ok(device);
762     }
763 };
764 template <> struct TypeManager<cl_ulong8> : public CommonTypeManager<cl_ulong8>
765 {
766     static const char *name() { return "ulong8"; }
767     static const char *add_typedef() { return "typedef ulong8 Type;\n"; }
768     using scalar_type = cl_ulong;
769     typedef std::true_type is_vector_type;
770     static const bool type_supported(cl_device_id device)
771     {
772         return int64_ok(device);
773     }
774 };
775 template <>
776 struct TypeManager<cl_ulong16> : public CommonTypeManager<cl_ulong16>
777 {
778     static const char *name() { return "ulong16"; }
779     static const char *add_typedef() { return "typedef ulong16 Type;\n"; }
780     using scalar_type = cl_ulong;
781     typedef std::true_type is_vector_type;
782     static const bool type_supported(cl_device_id device)
783     {
784         return int64_ok(device);
785     }
786 };
787 
788 // cl_float
789 template <> struct TypeManager<cl_float> : public CommonTypeManager<cl_float>
790 {
791     static const char *name() { return "float"; }
792     static const char *add_typedef() { return "typedef float Type;\n"; }
793     static cl_float identify_limits(ArithmeticOp operation)
794     {
795         switch (operation)
796         {
797             case ArithmeticOp::add_: return 0.0f;
798             case ArithmeticOp::max_:
799                 return -std::numeric_limits<float>::infinity();
800             case ArithmeticOp::min_:
801                 return std::numeric_limits<float>::infinity();
802             case ArithmeticOp::mul_: return (cl_float)1;
803             default: log_error("Unknown operation request"); break;
804         }
805         return 0;
806     }
807 };
808 template <> struct TypeManager<cl_float2> : public CommonTypeManager<cl_float2>
809 {
810     static const char *name() { return "float2"; }
811     static const char *add_typedef() { return "typedef float2 Type;\n"; }
812     using scalar_type = cl_float;
813     typedef std::true_type is_vector_type;
814 };
815 template <>
816 struct TypeManager<subgroups::cl_float3>
817     : public CommonTypeManager<subgroups::cl_float3>
818 {
819     static const char *name() { return "float3"; }
820     static const char *add_typedef() { return "typedef float3 Type;\n"; }
821     typedef std::true_type is_sb_vector_size3;
822     using scalar_type = cl_float;
823 };
824 template <> struct TypeManager<cl_float4> : public CommonTypeManager<cl_float4>
825 {
826     static const char *name() { return "float4"; }
827     static const char *add_typedef() { return "typedef float4 Type;\n"; }
828     using scalar_type = cl_float;
829     typedef std::true_type is_vector_type;
830 };
831 template <> struct TypeManager<cl_float8> : public CommonTypeManager<cl_float8>
832 {
833     static const char *name() { return "float8"; }
834     static const char *add_typedef() { return "typedef float8 Type;\n"; }
835     using scalar_type = cl_float;
836     typedef std::true_type is_vector_type;
837 };
838 template <>
839 struct TypeManager<cl_float16> : public CommonTypeManager<cl_float16>
840 {
841     static const char *name() { return "float16"; }
842     static const char *add_typedef() { return "typedef float16 Type;\n"; }
843     using scalar_type = cl_float;
844     typedef std::true_type is_vector_type;
845 };
846 
847 // cl_double
848 template <> struct TypeManager<cl_double> : public CommonTypeManager<cl_double>
849 {
850     static const char *name() { return "double"; }
851     static const char *add_typedef() { return "typedef double Type;\n"; }
852     static cl_double identify_limits(ArithmeticOp operation)
853     {
854         switch (operation)
855         {
856             case ArithmeticOp::add_: return 0.0;
857             case ArithmeticOp::max_:
858                 return -std::numeric_limits<double>::infinity();
859             case ArithmeticOp::min_:
860                 return std::numeric_limits<double>::infinity();
861             case ArithmeticOp::mul_: return (cl_double)1;
862             default: log_error("Unknown operation request"); break;
863         }
864         return 0;
865     }
866     static const bool type_supported(cl_device_id device)
867     {
868         return double_ok(device);
869     }
870 };
871 template <>
872 struct TypeManager<cl_double2> : public CommonTypeManager<cl_double2>
873 {
874     static const char *name() { return "double2"; }
875     static const char *add_typedef() { return "typedef double2 Type;\n"; }
876     using scalar_type = cl_double;
877     typedef std::true_type is_vector_type;
878     static const bool type_supported(cl_device_id device)
879     {
880         return double_ok(device);
881     }
882 };
883 template <>
884 struct TypeManager<subgroups::cl_double3>
885     : public CommonTypeManager<subgroups::cl_double3>
886 {
887     static const char *name() { return "double3"; }
888     static const char *add_typedef() { return "typedef double3 Type;\n"; }
889     typedef std::true_type is_sb_vector_size3;
890     using scalar_type = cl_double;
891     static const bool type_supported(cl_device_id device)
892     {
893         return double_ok(device);
894     }
895 };
896 template <>
897 struct TypeManager<cl_double4> : public CommonTypeManager<cl_double4>
898 {
899     static const char *name() { return "double4"; }
900     static const char *add_typedef() { return "typedef double4 Type;\n"; }
901     using scalar_type = cl_double;
902     typedef std::true_type is_vector_type;
903     static const bool type_supported(cl_device_id device)
904     {
905         return double_ok(device);
906     }
907 };
908 template <>
909 struct TypeManager<cl_double8> : public CommonTypeManager<cl_double8>
910 {
911     static const char *name() { return "double8"; }
912     static const char *add_typedef() { return "typedef double8 Type;\n"; }
913     using scalar_type = cl_double;
914     typedef std::true_type is_vector_type;
915     static const bool type_supported(cl_device_id device)
916     {
917         return double_ok(device);
918     }
919 };
920 template <>
921 struct TypeManager<cl_double16> : public CommonTypeManager<cl_double16>
922 {
923     static const char *name() { return "double16"; }
924     static const char *add_typedef() { return "typedef double16 Type;\n"; }
925     using scalar_type = cl_double;
926     typedef std::true_type is_vector_type;
927     static const bool type_supported(cl_device_id device)
928     {
929         return double_ok(device);
930     }
931 };
932 
933 // cl_half
934 template <>
935 struct TypeManager<subgroups::cl_half>
936     : public CommonTypeManager<subgroups::cl_half>
937 {
938     static const char *name() { return "half"; }
939     static const char *add_typedef() { return "typedef half Type;\n"; }
940     typedef std::true_type is_sb_scalar_type;
941     static subgroups::cl_half identify_limits(ArithmeticOp operation)
942     {
943         switch (operation)
944         {
945             case ArithmeticOp::add_: return { 0x0000 };
946             case ArithmeticOp::max_: return { 0xfc00 };
947             case ArithmeticOp::min_: return { 0x7c00 };
948             case ArithmeticOp::mul_: return { 0x3c00 };
949             default: log_error("Unknown operation request"); break;
950         }
951         return { 0 };
952     }
953     static const bool type_supported(cl_device_id device)
954     {
955         return half_ok(device);
956     }
957 };
958 template <>
959 struct TypeManager<subgroups::cl_half2>
960     : public CommonTypeManager<subgroups::cl_half2>
961 {
962     static const char *name() { return "half2"; }
963     static const char *add_typedef() { return "typedef half2 Type;\n"; }
964     using scalar_type = subgroups::cl_half;
965     typedef std::true_type is_sb_vector_type;
966     static const bool type_supported(cl_device_id device)
967     {
968         return half_ok(device);
969     }
970 };
971 template <>
972 struct TypeManager<subgroups::cl_half3>
973     : public CommonTypeManager<subgroups::cl_half3>
974 {
975     static const char *name() { return "half3"; }
976     static const char *add_typedef() { return "typedef half3 Type;\n"; }
977     typedef std::true_type is_sb_vector_size3;
978     using scalar_type = subgroups::cl_half;
979 
980     static const bool type_supported(cl_device_id device)
981     {
982         return half_ok(device);
983     }
984 };
985 template <>
986 struct TypeManager<subgroups::cl_half4>
987     : public CommonTypeManager<subgroups::cl_half4>
988 {
989     static const char *name() { return "half4"; }
990     static const char *add_typedef() { return "typedef half4 Type;\n"; }
991     using scalar_type = subgroups::cl_half;
992     typedef std::true_type is_sb_vector_type;
993     static const bool type_supported(cl_device_id device)
994     {
995         return half_ok(device);
996     }
997 };
998 template <>
999 struct TypeManager<subgroups::cl_half8>
1000     : public CommonTypeManager<subgroups::cl_half8>
1001 {
1002     static const char *name() { return "half8"; }
1003     static const char *add_typedef() { return "typedef half8 Type;\n"; }
1004     using scalar_type = subgroups::cl_half;
1005     typedef std::true_type is_sb_vector_type;
1006 
1007     static const bool type_supported(cl_device_id device)
1008     {
1009         return half_ok(device);
1010     }
1011 };
1012 template <>
1013 struct TypeManager<subgroups::cl_half16>
1014     : public CommonTypeManager<subgroups::cl_half16>
1015 {
1016     static const char *name() { return "half16"; }
1017     static const char *add_typedef() { return "typedef half16 Type;\n"; }
1018     using scalar_type = subgroups::cl_half;
1019     typedef std::true_type is_sb_vector_type;
1020     static const bool type_supported(cl_device_id device)
1021     {
1022         return half_ok(device);
1023     }
1024 };
1025 
1026 // set scalar value to vector of halfs
1027 template <typename Ty, int N = 0>
1028 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value>::type
1029 set_value(Ty &lhs, const cl_ulong &rhs)
1030 {
1031     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1032     for (auto i = 0; i < size; ++i)
1033     {
1034         lhs.data.s[i] = rhs;
1035     }
1036 }
1037 
1038 
1039 // set scalar value to vector
1040 template <typename Ty>
1041 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1042 set_value(Ty &lhs, const cl_ulong &rhs)
1043 {
1044     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1045     for (auto i = 0; i < size; ++i)
1046     {
1047         lhs.s[i] = rhs;
1048     }
1049 }
1050 
1051 // set vector to vector value
1052 template <typename Ty>
1053 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1054 set_value(Ty &lhs, const Ty &rhs)
1055 {
1056     lhs = rhs;
1057 }
1058 
1059 // set scalar value to vector size 3
1060 template <typename Ty, int N = 0>
1061 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value>::type
1062 set_value(Ty &lhs, const cl_ulong &rhs)
1063 {
1064     for (auto i = 0; i < 3; ++i)
1065     {
1066         lhs.data.s[i] = rhs;
1067     }
1068 }
1069 
1070 // set scalar value to scalar
1071 template <typename Ty>
1072 typename std::enable_if<std::is_scalar<Ty>::value>::type
1073 set_value(Ty &lhs, const cl_ulong &rhs)
1074 {
1075     lhs = static_cast<Ty>(rhs);
1076 }
1077 
1078 // set scalar value to half scalar
1079 template <typename Ty>
1080 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value>::type
1081 set_value(Ty &lhs, const cl_ulong &rhs)
1082 {
1083     lhs.data = rhs;
1084 }
1085 
1086 // compare for common vectors
1087 template <typename Ty>
1088 typename std::enable_if<TypeManager<Ty>::is_vector_type::value, bool>::type
1089 compare(const Ty &lhs, const Ty &rhs)
1090 {
1091     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1092     for (auto i = 0; i < size; ++i)
1093     {
1094         if (lhs.s[i] != rhs.s[i])
1095         {
1096             return false;
1097         }
1098     }
1099     return true;
1100 }
1101 
1102 // compare for vectors 3
1103 template <typename Ty>
1104 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value, bool>::type
1105 compare(const Ty &lhs, const Ty &rhs)
1106 {
1107     for (auto i = 0; i < 3; ++i)
1108     {
1109         if (lhs.data.s[i] != rhs.data.s[i])
1110         {
1111             return false;
1112         }
1113     }
1114     return true;
1115 }
1116 
1117 // compare for half vectors
1118 template <typename Ty>
1119 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value, bool>::type
1120 compare(const Ty &lhs, const Ty &rhs)
1121 {
1122     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1123     for (auto i = 0; i < size; ++i)
1124     {
1125         if (lhs.data.s[i] != rhs.data.s[i])
1126         {
1127             return false;
1128         }
1129     }
1130     return true;
1131 }
1132 
// Equality of two built-in scalars (anything std::is_scalar accepts) using
// the type's own operator==.
template <typename Ty>
typename std::enable_if<std::is_scalar<Ty>::value, bool>::type
compare(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1140 
1141 // compare for scalar halfs
1142 template <typename Ty>
1143 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value, bool>::type
1144 compare(const Ty &lhs, const Ty &rhs)
1145 {
1146     return lhs.data == rhs.data;
1147 }
1148 
// Generic value comparison: true when lhs == rhs according to Ty's own
// operator==. Specialized/overloaded below for half values.
template <typename Ty> inline bool compare_ordered(const Ty &lhs, const Ty &rhs)
{
    const bool equal = (lhs == rhs);
    return equal;
}
1153 
1154 template <>
1155 inline bool compare_ordered(const subgroups::cl_half &lhs,
1156                             const subgroups::cl_half &rhs)
1157 {
1158     return cl_half_to_float(lhs.data) == cl_half_to_float(rhs.data);
1159 }
1160 
1161 template <typename Ty>
1162 inline bool compare_ordered(const subgroups::cl_half &lhs, const int &rhs)
1163 {
1164     return cl_half_to_float(lhs.data) == rhs;
1165 }
1166 
1167 // Run a test kernel to compute the result of a built-in on an input
1168 static int run_kernel(cl_context context, cl_command_queue queue,
1169                       cl_kernel kernel, size_t global, size_t local,
1170                       void *idata, size_t isize, void *mdata, size_t msize,
1171                       void *odata, size_t osize, size_t tsize = 0)
1172 {
1173     clMemWrapper in;
1174     clMemWrapper xy;
1175     clMemWrapper out;
1176     clMemWrapper tmp;
1177     int error;
1178 
1179     in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
1180     test_error(error, "clCreateBuffer failed");
1181 
1182     xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
1183     test_error(error, "clCreateBuffer failed");
1184 
1185     out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
1186     test_error(error, "clCreateBuffer failed");
1187 
1188     if (tsize)
1189     {
1190         tmp = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
1191                              tsize, NULL, &error);
1192         test_error(error, "clCreateBuffer failed");
1193     }
1194 
1195     error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
1196     test_error(error, "clSetKernelArg failed");
1197 
1198     error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
1199     test_error(error, "clSetKernelArg failed");
1200 
1201     error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
1202     test_error(error, "clSetKernelArg failed");
1203 
1204     if (tsize)
1205     {
1206         error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
1207         test_error(error, "clSetKernelArg failed");
1208     }
1209 
1210     error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0, NULL,
1211                                  NULL);
1212     test_error(error, "clEnqueueWriteBuffer failed");
1213 
1214     error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0, NULL,
1215                                  NULL);
1216     test_error(error, "clEnqueueWriteBuffer failed");
1217     error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 0,
1218                                    NULL, NULL);
1219     test_error(error, "clEnqueueNDRangeKernel failed");
1220 
1221     error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0, NULL,
1222                                 NULL);
1223     test_error(error, "clEnqueueReadBuffer failed");
1224 
1225     error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0, NULL,
1226                                 NULL);
1227     test_error(error, "clEnqueueReadBuffer failed");
1228 
1229     error = clFinish(queue);
1230     test_error(error, "clFinish failed");
1231 
1232     return error;
1233 }
1234 
1235 // Driver for testing a single built in function
1236 template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
1237 {
1238     static int mrun(cl_device_id device, cl_context context,
1239                     cl_command_queue queue, int num_elements, const char *kname,
1240                     const char *src, WorkGroupParams test_params)
1241     {
1242         int error = TEST_PASS;
1243         for (auto &mask : test_params.all_work_item_masks)
1244         {
1245             test_params.work_items_mask = mask;
1246             error |= run(device, context, queue, num_elements, kname, src,
1247                          test_params);
1248         }
1249         return error;
1250     };
1251     static int run(cl_device_id device, cl_context context,
1252                    cl_command_queue queue, int num_elements, const char *kname,
1253                    const char *src, WorkGroupParams test_params)
1254     {
1255         size_t tmp;
1256         int error;
1257         int subgroup_size, num_subgroups;
1258         size_t realSize;
1259         size_t global = test_params.global_workgroup_size;
1260         size_t local = test_params.local_workgroup_size;
1261         clProgramWrapper program;
1262         clKernelWrapper kernel;
1263         cl_platform_id platform;
1264         std::vector<cl_int> sgmap;
1265         sgmap.resize(4 * global);
1266         std::vector<Ty> mapin;
1267         mapin.resize(local);
1268         std::vector<Ty> mapout;
1269         mapout.resize(local);
1270         std::stringstream kernel_sstr;
1271         if (test_params.work_items_mask != 0)
1272         {
1273             kernel_sstr << "#define WORK_ITEMS_MASK ";
1274             kernel_sstr << "0x" << std::hex << test_params.work_items_mask
1275                         << "\n";
1276         }
1277 
1278 
1279         kernel_sstr << "#define NR_OF_ACTIVE_WORK_ITEMS ";
1280         kernel_sstr << NR_OF_ACTIVE_WORK_ITEMS << "\n";
1281         // Make sure a test of type Ty is supported by the device
1282         if (!TypeManager<Ty>::type_supported(device))
1283         {
1284             log_info("Data type not supported : %s\n", TypeManager<Ty>::name());
1285             return 0;
1286         }
1287         else
1288         {
1289             if (strstr(TypeManager<Ty>::name(), "double"))
1290             {
1291                 kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n";
1292             }
1293             else if (strstr(TypeManager<Ty>::name(), "half"))
1294             {
1295                 kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp16: enable\n";
1296             }
1297         }
1298 
1299         for (std::string extension : test_params.required_extensions)
1300         {
1301             if (!is_extension_available(device, extension.c_str()))
1302             {
1303                 log_info("The extension %s not supported on this device. SKIP "
1304                          "testing - kernel %s data type %s\n",
1305                          extension.c_str(), kname, TypeManager<Ty>::name());
1306                 return TEST_PASS;
1307             }
1308             kernel_sstr << "#pragma OPENCL EXTENSION " + extension
1309                     + ": enable\n";
1310         }
1311 
1312         error = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(platform),
1313                                 (void *)&platform, NULL);
1314         test_error(error, "clGetDeviceInfo failed for CL_DEVICE_PLATFORM");
1315         if (test_params.use_core_subgroups)
1316         {
1317             kernel_sstr
1318                 << "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
1319         }
1320         kernel_sstr << "#define XY(M,I) M[I].x = get_sub_group_local_id(); "
1321                        "M[I].y = get_sub_group_id();\n";
1322         kernel_sstr << TypeManager<Ty>::add_typedef();
1323         kernel_sstr << src;
1324         const std::string &kernel_str = kernel_sstr.str();
1325         const char *kernel_src = kernel_str.c_str();
1326 
1327         error = create_single_kernel_helper(context, &program, &kernel, 1,
1328                                             &kernel_src, kname);
1329         if (error != 0) return error;
1330 
1331         // Determine some local dimensions to use for the test.
1332         error = get_max_common_work_group_size(
1333             context, kernel, test_params.global_workgroup_size, &local);
1334         test_error(error, "get_max_common_work_group_size failed");
1335 
1336         // Limit it a bit so we have muliple work groups
1337         // Ideally this will still be large enough to give us multiple
1338         if (local > test_params.local_workgroup_size)
1339             local = test_params.local_workgroup_size;
1340 
1341 
1342         // Get the sub group info
1343         subgroupsAPI subgroupsApiSet(platform, test_params.use_core_subgroups);
1344         clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr =
1345             subgroupsApiSet.clGetKernelSubGroupInfo_ptr();
1346         if (clGetKernelSubGroupInfo_ptr == NULL)
1347         {
1348             log_error("ERROR: %s function not available",
1349                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1350             return TEST_FAIL;
1351         }
1352         error = clGetKernelSubGroupInfo_ptr(
1353             kernel, device, CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
1354             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1355         if (error != CL_SUCCESS)
1356         {
1357             log_error("ERROR: %s function error for "
1358                       "CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE",
1359                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1360             return TEST_FAIL;
1361         }
1362 
1363         subgroup_size = (int)tmp;
1364 
1365         error = clGetKernelSubGroupInfo_ptr(
1366             kernel, device, CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE,
1367             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1368         if (error != CL_SUCCESS)
1369         {
1370             log_error("ERROR: %s function error for "
1371                       "CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE",
1372                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1373             return TEST_FAIL;
1374         }
1375 
1376         num_subgroups = (int)tmp;
1377         // Make sure the number of sub groups is what we expect
1378         if (num_subgroups != (local + subgroup_size - 1) / subgroup_size)
1379         {
1380             log_error("ERROR: unexpected number of subgroups (%d) returned\n",
1381                       num_subgroups);
1382             return TEST_FAIL;
1383         }
1384 
1385         std::vector<Ty> idata;
1386         std::vector<Ty> odata;
1387         size_t input_array_size = global;
1388         size_t output_array_size = global;
1389         int dynscl = test_params.dynsc;
1390 
1391         if (dynscl != 0)
1392         {
1393             input_array_size =
1394                 (int)global / (int)local * num_subgroups * dynscl;
1395             output_array_size = (int)global / (int)local * dynscl;
1396         }
1397 
1398         idata.resize(input_array_size);
1399         odata.resize(output_array_size);
1400 
1401         // Run the kernel once on zeroes to get the map
1402         memset(idata.data(), 0, input_array_size * sizeof(Ty));
1403         error = run_kernel(context, queue, kernel, global, local, idata.data(),
1404                            input_array_size * sizeof(Ty), sgmap.data(),
1405                            global * sizeof(cl_int4), odata.data(),
1406                            output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
1407         test_error(error, "Running kernel first time failed");
1408 
1409         // Generate the desired input for the kernel
1410 
1411         test_params.subgroup_size = subgroup_size;
1412         Fns::gen(idata.data(), mapin.data(), sgmap.data(), test_params);
1413         error = run_kernel(context, queue, kernel, global, local, idata.data(),
1414                            input_array_size * sizeof(Ty), sgmap.data(),
1415                            global * sizeof(cl_int4), odata.data(),
1416                            output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
1417         test_error(error, "Running kernel second time failed");
1418 
1419         // Check the result
1420         error = Fns::chk(idata.data(), odata.data(), mapin.data(),
1421                          mapout.data(), sgmap.data(), test_params);
1422         test_error(error, "Data verification failed");
1423         return TEST_PASS;
1424     }
1425 };
1426 
// Compute the sub-group layout of the last, non-uniform work-group.
//
// non_uniform_size    - number of work-items in the last work-group
// number_of_subgroups - [out] sub-groups needed to cover them, i.e.
//                       ceil(non_uniform_size / subgroup_size)
// subgroup_size       - sub-group size; must be > 0
// workgroup_size      - [out] set to non_uniform_size
// last_subgroup_size  - [out] size of the trailing partial sub-group, or 0
//                       when non_uniform_size is an exact multiple of
//                       subgroup_size (all sub-groups are then full)
static void set_last_workgroup_params(int non_uniform_size,
                                      int &number_of_subgroups,
                                      int subgroup_size, int &workgroup_size,
                                      int &last_subgroup_size)
{
    // Round up. `1 + size / subgroup_size` over-counts by one for exact
    // multiples and produces a trailing sub-group with zero work-items,
    // which later code may use as a modulus.
    number_of_subgroups =
        (non_uniform_size + subgroup_size - 1) / subgroup_size;
    last_subgroup_size = non_uniform_size % subgroup_size;
    workgroup_size = non_uniform_size;
}
1436 
1437 template <typename Ty>
1438 static void set_randomdata_for_subgroup(Ty *workgroup, int wg_offset,
1439                                         int current_sbs)
1440 {
1441     int randomize_data = (int)(genrand_int32(gMTdata) % 3);
1442     // Initialize data matrix indexed by local id and sub group id
1443     switch (randomize_data)
1444     {
1445         case 0:
1446             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1447             break;
1448         case 1: {
1449             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1450             int wi_id = (int)(genrand_int32(gMTdata) % (cl_uint)current_sbs);
1451             set_value(workgroup[wg_offset + wi_id], 41);
1452         }
1453         break;
1454         case 2:
1455             memset(&workgroup[wg_offset], 0xff, current_sbs * sizeof(Ty));
1456             break;
1457     }
1458 }
1459 
1460 struct RunTestForType
1461 {
1462     RunTestForType(cl_device_id device, cl_context context,
1463                    cl_command_queue queue, int num_elements,
1464                    WorkGroupParams test_params)
1465         : device_(device), context_(context), queue_(queue),
1466           num_elements_(num_elements), test_params_(test_params)
1467     {}
1468     template <typename T, typename U>
1469     int run_impl(const char *kernel_name, const char *source)
1470     {
1471         int error = TEST_PASS;
1472         if (test_params_.all_work_item_masks.size() > 0)
1473         {
1474             error = test<T, U>::mrun(device_, context_, queue_, num_elements_,
1475                                      kernel_name, source, test_params_);
1476         }
1477         else
1478         {
1479             error = test<T, U>::run(device_, context_, queue_, num_elements_,
1480                                     kernel_name, source, test_params_);
1481         }
1482 
1483         return error;
1484     }
1485 
1486 private:
1487     cl_device_id device_;
1488     cl_context context_;
1489     cl_command_queue queue_;
1490     int num_elements_;
1491     WorkGroupParams test_params_;
1492 };
1493 
1494 #endif
1495