• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2022 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "common.h"
18 
19 #include "utility.h" // for sizeNames and sizeValues.
20 
21 #include <sstream>
22 #include <string>
23 
24 namespace {
25 
GetTypeName(ParameterType type)26 const char *GetTypeName(ParameterType type)
27 {
28     switch (type)
29     {
30         case ParameterType::Float: return "float";
31         case ParameterType::Double: return "double";
32     }
33     return nullptr;
34 }
35 
GetUndefValue(ParameterType type)36 const char *GetUndefValue(ParameterType type)
37 {
38     switch (type)
39     {
40         case ParameterType::Float:
41         case ParameterType::Double: return "NAN";
42     }
43     return nullptr;
44 }
45 
EmitDefineType(std::ostringstream & kernel,const char * name,ParameterType type,int vector_size_index)46 void EmitDefineType(std::ostringstream &kernel, const char *name,
47                     ParameterType type, int vector_size_index)
48 {
49     kernel << "#define " << name << " " << GetTypeName(type)
50            << sizeNames[vector_size_index] << '\n';
51     kernel << "#define " << name << "_SCALAR " << GetTypeName(type) << '\n';
52 }
53 
EmitDefineUndef(std::ostringstream & kernel,const char * name,ParameterType type)54 void EmitDefineUndef(std::ostringstream &kernel, const char *name,
55                      ParameterType type)
56 {
57     kernel << "#define " << name << " " << GetUndefValue(type) << '\n';
58 }
59 
EmitEnableExtension(std::ostringstream & kernel,ParameterType type)60 void EmitEnableExtension(std::ostringstream &kernel, ParameterType type)
61 {
62     switch (type)
63     {
64         case ParameterType::Double:
65             kernel << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
66             break;
67 
68         case ParameterType::Float:
69             // No extension required.
70             break;
71     }
72 }
73 
74 } // anonymous namespace
75 
GetKernelName(int vector_size_index)76 std::string GetKernelName(int vector_size_index)
77 {
78     return std::string("math_kernel") + sizeNames[vector_size_index];
79 }
80 
GetTernaryKernel(const std::string & kernel_name,const char * builtin,ParameterType retType,ParameterType type1,ParameterType type2,ParameterType type3,int vector_size_index)81 std::string GetTernaryKernel(const std::string &kernel_name,
82                              const char *builtin, ParameterType retType,
83                              ParameterType type1, ParameterType type2,
84                              ParameterType type3, int vector_size_index)
85 {
86     // To keep the kernel code readable, use macros for types and undef values.
87     std::ostringstream kernel;
88     EmitDefineType(kernel, "RETTYPE", retType, vector_size_index);
89     EmitDefineType(kernel, "TYPE1", type1, vector_size_index);
90     EmitDefineType(kernel, "TYPE2", type2, vector_size_index);
91     EmitDefineType(kernel, "TYPE3", type3, vector_size_index);
92     EmitDefineUndef(kernel, "UNDEF1", type1);
93     EmitDefineUndef(kernel, "UNDEF2", type2);
94     EmitDefineUndef(kernel, "UNDEF3", type3);
95     EmitEnableExtension(kernel, type1);
96 
97     // clang-format off
98     const char *kernel_nonvec3[] = { R"(
99 __kernel void )", kernel_name.c_str(), R"((__global RETTYPE* out,
100                           __global TYPE1* in1,
101                           __global TYPE2* in2,
102                           __global TYPE3* in3)
103 {
104     size_t i = get_global_id(0);
105     out[i] = )", builtin, R"((in1[i], in2[i], in3[i]);
106 }
107 )" };
108 
109     const char *kernel_vec3[] = { R"(
110 __kernel void )", kernel_name.c_str(), R"((__global RETTYPE_SCALAR* out,
111                           __global TYPE1_SCALAR* in1,
112                           __global TYPE2_SCALAR* in2,
113                           __global TYPE3_SCALAR* in3)
114 {
115     size_t i = get_global_id(0);
116 
117     if (i + 1 < get_global_size(0))
118     {
119         TYPE1 a = vload3(0, in1 + 3 * i);
120         TYPE2 b = vload3(0, in2 + 3 * i);
121         TYPE3 c = vload3(0, in3 + 3 * i);
122         RETTYPE res = )", builtin, R"((a, b, c);
123         vstore3(res, 0, out + 3 * i);
124     }
125     else
126     {
127         // Figure out how many elements are left over after
128         // BUFFER_SIZE % (3 * sizeof(type)).
129         // Assume power of two buffer size.
130         size_t parity = i & 1;
131         TYPE1 a = (TYPE1)(UNDEF1, UNDEF1, UNDEF1);
132         TYPE2 b = (TYPE2)(UNDEF2, UNDEF2, UNDEF2);
133         TYPE3 c = (TYPE3)(UNDEF3, UNDEF3, UNDEF3);
134         switch (parity)
135         {
136             case 0:
137                 a.y = in1[3 * i + 1];
138                 b.y = in2[3 * i + 1];
139                 c.y = in3[3 * i + 1];
140                 // fall through
141             case 1:
142                 a.x = in1[3 * i];
143                 b.x = in2[3 * i];
144                 c.x = in3[3 * i];
145                 break;
146         }
147 
148         RETTYPE res = )", builtin, R"((a, b, c);
149 
150         switch (parity)
151         {
152             case 0:
153                 out[3 * i + 1] = res.y;
154                 // fall through
155             case 1:
156                 out[3 * i] = res.x;
157                 break;
158         }
159     }
160 }
161 )" };
162     // clang-format on
163 
164     if (sizeValues[vector_size_index] != 3)
165         for (const auto &chunk : kernel_nonvec3) kernel << chunk;
166     else
167         for (const auto &chunk : kernel_vec3) kernel << chunk;
168 
169     return kernel.str();
170 }
171