1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
20 #include "tensorflow/lite/delegates/gpu/cl/util.h"
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22
23 namespace tflite {
24 namespace gpu {
25 namespace cl {
26 namespace {
27
GetKernelMaxWorkGroupSize(cl_kernel kernel,cl_device_id device_id,int * result)28 absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
29 int* result) {
30 size_t max_work_group_size;
31 cl_int error_code =
32 clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
33 sizeof(size_t), &max_work_group_size, nullptr);
34 if (error_code != CL_SUCCESS) {
35 return absl::UnknownError(
36 absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
37 CLErrorCodeToString(error_code)));
38 }
39 *result = static_cast<int>(max_work_group_size);
40 return absl::OkStatus();
41 }
42
GetKernelPrivateMemorySize(cl_kernel kernel,cl_device_id device_id,int * result)43 absl::Status GetKernelPrivateMemorySize(cl_kernel kernel,
44 cl_device_id device_id, int* result) {
45 cl_ulong private_mem_size;
46 cl_int error_code =
47 clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
48 sizeof(cl_ulong), &private_mem_size, nullptr);
49 if (error_code != CL_SUCCESS) {
50 return absl::UnknownError(
51 absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
52 CLErrorCodeToString(error_code)));
53 }
54 *result = static_cast<int>(private_mem_size);
55 return absl::OkStatus();
56 }
57
58 } // namespace
59
CLKernel(CLKernel && kernel)60 CLKernel::CLKernel(CLKernel&& kernel)
61 : info_(kernel.info_),
62 binding_counter_(kernel.binding_counter_),
63 function_name_(std::move(kernel.function_name_)),
64 program_(kernel.program_),
65 kernel_(kernel.kernel_) {
66 kernel.kernel_ = nullptr;
67 }
68
operator =(CLKernel && kernel)69 CLKernel& CLKernel::operator=(CLKernel&& kernel) {
70 if (this != &kernel) {
71 Release();
72 std::swap(info_, kernel.info_);
73 std::swap(binding_counter_, kernel.binding_counter_);
74 function_name_ = std::move(kernel.function_name_);
75 std::swap(program_, kernel.program_);
76 std::swap(kernel_, kernel.kernel_);
77 }
78 return *this;
79 }
80
~CLKernel()81 CLKernel::~CLKernel() { Release(); }
82
ReInit() const83 absl::Status CLKernel::ReInit() const {
84 clReleaseKernel(kernel_);
85 cl_kernel* kern_ptr = const_cast<cl_kernel*>(&kernel_);
86 int error_code;
87 *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
88 if (!kernel_ || error_code != CL_SUCCESS) {
89 *kern_ptr = nullptr;
90 return absl::UnknownError(absl::StrCat("Failed to create ", function_name_,
91 CLErrorCodeToString(error_code)));
92 }
93 return absl::OkStatus();
94 }
95
Release()96 void CLKernel::Release() {
97 if (kernel_) {
98 clReleaseKernel(kernel_);
99 clReleaseProgram(program_);
100 kernel_ = nullptr;
101 }
102 }
103
CreateFromProgram(const CLProgram & program,const std::string & function_name)104 absl::Status CLKernel::CreateFromProgram(const CLProgram& program,
105 const std::string& function_name) {
106 int error_code;
107 function_name_ = function_name;
108 kernel_ =
109 clCreateKernel(program.program(), function_name.c_str(), &error_code);
110 if (!kernel_ || error_code != CL_SUCCESS) {
111 kernel_ = nullptr;
112 return absl::UnknownError(absl::StrCat("Failed to create ", function_name,
113 CLErrorCodeToString(error_code)));
114 }
115
116 program_ = program.program();
117 clRetainProgram(program_);
118
119 RETURN_IF_ERROR(GetKernelPrivateMemorySize(kernel_, program.GetDeviceId(),
120 &info_.private_memory_size));
121 RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(),
122 &info_.max_work_group_size));
123 return absl::OkStatus();
124 }
125
SetMemory(int index,cl_mem memory)126 absl::Status CLKernel::SetMemory(int index, cl_mem memory) {
127 return SetBytes(index, &memory, sizeof(cl_mem));
128 }
129
SetMemoryAuto(cl_mem memory)130 absl::Status CLKernel::SetMemoryAuto(cl_mem memory) {
131 return SetBytesAuto(&memory, sizeof(cl_mem));
132 }
133
SetBytes(int index,const void * ptr,int length) const134 absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
135 const int error_code = clSetKernelArg(kernel_, index, length, ptr);
136 if (error_code != CL_SUCCESS) {
137 return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
138 CLErrorCodeToString(error_code)));
139 }
140 return absl::OkStatus();
141 }
142
SetBytesAuto(const void * ptr,int length)143 absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) {
144 const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
145 if (error_code != CL_SUCCESS) {
146 return absl::UnknownError(absl::StrCat(
147 "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
148 "(at index - ", binding_counter_, ")"));
149 }
150 binding_counter_++;
151 return absl::OkStatus();
152 }
153
154 } // namespace cl
155 } // namespace gpu
156 } // namespace tflite
157