1 //
2 // Copyright (c) 2017 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBHELPERS_H
17 #define SUBHELPERS_H
18
19 #include "testHarness.h"
20 #include "kernelHelpers.h"
21 #include "typeWrappers.h"
22 #include "imageHelpers.h"
23
24 #include <limits>
25 #include <vector>
26 #include <type_traits>
27 #include <bitset>
28 #include <regex>
29 #include <map>
30
31 #define NR_OF_ACTIVE_WORK_ITEMS 4
32
33 extern MTdata gMTdata;
34 typedef std::bitset<128> bs128;
35 extern cl_half_rounding_mode g_rounding_mode;
36
cl_uint4_to_bs128(cl_uint4 v)37 static bs128 cl_uint4_to_bs128(cl_uint4 v)
38 {
39 return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64)
40 | (bs128(v.s3) << 96);
41 }
42
bs128_to_cl_uint4(bs128 v)43 static cl_uint4 bs128_to_cl_uint4(bs128 v)
44 {
45 bs128 bs128_ffffffff = 0xffffffffU;
46
47 cl_uint4 r;
48 r.s0 = ((v >> 0) & bs128_ffffffff).to_ulong();
49 r.s1 = ((v >> 32) & bs128_ffffffff).to_ulong();
50 r.s2 = ((v >> 64) & bs128_ffffffff).to_ulong();
51 r.s3 = ((v >> 96) & bs128_ffffffff).to_ulong();
52
53 return r;
54 }
55
56 struct WorkGroupParams
57 {
58
59 WorkGroupParams(size_t gws, size_t lws, int dm_arg = -1, int cs_arg = -1)
global_workgroup_sizeWorkGroupParams60 : global_workgroup_size(gws), local_workgroup_size(lws),
61 divergence_mask_arg(dm_arg), cluster_size_arg(cs_arg)
62 {
63 subgroup_size = 0;
64 cluster_size = 0;
65 work_items_mask = 0;
66 use_core_subgroups = true;
67 dynsc = 0;
68 load_masks();
69 }
70 size_t global_workgroup_size;
71 size_t local_workgroup_size;
72 size_t subgroup_size;
73 cl_uint cluster_size;
74 bs128 work_items_mask;
75 size_t dynsc;
76 bool use_core_subgroups;
77 std::vector<bs128> all_work_item_masks;
78 int divergence_mask_arg;
79 int cluster_size_arg;
80 void save_kernel_source(const std::string &source, std::string name = "")
81 {
82 if (name == "")
83 {
84 name = "default";
85 }
86 if (kernel_function_name.find(name) != kernel_function_name.end())
87 {
88 log_info("Kernel definition duplication. Source will be "
89 "overwritten for function name %s\n",
90 name.c_str());
91 }
92 kernel_function_name[name] = source;
93 };
94 // return specific defined kernel or default.
get_kernel_sourceWorkGroupParams95 std::string get_kernel_source(std::string name)
96 {
97 if (kernel_function_name.find(name) == kernel_function_name.end())
98 {
99 return kernel_function_name["default"];
100 }
101 return kernel_function_name[name];
102 }
103
104
105 private:
106 std::map<std::string, std::string> kernel_function_name;
load_masksWorkGroupParams107 void load_masks()
108 {
109 if (divergence_mask_arg != -1)
110 {
111 // 1 in string will be set 1, 0 will be set 0
112 bs128 mask_0xf0f0f0f0("11110000111100001111000011110000"
113 "11110000111100001111000011110000"
114 "11110000111100001111000011110000"
115 "11110000111100001111000011110000",
116 128, '0', '1');
117 all_work_item_masks.push_back(mask_0xf0f0f0f0);
118 // 1 in string will be set 0, 0 will be set 1
119 bs128 mask_0x0f0f0f0f("11110000111100001111000011110000"
120 "11110000111100001111000011110000"
121 "11110000111100001111000011110000"
122 "11110000111100001111000011110000",
123 128, '1', '0');
124 all_work_item_masks.push_back(mask_0x0f0f0f0f);
125 bs128 mask_0x5555aaaa("10101010101010101010101010101010"
126 "10101010101010101010101010101010"
127 "10101010101010101010101010101010"
128 "10101010101010101010101010101010",
129 128, '0', '1');
130 all_work_item_masks.push_back(mask_0x5555aaaa);
131 bs128 mask_0xaaaa5555("10101010101010101010101010101010"
132 "10101010101010101010101010101010"
133 "10101010101010101010101010101010"
134 "10101010101010101010101010101010",
135 128, '1', '0');
136 all_work_item_masks.push_back(mask_0xaaaa5555);
137 // 0x0f0ff0f0
138 bs128 mask_0x0f0ff0f0("00001111000011111111000011110000"
139 "00001111000011111111000011110000"
140 "00001111000011111111000011110000"
141 "00001111000011111111000011110000",
142 128, '0', '1');
143 all_work_item_masks.push_back(mask_0x0f0ff0f0);
144 // 0xff0000ff
145 bs128 mask_0xff0000ff("11111111000000000000000011111111"
146 "11111111000000000000000011111111"
147 "11111111000000000000000011111111"
148 "11111111000000000000000011111111",
149 128, '0', '1');
150 all_work_item_masks.push_back(mask_0xff0000ff);
151 // 0xff00ff00
152 bs128 mask_0xff00ff00("11111111000000001111111100000000"
153 "11111111000000001111111100000000"
154 "11111111000000001111111100000000"
155 "11111111000000001111111100000000",
156 128, '0', '1');
157 all_work_item_masks.push_back(mask_0xff00ff00);
158 // 0x00ffff00
159 bs128 mask_0x00ffff00("00000000111111111111111100000000"
160 "00000000111111111111111100000000"
161 "00000000111111111111111100000000"
162 "00000000111111111111111100000000",
163 128, '0', '1');
164 all_work_item_masks.push_back(mask_0x00ffff00);
165 // 0x80 1 workitem highest id for 8 subgroup size
166 bs128 mask_0x80808080("10000000100000001000000010000000"
167 "10000000100000001000000010000000"
168 "10000000100000001000000010000000"
169 "10000000100000001000000010000000",
170 128, '0', '1');
171
172 all_work_item_masks.push_back(mask_0x80808080);
173 // 0x8000 1 workitem highest id for 16 subgroup size
174 bs128 mask_0x80008000("10000000000000001000000000000000"
175 "10000000000000001000000000000000"
176 "10000000000000001000000000000000"
177 "10000000000000001000000000000000",
178 128, '0', '1');
179 all_work_item_masks.push_back(mask_0x80008000);
180 // 0x80000000 1 workitem highest id for 32 subgroup size
181 bs128 mask_0x80000000("10000000000000000000000000000000"
182 "10000000000000000000000000000000"
183 "10000000000000000000000000000000"
184 "10000000000000000000000000000000",
185 128, '0', '1');
186 all_work_item_masks.push_back(mask_0x80000000);
187 // 0x80000000 00000000 1 workitem highest id for 64 subgroup size
188 // 0x80000000 1 workitem highest id for 32 subgroup size
189 bs128 mask_0x8000000000000000("10000000000000000000000000000000"
190 "00000000000000000000000000000000"
191 "10000000000000000000000000000000"
192 "00000000000000000000000000000000",
193 128, '0', '1');
194
195 all_work_item_masks.push_back(mask_0x8000000000000000);
196 // 0x80000000 00000000 00000000 00000000 1 workitem highest id for
197 // 128 subgroup size
198 bs128 mask_0x80000000000000000000000000000000(
199 "10000000000000000000000000000000"
200 "00000000000000000000000000000000"
201 "00000000000000000000000000000000"
202 "00000000000000000000000000000000",
203 128, '0', '1');
204 all_work_item_masks.push_back(
205 mask_0x80000000000000000000000000000000);
206
207 bs128 mask_0xffffffff("11111111111111111111111111111111"
208 "11111111111111111111111111111111"
209 "11111111111111111111111111111111"
210 "11111111111111111111111111111111",
211 128, '0', '1');
212 all_work_item_masks.push_back(mask_0xffffffff);
213 }
214 }
215 };
216
// Broadcast operations; operation_names() gives their string spellings.
enum class SubgroupsBroadcastOp
{
    broadcast = 0,
    broadcast_first = 1,
    non_uniform_broadcast = 2,
};
223
// Non-uniform vote operations; operation_names() gives their string
// spellings.
enum class NonUniformVoteOp
{
    elect = 0,
    all = 1,
    any = 2,
    all_equal = 3,
};
231
// Ballot operations and subgroup mask queries; operation_names() gives
// their string spellings.
enum class BallotOp
{
    ballot = 0,
    inverse_ballot = 1,
    ballot_bit_extract = 2,
    ballot_bit_count = 3,
    ballot_inclusive_scan = 4,
    ballot_exclusive_scan = 5,
    ballot_find_lsb = 6,
    ballot_find_msb = 7,
    eq_mask = 8,
    ge_mask = 9,
    gt_mask = 10,
    le_mask = 11,
    lt_mask = 12,
};
248
// Shuffle and rotate operations; operation_names() gives their string
// spellings.
enum class ShuffleOp
{
    shuffle = 0,
    shuffle_up = 1,
    shuffle_down = 2,
    shuffle_xor = 3,
    rotate = 4,
    clustered_rotate = 5,
};
258
// Arithmetic, bitwise and logical reduction/scan operations; the trailing
// underscore avoids clashes with keywords and standard names.
enum class ArithmeticOp
{
    add_ = 0,
    max_ = 1,
    min_ = 2,
    mul_ = 3,
    and_ = 4,
    or_ = 5,
    xor_ = 6,
    logical_and = 7,
    logical_or = 8,
    logical_xor = 9
};
272
operation_names(ArithmeticOp operation)273 static const char *const operation_names(ArithmeticOp operation)
274 {
275 switch (operation)
276 {
277 case ArithmeticOp::add_: return "add";
278 case ArithmeticOp::max_: return "max";
279 case ArithmeticOp::min_: return "min";
280 case ArithmeticOp::mul_: return "mul";
281 case ArithmeticOp::and_: return "and";
282 case ArithmeticOp::or_: return "or";
283 case ArithmeticOp::xor_: return "xor";
284 case ArithmeticOp::logical_and: return "logical_and";
285 case ArithmeticOp::logical_or: return "logical_or";
286 case ArithmeticOp::logical_xor: return "logical_xor";
287 default: log_error("Unknown operation request\n"); break;
288 }
289 return "";
290 }
291
operation_names(BallotOp operation)292 static const char *const operation_names(BallotOp operation)
293 {
294 switch (operation)
295 {
296 case BallotOp::ballot: return "ballot";
297 case BallotOp::inverse_ballot: return "inverse_ballot";
298 case BallotOp::ballot_bit_extract: return "bit_extract";
299 case BallotOp::ballot_bit_count: return "bit_count";
300 case BallotOp::ballot_inclusive_scan: return "inclusive_scan";
301 case BallotOp::ballot_exclusive_scan: return "exclusive_scan";
302 case BallotOp::ballot_find_lsb: return "find_lsb";
303 case BallotOp::ballot_find_msb: return "find_msb";
304 case BallotOp::eq_mask: return "eq";
305 case BallotOp::ge_mask: return "ge";
306 case BallotOp::gt_mask: return "gt";
307 case BallotOp::le_mask: return "le";
308 case BallotOp::lt_mask: return "lt";
309 default: log_error("Unknown operation request\n"); break;
310 }
311 return "";
312 }
313
operation_names(ShuffleOp operation)314 static const char *const operation_names(ShuffleOp operation)
315 {
316 switch (operation)
317 {
318 case ShuffleOp::shuffle: return "shuffle";
319 case ShuffleOp::shuffle_up: return "shuffle_up";
320 case ShuffleOp::shuffle_down: return "shuffle_down";
321 case ShuffleOp::shuffle_xor: return "shuffle_xor";
322 case ShuffleOp::rotate: return "rotate";
323 case ShuffleOp::clustered_rotate: return "clustered_rotate";
324 default: log_error("Unknown operation request\n"); break;
325 }
326 return "";
327 }
328
operation_names(NonUniformVoteOp operation)329 static const char *const operation_names(NonUniformVoteOp operation)
330 {
331 switch (operation)
332 {
333 case NonUniformVoteOp::all: return "all";
334 case NonUniformVoteOp::all_equal: return "all_equal";
335 case NonUniformVoteOp::any: return "any";
336 case NonUniformVoteOp::elect: return "elect";
337 default: log_error("Unknown operation request\n"); break;
338 }
339 return "";
340 }
341
operation_names(SubgroupsBroadcastOp operation)342 static const char *const operation_names(SubgroupsBroadcastOp operation)
343 {
344 switch (operation)
345 {
346 case SubgroupsBroadcastOp::broadcast: return "broadcast";
347 case SubgroupsBroadcastOp::broadcast_first: return "broadcast_first";
348 case SubgroupsBroadcastOp::non_uniform_broadcast:
349 return "non_uniform_broadcast";
350 default: log_error("Unknown operation request\n"); break;
351 }
352 return "";
353 }
354
355 class subgroupsAPI {
356 public:
subgroupsAPI(cl_platform_id platform,bool use_core_subgroups)357 subgroupsAPI(cl_platform_id platform, bool use_core_subgroups)
358 {
359 static_assert(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE
360 == CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR,
361 "Enums have to be the same");
362 static_assert(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE
363 == CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR,
364 "Enums have to be the same");
365 if (use_core_subgroups)
366 {
367 _clGetKernelSubGroupInfo_ptr = &clGetKernelSubGroupInfo;
368 clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfo";
369 }
370 else
371 {
372 _clGetKernelSubGroupInfo_ptr = (clGetKernelSubGroupInfoKHR_fn)
373 clGetExtensionFunctionAddressForPlatform(
374 platform, "clGetKernelSubGroupInfoKHR");
375 clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfoKHR";
376 }
377 }
clGetKernelSubGroupInfo_ptr()378 clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr()
379 {
380 return _clGetKernelSubGroupInfo_ptr;
381 }
382 const char *clGetKernelSubGroupInfo_name;
383
384 private:
385 clGetKernelSubGroupInfoKHR_fn _clGetKernelSubGroupInfo_ptr;
386 };
387
// We need to define custom types for vector size 3 and for the half types.
// The 3-component types are otherwise indistinguishable from the
// 4-component types, and the half type is indistinguishable from another
// 16-bit type (ushort).
392 namespace subgroups {
393 struct cl_char3
394 {
395 ::cl_char3 data;
396 };
397 struct cl_uchar3
398 {
399 ::cl_uchar3 data;
400 };
401 struct cl_short3
402 {
403 ::cl_short3 data;
404 };
405 struct cl_ushort3
406 {
407 ::cl_ushort3 data;
408 };
409 struct cl_int3
410 {
411 ::cl_int3 data;
412 };
413 struct cl_uint3
414 {
415 ::cl_uint3 data;
416 };
417 struct cl_long3
418 {
419 ::cl_long3 data;
420 };
421 struct cl_ulong3
422 {
423 ::cl_ulong3 data;
424 };
425 struct cl_float3
426 {
427 ::cl_float3 data;
428 };
429 struct cl_double3
430 {
431 ::cl_double3 data;
432 };
433 struct cl_half
434 {
435 ::cl_half data;
436 };
437 struct cl_half2
438 {
439 ::cl_half2 data;
440 };
441 struct cl_half3
442 {
443 ::cl_half3 data;
444 };
445 struct cl_half4
446 {
447 ::cl_half4 data;
448 };
449 struct cl_half8
450 {
451 ::cl_half8 data;
452 };
453 struct cl_half16
454 {
455 ::cl_half16 data;
456 };
457 }
458
int64_ok(cl_device_id device)459 static bool int64_ok(cl_device_id device)
460 {
461 char profile[128];
462 int error;
463
464 error = clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(profile),
465 (void *)&profile, NULL);
466 if (error)
467 {
468 log_info("clGetDeviceInfo failed with CL_DEVICE_PROFILE\n");
469 return false;
470 }
471
472 if (strcmp(profile, "EMBEDDED_PROFILE") == 0)
473 return is_extension_available(device, "cles_khr_int64");
474
475 return true;
476 }
477
double_ok(cl_device_id device)478 static bool double_ok(cl_device_id device)
479 {
480 int error;
481 cl_device_fp_config c;
482 error = clGetDeviceInfo(device, CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(c),
483 (void *)&c, NULL);
484 if (error)
485 {
486 log_info("clGetDeviceInfo failed with CL_DEVICE_DOUBLE_FP_CONFIG\n");
487 return false;
488 }
489 return c != 0;
490 }
491
half_ok(cl_device_id device)492 static bool half_ok(cl_device_id device)
493 {
494 int error;
495 cl_device_fp_config c;
496 error = clGetDeviceInfo(device, CL_DEVICE_HALF_FP_CONFIG, sizeof(c),
497 (void *)&c, NULL);
498 if (error)
499 {
500 log_info("clGetDeviceInfo failed with CL_DEVICE_HALF_FP_CONFIG\n");
501 return false;
502 }
503 return c != 0;
504 }
505
506 template <typename Ty> struct CommonTypeManager
507 {
508
nameCommonTypeManager509 static const char *name() { return ""; }
add_typedefCommonTypeManager510 static const char *add_typedef() { return "\n"; }
511 typedef std::false_type is_vector_type;
512 typedef std::false_type is_sb_vector_size3;
513 typedef std::false_type is_sb_vector_type;
514 typedef std::false_type is_sb_scalar_type;
type_supportedCommonTypeManager515 static const bool type_supported(cl_device_id) { return true; }
identify_limitsCommonTypeManager516 static const Ty identify_limits(ArithmeticOp operation)
517 {
518 switch (operation)
519 {
520 case ArithmeticOp::add_: return (Ty)0;
521 case ArithmeticOp::max_: return (std::numeric_limits<Ty>::min)();
522 case ArithmeticOp::min_: return (std::numeric_limits<Ty>::max)();
523 case ArithmeticOp::mul_: return (Ty)1;
524 case ArithmeticOp::and_: return (Ty)~0;
525 case ArithmeticOp::or_: return (Ty)0;
526 case ArithmeticOp::xor_: return (Ty)0;
527 default: log_error("Unknown operation request\n"); break;
528 }
529 return 0;
530 }
531 };
532
533 template <typename> struct TypeManager;
534
535 template <> struct TypeManager<cl_int> : public CommonTypeManager<cl_int>
536 {
537 static const char *name() { return "int"; }
538 static const char *add_typedef() { return "typedef int Type;\n"; }
539 static cl_int identify_limits(ArithmeticOp operation)
540 {
541 switch (operation)
542 {
543 case ArithmeticOp::add_: return (cl_int)0;
544 case ArithmeticOp::max_:
545 return (std::numeric_limits<cl_int>::min)();
546 case ArithmeticOp::min_:
547 return (std::numeric_limits<cl_int>::max)();
548 case ArithmeticOp::mul_: return (cl_int)1;
549 case ArithmeticOp::and_: return (cl_int)~0;
550 case ArithmeticOp::or_: return (cl_int)0;
551 case ArithmeticOp::xor_: return (cl_int)0;
552 case ArithmeticOp::logical_and: return (cl_int)1;
553 case ArithmeticOp::logical_or: return (cl_int)0;
554 case ArithmeticOp::logical_xor: return (cl_int)0;
555 default: log_error("Unknown operation request\n"); break;
556 }
557 return 0;
558 }
559 };
560 template <> struct TypeManager<cl_int2> : public CommonTypeManager<cl_int2>
561 {
562 static const char *name() { return "int2"; }
563 static const char *add_typedef() { return "typedef int2 Type;\n"; }
564 typedef std::true_type is_vector_type;
565 using scalar_type = cl_int;
566 };
567 template <>
568 struct TypeManager<subgroups::cl_int3>
569 : public CommonTypeManager<subgroups::cl_int3>
570 {
571 static const char *name() { return "int3"; }
572 static const char *add_typedef() { return "typedef int3 Type;\n"; }
573 typedef std::true_type is_sb_vector_size3;
574 using scalar_type = cl_int;
575 };
576 template <> struct TypeManager<cl_int4> : public CommonTypeManager<cl_int4>
577 {
578 static const char *name() { return "int4"; }
579 static const char *add_typedef() { return "typedef int4 Type;\n"; }
580 using scalar_type = cl_int;
581 typedef std::true_type is_vector_type;
582 };
583 template <> struct TypeManager<cl_int8> : public CommonTypeManager<cl_int8>
584 {
585 static const char *name() { return "int8"; }
586 static const char *add_typedef() { return "typedef int8 Type;\n"; }
587 using scalar_type = cl_int;
588 typedef std::true_type is_vector_type;
589 };
590 template <> struct TypeManager<cl_int16> : public CommonTypeManager<cl_int16>
591 {
592 static const char *name() { return "int16"; }
593 static const char *add_typedef() { return "typedef int16 Type;\n"; }
594 using scalar_type = cl_int;
595 typedef std::true_type is_vector_type;
596 };
597 // cl_uint
598 template <> struct TypeManager<cl_uint> : public CommonTypeManager<cl_uint>
599 {
600 static const char *name() { return "uint"; }
601 static const char *add_typedef() { return "typedef uint Type;\n"; }
602 };
603 template <> struct TypeManager<cl_uint2> : public CommonTypeManager<cl_uint2>
604 {
605 static const char *name() { return "uint2"; }
606 static const char *add_typedef() { return "typedef uint2 Type;\n"; }
607 using scalar_type = cl_uint;
608 typedef std::true_type is_vector_type;
609 };
610 template <>
611 struct TypeManager<subgroups::cl_uint3>
612 : public CommonTypeManager<subgroups::cl_uint3>
613 {
614 static const char *name() { return "uint3"; }
615 static const char *add_typedef() { return "typedef uint3 Type;\n"; }
616 typedef std::true_type is_sb_vector_size3;
617 using scalar_type = cl_uint;
618 };
619 template <> struct TypeManager<cl_uint4> : public CommonTypeManager<cl_uint4>
620 {
621 static const char *name() { return "uint4"; }
622 static const char *add_typedef() { return "typedef uint4 Type;\n"; }
623 using scalar_type = cl_uint;
624 typedef std::true_type is_vector_type;
625 };
626 template <> struct TypeManager<cl_uint8> : public CommonTypeManager<cl_uint8>
627 {
628 static const char *name() { return "uint8"; }
629 static const char *add_typedef() { return "typedef uint8 Type;\n"; }
630 using scalar_type = cl_uint;
631 typedef std::true_type is_vector_type;
632 };
633 template <> struct TypeManager<cl_uint16> : public CommonTypeManager<cl_uint16>
634 {
635 static const char *name() { return "uint16"; }
636 static const char *add_typedef() { return "typedef uint16 Type;\n"; }
637 using scalar_type = cl_uint;
638 typedef std::true_type is_vector_type;
639 };
640 // cl_short
641 template <> struct TypeManager<cl_short> : public CommonTypeManager<cl_short>
642 {
643 static const char *name() { return "short"; }
644 static const char *add_typedef() { return "typedef short Type;\n"; }
645 };
646 template <> struct TypeManager<cl_short2> : public CommonTypeManager<cl_short2>
647 {
648 static const char *name() { return "short2"; }
649 static const char *add_typedef() { return "typedef short2 Type;\n"; }
650 using scalar_type = cl_short;
651 typedef std::true_type is_vector_type;
652 };
653 template <>
654 struct TypeManager<subgroups::cl_short3>
655 : public CommonTypeManager<subgroups::cl_short3>
656 {
657 static const char *name() { return "short3"; }
658 static const char *add_typedef() { return "typedef short3 Type;\n"; }
659 typedef std::true_type is_sb_vector_size3;
660 using scalar_type = cl_short;
661 };
662 template <> struct TypeManager<cl_short4> : public CommonTypeManager<cl_short4>
663 {
664 static const char *name() { return "short4"; }
665 static const char *add_typedef() { return "typedef short4 Type;\n"; }
666 using scalar_type = cl_short;
667 typedef std::true_type is_vector_type;
668 };
669 template <> struct TypeManager<cl_short8> : public CommonTypeManager<cl_short8>
670 {
671 static const char *name() { return "short8"; }
672 static const char *add_typedef() { return "typedef short8 Type;\n"; }
673 using scalar_type = cl_short;
674 typedef std::true_type is_vector_type;
675 };
676 template <>
677 struct TypeManager<cl_short16> : public CommonTypeManager<cl_short16>
678 {
679 static const char *name() { return "short16"; }
680 static const char *add_typedef() { return "typedef short16 Type;\n"; }
681 using scalar_type = cl_short;
682 typedef std::true_type is_vector_type;
683 };
684 // cl_ushort
685 template <> struct TypeManager<cl_ushort> : public CommonTypeManager<cl_ushort>
686 {
687 static const char *name() { return "ushort"; }
688 static const char *add_typedef() { return "typedef ushort Type;\n"; }
689 };
690 template <>
691 struct TypeManager<cl_ushort2> : public CommonTypeManager<cl_ushort2>
692 {
693 static const char *name() { return "ushort2"; }
694 static const char *add_typedef() { return "typedef ushort2 Type;\n"; }
695 using scalar_type = cl_ushort;
696 typedef std::true_type is_vector_type;
697 };
698 template <>
699 struct TypeManager<subgroups::cl_ushort3>
700 : public CommonTypeManager<subgroups::cl_ushort3>
701 {
702 static const char *name() { return "ushort3"; }
703 static const char *add_typedef() { return "typedef ushort3 Type;\n"; }
704 typedef std::true_type is_sb_vector_size3;
705 using scalar_type = cl_ushort;
706 };
707 template <>
708 struct TypeManager<cl_ushort4> : public CommonTypeManager<cl_ushort4>
709 {
710 static const char *name() { return "ushort4"; }
711 static const char *add_typedef() { return "typedef ushort4 Type;\n"; }
712 using scalar_type = cl_ushort;
713 typedef std::true_type is_vector_type;
714 };
715 template <>
716 struct TypeManager<cl_ushort8> : public CommonTypeManager<cl_ushort8>
717 {
718 static const char *name() { return "ushort8"; }
719 static const char *add_typedef() { return "typedef ushort8 Type;\n"; }
720 using scalar_type = cl_ushort;
721 typedef std::true_type is_vector_type;
722 };
723 template <>
724 struct TypeManager<cl_ushort16> : public CommonTypeManager<cl_ushort16>
725 {
726 static const char *name() { return "ushort16"; }
727 static const char *add_typedef() { return "typedef ushort16 Type;\n"; }
728 using scalar_type = cl_ushort;
729 typedef std::true_type is_vector_type;
730 };
731 // cl_char
732 template <> struct TypeManager<cl_char> : public CommonTypeManager<cl_char>
733 {
734 static const char *name() { return "char"; }
735 static const char *add_typedef() { return "typedef char Type;\n"; }
736 };
737 template <> struct TypeManager<cl_char2> : public CommonTypeManager<cl_char2>
738 {
739 static const char *name() { return "char2"; }
740 static const char *add_typedef() { return "typedef char2 Type;\n"; }
741 using scalar_type = cl_char;
742 typedef std::true_type is_vector_type;
743 };
744 template <>
745 struct TypeManager<subgroups::cl_char3>
746 : public CommonTypeManager<subgroups::cl_char3>
747 {
748 static const char *name() { return "char3"; }
749 static const char *add_typedef() { return "typedef char3 Type;\n"; }
750 typedef std::true_type is_sb_vector_size3;
751 using scalar_type = cl_char;
752 };
753 template <> struct TypeManager<cl_char4> : public CommonTypeManager<cl_char4>
754 {
755 static const char *name() { return "char4"; }
756 static const char *add_typedef() { return "typedef char4 Type;\n"; }
757 using scalar_type = cl_char;
758 typedef std::true_type is_vector_type;
759 };
760 template <> struct TypeManager<cl_char8> : public CommonTypeManager<cl_char8>
761 {
762 static const char *name() { return "char8"; }
763 static const char *add_typedef() { return "typedef char8 Type;\n"; }
764 using scalar_type = cl_char;
765 typedef std::true_type is_vector_type;
766 };
767 template <> struct TypeManager<cl_char16> : public CommonTypeManager<cl_char16>
768 {
769 static const char *name() { return "char16"; }
770 static const char *add_typedef() { return "typedef char16 Type;\n"; }
771 using scalar_type = cl_char;
772 typedef std::true_type is_vector_type;
773 };
774 // cl_uchar
775 template <> struct TypeManager<cl_uchar> : public CommonTypeManager<cl_uchar>
776 {
777 static const char *name() { return "uchar"; }
778 static const char *add_typedef() { return "typedef uchar Type;\n"; }
779 };
780 template <> struct TypeManager<cl_uchar2> : public CommonTypeManager<cl_uchar2>
781 {
782 static const char *name() { return "uchar2"; }
783 static const char *add_typedef() { return "typedef uchar2 Type;\n"; }
784 using scalar_type = cl_uchar;
785 typedef std::true_type is_vector_type;
786 };
787 template <>
788 struct TypeManager<subgroups::cl_uchar3>
789 : public CommonTypeManager<subgroups::cl_char3>
790 {
791 static const char *name() { return "uchar3"; }
792 static const char *add_typedef() { return "typedef uchar3 Type;\n"; }
793 typedef std::true_type is_sb_vector_size3;
794 using scalar_type = cl_uchar;
795 };
796 template <> struct TypeManager<cl_uchar4> : public CommonTypeManager<cl_uchar4>
797 {
798 static const char *name() { return "uchar4"; }
799 static const char *add_typedef() { return "typedef uchar4 Type;\n"; }
800 using scalar_type = cl_uchar;
801 typedef std::true_type is_vector_type;
802 };
803 template <> struct TypeManager<cl_uchar8> : public CommonTypeManager<cl_uchar8>
804 {
805 static const char *name() { return "uchar8"; }
806 static const char *add_typedef() { return "typedef uchar8 Type;\n"; }
807 using scalar_type = cl_uchar;
808 typedef std::true_type is_vector_type;
809 };
810 template <>
811 struct TypeManager<cl_uchar16> : public CommonTypeManager<cl_uchar16>
812 {
813 static const char *name() { return "uchar16"; }
814 static const char *add_typedef() { return "typedef uchar16 Type;\n"; }
815 using scalar_type = cl_uchar;
816 typedef std::true_type is_vector_type;
817 };
818 // cl_long
819 template <> struct TypeManager<cl_long> : public CommonTypeManager<cl_long>
820 {
821 static const char *name() { return "long"; }
822 static const char *add_typedef() { return "typedef long Type;\n"; }
823 static const bool type_supported(cl_device_id device)
824 {
825 return int64_ok(device);
826 }
827 };
828 template <> struct TypeManager<cl_long2> : public CommonTypeManager<cl_long2>
829 {
830 static const char *name() { return "long2"; }
831 static const char *add_typedef() { return "typedef long2 Type;\n"; }
832 using scalar_type = cl_long;
833 typedef std::true_type is_vector_type;
834 static const bool type_supported(cl_device_id device)
835 {
836 return int64_ok(device);
837 }
838 };
839 template <>
840 struct TypeManager<subgroups::cl_long3>
841 : public CommonTypeManager<subgroups::cl_long3>
842 {
843 static const char *name() { return "long3"; }
844 static const char *add_typedef() { return "typedef long3 Type;\n"; }
845 typedef std::true_type is_sb_vector_size3;
846 using scalar_type = cl_long;
847 static const bool type_supported(cl_device_id device)
848 {
849 return int64_ok(device);
850 }
851 };
852 template <> struct TypeManager<cl_long4> : public CommonTypeManager<cl_long4>
853 {
854 static const char *name() { return "long4"; }
855 static const char *add_typedef() { return "typedef long4 Type;\n"; }
856 using scalar_type = cl_long;
857 typedef std::true_type is_vector_type;
858 static const bool type_supported(cl_device_id device)
859 {
860 return int64_ok(device);
861 }
862 };
863 template <> struct TypeManager<cl_long8> : public CommonTypeManager<cl_long8>
864 {
865 static const char *name() { return "long8"; }
866 static const char *add_typedef() { return "typedef long8 Type;\n"; }
867 using scalar_type = cl_long;
868 typedef std::true_type is_vector_type;
869 static const bool type_supported(cl_device_id device)
870 {
871 return int64_ok(device);
872 }
873 };
874 template <> struct TypeManager<cl_long16> : public CommonTypeManager<cl_long16>
875 {
876 static const char *name() { return "long16"; }
877 static const char *add_typedef() { return "typedef long16 Type;\n"; }
878 using scalar_type = cl_long;
879 typedef std::true_type is_vector_type;
880 static const bool type_supported(cl_device_id device)
881 {
882 return int64_ok(device);
883 }
884 };
885 // cl_ulong
886 template <> struct TypeManager<cl_ulong> : public CommonTypeManager<cl_ulong>
887 {
888 static const char *name() { return "ulong"; }
889 static const char *add_typedef() { return "typedef ulong Type;\n"; }
890 static const bool type_supported(cl_device_id device)
891 {
892 return int64_ok(device);
893 }
894 };
895 template <> struct TypeManager<cl_ulong2> : public CommonTypeManager<cl_ulong2>
896 {
897 static const char *name() { return "ulong2"; }
898 static const char *add_typedef() { return "typedef ulong2 Type;\n"; }
899 using scalar_type = cl_ulong;
900 typedef std::true_type is_vector_type;
901 static const bool type_supported(cl_device_id device)
902 {
903 return int64_ok(device);
904 }
905 };
906 template <>
907 struct TypeManager<subgroups::cl_ulong3>
908 : public CommonTypeManager<subgroups::cl_ulong3>
909 {
910 static const char *name() { return "ulong3"; }
911 static const char *add_typedef() { return "typedef ulong3 Type;\n"; }
912 typedef std::true_type is_sb_vector_size3;
913 using scalar_type = cl_ulong;
914 static const bool type_supported(cl_device_id device)
915 {
916 return int64_ok(device);
917 }
918 };
919 template <> struct TypeManager<cl_ulong4> : public CommonTypeManager<cl_ulong4>
920 {
921 static const char *name() { return "ulong4"; }
922 static const char *add_typedef() { return "typedef ulong4 Type;\n"; }
923 using scalar_type = cl_ulong;
924 typedef std::true_type is_vector_type;
925 static const bool type_supported(cl_device_id device)
926 {
927 return int64_ok(device);
928 }
929 };
930 template <> struct TypeManager<cl_ulong8> : public CommonTypeManager<cl_ulong8>
931 {
932 static const char *name() { return "ulong8"; }
933 static const char *add_typedef() { return "typedef ulong8 Type;\n"; }
934 using scalar_type = cl_ulong;
935 typedef std::true_type is_vector_type;
936 static const bool type_supported(cl_device_id device)
937 {
938 return int64_ok(device);
939 }
940 };
941 template <>
942 struct TypeManager<cl_ulong16> : public CommonTypeManager<cl_ulong16>
943 {
944 static const char *name() { return "ulong16"; }
945 static const char *add_typedef() { return "typedef ulong16 Type;\n"; }
946 using scalar_type = cl_ulong;
947 typedef std::true_type is_vector_type;
948 static const bool type_supported(cl_device_id device)
949 {
950 return int64_ok(device);
951 }
952 };
953
954 // cl_float
955 template <> struct TypeManager<cl_float> : public CommonTypeManager<cl_float>
956 {
957 static const char *name() { return "float"; }
958 static const char *add_typedef() { return "typedef float Type;\n"; }
959 static cl_float identify_limits(ArithmeticOp operation)
960 {
961 switch (operation)
962 {
963 case ArithmeticOp::add_: return 0.0f;
964 case ArithmeticOp::max_:
965 return -std::numeric_limits<float>::infinity();
966 case ArithmeticOp::min_:
967 return std::numeric_limits<float>::infinity();
968 case ArithmeticOp::mul_: return (cl_float)1;
969 default: log_error("Unknown operation request\n"); break;
970 }
971 return 0;
972 }
973 };
974 template <> struct TypeManager<cl_float2> : public CommonTypeManager<cl_float2>
975 {
976 static const char *name() { return "float2"; }
977 static const char *add_typedef() { return "typedef float2 Type;\n"; }
978 using scalar_type = cl_float;
979 typedef std::true_type is_vector_type;
980 };
981 template <>
982 struct TypeManager<subgroups::cl_float3>
983 : public CommonTypeManager<subgroups::cl_float3>
984 {
985 static const char *name() { return "float3"; }
986 static const char *add_typedef() { return "typedef float3 Type;\n"; }
987 typedef std::true_type is_sb_vector_size3;
988 using scalar_type = cl_float;
989 };
990 template <> struct TypeManager<cl_float4> : public CommonTypeManager<cl_float4>
991 {
992 static const char *name() { return "float4"; }
993 static const char *add_typedef() { return "typedef float4 Type;\n"; }
994 using scalar_type = cl_float;
995 typedef std::true_type is_vector_type;
996 };
997 template <> struct TypeManager<cl_float8> : public CommonTypeManager<cl_float8>
998 {
999 static const char *name() { return "float8"; }
1000 static const char *add_typedef() { return "typedef float8 Type;\n"; }
1001 using scalar_type = cl_float;
1002 typedef std::true_type is_vector_type;
1003 };
1004 template <>
1005 struct TypeManager<cl_float16> : public CommonTypeManager<cl_float16>
1006 {
1007 static const char *name() { return "float16"; }
1008 static const char *add_typedef() { return "typedef float16 Type;\n"; }
1009 using scalar_type = cl_float;
1010 typedef std::true_type is_vector_type;
1011 };
1012
1013 // cl_double
1014 template <> struct TypeManager<cl_double> : public CommonTypeManager<cl_double>
1015 {
1016 static const char *name() { return "double"; }
1017 static const char *add_typedef() { return "typedef double Type;\n"; }
1018 static cl_double identify_limits(ArithmeticOp operation)
1019 {
1020 switch (operation)
1021 {
1022 case ArithmeticOp::add_: return 0.0;
1023 case ArithmeticOp::max_:
1024 return -std::numeric_limits<double>::infinity();
1025 case ArithmeticOp::min_:
1026 return std::numeric_limits<double>::infinity();
1027 case ArithmeticOp::mul_: return (cl_double)1;
1028 default: log_error("Unknown operation request\n"); break;
1029 }
1030 return 0;
1031 }
1032 static const bool type_supported(cl_device_id device)
1033 {
1034 return double_ok(device);
1035 }
1036 };
1037 template <>
1038 struct TypeManager<cl_double2> : public CommonTypeManager<cl_double2>
1039 {
1040 static const char *name() { return "double2"; }
1041 static const char *add_typedef() { return "typedef double2 Type;\n"; }
1042 using scalar_type = cl_double;
1043 typedef std::true_type is_vector_type;
1044 static const bool type_supported(cl_device_id device)
1045 {
1046 return double_ok(device);
1047 }
1048 };
1049 template <>
1050 struct TypeManager<subgroups::cl_double3>
1051 : public CommonTypeManager<subgroups::cl_double3>
1052 {
1053 static const char *name() { return "double3"; }
1054 static const char *add_typedef() { return "typedef double3 Type;\n"; }
1055 typedef std::true_type is_sb_vector_size3;
1056 using scalar_type = cl_double;
1057 static const bool type_supported(cl_device_id device)
1058 {
1059 return double_ok(device);
1060 }
1061 };
1062 template <>
1063 struct TypeManager<cl_double4> : public CommonTypeManager<cl_double4>
1064 {
1065 static const char *name() { return "double4"; }
1066 static const char *add_typedef() { return "typedef double4 Type;\n"; }
1067 using scalar_type = cl_double;
1068 typedef std::true_type is_vector_type;
1069 static const bool type_supported(cl_device_id device)
1070 {
1071 return double_ok(device);
1072 }
1073 };
1074 template <>
1075 struct TypeManager<cl_double8> : public CommonTypeManager<cl_double8>
1076 {
1077 static const char *name() { return "double8"; }
1078 static const char *add_typedef() { return "typedef double8 Type;\n"; }
1079 using scalar_type = cl_double;
1080 typedef std::true_type is_vector_type;
1081 static const bool type_supported(cl_device_id device)
1082 {
1083 return double_ok(device);
1084 }
1085 };
1086 template <>
1087 struct TypeManager<cl_double16> : public CommonTypeManager<cl_double16>
1088 {
1089 static const char *name() { return "double16"; }
1090 static const char *add_typedef() { return "typedef double16 Type;\n"; }
1091 using scalar_type = cl_double;
1092 typedef std::true_type is_vector_type;
1093 static const bool type_supported(cl_device_id device)
1094 {
1095 return double_ok(device);
1096 }
1097 };
1098
1099 // cl_half
1100 template <>
1101 struct TypeManager<subgroups::cl_half>
1102 : public CommonTypeManager<subgroups::cl_half>
1103 {
1104 static const char *name() { return "half"; }
1105 static const char *add_typedef() { return "typedef half Type;\n"; }
1106 typedef std::true_type is_sb_scalar_type;
1107 static subgroups::cl_half identify_limits(ArithmeticOp operation)
1108 {
1109 switch (operation)
1110 {
1111 case ArithmeticOp::add_: return { 0x0000 };
1112 case ArithmeticOp::max_: return { 0xfc00 };
1113 case ArithmeticOp::min_: return { 0x7c00 };
1114 case ArithmeticOp::mul_: return { 0x3c00 };
1115 default: log_error("Unknown operation request\n"); break;
1116 }
1117 return { 0 };
1118 }
1119 static const bool type_supported(cl_device_id device)
1120 {
1121 return half_ok(device);
1122 }
1123 };
1124 template <>
1125 struct TypeManager<subgroups::cl_half2>
1126 : public CommonTypeManager<subgroups::cl_half2>
1127 {
1128 static const char *name() { return "half2"; }
1129 static const char *add_typedef() { return "typedef half2 Type;\n"; }
1130 using scalar_type = subgroups::cl_half;
1131 typedef std::true_type is_sb_vector_type;
1132 static const bool type_supported(cl_device_id device)
1133 {
1134 return half_ok(device);
1135 }
1136 };
1137 template <>
1138 struct TypeManager<subgroups::cl_half3>
1139 : public CommonTypeManager<subgroups::cl_half3>
1140 {
1141 static const char *name() { return "half3"; }
1142 static const char *add_typedef() { return "typedef half3 Type;\n"; }
1143 typedef std::true_type is_sb_vector_size3;
1144 using scalar_type = subgroups::cl_half;
1145
1146 static const bool type_supported(cl_device_id device)
1147 {
1148 return half_ok(device);
1149 }
1150 };
1151 template <>
1152 struct TypeManager<subgroups::cl_half4>
1153 : public CommonTypeManager<subgroups::cl_half4>
1154 {
1155 static const char *name() { return "half4"; }
1156 static const char *add_typedef() { return "typedef half4 Type;\n"; }
1157 using scalar_type = subgroups::cl_half;
1158 typedef std::true_type is_sb_vector_type;
1159 static const bool type_supported(cl_device_id device)
1160 {
1161 return half_ok(device);
1162 }
1163 };
1164 template <>
1165 struct TypeManager<subgroups::cl_half8>
1166 : public CommonTypeManager<subgroups::cl_half8>
1167 {
1168 static const char *name() { return "half8"; }
1169 static const char *add_typedef() { return "typedef half8 Type;\n"; }
1170 using scalar_type = subgroups::cl_half;
1171 typedef std::true_type is_sb_vector_type;
1172
1173 static const bool type_supported(cl_device_id device)
1174 {
1175 return half_ok(device);
1176 }
1177 };
1178 template <>
1179 struct TypeManager<subgroups::cl_half16>
1180 : public CommonTypeManager<subgroups::cl_half16>
1181 {
1182 static const char *name() { return "half16"; }
1183 static const char *add_typedef() { return "typedef half16 Type;\n"; }
1184 using scalar_type = subgroups::cl_half;
1185 typedef std::true_type is_sb_vector_type;
1186 static const bool type_supported(cl_device_id device)
1187 {
1188 return half_ok(device);
1189 }
1190 };
1191
1192 // set scalar value to vector of halfs
1193 template <typename Ty, int N = 0>
1194 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value>::type
1195 set_value(Ty &lhs, const cl_ulong &rhs)
1196 {
1197 const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1198 for (auto i = 0; i < size; ++i)
1199 {
1200 lhs.data.s[i] = rhs;
1201 }
1202 }
1203
1204
1205 // set scalar value to vector
1206 template <typename Ty>
1207 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1208 set_value(Ty &lhs, const cl_ulong &rhs)
1209 {
1210 const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1211 for (auto i = 0; i < size; ++i)
1212 {
1213 lhs.s[i] = rhs;
1214 }
1215 }
1216
1217 // set vector to vector value
1218 template <typename Ty>
1219 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1220 set_value(Ty &lhs, const Ty &rhs)
1221 {
1222 lhs = rhs;
1223 }
1224
1225 // set scalar value to vector size 3
1226 template <typename Ty, int N = 0>
1227 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value>::type
1228 set_value(Ty &lhs, const cl_ulong &rhs)
1229 {
1230 for (auto i = 0; i < 3; ++i)
1231 {
1232 lhs.data.s[i] = rhs;
1233 }
1234 }
1235
1236 // set scalar value to scalar
1237 template <typename Ty>
1238 typename std::enable_if<std::is_scalar<Ty>::value>::type
1239 set_value(Ty &lhs, const cl_ulong &rhs)
1240 {
1241 lhs = static_cast<Ty>(rhs);
1242 }
1243
1244 // set scalar value to half scalar
1245 template <typename Ty>
1246 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value>::type
1247 set_value(Ty &lhs, const cl_ulong &rhs)
1248 {
1249 lhs.data = cl_half_from_float(static_cast<cl_float>(rhs), g_rounding_mode);
1250 }
1251
1252 // compare for common vectors
1253 template <typename Ty>
1254 typename std::enable_if<TypeManager<Ty>::is_vector_type::value, bool>::type
1255 compare(const Ty &lhs, const Ty &rhs)
1256 {
1257 const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1258 for (auto i = 0; i < size; ++i)
1259 {
1260 if (lhs.s[i] != rhs.s[i])
1261 {
1262 return false;
1263 }
1264 }
1265 return true;
1266 }
1267
1268 // compare for vectors 3
1269 template <typename Ty>
1270 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value, bool>::type
1271 compare(const Ty &lhs, const Ty &rhs)
1272 {
1273 for (auto i = 0; i < 3; ++i)
1274 {
1275 if (lhs.data.s[i] != rhs.data.s[i])
1276 {
1277 return false;
1278 }
1279 }
1280 return true;
1281 }
1282
1283 // compare for half vectors
1284 template <typename Ty>
1285 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value, bool>::type
1286 compare(const Ty &lhs, const Ty &rhs)
1287 {
1288 const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1289 for (auto i = 0; i < size; ++i)
1290 {
1291 if (lhs.data.s[i] != rhs.data.s[i])
1292 {
1293 return false;
1294 }
1295 }
1296 return true;
1297 }
1298
// Equality for built-in scalar types via operator==.
template <typename Ty>
typename std::enable_if<std::is_scalar<Ty>::value, bool>::type
compare(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1306
1307 // compare for scalar halfs
1308 template <typename Ty>
1309 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value, bool>::type
1310 compare(const Ty &lhs, const Ty &rhs)
1311 {
1312 return lhs.data == rhs.data;
1313 }
1314
// Generic ordered comparison: plain operator== equality.
template <typename Ty>
inline bool compare_ordered(const Ty &lhs, const Ty &rhs)
{
    const bool equal = (lhs == rhs);
    return equal;
}
1319
1320 template <>
1321 inline bool compare_ordered(const subgroups::cl_half &lhs,
1322 const subgroups::cl_half &rhs)
1323 {
1324 return cl_half_to_float(lhs.data) == cl_half_to_float(rhs.data);
1325 }
1326
1327 template <typename Ty>
1328 inline bool compare_ordered(const subgroups::cl_half &lhs, const int &rhs)
1329 {
1330 return cl_half_to_float(lhs.data) == rhs;
1331 }
1332
1333 template <typename Ty, typename Fns> class KernelExecutor {
1334 public:
1335 KernelExecutor(cl_context c, cl_command_queue q, cl_kernel k, size_t g,
1336 size_t l, Ty *id, size_t is, Ty *mid, Ty *mod, cl_int *md,
1337 size_t ms, Ty *od, size_t os, size_t ts = 0)
1338 : context(c), queue(q), kernel(k), global(g), local(l), idata(id),
1339 isize(is), mapin_data(mid), mapout_data(mod), mdata(md), msize(ms),
1340 odata(od), osize(os), tsize(ts)
1341 {
1342 has_status = false;
1343 run_failed = false;
1344 }
1345 cl_context context;
1346 cl_command_queue queue;
1347 cl_kernel kernel;
1348 size_t global;
1349 size_t local;
1350 Ty *idata;
1351 size_t isize;
1352 Ty *mapin_data;
1353 Ty *mapout_data;
1354 cl_int *mdata;
1355 size_t msize;
1356 Ty *odata;
1357 size_t osize;
1358 size_t tsize;
1359 bool run_failed;
1360
1361 private:
1362 bool has_status;
1363 test_status status;
1364
1365 public:
1366 // Run a test kernel to compute the result of a built-in on an input
1367 int run()
1368 {
1369 clMemWrapper in;
1370 clMemWrapper xy;
1371 clMemWrapper out;
1372 clMemWrapper tmp;
1373 int error;
1374
1375 in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
1376 test_error(error, "clCreateBuffer failed");
1377
1378 xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
1379 test_error(error, "clCreateBuffer failed");
1380
1381 out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
1382 test_error(error, "clCreateBuffer failed");
1383
1384 if (tsize)
1385 {
1386 tmp = clCreateBuffer(context,
1387 CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
1388 tsize, NULL, &error);
1389 test_error(error, "clCreateBuffer failed");
1390 }
1391
1392 error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
1393 test_error(error, "clSetKernelArg failed");
1394
1395 error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
1396 test_error(error, "clSetKernelArg failed");
1397
1398 error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
1399 test_error(error, "clSetKernelArg failed");
1400
1401 if (tsize)
1402 {
1403 error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
1404 test_error(error, "clSetKernelArg failed");
1405 }
1406
1407 error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0,
1408 NULL, NULL);
1409 test_error(error, "clEnqueueWriteBuffer failed");
1410
1411 error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1412 NULL, NULL);
1413 test_error(error, "clEnqueueWriteBuffer failed");
1414 error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local,
1415 0, NULL, NULL);
1416 test_error(error, "clEnqueueNDRangeKernel failed");
1417
1418 error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1419 NULL, NULL);
1420 test_error(error, "clEnqueueReadBuffer failed");
1421
1422 error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0,
1423 NULL, NULL);
1424 test_error(error, "clEnqueueReadBuffer failed");
1425
1426 error = clFinish(queue);
1427 test_error(error, "clFinish failed");
1428
1429 return error;
1430 }
1431
1432 private:
1433 test_status
1434 run_and_check_with_cluster_size(const WorkGroupParams &test_params)
1435 {
1436 cl_int error = run();
1437 if (error != CL_SUCCESS)
1438 {
1439 print_error(error, "Failed to run subgroup test kernel");
1440 status = TEST_FAIL;
1441 run_failed = true;
1442 return status;
1443 }
1444
1445 test_status tmp_status =
1446 Fns::chk(idata, odata, mapin_data, mapout_data, mdata, test_params);
1447
1448 if (!has_status || tmp_status == TEST_FAIL
1449 || (tmp_status == TEST_PASS && status != TEST_FAIL))
1450 {
1451 status = tmp_status;
1452 has_status = true;
1453 }
1454
1455 return status;
1456 }
1457
1458 public:
1459 test_status run_and_check(WorkGroupParams &test_params)
1460 {
1461 test_status tmp_status = TEST_SKIPPED_ITSELF;
1462
1463 if (test_params.cluster_size_arg != -1)
1464 {
1465 for (cl_uint cluster_size = 1;
1466 cluster_size <= test_params.subgroup_size; cluster_size *= 2)
1467 {
1468 test_params.cluster_size = cluster_size;
1469 cl_int error =
1470 clSetKernelArg(kernel, test_params.cluster_size_arg,
1471 sizeof(cl_uint), &cluster_size);
1472 test_error_fail(error, "Unable to set cluster size");
1473
1474 tmp_status = run_and_check_with_cluster_size(test_params);
1475
1476 if (tmp_status == TEST_FAIL) break;
1477 }
1478 }
1479 else
1480 {
1481 tmp_status = run_and_check_with_cluster_size(test_params);
1482 }
1483
1484 return tmp_status;
1485 }
1486 };
1487
1488 // Driver for testing a single built in function
1489 template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
1490 {
1491 static test_status run(cl_device_id device, cl_context context,
1492 cl_command_queue queue, int num_elements,
1493 const char *kname, const char *src,
1494 WorkGroupParams test_params)
1495 {
1496 size_t tmp;
1497 cl_int error;
1498 size_t subgroup_size, num_subgroups;
1499 size_t global = test_params.global_workgroup_size;
1500 size_t local = test_params.local_workgroup_size;
1501 clProgramWrapper program;
1502 clKernelWrapper kernel;
1503 cl_platform_id platform;
1504 std::vector<cl_int> sgmap;
1505 sgmap.resize(4 * global);
1506 std::vector<Ty> mapin;
1507 mapin.resize(local);
1508 std::vector<Ty> mapout;
1509 mapout.resize(local);
1510 std::stringstream kernel_sstr;
1511
1512 Fns::log_test(test_params, "");
1513
1514 kernel_sstr << "#define NR_OF_ACTIVE_WORK_ITEMS ";
1515 kernel_sstr << NR_OF_ACTIVE_WORK_ITEMS << "\n";
1516 // Make sure a test of type Ty is supported by the device
1517 if (!TypeManager<Ty>::type_supported(device))
1518 {
1519 log_info("Data type not supported : %s\n", TypeManager<Ty>::name());
1520 return TEST_SKIPPED_ITSELF;
1521 }
1522
1523 if (strstr(TypeManager<Ty>::name(), "double"))
1524 {
1525 kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n";
1526 }
1527 else if (strstr(TypeManager<Ty>::name(), "half"))
1528 {
1529 kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp16: enable\n";
1530 }
1531
1532 error = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(platform),
1533 (void *)&platform, NULL);
1534 test_error_fail(error, "clGetDeviceInfo failed for CL_DEVICE_PLATFORM");
1535 if (test_params.use_core_subgroups)
1536 {
1537 kernel_sstr
1538 << "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
1539 }
1540 kernel_sstr << "#define XY(M,I) M[I].x = get_sub_group_local_id(); "
1541 "M[I].y = get_sub_group_id();\n";
1542 kernel_sstr << TypeManager<Ty>::add_typedef();
1543 kernel_sstr << src;
1544 const std::string &kernel_str = kernel_sstr.str();
1545 const char *kernel_src = kernel_str.c_str();
1546
1547 error = create_single_kernel_helper(context, &program, &kernel, 1,
1548 &kernel_src, kname);
1549 if (error != CL_SUCCESS) return TEST_FAIL;
1550
1551 // Determine some local dimensions to use for the test.
1552 error = get_max_common_work_group_size(
1553 context, kernel, test_params.global_workgroup_size, &local);
1554 test_error_fail(error, "get_max_common_work_group_size failed");
1555
1556 // Limit it a bit so we have muliple work groups
1557 // Ideally this will still be large enough to give us multiple
1558 if (local > test_params.local_workgroup_size)
1559 local = test_params.local_workgroup_size;
1560
1561
1562 // Get the sub group info
1563 subgroupsAPI subgroupsApiSet(platform, test_params.use_core_subgroups);
1564 clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr =
1565 subgroupsApiSet.clGetKernelSubGroupInfo_ptr();
1566 if (clGetKernelSubGroupInfo_ptr == NULL)
1567 {
1568 log_error("ERROR: %s function not available\n",
1569 subgroupsApiSet.clGetKernelSubGroupInfo_name);
1570 return TEST_FAIL;
1571 }
1572 error = clGetKernelSubGroupInfo_ptr(
1573 kernel, device, CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
1574 sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1575 if (error != CL_SUCCESS)
1576 {
1577 log_error("ERROR: %s function error for "
1578 "CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE\n",
1579 subgroupsApiSet.clGetKernelSubGroupInfo_name);
1580 return TEST_FAIL;
1581 }
1582
1583 subgroup_size = tmp;
1584
1585 error = clGetKernelSubGroupInfo_ptr(
1586 kernel, device, CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE,
1587 sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1588 if (error != CL_SUCCESS)
1589 {
1590 log_error("ERROR: %s function error for "
1591 "CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE\n",
1592 subgroupsApiSet.clGetKernelSubGroupInfo_name);
1593 return TEST_FAIL;
1594 }
1595
1596 num_subgroups = tmp;
1597 // Make sure the number of sub groups is what we expect
1598 if (num_subgroups != (local + subgroup_size - 1) / subgroup_size)
1599 {
1600 log_error("ERROR: unexpected number of subgroups (%zu) returned\n",
1601 num_subgroups);
1602 return TEST_FAIL;
1603 }
1604
1605 std::vector<Ty> idata;
1606 std::vector<Ty> odata;
1607 size_t input_array_size = global;
1608 size_t output_array_size = global;
1609 size_t dynscl = test_params.dynsc;
1610
1611 if (dynscl != 0)
1612 {
1613 input_array_size = global / local * num_subgroups * dynscl;
1614 output_array_size = global / local * dynscl;
1615 }
1616
1617 idata.resize(input_array_size);
1618 odata.resize(output_array_size);
1619
1620 if (test_params.divergence_mask_arg != -1)
1621 {
1622 cl_uint4 mask_vector;
1623 mask_vector.x = 0xffffffffU;
1624 mask_vector.y = 0xffffffffU;
1625 mask_vector.z = 0xffffffffU;
1626 mask_vector.w = 0xffffffffU;
1627 error = clSetKernelArg(kernel, test_params.divergence_mask_arg,
1628 sizeof(cl_uint4), &mask_vector);
1629 test_error_fail(error, "Unable to set divergence mask argument");
1630 }
1631
1632 if (test_params.cluster_size_arg != -1)
1633 {
1634 cl_uint dummy_cluster_size = 1;
1635 error = clSetKernelArg(kernel, test_params.cluster_size_arg,
1636 sizeof(cl_uint), &dummy_cluster_size);
1637 test_error_fail(error, "Unable to set dummy cluster size");
1638 }
1639
1640 KernelExecutor<Ty, Fns> executor(
1641 context, queue, kernel, global, local, idata.data(),
1642 input_array_size * sizeof(Ty), mapin.data(), mapout.data(),
1643 sgmap.data(), global * sizeof(cl_int4), odata.data(),
1644 output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
1645
1646 // Run the kernel once on zeroes to get the map
1647 memset(idata.data(), 0, input_array_size * sizeof(Ty));
1648 error = executor.run();
1649 test_error_fail(error, "Running kernel first time failed");
1650
1651 // Generate the desired input for the kernel
1652 test_params.subgroup_size = subgroup_size;
1653 Fns::gen(idata.data(), mapin.data(), sgmap.data(), test_params);
1654
1655 test_status status;
1656
1657 if (test_params.divergence_mask_arg != -1)
1658 {
1659 for (auto &mask : test_params.all_work_item_masks)
1660 {
1661 test_params.work_items_mask = mask;
1662 cl_uint4 mask_vector = bs128_to_cl_uint4(mask);
1663 clSetKernelArg(kernel, test_params.divergence_mask_arg,
1664 sizeof(cl_uint4), &mask_vector);
1665
1666 status = executor.run_and_check(test_params);
1667
1668 if (status == TEST_FAIL) break;
1669 }
1670 }
1671 else
1672 {
1673 status = executor.run_and_check(test_params);
1674 }
1675 // Detailed failure and skip messages should be logged by
1676 // run_and_check.
1677 if (status == TEST_PASS)
1678 {
1679 Fns::log_test(test_params, " passed");
1680 }
1681 else if (!executor.run_failed && status == TEST_FAIL)
1682 {
1683 test_fail("Data verification failed\n");
1684 }
1685 return status;
1686 }
1687 };
1688
// Computes the sub-group layout of the last (non-uniform) work-group.
//   non_uniform_size    - number of work-items in the last work-group
//   number_of_subgroups - [out] sub-groups needed to cover those work-items
//   subgroup_size       - sub-group size (must be > 0)
//   workgroup_size      - [out] set to non_uniform_size
//   last_subgroup_size  - [out] size of the trailing partial sub-group, or 0
//                         when non_uniform_size is an exact multiple of
//                         subgroup_size
static void set_last_workgroup_params(int non_uniform_size,
                                      int &number_of_subgroups,
                                      int subgroup_size, int &workgroup_size,
                                      int &last_subgroup_size)
{
    // Ceiling division: an exact multiple must not count an extra, empty
    // sub-group (the previous `1 + n / size` did).
    number_of_subgroups =
        (non_uniform_size + subgroup_size - 1) / subgroup_size;
    last_subgroup_size = non_uniform_size % subgroup_size;
    workgroup_size = non_uniform_size;
}
1698
1699 template <typename Ty>
1700 static void set_randomdata_for_subgroup(Ty *workgroup, int wg_offset,
1701 int current_sbs)
1702 {
1703 int randomize_data = (int)(genrand_int32(gMTdata) % 3);
1704 // Initialize data matrix indexed by local id and sub group id
1705 switch (randomize_data)
1706 {
1707 case 0:
1708 memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1709 break;
1710 case 1: {
1711 memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1712 int wi_id = (int)(genrand_int32(gMTdata) % (cl_uint)current_sbs);
1713 set_value(workgroup[wg_offset + wi_id], 41);
1714 }
1715 break;
1716 case 2:
1717 memset(&workgroup[wg_offset], 0xff, current_sbs * sizeof(Ty));
1718 break;
1719 }
1720 }
1721
1722 struct RunTestForType
1723 {
1724 RunTestForType(cl_device_id device, cl_context context,
1725 cl_command_queue queue, int num_elements,
1726 WorkGroupParams test_params)
1727 : device_(device), context_(context), queue_(queue),
1728 num_elements_(num_elements), test_params_(test_params)
1729 {}
1730 template <typename T, typename U>
1731 int run_impl(const std::string &function_name)
1732 {
1733 int error = TEST_PASS;
1734 std::string source =
1735 std::regex_replace(test_params_.get_kernel_source(function_name),
1736 std::regex("\\%s"), function_name);
1737 std::string kernel_name = "test_" + function_name;
1738 error =
1739 test<T, U>::run(device_, context_, queue_, num_elements_,
1740 kernel_name.c_str(), source.c_str(), test_params_);
1741
1742 // If we return TEST_SKIPPED_ITSELF here, then an entire suite may be
1743 // reported as having been skipped even if some tests within it
1744 // passed, as the status codes are erroneously ORed together:
1745 return error == TEST_FAIL ? TEST_FAIL : TEST_PASS;
1746 }
1747
1748 private:
1749 cl_device_id device_;
1750 cl_context context_;
1751 cl_command_queue queue_;
1752 int num_elements_;
1753 WorkGroupParams test_params_;
1754 };
1755
1756 #endif
1757