• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2017 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBHELPERS_H
17 #define SUBHELPERS_H
18 
19 #include "testHarness.h"
20 #include "kernelHelpers.h"
21 #include "typeWrappers.h"
22 #include "imageHelpers.h"
23 
24 #include <limits>
25 #include <vector>
26 #include <type_traits>
27 #include <bitset>
28 #include <regex>
29 #include <map>
30 
31 #define NR_OF_ACTIVE_WORK_ITEMS 4
32 
33 extern MTdata gMTdata;
34 typedef std::bitset<128> bs128;
35 extern cl_half_rounding_mode g_rounding_mode;
36 
cl_uint4_to_bs128(cl_uint4 v)37 static bs128 cl_uint4_to_bs128(cl_uint4 v)
38 {
39     return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64)
40         | (bs128(v.s3) << 96);
41 }
42 
bs128_to_cl_uint4(bs128 v)43 static cl_uint4 bs128_to_cl_uint4(bs128 v)
44 {
45     bs128 bs128_ffffffff = 0xffffffffU;
46 
47     cl_uint4 r;
48     r.s0 = ((v >> 0) & bs128_ffffffff).to_ulong();
49     r.s1 = ((v >> 32) & bs128_ffffffff).to_ulong();
50     r.s2 = ((v >> 64) & bs128_ffffffff).to_ulong();
51     r.s3 = ((v >> 96) & bs128_ffffffff).to_ulong();
52 
53     return r;
54 }
55 
56 struct WorkGroupParams
57 {
58 
59     WorkGroupParams(size_t gws, size_t lws, int dm_arg = -1, int cs_arg = -1)
global_workgroup_sizeWorkGroupParams60         : global_workgroup_size(gws), local_workgroup_size(lws),
61           divergence_mask_arg(dm_arg), cluster_size_arg(cs_arg)
62     {
63         subgroup_size = 0;
64         cluster_size = 0;
65         work_items_mask = 0;
66         use_core_subgroups = true;
67         dynsc = 0;
68         load_masks();
69     }
70     size_t global_workgroup_size;
71     size_t local_workgroup_size;
72     size_t subgroup_size;
73     cl_uint cluster_size;
74     bs128 work_items_mask;
75     size_t dynsc;
76     bool use_core_subgroups;
77     std::vector<bs128> all_work_item_masks;
78     int divergence_mask_arg;
79     int cluster_size_arg;
80     void save_kernel_source(const std::string &source, std::string name = "")
81     {
82         if (name == "")
83         {
84             name = "default";
85         }
86         if (kernel_function_name.find(name) != kernel_function_name.end())
87         {
88             log_info("Kernel definition duplication. Source will be "
89                      "overwritten for function name %s\n",
90                      name.c_str());
91         }
92         kernel_function_name[name] = source;
93     };
94     // return specific defined kernel or default.
get_kernel_sourceWorkGroupParams95     std::string get_kernel_source(std::string name)
96     {
97         if (kernel_function_name.find(name) == kernel_function_name.end())
98         {
99             return kernel_function_name["default"];
100         }
101         return kernel_function_name[name];
102     }
103 
104 
105 private:
106     std::map<std::string, std::string> kernel_function_name;
load_masksWorkGroupParams107     void load_masks()
108     {
109         if (divergence_mask_arg != -1)
110         {
111             // 1 in string will be set 1, 0 will be set 0
112             bs128 mask_0xf0f0f0f0("11110000111100001111000011110000"
113                                   "11110000111100001111000011110000"
114                                   "11110000111100001111000011110000"
115                                   "11110000111100001111000011110000",
116                                   128, '0', '1');
117             all_work_item_masks.push_back(mask_0xf0f0f0f0);
118             // 1 in string will be set 0, 0 will be set 1
119             bs128 mask_0x0f0f0f0f("11110000111100001111000011110000"
120                                   "11110000111100001111000011110000"
121                                   "11110000111100001111000011110000"
122                                   "11110000111100001111000011110000",
123                                   128, '1', '0');
124             all_work_item_masks.push_back(mask_0x0f0f0f0f);
125             bs128 mask_0x5555aaaa("10101010101010101010101010101010"
126                                   "10101010101010101010101010101010"
127                                   "10101010101010101010101010101010"
128                                   "10101010101010101010101010101010",
129                                   128, '0', '1');
130             all_work_item_masks.push_back(mask_0x5555aaaa);
131             bs128 mask_0xaaaa5555("10101010101010101010101010101010"
132                                   "10101010101010101010101010101010"
133                                   "10101010101010101010101010101010"
134                                   "10101010101010101010101010101010",
135                                   128, '1', '0');
136             all_work_item_masks.push_back(mask_0xaaaa5555);
137             // 0x0f0ff0f0
138             bs128 mask_0x0f0ff0f0("00001111000011111111000011110000"
139                                   "00001111000011111111000011110000"
140                                   "00001111000011111111000011110000"
141                                   "00001111000011111111000011110000",
142                                   128, '0', '1');
143             all_work_item_masks.push_back(mask_0x0f0ff0f0);
144             // 0xff0000ff
145             bs128 mask_0xff0000ff("11111111000000000000000011111111"
146                                   "11111111000000000000000011111111"
147                                   "11111111000000000000000011111111"
148                                   "11111111000000000000000011111111",
149                                   128, '0', '1');
150             all_work_item_masks.push_back(mask_0xff0000ff);
151             // 0xff00ff00
152             bs128 mask_0xff00ff00("11111111000000001111111100000000"
153                                   "11111111000000001111111100000000"
154                                   "11111111000000001111111100000000"
155                                   "11111111000000001111111100000000",
156                                   128, '0', '1');
157             all_work_item_masks.push_back(mask_0xff00ff00);
158             // 0x00ffff00
159             bs128 mask_0x00ffff00("00000000111111111111111100000000"
160                                   "00000000111111111111111100000000"
161                                   "00000000111111111111111100000000"
162                                   "00000000111111111111111100000000",
163                                   128, '0', '1');
164             all_work_item_masks.push_back(mask_0x00ffff00);
165             // 0x80 1 workitem highest id for 8 subgroup size
166             bs128 mask_0x80808080("10000000100000001000000010000000"
167                                   "10000000100000001000000010000000"
168                                   "10000000100000001000000010000000"
169                                   "10000000100000001000000010000000",
170                                   128, '0', '1');
171 
172             all_work_item_masks.push_back(mask_0x80808080);
173             // 0x8000 1 workitem highest id for 16 subgroup size
174             bs128 mask_0x80008000("10000000000000001000000000000000"
175                                   "10000000000000001000000000000000"
176                                   "10000000000000001000000000000000"
177                                   "10000000000000001000000000000000",
178                                   128, '0', '1');
179             all_work_item_masks.push_back(mask_0x80008000);
180             // 0x80000000 1 workitem highest id for 32 subgroup size
181             bs128 mask_0x80000000("10000000000000000000000000000000"
182                                   "10000000000000000000000000000000"
183                                   "10000000000000000000000000000000"
184                                   "10000000000000000000000000000000",
185                                   128, '0', '1');
186             all_work_item_masks.push_back(mask_0x80000000);
187             // 0x80000000 00000000 1 workitem highest id for 64 subgroup size
188             // 0x80000000 1 workitem highest id for 32 subgroup size
189             bs128 mask_0x8000000000000000("10000000000000000000000000000000"
190                                           "00000000000000000000000000000000"
191                                           "10000000000000000000000000000000"
192                                           "00000000000000000000000000000000",
193                                           128, '0', '1');
194 
195             all_work_item_masks.push_back(mask_0x8000000000000000);
196             // 0x80000000 00000000 00000000 00000000 1 workitem highest id for
197             // 128 subgroup size
198             bs128 mask_0x80000000000000000000000000000000(
199                 "10000000000000000000000000000000"
200                 "00000000000000000000000000000000"
201                 "00000000000000000000000000000000"
202                 "00000000000000000000000000000000",
203                 128, '0', '1');
204             all_work_item_masks.push_back(
205                 mask_0x80000000000000000000000000000000);
206 
207             bs128 mask_0xffffffff("11111111111111111111111111111111"
208                                   "11111111111111111111111111111111"
209                                   "11111111111111111111111111111111"
210                                   "11111111111111111111111111111111",
211                                   128, '0', '1');
212             all_work_item_masks.push_back(mask_0xffffffff);
213         }
214     }
215 };
216 
// Selects a broadcast operation; operation_names(SubgroupsBroadcastOp)
// supplies the printable name of each value.
enum class SubgroupsBroadcastOp
{
    broadcast, // "broadcast"
    broadcast_first, // "broadcast_first"
    non_uniform_broadcast // "non_uniform_broadcast"
};
223 
// Selects a non-uniform vote operation; operation_names(NonUniformVoteOp)
// supplies the printable name of each value.
enum class NonUniformVoteOp
{
    elect, // "elect"
    all, // "all"
    any, // "any"
    all_equal // "all_equal"
};
231 
// Selects a ballot operation; operation_names(BallotOp) supplies the
// printable name of each value.
enum class BallotOp
{
    // Ballot construction and queries.
    ballot,
    inverse_ballot,
    ballot_bit_extract,
    ballot_bit_count,
    ballot_inclusive_scan,
    ballot_exclusive_scan,
    ballot_find_lsb,
    ballot_find_msb,
    // Mask operations.
    eq_mask,
    ge_mask,
    gt_mask,
    le_mask,
    lt_mask,
};
248 
// Selects a shuffle or rotate operation; operation_names(ShuffleOp) supplies
// the printable name of each value.
enum class ShuffleOp
{
    shuffle, // "shuffle"
    shuffle_up, // "shuffle_up"
    shuffle_down, // "shuffle_down"
    shuffle_xor, // "shuffle_xor"
    rotate, // "rotate"
    clustered_rotate, // "clustered_rotate"
};
258 
// Selects an arithmetic, bitwise or logical operation. The trailing
// underscore keeps the enumerators clear of the C++ alternative tokens
// (and, or, xor); operation_names(ArithmeticOp) supplies the printable name.
enum class ArithmeticOp
{
    // Arithmetic.
    add_,
    max_,
    min_,
    mul_,
    // Bitwise.
    and_,
    or_,
    xor_,
    // Logical.
    logical_and,
    logical_or,
    logical_xor
};
272 
operation_names(ArithmeticOp operation)273 static const char *const operation_names(ArithmeticOp operation)
274 {
275     switch (operation)
276     {
277         case ArithmeticOp::add_: return "add";
278         case ArithmeticOp::max_: return "max";
279         case ArithmeticOp::min_: return "min";
280         case ArithmeticOp::mul_: return "mul";
281         case ArithmeticOp::and_: return "and";
282         case ArithmeticOp::or_: return "or";
283         case ArithmeticOp::xor_: return "xor";
284         case ArithmeticOp::logical_and: return "logical_and";
285         case ArithmeticOp::logical_or: return "logical_or";
286         case ArithmeticOp::logical_xor: return "logical_xor";
287         default: log_error("Unknown operation request\n"); break;
288     }
289     return "";
290 }
291 
operation_names(BallotOp operation)292 static const char *const operation_names(BallotOp operation)
293 {
294     switch (operation)
295     {
296         case BallotOp::ballot: return "ballot";
297         case BallotOp::inverse_ballot: return "inverse_ballot";
298         case BallotOp::ballot_bit_extract: return "bit_extract";
299         case BallotOp::ballot_bit_count: return "bit_count";
300         case BallotOp::ballot_inclusive_scan: return "inclusive_scan";
301         case BallotOp::ballot_exclusive_scan: return "exclusive_scan";
302         case BallotOp::ballot_find_lsb: return "find_lsb";
303         case BallotOp::ballot_find_msb: return "find_msb";
304         case BallotOp::eq_mask: return "eq";
305         case BallotOp::ge_mask: return "ge";
306         case BallotOp::gt_mask: return "gt";
307         case BallotOp::le_mask: return "le";
308         case BallotOp::lt_mask: return "lt";
309         default: log_error("Unknown operation request\n"); break;
310     }
311     return "";
312 }
313 
operation_names(ShuffleOp operation)314 static const char *const operation_names(ShuffleOp operation)
315 {
316     switch (operation)
317     {
318         case ShuffleOp::shuffle: return "shuffle";
319         case ShuffleOp::shuffle_up: return "shuffle_up";
320         case ShuffleOp::shuffle_down: return "shuffle_down";
321         case ShuffleOp::shuffle_xor: return "shuffle_xor";
322         case ShuffleOp::rotate: return "rotate";
323         case ShuffleOp::clustered_rotate: return "clustered_rotate";
324         default: log_error("Unknown operation request\n"); break;
325     }
326     return "";
327 }
328 
operation_names(NonUniformVoteOp operation)329 static const char *const operation_names(NonUniformVoteOp operation)
330 {
331     switch (operation)
332     {
333         case NonUniformVoteOp::all: return "all";
334         case NonUniformVoteOp::all_equal: return "all_equal";
335         case NonUniformVoteOp::any: return "any";
336         case NonUniformVoteOp::elect: return "elect";
337         default: log_error("Unknown operation request\n"); break;
338     }
339     return "";
340 }
341 
operation_names(SubgroupsBroadcastOp operation)342 static const char *const operation_names(SubgroupsBroadcastOp operation)
343 {
344     switch (operation)
345     {
346         case SubgroupsBroadcastOp::broadcast: return "broadcast";
347         case SubgroupsBroadcastOp::broadcast_first: return "broadcast_first";
348         case SubgroupsBroadcastOp::non_uniform_broadcast:
349             return "non_uniform_broadcast";
350         default: log_error("Unknown operation request\n"); break;
351     }
352     return "";
353 }
354 
355 class subgroupsAPI {
356 public:
subgroupsAPI(cl_platform_id platform,bool use_core_subgroups)357     subgroupsAPI(cl_platform_id platform, bool use_core_subgroups)
358     {
359         static_assert(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE
360                           == CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR,
361                       "Enums have to be the same");
362         static_assert(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE
363                           == CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR,
364                       "Enums have to be the same");
365         if (use_core_subgroups)
366         {
367             _clGetKernelSubGroupInfo_ptr = &clGetKernelSubGroupInfo;
368             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfo";
369         }
370         else
371         {
372             _clGetKernelSubGroupInfo_ptr = (clGetKernelSubGroupInfoKHR_fn)
373                 clGetExtensionFunctionAddressForPlatform(
374                     platform, "clGetKernelSubGroupInfoKHR");
375             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfoKHR";
376         }
377     }
clGetKernelSubGroupInfo_ptr()378     clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr()
379     {
380         return _clGetKernelSubGroupInfo_ptr;
381     }
382     const char *clGetKernelSubGroupInfo_name;
383 
384 private:
385     clGetKernelSubGroupInfoKHR_fn _clGetKernelSubGroupInfo_ptr;
386 };
387 
// Custom wrapper types are needed for vectors of size 3 and for the half
// types: the 3-component types are otherwise indistinguishable from the
// 4-component types, and the half type is otherwise indistinguishable from
// another 16-bit type (ushort).
392 namespace subgroups {
393 struct cl_char3
394 {
395     ::cl_char3 data;
396 };
397 struct cl_uchar3
398 {
399     ::cl_uchar3 data;
400 };
401 struct cl_short3
402 {
403     ::cl_short3 data;
404 };
405 struct cl_ushort3
406 {
407     ::cl_ushort3 data;
408 };
409 struct cl_int3
410 {
411     ::cl_int3 data;
412 };
413 struct cl_uint3
414 {
415     ::cl_uint3 data;
416 };
417 struct cl_long3
418 {
419     ::cl_long3 data;
420 };
421 struct cl_ulong3
422 {
423     ::cl_ulong3 data;
424 };
425 struct cl_float3
426 {
427     ::cl_float3 data;
428 };
429 struct cl_double3
430 {
431     ::cl_double3 data;
432 };
433 struct cl_half
434 {
435     ::cl_half data;
436 };
437 struct cl_half2
438 {
439     ::cl_half2 data;
440 };
441 struct cl_half3
442 {
443     ::cl_half3 data;
444 };
445 struct cl_half4
446 {
447     ::cl_half4 data;
448 };
449 struct cl_half8
450 {
451     ::cl_half8 data;
452 };
453 struct cl_half16
454 {
455     ::cl_half16 data;
456 };
457 }
458 
int64_ok(cl_device_id device)459 static bool int64_ok(cl_device_id device)
460 {
461     char profile[128];
462     int error;
463 
464     error = clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(profile),
465                             (void *)&profile, NULL);
466     if (error)
467     {
468         log_info("clGetDeviceInfo failed with CL_DEVICE_PROFILE\n");
469         return false;
470     }
471 
472     if (strcmp(profile, "EMBEDDED_PROFILE") == 0)
473         return is_extension_available(device, "cles_khr_int64");
474 
475     return true;
476 }
477 
double_ok(cl_device_id device)478 static bool double_ok(cl_device_id device)
479 {
480     int error;
481     cl_device_fp_config c;
482     error = clGetDeviceInfo(device, CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(c),
483                             (void *)&c, NULL);
484     if (error)
485     {
486         log_info("clGetDeviceInfo failed with CL_DEVICE_DOUBLE_FP_CONFIG\n");
487         return false;
488     }
489     return c != 0;
490 }
491 
half_ok(cl_device_id device)492 static bool half_ok(cl_device_id device)
493 {
494     int error;
495     cl_device_fp_config c;
496     error = clGetDeviceInfo(device, CL_DEVICE_HALF_FP_CONFIG, sizeof(c),
497                             (void *)&c, NULL);
498     if (error)
499     {
500         log_info("clGetDeviceInfo failed with CL_DEVICE_HALF_FP_CONFIG\n");
501         return false;
502     }
503     return c != 0;
504 }
505 
506 template <typename Ty> struct CommonTypeManager
507 {
508 
nameCommonTypeManager509     static const char *name() { return ""; }
add_typedefCommonTypeManager510     static const char *add_typedef() { return "\n"; }
511     typedef std::false_type is_vector_type;
512     typedef std::false_type is_sb_vector_size3;
513     typedef std::false_type is_sb_vector_type;
514     typedef std::false_type is_sb_scalar_type;
type_supportedCommonTypeManager515     static const bool type_supported(cl_device_id) { return true; }
identify_limitsCommonTypeManager516     static const Ty identify_limits(ArithmeticOp operation)
517     {
518         switch (operation)
519         {
520             case ArithmeticOp::add_: return (Ty)0;
521             case ArithmeticOp::max_: return (std::numeric_limits<Ty>::min)();
522             case ArithmeticOp::min_: return (std::numeric_limits<Ty>::max)();
523             case ArithmeticOp::mul_: return (Ty)1;
524             case ArithmeticOp::and_: return (Ty)~0;
525             case ArithmeticOp::or_: return (Ty)0;
526             case ArithmeticOp::xor_: return (Ty)0;
527             default: log_error("Unknown operation request\n"); break;
528         }
529         return 0;
530     }
531 };
532 
// Primary template; only the explicit specializations are defined.
template <typename Ty> struct TypeManager;
534 
535 template <> struct TypeManager<cl_int> : public CommonTypeManager<cl_int>
536 {
537     static const char *name() { return "int"; }
538     static const char *add_typedef() { return "typedef int Type;\n"; }
539     static cl_int identify_limits(ArithmeticOp operation)
540     {
541         switch (operation)
542         {
543             case ArithmeticOp::add_: return (cl_int)0;
544             case ArithmeticOp::max_:
545                 return (std::numeric_limits<cl_int>::min)();
546             case ArithmeticOp::min_:
547                 return (std::numeric_limits<cl_int>::max)();
548             case ArithmeticOp::mul_: return (cl_int)1;
549             case ArithmeticOp::and_: return (cl_int)~0;
550             case ArithmeticOp::or_: return (cl_int)0;
551             case ArithmeticOp::xor_: return (cl_int)0;
552             case ArithmeticOp::logical_and: return (cl_int)1;
553             case ArithmeticOp::logical_or: return (cl_int)0;
554             case ArithmeticOp::logical_xor: return (cl_int)0;
555             default: log_error("Unknown operation request\n"); break;
556         }
557         return 0;
558     }
559 };
560 template <> struct TypeManager<cl_int2> : public CommonTypeManager<cl_int2>
561 {
562     static const char *name() { return "int2"; }
563     static const char *add_typedef() { return "typedef int2 Type;\n"; }
564     typedef std::true_type is_vector_type;
565     using scalar_type = cl_int;
566 };
567 template <>
568 struct TypeManager<subgroups::cl_int3>
569     : public CommonTypeManager<subgroups::cl_int3>
570 {
571     static const char *name() { return "int3"; }
572     static const char *add_typedef() { return "typedef int3 Type;\n"; }
573     typedef std::true_type is_sb_vector_size3;
574     using scalar_type = cl_int;
575 };
576 template <> struct TypeManager<cl_int4> : public CommonTypeManager<cl_int4>
577 {
578     static const char *name() { return "int4"; }
579     static const char *add_typedef() { return "typedef int4 Type;\n"; }
580     using scalar_type = cl_int;
581     typedef std::true_type is_vector_type;
582 };
583 template <> struct TypeManager<cl_int8> : public CommonTypeManager<cl_int8>
584 {
585     static const char *name() { return "int8"; }
586     static const char *add_typedef() { return "typedef int8 Type;\n"; }
587     using scalar_type = cl_int;
588     typedef std::true_type is_vector_type;
589 };
590 template <> struct TypeManager<cl_int16> : public CommonTypeManager<cl_int16>
591 {
592     static const char *name() { return "int16"; }
593     static const char *add_typedef() { return "typedef int16 Type;\n"; }
594     using scalar_type = cl_int;
595     typedef std::true_type is_vector_type;
596 };
597 // cl_uint
598 template <> struct TypeManager<cl_uint> : public CommonTypeManager<cl_uint>
599 {
600     static const char *name() { return "uint"; }
601     static const char *add_typedef() { return "typedef uint Type;\n"; }
602 };
603 template <> struct TypeManager<cl_uint2> : public CommonTypeManager<cl_uint2>
604 {
605     static const char *name() { return "uint2"; }
606     static const char *add_typedef() { return "typedef uint2 Type;\n"; }
607     using scalar_type = cl_uint;
608     typedef std::true_type is_vector_type;
609 };
610 template <>
611 struct TypeManager<subgroups::cl_uint3>
612     : public CommonTypeManager<subgroups::cl_uint3>
613 {
614     static const char *name() { return "uint3"; }
615     static const char *add_typedef() { return "typedef uint3 Type;\n"; }
616     typedef std::true_type is_sb_vector_size3;
617     using scalar_type = cl_uint;
618 };
619 template <> struct TypeManager<cl_uint4> : public CommonTypeManager<cl_uint4>
620 {
621     static const char *name() { return "uint4"; }
622     static const char *add_typedef() { return "typedef uint4 Type;\n"; }
623     using scalar_type = cl_uint;
624     typedef std::true_type is_vector_type;
625 };
626 template <> struct TypeManager<cl_uint8> : public CommonTypeManager<cl_uint8>
627 {
628     static const char *name() { return "uint8"; }
629     static const char *add_typedef() { return "typedef uint8 Type;\n"; }
630     using scalar_type = cl_uint;
631     typedef std::true_type is_vector_type;
632 };
633 template <> struct TypeManager<cl_uint16> : public CommonTypeManager<cl_uint16>
634 {
635     static const char *name() { return "uint16"; }
636     static const char *add_typedef() { return "typedef uint16 Type;\n"; }
637     using scalar_type = cl_uint;
638     typedef std::true_type is_vector_type;
639 };
640 // cl_short
641 template <> struct TypeManager<cl_short> : public CommonTypeManager<cl_short>
642 {
643     static const char *name() { return "short"; }
644     static const char *add_typedef() { return "typedef short Type;\n"; }
645 };
646 template <> struct TypeManager<cl_short2> : public CommonTypeManager<cl_short2>
647 {
648     static const char *name() { return "short2"; }
649     static const char *add_typedef() { return "typedef short2 Type;\n"; }
650     using scalar_type = cl_short;
651     typedef std::true_type is_vector_type;
652 };
653 template <>
654 struct TypeManager<subgroups::cl_short3>
655     : public CommonTypeManager<subgroups::cl_short3>
656 {
657     static const char *name() { return "short3"; }
658     static const char *add_typedef() { return "typedef short3 Type;\n"; }
659     typedef std::true_type is_sb_vector_size3;
660     using scalar_type = cl_short;
661 };
662 template <> struct TypeManager<cl_short4> : public CommonTypeManager<cl_short4>
663 {
664     static const char *name() { return "short4"; }
665     static const char *add_typedef() { return "typedef short4 Type;\n"; }
666     using scalar_type = cl_short;
667     typedef std::true_type is_vector_type;
668 };
669 template <> struct TypeManager<cl_short8> : public CommonTypeManager<cl_short8>
670 {
671     static const char *name() { return "short8"; }
672     static const char *add_typedef() { return "typedef short8 Type;\n"; }
673     using scalar_type = cl_short;
674     typedef std::true_type is_vector_type;
675 };
676 template <>
677 struct TypeManager<cl_short16> : public CommonTypeManager<cl_short16>
678 {
679     static const char *name() { return "short16"; }
680     static const char *add_typedef() { return "typedef short16 Type;\n"; }
681     using scalar_type = cl_short;
682     typedef std::true_type is_vector_type;
683 };
684 // cl_ushort
685 template <> struct TypeManager<cl_ushort> : public CommonTypeManager<cl_ushort>
686 {
687     static const char *name() { return "ushort"; }
688     static const char *add_typedef() { return "typedef ushort Type;\n"; }
689 };
690 template <>
691 struct TypeManager<cl_ushort2> : public CommonTypeManager<cl_ushort2>
692 {
693     static const char *name() { return "ushort2"; }
694     static const char *add_typedef() { return "typedef ushort2 Type;\n"; }
695     using scalar_type = cl_ushort;
696     typedef std::true_type is_vector_type;
697 };
698 template <>
699 struct TypeManager<subgroups::cl_ushort3>
700     : public CommonTypeManager<subgroups::cl_ushort3>
701 {
702     static const char *name() { return "ushort3"; }
703     static const char *add_typedef() { return "typedef ushort3 Type;\n"; }
704     typedef std::true_type is_sb_vector_size3;
705     using scalar_type = cl_ushort;
706 };
707 template <>
708 struct TypeManager<cl_ushort4> : public CommonTypeManager<cl_ushort4>
709 {
710     static const char *name() { return "ushort4"; }
711     static const char *add_typedef() { return "typedef ushort4 Type;\n"; }
712     using scalar_type = cl_ushort;
713     typedef std::true_type is_vector_type;
714 };
715 template <>
716 struct TypeManager<cl_ushort8> : public CommonTypeManager<cl_ushort8>
717 {
718     static const char *name() { return "ushort8"; }
719     static const char *add_typedef() { return "typedef ushort8 Type;\n"; }
720     using scalar_type = cl_ushort;
721     typedef std::true_type is_vector_type;
722 };
723 template <>
724 struct TypeManager<cl_ushort16> : public CommonTypeManager<cl_ushort16>
725 {
726     static const char *name() { return "ushort16"; }
727     static const char *add_typedef() { return "typedef ushort16 Type;\n"; }
728     using scalar_type = cl_ushort;
729     typedef std::true_type is_vector_type;
730 };
731 // cl_char
732 template <> struct TypeManager<cl_char> : public CommonTypeManager<cl_char>
733 {
734     static const char *name() { return "char"; }
735     static const char *add_typedef() { return "typedef char Type;\n"; }
736 };
737 template <> struct TypeManager<cl_char2> : public CommonTypeManager<cl_char2>
738 {
739     static const char *name() { return "char2"; }
740     static const char *add_typedef() { return "typedef char2 Type;\n"; }
741     using scalar_type = cl_char;
742     typedef std::true_type is_vector_type;
743 };
744 template <>
745 struct TypeManager<subgroups::cl_char3>
746     : public CommonTypeManager<subgroups::cl_char3>
747 {
748     static const char *name() { return "char3"; }
749     static const char *add_typedef() { return "typedef char3 Type;\n"; }
750     typedef std::true_type is_sb_vector_size3;
751     using scalar_type = cl_char;
752 };
753 template <> struct TypeManager<cl_char4> : public CommonTypeManager<cl_char4>
754 {
755     static const char *name() { return "char4"; }
756     static const char *add_typedef() { return "typedef char4 Type;\n"; }
757     using scalar_type = cl_char;
758     typedef std::true_type is_vector_type;
759 };
760 template <> struct TypeManager<cl_char8> : public CommonTypeManager<cl_char8>
761 {
762     static const char *name() { return "char8"; }
763     static const char *add_typedef() { return "typedef char8 Type;\n"; }
764     using scalar_type = cl_char;
765     typedef std::true_type is_vector_type;
766 };
767 template <> struct TypeManager<cl_char16> : public CommonTypeManager<cl_char16>
768 {
769     static const char *name() { return "char16"; }
770     static const char *add_typedef() { return "typedef char16 Type;\n"; }
771     using scalar_type = cl_char;
772     typedef std::true_type is_vector_type;
773 };
774 // cl_uchar
775 template <> struct TypeManager<cl_uchar> : public CommonTypeManager<cl_uchar>
776 {
777     static const char *name() { return "uchar"; }
778     static const char *add_typedef() { return "typedef uchar Type;\n"; }
779 };
780 template <> struct TypeManager<cl_uchar2> : public CommonTypeManager<cl_uchar2>
781 {
782     static const char *name() { return "uchar2"; }
783     static const char *add_typedef() { return "typedef uchar2 Type;\n"; }
784     using scalar_type = cl_uchar;
785     typedef std::true_type is_vector_type;
786 };
787 template <>
788 struct TypeManager<subgroups::cl_uchar3>
789     : public CommonTypeManager<subgroups::cl_char3>
790 {
791     static const char *name() { return "uchar3"; }
792     static const char *add_typedef() { return "typedef uchar3 Type;\n"; }
793     typedef std::true_type is_sb_vector_size3;
794     using scalar_type = cl_uchar;
795 };
796 template <> struct TypeManager<cl_uchar4> : public CommonTypeManager<cl_uchar4>
797 {
798     static const char *name() { return "uchar4"; }
799     static const char *add_typedef() { return "typedef uchar4 Type;\n"; }
800     using scalar_type = cl_uchar;
801     typedef std::true_type is_vector_type;
802 };
803 template <> struct TypeManager<cl_uchar8> : public CommonTypeManager<cl_uchar8>
804 {
805     static const char *name() { return "uchar8"; }
806     static const char *add_typedef() { return "typedef uchar8 Type;\n"; }
807     using scalar_type = cl_uchar;
808     typedef std::true_type is_vector_type;
809 };
810 template <>
811 struct TypeManager<cl_uchar16> : public CommonTypeManager<cl_uchar16>
812 {
813     static const char *name() { return "uchar16"; }
814     static const char *add_typedef() { return "typedef uchar16 Type;\n"; }
815     using scalar_type = cl_uchar;
816     typedef std::true_type is_vector_type;
817 };
818 // cl_long
819 template <> struct TypeManager<cl_long> : public CommonTypeManager<cl_long>
820 {
821     static const char *name() { return "long"; }
822     static const char *add_typedef() { return "typedef long Type;\n"; }
823     static const bool type_supported(cl_device_id device)
824     {
825         return int64_ok(device);
826     }
827 };
828 template <> struct TypeManager<cl_long2> : public CommonTypeManager<cl_long2>
829 {
830     static const char *name() { return "long2"; }
831     static const char *add_typedef() { return "typedef long2 Type;\n"; }
832     using scalar_type = cl_long;
833     typedef std::true_type is_vector_type;
834     static const bool type_supported(cl_device_id device)
835     {
836         return int64_ok(device);
837     }
838 };
839 template <>
840 struct TypeManager<subgroups::cl_long3>
841     : public CommonTypeManager<subgroups::cl_long3>
842 {
843     static const char *name() { return "long3"; }
844     static const char *add_typedef() { return "typedef long3 Type;\n"; }
845     typedef std::true_type is_sb_vector_size3;
846     using scalar_type = cl_long;
847     static const bool type_supported(cl_device_id device)
848     {
849         return int64_ok(device);
850     }
851 };
852 template <> struct TypeManager<cl_long4> : public CommonTypeManager<cl_long4>
853 {
854     static const char *name() { return "long4"; }
855     static const char *add_typedef() { return "typedef long4 Type;\n"; }
856     using scalar_type = cl_long;
857     typedef std::true_type is_vector_type;
858     static const bool type_supported(cl_device_id device)
859     {
860         return int64_ok(device);
861     }
862 };
863 template <> struct TypeManager<cl_long8> : public CommonTypeManager<cl_long8>
864 {
865     static const char *name() { return "long8"; }
866     static const char *add_typedef() { return "typedef long8 Type;\n"; }
867     using scalar_type = cl_long;
868     typedef std::true_type is_vector_type;
869     static const bool type_supported(cl_device_id device)
870     {
871         return int64_ok(device);
872     }
873 };
874 template <> struct TypeManager<cl_long16> : public CommonTypeManager<cl_long16>
875 {
876     static const char *name() { return "long16"; }
877     static const char *add_typedef() { return "typedef long16 Type;\n"; }
878     using scalar_type = cl_long;
879     typedef std::true_type is_vector_type;
880     static const bool type_supported(cl_device_id device)
881     {
882         return int64_ok(device);
883     }
884 };
885 // cl_ulong
886 template <> struct TypeManager<cl_ulong> : public CommonTypeManager<cl_ulong>
887 {
888     static const char *name() { return "ulong"; }
889     static const char *add_typedef() { return "typedef ulong Type;\n"; }
890     static const bool type_supported(cl_device_id device)
891     {
892         return int64_ok(device);
893     }
894 };
895 template <> struct TypeManager<cl_ulong2> : public CommonTypeManager<cl_ulong2>
896 {
897     static const char *name() { return "ulong2"; }
898     static const char *add_typedef() { return "typedef ulong2 Type;\n"; }
899     using scalar_type = cl_ulong;
900     typedef std::true_type is_vector_type;
901     static const bool type_supported(cl_device_id device)
902     {
903         return int64_ok(device);
904     }
905 };
906 template <>
907 struct TypeManager<subgroups::cl_ulong3>
908     : public CommonTypeManager<subgroups::cl_ulong3>
909 {
910     static const char *name() { return "ulong3"; }
911     static const char *add_typedef() { return "typedef ulong3 Type;\n"; }
912     typedef std::true_type is_sb_vector_size3;
913     using scalar_type = cl_ulong;
914     static const bool type_supported(cl_device_id device)
915     {
916         return int64_ok(device);
917     }
918 };
919 template <> struct TypeManager<cl_ulong4> : public CommonTypeManager<cl_ulong4>
920 {
921     static const char *name() { return "ulong4"; }
922     static const char *add_typedef() { return "typedef ulong4 Type;\n"; }
923     using scalar_type = cl_ulong;
924     typedef std::true_type is_vector_type;
925     static const bool type_supported(cl_device_id device)
926     {
927         return int64_ok(device);
928     }
929 };
930 template <> struct TypeManager<cl_ulong8> : public CommonTypeManager<cl_ulong8>
931 {
932     static const char *name() { return "ulong8"; }
933     static const char *add_typedef() { return "typedef ulong8 Type;\n"; }
934     using scalar_type = cl_ulong;
935     typedef std::true_type is_vector_type;
936     static const bool type_supported(cl_device_id device)
937     {
938         return int64_ok(device);
939     }
940 };
941 template <>
942 struct TypeManager<cl_ulong16> : public CommonTypeManager<cl_ulong16>
943 {
944     static const char *name() { return "ulong16"; }
945     static const char *add_typedef() { return "typedef ulong16 Type;\n"; }
946     using scalar_type = cl_ulong;
947     typedef std::true_type is_vector_type;
948     static const bool type_supported(cl_device_id device)
949     {
950         return int64_ok(device);
951     }
952 };
953 
954 // cl_float
955 template <> struct TypeManager<cl_float> : public CommonTypeManager<cl_float>
956 {
957     static const char *name() { return "float"; }
958     static const char *add_typedef() { return "typedef float Type;\n"; }
959     static cl_float identify_limits(ArithmeticOp operation)
960     {
961         switch (operation)
962         {
963             case ArithmeticOp::add_: return 0.0f;
964             case ArithmeticOp::max_:
965                 return -std::numeric_limits<float>::infinity();
966             case ArithmeticOp::min_:
967                 return std::numeric_limits<float>::infinity();
968             case ArithmeticOp::mul_: return (cl_float)1;
969             default: log_error("Unknown operation request\n"); break;
970         }
971         return 0;
972     }
973 };
974 template <> struct TypeManager<cl_float2> : public CommonTypeManager<cl_float2>
975 {
976     static const char *name() { return "float2"; }
977     static const char *add_typedef() { return "typedef float2 Type;\n"; }
978     using scalar_type = cl_float;
979     typedef std::true_type is_vector_type;
980 };
981 template <>
982 struct TypeManager<subgroups::cl_float3>
983     : public CommonTypeManager<subgroups::cl_float3>
984 {
985     static const char *name() { return "float3"; }
986     static const char *add_typedef() { return "typedef float3 Type;\n"; }
987     typedef std::true_type is_sb_vector_size3;
988     using scalar_type = cl_float;
989 };
990 template <> struct TypeManager<cl_float4> : public CommonTypeManager<cl_float4>
991 {
992     static const char *name() { return "float4"; }
993     static const char *add_typedef() { return "typedef float4 Type;\n"; }
994     using scalar_type = cl_float;
995     typedef std::true_type is_vector_type;
996 };
997 template <> struct TypeManager<cl_float8> : public CommonTypeManager<cl_float8>
998 {
999     static const char *name() { return "float8"; }
1000     static const char *add_typedef() { return "typedef float8 Type;\n"; }
1001     using scalar_type = cl_float;
1002     typedef std::true_type is_vector_type;
1003 };
1004 template <>
1005 struct TypeManager<cl_float16> : public CommonTypeManager<cl_float16>
1006 {
1007     static const char *name() { return "float16"; }
1008     static const char *add_typedef() { return "typedef float16 Type;\n"; }
1009     using scalar_type = cl_float;
1010     typedef std::true_type is_vector_type;
1011 };
1012 
1013 // cl_double
1014 template <> struct TypeManager<cl_double> : public CommonTypeManager<cl_double>
1015 {
1016     static const char *name() { return "double"; }
1017     static const char *add_typedef() { return "typedef double Type;\n"; }
1018     static cl_double identify_limits(ArithmeticOp operation)
1019     {
1020         switch (operation)
1021         {
1022             case ArithmeticOp::add_: return 0.0;
1023             case ArithmeticOp::max_:
1024                 return -std::numeric_limits<double>::infinity();
1025             case ArithmeticOp::min_:
1026                 return std::numeric_limits<double>::infinity();
1027             case ArithmeticOp::mul_: return (cl_double)1;
1028             default: log_error("Unknown operation request\n"); break;
1029         }
1030         return 0;
1031     }
1032     static const bool type_supported(cl_device_id device)
1033     {
1034         return double_ok(device);
1035     }
1036 };
1037 template <>
1038 struct TypeManager<cl_double2> : public CommonTypeManager<cl_double2>
1039 {
1040     static const char *name() { return "double2"; }
1041     static const char *add_typedef() { return "typedef double2 Type;\n"; }
1042     using scalar_type = cl_double;
1043     typedef std::true_type is_vector_type;
1044     static const bool type_supported(cl_device_id device)
1045     {
1046         return double_ok(device);
1047     }
1048 };
1049 template <>
1050 struct TypeManager<subgroups::cl_double3>
1051     : public CommonTypeManager<subgroups::cl_double3>
1052 {
1053     static const char *name() { return "double3"; }
1054     static const char *add_typedef() { return "typedef double3 Type;\n"; }
1055     typedef std::true_type is_sb_vector_size3;
1056     using scalar_type = cl_double;
1057     static const bool type_supported(cl_device_id device)
1058     {
1059         return double_ok(device);
1060     }
1061 };
1062 template <>
1063 struct TypeManager<cl_double4> : public CommonTypeManager<cl_double4>
1064 {
1065     static const char *name() { return "double4"; }
1066     static const char *add_typedef() { return "typedef double4 Type;\n"; }
1067     using scalar_type = cl_double;
1068     typedef std::true_type is_vector_type;
1069     static const bool type_supported(cl_device_id device)
1070     {
1071         return double_ok(device);
1072     }
1073 };
1074 template <>
1075 struct TypeManager<cl_double8> : public CommonTypeManager<cl_double8>
1076 {
1077     static const char *name() { return "double8"; }
1078     static const char *add_typedef() { return "typedef double8 Type;\n"; }
1079     using scalar_type = cl_double;
1080     typedef std::true_type is_vector_type;
1081     static const bool type_supported(cl_device_id device)
1082     {
1083         return double_ok(device);
1084     }
1085 };
1086 template <>
1087 struct TypeManager<cl_double16> : public CommonTypeManager<cl_double16>
1088 {
1089     static const char *name() { return "double16"; }
1090     static const char *add_typedef() { return "typedef double16 Type;\n"; }
1091     using scalar_type = cl_double;
1092     typedef std::true_type is_vector_type;
1093     static const bool type_supported(cl_device_id device)
1094     {
1095         return double_ok(device);
1096     }
1097 };
1098 
1099 // cl_half
1100 template <>
1101 struct TypeManager<subgroups::cl_half>
1102     : public CommonTypeManager<subgroups::cl_half>
1103 {
1104     static const char *name() { return "half"; }
1105     static const char *add_typedef() { return "typedef half Type;\n"; }
1106     typedef std::true_type is_sb_scalar_type;
1107     static subgroups::cl_half identify_limits(ArithmeticOp operation)
1108     {
1109         switch (operation)
1110         {
1111             case ArithmeticOp::add_: return { 0x0000 };
1112             case ArithmeticOp::max_: return { 0xfc00 };
1113             case ArithmeticOp::min_: return { 0x7c00 };
1114             case ArithmeticOp::mul_: return { 0x3c00 };
1115             default: log_error("Unknown operation request\n"); break;
1116         }
1117         return { 0 };
1118     }
1119     static const bool type_supported(cl_device_id device)
1120     {
1121         return half_ok(device);
1122     }
1123 };
1124 template <>
1125 struct TypeManager<subgroups::cl_half2>
1126     : public CommonTypeManager<subgroups::cl_half2>
1127 {
1128     static const char *name() { return "half2"; }
1129     static const char *add_typedef() { return "typedef half2 Type;\n"; }
1130     using scalar_type = subgroups::cl_half;
1131     typedef std::true_type is_sb_vector_type;
1132     static const bool type_supported(cl_device_id device)
1133     {
1134         return half_ok(device);
1135     }
1136 };
1137 template <>
1138 struct TypeManager<subgroups::cl_half3>
1139     : public CommonTypeManager<subgroups::cl_half3>
1140 {
1141     static const char *name() { return "half3"; }
1142     static const char *add_typedef() { return "typedef half3 Type;\n"; }
1143     typedef std::true_type is_sb_vector_size3;
1144     using scalar_type = subgroups::cl_half;
1145 
1146     static const bool type_supported(cl_device_id device)
1147     {
1148         return half_ok(device);
1149     }
1150 };
1151 template <>
1152 struct TypeManager<subgroups::cl_half4>
1153     : public CommonTypeManager<subgroups::cl_half4>
1154 {
1155     static const char *name() { return "half4"; }
1156     static const char *add_typedef() { return "typedef half4 Type;\n"; }
1157     using scalar_type = subgroups::cl_half;
1158     typedef std::true_type is_sb_vector_type;
1159     static const bool type_supported(cl_device_id device)
1160     {
1161         return half_ok(device);
1162     }
1163 };
1164 template <>
1165 struct TypeManager<subgroups::cl_half8>
1166     : public CommonTypeManager<subgroups::cl_half8>
1167 {
1168     static const char *name() { return "half8"; }
1169     static const char *add_typedef() { return "typedef half8 Type;\n"; }
1170     using scalar_type = subgroups::cl_half;
1171     typedef std::true_type is_sb_vector_type;
1172 
1173     static const bool type_supported(cl_device_id device)
1174     {
1175         return half_ok(device);
1176     }
1177 };
1178 template <>
1179 struct TypeManager<subgroups::cl_half16>
1180     : public CommonTypeManager<subgroups::cl_half16>
1181 {
1182     static const char *name() { return "half16"; }
1183     static const char *add_typedef() { return "typedef half16 Type;\n"; }
1184     using scalar_type = subgroups::cl_half;
1185     typedef std::true_type is_sb_vector_type;
1186     static const bool type_supported(cl_device_id device)
1187     {
1188         return half_ok(device);
1189     }
1190 };
1191 
1192 // set scalar value to vector of halfs
1193 template <typename Ty, int N = 0>
1194 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value>::type
1195 set_value(Ty &lhs, const cl_ulong &rhs)
1196 {
1197     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1198     for (auto i = 0; i < size; ++i)
1199     {
1200         lhs.data.s[i] = rhs;
1201     }
1202 }
1203 
1204 
1205 // set scalar value to vector
1206 template <typename Ty>
1207 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1208 set_value(Ty &lhs, const cl_ulong &rhs)
1209 {
1210     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1211     for (auto i = 0; i < size; ++i)
1212     {
1213         lhs.s[i] = rhs;
1214     }
1215 }
1216 
1217 // set vector to vector value
1218 template <typename Ty>
1219 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1220 set_value(Ty &lhs, const Ty &rhs)
1221 {
1222     lhs = rhs;
1223 }
1224 
1225 // set scalar value to vector size 3
1226 template <typename Ty, int N = 0>
1227 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value>::type
1228 set_value(Ty &lhs, const cl_ulong &rhs)
1229 {
1230     for (auto i = 0; i < 3; ++i)
1231     {
1232         lhs.data.s[i] = rhs;
1233     }
1234 }
1235 
1236 // set scalar value to scalar
1237 template <typename Ty>
1238 typename std::enable_if<std::is_scalar<Ty>::value>::type
1239 set_value(Ty &lhs, const cl_ulong &rhs)
1240 {
1241     lhs = static_cast<Ty>(rhs);
1242 }
1243 
1244 // set scalar value to half scalar
1245 template <typename Ty>
1246 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value>::type
1247 set_value(Ty &lhs, const cl_ulong &rhs)
1248 {
1249     lhs.data = cl_half_from_float(static_cast<cl_float>(rhs), g_rounding_mode);
1250 }
1251 
1252 // compare for common vectors
1253 template <typename Ty>
1254 typename std::enable_if<TypeManager<Ty>::is_vector_type::value, bool>::type
1255 compare(const Ty &lhs, const Ty &rhs)
1256 {
1257     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1258     for (auto i = 0; i < size; ++i)
1259     {
1260         if (lhs.s[i] != rhs.s[i])
1261         {
1262             return false;
1263         }
1264     }
1265     return true;
1266 }
1267 
1268 // compare for vectors 3
1269 template <typename Ty>
1270 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value, bool>::type
1271 compare(const Ty &lhs, const Ty &rhs)
1272 {
1273     for (auto i = 0; i < 3; ++i)
1274     {
1275         if (lhs.data.s[i] != rhs.data.s[i])
1276         {
1277             return false;
1278         }
1279     }
1280     return true;
1281 }
1282 
1283 // compare for half vectors
1284 template <typename Ty>
1285 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value, bool>::type
1286 compare(const Ty &lhs, const Ty &rhs)
1287 {
1288     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1289     for (auto i = 0; i < size; ++i)
1290     {
1291         if (lhs.data.s[i] != rhs.data.s[i])
1292         {
1293             return false;
1294         }
1295     }
1296     return true;
1297 }
1298 
// Equality for scalar types (std::is_scalar), using operator==.
template <typename Ty>
typename std::enable_if<std::is_scalar<Ty>::value, bool>::type
compare(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1306 
1307 // compare for scalar halfs
1308 template <typename Ty>
1309 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value, bool>::type
1310 compare(const Ty &lhs, const Ty &rhs)
1311 {
1312     return lhs.data == rhs.data;
1313 }
1314 
// Generic value comparison: true when lhs == rhs according to Ty's
// operator==.
template <typename Ty> inline bool compare_ordered(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1319 
1320 template <>
1321 inline bool compare_ordered(const subgroups::cl_half &lhs,
1322                             const subgroups::cl_half &rhs)
1323 {
1324     return cl_half_to_float(lhs.data) == cl_half_to_float(rhs.data);
1325 }
1326 
1327 template <typename Ty>
1328 inline bool compare_ordered(const subgroups::cl_half &lhs, const int &rhs)
1329 {
1330     return cl_half_to_float(lhs.data) == rhs;
1331 }
1332 
1333 template <typename Ty, typename Fns> class KernelExecutor {
1334 public:
1335     KernelExecutor(cl_context c, cl_command_queue q, cl_kernel k, size_t g,
1336                    size_t l, Ty *id, size_t is, Ty *mid, Ty *mod, cl_int *md,
1337                    size_t ms, Ty *od, size_t os, size_t ts = 0)
1338         : context(c), queue(q), kernel(k), global(g), local(l), idata(id),
1339           isize(is), mapin_data(mid), mapout_data(mod), mdata(md), msize(ms),
1340           odata(od), osize(os), tsize(ts)
1341     {
1342         has_status = false;
1343         run_failed = false;
1344     }
1345     cl_context context;
1346     cl_command_queue queue;
1347     cl_kernel kernel;
1348     size_t global;
1349     size_t local;
1350     Ty *idata;
1351     size_t isize;
1352     Ty *mapin_data;
1353     Ty *mapout_data;
1354     cl_int *mdata;
1355     size_t msize;
1356     Ty *odata;
1357     size_t osize;
1358     size_t tsize;
1359     bool run_failed;
1360 
1361 private:
1362     bool has_status;
1363     test_status status;
1364 
1365 public:
1366     // Run a test kernel to compute the result of a built-in on an input
1367     int run()
1368     {
1369         clMemWrapper in;
1370         clMemWrapper xy;
1371         clMemWrapper out;
1372         clMemWrapper tmp;
1373         int error;
1374 
1375         in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
1376         test_error(error, "clCreateBuffer failed");
1377 
1378         xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
1379         test_error(error, "clCreateBuffer failed");
1380 
1381         out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
1382         test_error(error, "clCreateBuffer failed");
1383 
1384         if (tsize)
1385         {
1386             tmp = clCreateBuffer(context,
1387                                  CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
1388                                  tsize, NULL, &error);
1389             test_error(error, "clCreateBuffer failed");
1390         }
1391 
1392         error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
1393         test_error(error, "clSetKernelArg failed");
1394 
1395         error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
1396         test_error(error, "clSetKernelArg failed");
1397 
1398         error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
1399         test_error(error, "clSetKernelArg failed");
1400 
1401         if (tsize)
1402         {
1403             error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
1404             test_error(error, "clSetKernelArg failed");
1405         }
1406 
1407         error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0,
1408                                      NULL, NULL);
1409         test_error(error, "clEnqueueWriteBuffer failed");
1410 
1411         error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1412                                      NULL, NULL);
1413         test_error(error, "clEnqueueWriteBuffer failed");
1414         error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local,
1415                                        0, NULL, NULL);
1416         test_error(error, "clEnqueueNDRangeKernel failed");
1417 
1418         error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1419                                     NULL, NULL);
1420         test_error(error, "clEnqueueReadBuffer failed");
1421 
1422         error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0,
1423                                     NULL, NULL);
1424         test_error(error, "clEnqueueReadBuffer failed");
1425 
1426         error = clFinish(queue);
1427         test_error(error, "clFinish failed");
1428 
1429         return error;
1430     }
1431 
1432 private:
1433     test_status
1434     run_and_check_with_cluster_size(const WorkGroupParams &test_params)
1435     {
1436         cl_int error = run();
1437         if (error != CL_SUCCESS)
1438         {
1439             print_error(error, "Failed to run subgroup test kernel");
1440             status = TEST_FAIL;
1441             run_failed = true;
1442             return status;
1443         }
1444 
1445         test_status tmp_status =
1446             Fns::chk(idata, odata, mapin_data, mapout_data, mdata, test_params);
1447 
1448         if (!has_status || tmp_status == TEST_FAIL
1449             || (tmp_status == TEST_PASS && status != TEST_FAIL))
1450         {
1451             status = tmp_status;
1452             has_status = true;
1453         }
1454 
1455         return status;
1456     }
1457 
1458 public:
1459     test_status run_and_check(WorkGroupParams &test_params)
1460     {
1461         test_status tmp_status = TEST_SKIPPED_ITSELF;
1462 
1463         if (test_params.cluster_size_arg != -1)
1464         {
1465             for (cl_uint cluster_size = 1;
1466                  cluster_size <= test_params.subgroup_size; cluster_size *= 2)
1467             {
1468                 test_params.cluster_size = cluster_size;
1469                 cl_int error =
1470                     clSetKernelArg(kernel, test_params.cluster_size_arg,
1471                                    sizeof(cl_uint), &cluster_size);
1472                 test_error_fail(error, "Unable to set cluster size");
1473 
1474                 tmp_status = run_and_check_with_cluster_size(test_params);
1475 
1476                 if (tmp_status == TEST_FAIL) break;
1477             }
1478         }
1479         else
1480         {
1481             tmp_status = run_and_check_with_cluster_size(test_params);
1482         }
1483 
1484         return tmp_status;
1485     }
1486 };
1487 
1488 // Driver for testing a single built in function
1489 template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
1490 {
1491     static test_status run(cl_device_id device, cl_context context,
1492                            cl_command_queue queue, int num_elements,
1493                            const char *kname, const char *src,
1494                            WorkGroupParams test_params)
1495     {
1496         size_t tmp;
1497         cl_int error;
1498         size_t subgroup_size, num_subgroups;
1499         size_t global = test_params.global_workgroup_size;
1500         size_t local = test_params.local_workgroup_size;
1501         clProgramWrapper program;
1502         clKernelWrapper kernel;
1503         cl_platform_id platform;
1504         std::vector<cl_int> sgmap;
1505         sgmap.resize(4 * global);
1506         std::vector<Ty> mapin;
1507         mapin.resize(local);
1508         std::vector<Ty> mapout;
1509         mapout.resize(local);
1510         std::stringstream kernel_sstr;
1511 
1512         Fns::log_test(test_params, "");
1513 
1514         kernel_sstr << "#define NR_OF_ACTIVE_WORK_ITEMS ";
1515         kernel_sstr << NR_OF_ACTIVE_WORK_ITEMS << "\n";
1516         // Make sure a test of type Ty is supported by the device
1517         if (!TypeManager<Ty>::type_supported(device))
1518         {
1519             log_info("Data type not supported : %s\n", TypeManager<Ty>::name());
1520             return TEST_SKIPPED_ITSELF;
1521         }
1522 
1523         if (strstr(TypeManager<Ty>::name(), "double"))
1524         {
1525             kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n";
1526         }
1527         else if (strstr(TypeManager<Ty>::name(), "half"))
1528         {
1529             kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp16: enable\n";
1530         }
1531 
1532         error = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(platform),
1533                                 (void *)&platform, NULL);
1534         test_error_fail(error, "clGetDeviceInfo failed for CL_DEVICE_PLATFORM");
1535         if (test_params.use_core_subgroups)
1536         {
1537             kernel_sstr
1538                 << "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
1539         }
1540         kernel_sstr << "#define XY(M,I) M[I].x = get_sub_group_local_id(); "
1541                        "M[I].y = get_sub_group_id();\n";
1542         kernel_sstr << TypeManager<Ty>::add_typedef();
1543         kernel_sstr << src;
1544         const std::string &kernel_str = kernel_sstr.str();
1545         const char *kernel_src = kernel_str.c_str();
1546 
1547         error = create_single_kernel_helper(context, &program, &kernel, 1,
1548                                             &kernel_src, kname);
1549         if (error != CL_SUCCESS) return TEST_FAIL;
1550 
1551         // Determine some local dimensions to use for the test.
1552         error = get_max_common_work_group_size(
1553             context, kernel, test_params.global_workgroup_size, &local);
1554         test_error_fail(error, "get_max_common_work_group_size failed");
1555 
1556         // Limit it a bit so we have muliple work groups
1557         // Ideally this will still be large enough to give us multiple
1558         if (local > test_params.local_workgroup_size)
1559             local = test_params.local_workgroup_size;
1560 
1561 
1562         // Get the sub group info
1563         subgroupsAPI subgroupsApiSet(platform, test_params.use_core_subgroups);
1564         clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr =
1565             subgroupsApiSet.clGetKernelSubGroupInfo_ptr();
1566         if (clGetKernelSubGroupInfo_ptr == NULL)
1567         {
1568             log_error("ERROR: %s function not available\n",
1569                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1570             return TEST_FAIL;
1571         }
1572         error = clGetKernelSubGroupInfo_ptr(
1573             kernel, device, CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
1574             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1575         if (error != CL_SUCCESS)
1576         {
1577             log_error("ERROR: %s function error for "
1578                       "CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE\n",
1579                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1580             return TEST_FAIL;
1581         }
1582 
1583         subgroup_size = tmp;
1584 
1585         error = clGetKernelSubGroupInfo_ptr(
1586             kernel, device, CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE,
1587             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1588         if (error != CL_SUCCESS)
1589         {
1590             log_error("ERROR: %s function error for "
1591                       "CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE\n",
1592                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1593             return TEST_FAIL;
1594         }
1595 
1596         num_subgroups = tmp;
1597         // Make sure the number of sub groups is what we expect
1598         if (num_subgroups != (local + subgroup_size - 1) / subgroup_size)
1599         {
1600             log_error("ERROR: unexpected number of subgroups (%zu) returned\n",
1601                       num_subgroups);
1602             return TEST_FAIL;
1603         }
1604 
1605         std::vector<Ty> idata;
1606         std::vector<Ty> odata;
1607         size_t input_array_size = global;
1608         size_t output_array_size = global;
1609         size_t dynscl = test_params.dynsc;
1610 
1611         if (dynscl != 0)
1612         {
1613             input_array_size = global / local * num_subgroups * dynscl;
1614             output_array_size = global / local * dynscl;
1615         }
1616 
1617         idata.resize(input_array_size);
1618         odata.resize(output_array_size);
1619 
1620         if (test_params.divergence_mask_arg != -1)
1621         {
1622             cl_uint4 mask_vector;
1623             mask_vector.x = 0xffffffffU;
1624             mask_vector.y = 0xffffffffU;
1625             mask_vector.z = 0xffffffffU;
1626             mask_vector.w = 0xffffffffU;
1627             error = clSetKernelArg(kernel, test_params.divergence_mask_arg,
1628                                    sizeof(cl_uint4), &mask_vector);
1629             test_error_fail(error, "Unable to set divergence mask argument");
1630         }
1631 
1632         if (test_params.cluster_size_arg != -1)
1633         {
1634             cl_uint dummy_cluster_size = 1;
1635             error = clSetKernelArg(kernel, test_params.cluster_size_arg,
1636                                    sizeof(cl_uint), &dummy_cluster_size);
1637             test_error_fail(error, "Unable to set dummy cluster size");
1638         }
1639 
1640         KernelExecutor<Ty, Fns> executor(
1641             context, queue, kernel, global, local, idata.data(),
1642             input_array_size * sizeof(Ty), mapin.data(), mapout.data(),
1643             sgmap.data(), global * sizeof(cl_int4), odata.data(),
1644             output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
1645 
1646         // Run the kernel once on zeroes to get the map
1647         memset(idata.data(), 0, input_array_size * sizeof(Ty));
1648         error = executor.run();
1649         test_error_fail(error, "Running kernel first time failed");
1650 
1651         // Generate the desired input for the kernel
1652         test_params.subgroup_size = subgroup_size;
1653         Fns::gen(idata.data(), mapin.data(), sgmap.data(), test_params);
1654 
1655         test_status status;
1656 
1657         if (test_params.divergence_mask_arg != -1)
1658         {
1659             for (auto &mask : test_params.all_work_item_masks)
1660             {
1661                 test_params.work_items_mask = mask;
1662                 cl_uint4 mask_vector = bs128_to_cl_uint4(mask);
1663                 clSetKernelArg(kernel, test_params.divergence_mask_arg,
1664                                sizeof(cl_uint4), &mask_vector);
1665 
1666                 status = executor.run_and_check(test_params);
1667 
1668                 if (status == TEST_FAIL) break;
1669             }
1670         }
1671         else
1672         {
1673             status = executor.run_and_check(test_params);
1674         }
1675         // Detailed failure and skip messages should be logged by
1676         // run_and_check.
1677         if (status == TEST_PASS)
1678         {
1679             Fns::log_test(test_params, " passed");
1680         }
1681         else if (!executor.run_failed && status == TEST_FAIL)
1682         {
1683             test_fail("Data verification failed\n");
1684         }
1685         return status;
1686     }
1687 };
1688 
// Describes the trailing work-group of a non-uniform dispatch, i.e. the one
// that holds only `non_uniform_size` work-items, and reports the result
// through the three reference parameters.
//
//   non_uniform_size     work-items in the trailing work-group
//   subgroup_size        work-items in one full sub-group
//   number_of_subgroups  [out] non_uniform_size / subgroup_size, plus one
//   workgroup_size       [out] set to non_uniform_size
//   last_subgroup_size   [out] non_uniform_size % subgroup_size
//
// NOTE(review): when non_uniform_size is an exact multiple of subgroup_size
// the "+ 1" is still applied and last_subgroup_size comes out as 0 - confirm
// that every caller expects that combination.
static void set_last_workgroup_params(int non_uniform_size,
                                      int &number_of_subgroups,
                                      int subgroup_size, int &workgroup_size,
                                      int &last_subgroup_size)
{
    const int full_subgroups = non_uniform_size / subgroup_size;
    const int leftover_items = non_uniform_size % subgroup_size;

    // The outputs are written in a fixed order: count, remainder, size.
    number_of_subgroups = full_subgroups + 1;
    last_subgroup_size = leftover_items;
    workgroup_size = non_uniform_size;
}
1698 
1699 template <typename Ty>
1700 static void set_randomdata_for_subgroup(Ty *workgroup, int wg_offset,
1701                                         int current_sbs)
1702 {
1703     int randomize_data = (int)(genrand_int32(gMTdata) % 3);
1704     // Initialize data matrix indexed by local id and sub group id
1705     switch (randomize_data)
1706     {
1707         case 0:
1708             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1709             break;
1710         case 1: {
1711             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1712             int wi_id = (int)(genrand_int32(gMTdata) % (cl_uint)current_sbs);
1713             set_value(workgroup[wg_offset + wi_id], 41);
1714         }
1715         break;
1716         case 2:
1717             memset(&workgroup[wg_offset], 0xff, current_sbs * sizeof(Ty));
1718             break;
1719     }
1720 }
1721 
1722 struct RunTestForType
1723 {
1724     RunTestForType(cl_device_id device, cl_context context,
1725                    cl_command_queue queue, int num_elements,
1726                    WorkGroupParams test_params)
1727         : device_(device), context_(context), queue_(queue),
1728           num_elements_(num_elements), test_params_(test_params)
1729     {}
1730     template <typename T, typename U>
1731     int run_impl(const std::string &function_name)
1732     {
1733         int error = TEST_PASS;
1734         std::string source =
1735             std::regex_replace(test_params_.get_kernel_source(function_name),
1736                                std::regex("\\%s"), function_name);
1737         std::string kernel_name = "test_" + function_name;
1738         error =
1739             test<T, U>::run(device_, context_, queue_, num_elements_,
1740                             kernel_name.c_str(), source.c_str(), test_params_);
1741 
1742         // If we return TEST_SKIPPED_ITSELF here, then an entire suite may be
1743         // reported as having been skipped even if some tests within it
1744         // passed, as the status codes are erroneously ORed together:
1745         return error == TEST_FAIL ? TEST_FAIL : TEST_PASS;
1746     }
1747 
1748 private:
1749     cl_device_id device_;
1750     cl_context context_;
1751     cl_command_queue queue_;
1752     int num_elements_;
1753     WorkGroupParams test_params_;
1754 };
1755 
1756 #endif
1757