1 //
2 // Copyright (c) 2022 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "common.h"
18
19 #include "utility.h" // for sizeNames and sizeValues.
20
21 #include <sstream>
22 #include <string>
23
24 namespace {
25
GetTypeName(ParameterType type)26 const char *GetTypeName(ParameterType type)
27 {
28 switch (type)
29 {
30 case ParameterType::Float: return "float";
31 case ParameterType::Double: return "double";
32 }
33 return nullptr;
34 }
35
GetUndefValue(ParameterType type)36 const char *GetUndefValue(ParameterType type)
37 {
38 switch (type)
39 {
40 case ParameterType::Float:
41 case ParameterType::Double: return "NAN";
42 }
43 return nullptr;
44 }
45
EmitDefineType(std::ostringstream & kernel,const char * name,ParameterType type,int vector_size_index)46 void EmitDefineType(std::ostringstream &kernel, const char *name,
47 ParameterType type, int vector_size_index)
48 {
49 kernel << "#define " << name << " " << GetTypeName(type)
50 << sizeNames[vector_size_index] << '\n';
51 kernel << "#define " << name << "_SCALAR " << GetTypeName(type) << '\n';
52 }
53
EmitDefineUndef(std::ostringstream & kernel,const char * name,ParameterType type)54 void EmitDefineUndef(std::ostringstream &kernel, const char *name,
55 ParameterType type)
56 {
57 kernel << "#define " << name << " " << GetUndefValue(type) << '\n';
58 }
59
EmitEnableExtension(std::ostringstream & kernel,ParameterType type)60 void EmitEnableExtension(std::ostringstream &kernel, ParameterType type)
61 {
62 switch (type)
63 {
64 case ParameterType::Double:
65 kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
66 break;
67
68 case ParameterType::Float:
69 // No extension required.
70 break;
71 }
72 }
73
74 } // anonymous namespace
75
GetKernelName(int vector_size_index)76 std::string GetKernelName(int vector_size_index)
77 {
78 return std::string("math_kernel") + sizeNames[vector_size_index];
79 }
80
GetTernaryKernel(const std::string & kernel_name,const char * builtin,ParameterType retType,ParameterType type1,ParameterType type2,ParameterType type3,int vector_size_index)81 std::string GetTernaryKernel(const std::string &kernel_name,
82 const char *builtin, ParameterType retType,
83 ParameterType type1, ParameterType type2,
84 ParameterType type3, int vector_size_index)
85 {
86 // To keep the kernel code readable, use macros for types and undef values.
87 std::ostringstream kernel;
88 EmitDefineType(kernel, "RETTYPE", retType, vector_size_index);
89 EmitDefineType(kernel, "TYPE1", type1, vector_size_index);
90 EmitDefineType(kernel, "TYPE2", type2, vector_size_index);
91 EmitDefineType(kernel, "TYPE3", type3, vector_size_index);
92 EmitDefineUndef(kernel, "UNDEF1", type1);
93 EmitDefineUndef(kernel, "UNDEF2", type2);
94 EmitDefineUndef(kernel, "UNDEF3", type3);
95 EmitEnableExtension(kernel, type1);
96
97 // clang-format off
98 const char *kernel_nonvec3[] = { R"(
99 __kernel void )", kernel_name.c_str(), R"((__global RETTYPE* out,
100 __global TYPE1* in1,
101 __global TYPE2* in2,
102 __global TYPE3* in3)
103 {
104 size_t i = get_global_id(0);
105 out[i] = )", builtin, R"((in1[i], in2[i], in3[i]);
106 }
107 )" };
108
109 const char *kernel_vec3[] = { R"(
110 __kernel void )", kernel_name.c_str(), R"((__global RETTYPE_SCALAR* out,
111 __global TYPE1_SCALAR* in1,
112 __global TYPE2_SCALAR* in2,
113 __global TYPE3_SCALAR* in3)
114 {
115 size_t i = get_global_id(0);
116
117 if (i + 1 < get_global_size(0))
118 {
119 TYPE1 a = vload3(0, in1 + 3 * i);
120 TYPE2 b = vload3(0, in2 + 3 * i);
121 TYPE3 c = vload3(0, in3 + 3 * i);
122 RETTYPE res = )", builtin, R"((a, b, c);
123 vstore3(res, 0, out + 3 * i);
124 }
125 else
126 {
127 // Figure out how many elements are left over after
128 // BUFFER_SIZE % (3 * sizeof(type)).
129 // Assume power of two buffer size.
130 size_t parity = i & 1;
131 TYPE1 a = (TYPE1)(UNDEF1, UNDEF1, UNDEF1);
132 TYPE2 b = (TYPE2)(UNDEF2, UNDEF2, UNDEF2);
133 TYPE3 c = (TYPE3)(UNDEF3, UNDEF3, UNDEF3);
134 switch (parity)
135 {
136 case 0:
137 a.y = in1[3 * i + 1];
138 b.y = in2[3 * i + 1];
139 c.y = in3[3 * i + 1];
140 // fall through
141 case 1:
142 a.x = in1[3 * i];
143 b.x = in2[3 * i];
144 c.x = in3[3 * i];
145 break;
146 }
147
148 RETTYPE res = )", builtin, R"((a, b, c);
149
150 switch (parity)
151 {
152 case 0:
153 out[3 * i + 1] = res.y;
154 // fall through
155 case 1:
156 out[3 * i] = res.x;
157 break;
158 }
159 }
160 }
161 )" };
162 // clang-format on
163
164 if (sizeValues[vector_size_index] != 3)
165 for (const auto &chunk : kernel_nonvec3) kernel << chunk;
166 else
167 for (const auto &chunk : kernel_vec3) kernel << chunk;
168
169 return kernel.str();
170 }
171