• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2017 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 
// OpenCL C source for test_kernel_call_kernel_function, built as one program.
//
// It defines two callees with identical bodies:
//   * test_kernel_to_call   -- a __kernel, called from test_call_kernel and
//                              also enqueued directly by the host test
//   * test_function_to_call -- a plain function, called from test_call_function
// Each callee sets output[gid] to 0 when `where` is 0, then adds
// input[0..where-1] into output[gid].
//
// Each caller kernel sets dst[tid] to 1 and then invokes its callee `times`
// times, passing where = tid.
const char *kernel_call_kernel_code[] = {
    // Prototype of the plain-function callee.
    "void test_function_to_call(__global int *output, __global int *input, int where);\n"
    "\n"

    // Callee #1: a __kernel that is invoked from another kernel.
    "__kernel void test_kernel_to_call(__global int *output, __global int *input, int where) \n"
    "{\n"
    "  int b;\n"
    "  if (where == 0) {\n"
    "    output[get_global_id(0)] = 0;\n"
    "  }\n"
    "  for (b=0; b<where; b++)\n"
    "    output[get_global_id(0)] += input[b];  \n"
    "}\n"
    "\n"

    // Caller #1: kernel -> kernel.
    "__kernel void test_call_kernel(__global int *src, __global int *dst, int times) \n"
    "{\n"
    "  int tid = get_global_id(0);\n"
    "  int a;\n"
    "  dst[tid] = 1;\n"
    "  for (a=0; a<times; a++)\n"
    "    test_kernel_to_call(dst, src, tid);\n"
    "}\n"

    // Callee #2: same body as test_kernel_to_call, as an ordinary function.
    "void test_function_to_call(__global int *output, __global int *input, int where) \n"
    "{\n"
    "  int b;\n"
    "  if (where == 0) {\n"
    "    output[get_global_id(0)] = 0;\n"
    "  }\n"
    "  for (b=0; b<where; b++)\n"
    "    output[get_global_id(0)] += input[b];  \n"
    "}\n"
    "\n"

    // Caller #2: kernel -> function.
    "__kernel void test_call_function(__global int *src, __global int *dst, int times) \n"
    "{\n"
    "  int tid = get_global_id(0);\n"
    "  int a;\n"
    "  dst[tid] = 1;\n"
    "  for (a=0; a<times; a++)\n"
    "    test_function_to_call(dst, src, tid);\n"
    "}\n"
};
58 
59 
60 
test_kernel_call_kernel_function(cl_device_id deviceID,cl_context context,cl_command_queue queue,int num_elements)61 int test_kernel_call_kernel_function(cl_device_id deviceID, cl_context context, cl_command_queue queue, int num_elements)
62 {
63     num_elements = 256;
64 
65     int error, errors = 0;
66     clProgramWrapper program;
67     clKernelWrapper kernel1, kernel2, kernel_to_call;
68     clMemWrapper    streams[2];
69 
70     size_t    threads[] = {num_elements,1,1};
71     cl_int *input, *output, *expected;
72     cl_int times = 4;
73     int pass = 0;
74 
75     input = (cl_int*)malloc(sizeof(cl_int)*num_elements);
76     output = (cl_int*)malloc(sizeof(cl_int)*num_elements);
77     expected = (cl_int*)malloc(sizeof(cl_int)*num_elements);
78 
79     for (int i=0; i<num_elements; i++) {
80         input[i] = i;
81         output[i] = i;
82         expected[i] = output[i];
83     }
84     // Calculate the expected results
85     for (int tid=0; tid<num_elements; tid++) {
86         expected[tid] = 1;
87         for (int a=0; a<times; a++) {
88             int where = tid;
89             if (where == 0)
90                 expected[tid] = 0;
91             for (int b=0; b<where; b++) {
92                 expected[tid] += input[b];
93             }
94         }
95     }
96 
97     // Test kernel calling a kernel
98     log_info("Testing kernel calling kernel...\n");
99     // Create the kernel
100     if( create_single_kernel_helper( context, &program, &kernel1, 1, kernel_call_kernel_code, "test_call_kernel" ) != 0 )
101     {
102         return -1;
103     }
104 
105     kernel_to_call = clCreateKernel(program, "test_kernel_to_call", &error);
106     test_error(error, "clCreateKernel failed");
107 
108     /* Create some I/O streams */
109     streams[0] = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,  sizeof(cl_int)*num_elements, input, &error);
110     test_error( error, "clCreateBuffer failed" );
111     streams[1] = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,  sizeof(cl_int)*num_elements, output, &error);
112     test_error( error, "clCreateBuffer failed" );
113 
114     error = clSetKernelArg(kernel1, 0, sizeof( streams[0] ), &streams[0]);
115     test_error( error, "clSetKernelArg failed" );
116     error = clSetKernelArg(kernel1, 1, sizeof( streams[1] ), &streams[1]);
117     test_error( error, "clSetKernelArg failed" );
118     error = clSetKernelArg(kernel1, 2, sizeof( times ), &times);
119     test_error( error, "clSetKernelArg failed" );
120 
121     error = clEnqueueNDRangeKernel( queue, kernel1, 1, NULL, threads, NULL, 0, NULL, NULL );
122     test_error( error, "clEnqueueNDRangeKernel failed" );
123 
124     error = clEnqueueReadBuffer( queue, streams[1], CL_TRUE, 0, sizeof(cl_int)*num_elements, output, 0, NULL, NULL );
125     test_error( error, "clEnqueueReadBuffer failed" );
126 
127     // Compare the results
128     pass = 1;
129     for (int i=0; i<num_elements; i++) {
130         if (output[i] != expected[i]) {
131             if (errors > 10)
132                 continue;
133             if (errors == 10) {
134                 log_error("Suppressing further results...\n");
135                 continue;
136             }
137             log_error("Results do not match: output[%d]=%d != expected[%d]=%d\n", i, output[i], i, expected[i]);
138             errors++;
139             pass = 0;
140         }
141     }
142     if (pass) log_info("Passed kernel calling kernel...\n");
143 
144 
145 
146     // Test kernel calling a function
147     log_info("Testing kernel calling function...\n");
148     // Reset the inputs
149     for (int i=0; i<num_elements; i++) {
150         input[i] = i;
151         output[i] = i;
152     }
153     error = clEnqueueWriteBuffer(queue, streams[0], CL_TRUE, 0, sizeof(cl_int)*num_elements, input, 0, NULL, NULL);
154     test_error(error, "clEnqueueWriteBuffer failed");
155     error = clEnqueueWriteBuffer(queue, streams[1], CL_TRUE, 0, sizeof(cl_int)*num_elements, output, 0, NULL, NULL);
156     test_error(error, "clEnqueueWriteBuffer failed");
157 
158     kernel2 = clCreateKernel(program, "test_call_function", &error);
159     test_error(error, "clCreateKernel failed");
160 
161     error = clSetKernelArg(kernel2, 0, sizeof( streams[0] ), &streams[0]);
162     test_error( error, "clSetKernelArg failed" );
163     error = clSetKernelArg(kernel2, 1, sizeof( streams[1] ), &streams[1]);
164     test_error( error, "clSetKernelArg failed" );
165     error = clSetKernelArg(kernel2, 2, sizeof( times ), &times);
166     test_error( error, "clSetKernelArg failed" );
167 
168     error = clEnqueueNDRangeKernel( queue, kernel2, 1, NULL, threads, NULL, 0, NULL, NULL );
169     test_error( error, "clEnqueueNDRangeKernel failed" );
170 
171     error = clEnqueueReadBuffer( queue, streams[1], CL_TRUE, 0, sizeof(cl_int)*num_elements, output, 0, NULL, NULL );
172     test_error( error, "clEnqueueReadBuffer failed" );
173 
174     // Compare the results
175     pass = 1;
176     for (int i=0; i<num_elements; i++) {
177         if (output[i] != expected[i]) {
178             if (errors > 10)
179                 continue;
180             if (errors > 10) {
181                 log_error("Suppressing further results...\n");
182                 continue;
183             }
184             log_error("Results do not match: output[%d]=%d != expected[%d]=%d\n", i, output[i], i, expected[i]);
185             errors++;
186             pass = 0;
187         }
188     }
189     if (pass) log_info("Passed kernel calling function...\n");
190 
191 
192     // Test calling the kernel we called from another kernel
193     log_info("Testing calling the kernel we called from another kernel before...\n");
194     // Reset the inputs
195     for (int i=0; i<num_elements; i++) {
196         input[i] = i;
197         output[i] = i;
198         expected[i] = output[i];
199     }
200     error = clEnqueueWriteBuffer(queue, streams[0], CL_TRUE, 0, sizeof(cl_int)*num_elements, input, 0, NULL, NULL);
201     test_error(error, "clEnqueueWriteBuffer failed");
202     error = clEnqueueWriteBuffer(queue, streams[1], CL_TRUE, 0, sizeof(cl_int)*num_elements, output, 0, NULL, NULL);
203     test_error(error, "clEnqueueWriteBuffer failed");
204 
205     // Calculate the expected results
206     int where = times;
207     for (int tid=0; tid<num_elements; tid++) {
208         if (where == 0)
209             expected[tid] = 0;
210         for (int b=0; b<where; b++) {
211             expected[tid] += input[b];
212         }
213     }
214 
215 
216     error = clSetKernelArg(kernel_to_call, 0, sizeof( streams[1] ), &streams[1]);
217     test_error( error, "clSetKernelArg failed" );
218     error = clSetKernelArg(kernel_to_call, 1, sizeof( streams[0] ), &streams[0]);
219     test_error( error, "clSetKernelArg failed" );
220     error = clSetKernelArg(kernel_to_call, 2, sizeof( times ), &times);
221     test_error( error, "clSetKernelArg failed" );
222 
223     error = clEnqueueNDRangeKernel( queue, kernel_to_call, 1, NULL, threads, NULL, 0, NULL, NULL );
224     test_error( error, "clEnqueueNDRangeKernel failed" );
225 
226     error = clEnqueueReadBuffer( queue, streams[1], CL_TRUE, 0, sizeof(cl_int)*num_elements, output, 0, NULL, NULL );
227     test_error( error, "clEnqueueReadBuffer failed" );
228 
229     // Compare the results
230     pass = 1;
231     for (int i=0; i<num_elements; i++) {
232         if (output[i] != expected[i]) {
233             if (errors > 10)
234                 continue;
235             if (errors > 10) {
236                 log_error("Suppressing further results...\n");
237                 continue;
238             }
239             log_error("Results do not match: output[%d]=%d != expected[%d]=%d\n", i, output[i], i, expected[i]);
240             errors++;
241             pass = 0;
242         }
243     }
244     if (pass) log_info("Passed calling the kernel we called from another kernel before...\n");
245 
246     free( input );
247     free( output );
248     free( expected );
249 
250     return errors;
251 }
252 
253 
254