1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"

#include <cstddef>
#include <string>
#include <vector>
19
20 namespace tflite {
21 namespace gpu {
22 namespace cl {
23 namespace {
GetCommonOpenCLDefines(CalculationsPrecision precision)24 std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
25 std::string result;
26
27 result += "#define FLT16_0123(V) V.s0123\n";
28 result += "#define FLT16_4567(V) V.s4567\n";
29 result += "#define FLT16_89ab(V) V.s89ab\n";
30 result += "#define FLT16_cdef(V) V.scdef\n";
31 result += "#define GLOBAL_ID_0 get_global_id(0)\n";
32 result += "#define GLOBAL_ID_1 get_global_id(1)\n";
33 result += "#define GLOBAL_ID_2 get_global_id(2)\n";
34 result += "#define LOCAL_ID_0 get_local_id(0)\n";
35 result += "#define LOCAL_ID_1 get_local_id(1)\n";
36 result += "#define LOCAL_ID_2 get_local_id(2)\n";
37 result += "#define GROUP_ID_0 get_group_id(0)\n";
38 result += "#define GROUP_ID_1 get_group_id(1)\n";
39 result += "#define GROUP_ID_2 get_group_id(2)\n";
40 result += "#define GROUP_SIZE_0 get_local_size(0)\n";
41 result += "#define GROUP_SIZE_1 get_local_size(1)\n";
42 result += "#define GROUP_SIZE_2 get_local_size(2)\n";
43 result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
44 result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
45 result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
46 result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
47 result += "#define MAIN_FUNCTION __kernel void main_function\n";
48 result += "#define INIT_FLOAT(value) (float)(value)\n";
49 result += "#define INIT_FLOAT2(value) (float2)(value)\n";
50 result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
51 result += "#define INIT_FLOAT3(value) (float3)(value)\n";
52 result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
53 result += "#define INIT_FLOAT4(value) (float4)(value)\n";
54 result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
55 result += "#define INIT_INT(value) (int)(value)\n";
56 result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
57 result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
58 result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
59 switch (precision) {
60 case CalculationsPrecision::F32:
61 result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
62 result += "#define ACCUM_FLT4 float4\n";
63 result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
64 result += "#define FLT float\n";
65 result += "#define FLT2 float2\n";
66 result += "#define FLT3 float3\n";
67 result += "#define FLT4 float4\n";
68 result += "#define TO_FLT4 convert_float4\n";
69 result += "#define TO_ACCUM_TYPE convert_float4\n";
70 result += "#define TO_ACCUM_FLT convert_float\n";
71 result += "#define TO_ACCUM_FLT2 convert_float2\n";
72 result += "#define TO_ACCUM_FLT3 convert_float3\n";
73 result += "#define TO_ACCUM_FLT4 convert_float4\n";
74 result += "#define INIT_FLT(value) (float)(value)\n";
75 result += "#define INIT_FLT4(value) (float4)(value)\n";
76 result +=
77 "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
78 break;
79 case CalculationsPrecision::F16:
80 result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
81 result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
82 result += "#define ACCUM_FLT4 half4\n";
83 result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
84 result += "#define FLT half\n";
85 result += "#define FLT2 half2\n";
86 result += "#define FLT3 half3\n";
87 result += "#define FLT4 half4\n";
88 result += "#define TO_FLT4 convert_half4\n";
89 result += "#define TO_ACCUM_TYPE convert_half4\n";
90 result += "#define TO_ACCUM_FLT convert_half\n";
91 result += "#define TO_ACCUM_FLT2 convert_half2\n";
92 result += "#define TO_ACCUM_FLT3 convert_half3\n";
93 result += "#define TO_ACCUM_FLT4 convert_half4\n";
94 result += "#define INIT_FLT(value) (half)(value)\n";
95 result += "#define INIT_FLT4(value) (half4)(value)\n";
96 result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
97 break;
98 case CalculationsPrecision::F32_F16:
99 result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
100 result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
101 result += "#define ACCUM_FLT4 float4\n";
102 result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
103 result += "#define FLT half\n";
104 result += "#define FLT2 half2\n";
105 result += "#define FLT3 half3\n";
106 result += "#define FLT4 half4\n";
107 result += "#define TO_FLT4 convert_half4\n";
108 result += "#define TO_ACCUM_TYPE convert_float4\n";
109 result += "#define TO_ACCUM_FLT convert_float\n";
110 result += "#define TO_ACCUM_FLT2 convert_float2\n";
111 result += "#define TO_ACCUM_FLT3 convert_float3\n";
112 result += "#define TO_ACCUM_FLT4 convert_float4\n";
113 result += "#define INIT_FLT(value) (half)(value)\n";
114 result += "#define INIT_FLT4(value) (half4)(value)\n";
115 result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
116 break;
117 }
118 result += "#define bool2 uchar2\n";
119 result += "#define bool3 uchar3\n";
120 result += "#define bool4 uchar4\n";
121 return result;
122 }
123 } // namespace
124
UpdateParams()125 absl::Status ClOperation::UpdateParams() {
126 for (int i = 0; i < operation_->GetSrcTensorsNames().size(); ++i) {
127 const auto* cl_spatial_tensor =
128 dynamic_cast<const Tensor*>(operation_->GetSrcTensors()[i]);
129 if (!cl_spatial_tensor) {
130 return absl::InvalidArgumentError("Expected CLSpatialTensor.");
131 }
132 RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->GetSrcTensorsNames()[i],
133 cl_spatial_tensor));
134 }
135 for (int i = 0; i < operation_->GetDstTensorsNames().size(); ++i) {
136 const auto* cl_spatial_tensor =
137 dynamic_cast<const Tensor*>(operation_->GetDstTensors()[i]);
138 if (!cl_spatial_tensor) {
139 return absl::InvalidArgumentError("Expected CLSpatialTensor.");
140 }
141 RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->GetDstTensorsNames()[i],
142 cl_spatial_tensor));
143 }
144 RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
145 operation_->RecalculateGridSize();
146 operation_->RecalculateWorkGroupsCount();
147 return absl::OkStatus();
148 }
149
SetSrcTensor(int index,Tensor * tensor)150 absl::Status ClOperation::SetSrcTensor(int index, Tensor* tensor) {
151 operation_->SetSrc(tensor, index);
152 return cl_args_.SetObjectRef(operation_->GetSrcTensorsNames()[index], tensor);
153 }
154
SetDstTensor(int index,Tensor * tensor)155 absl::Status ClOperation::SetDstTensor(int index, Tensor* tensor) {
156 operation_->SetDst(tensor, index);
157 return cl_args_.SetObjectRef(operation_->GetDstTensorsNames()[index], tensor);
158 }
159
Compile(const CreationContext & creation_context)160 absl::Status ClOperation::Compile(const CreationContext& creation_context) {
161 operation_->code_ =
162 GetCommonOpenCLDefines(operation_->GetDefinition().precision) +
163 operation_->code_;
164 RETURN_IF_ERROR(cl_args_.Init(
165 creation_context.GetGpuInfo(),
166 creation_context.context, &operation_->args_, &operation_->code_));
167 operation_->args_.ReleaseCPURepresentation();
168 RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
169 operation_->code_, "main_function", operation_->compiler_options_,
170 *creation_context.context, *creation_context.device, &kernel_,
171 &kernel_fingerprint_));
172 return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
173 kernel_.info_);
174 }
175
RestoreDeserialized(const ProgramCache & program_cache,uint64_t fingerprint,const GpuInfo & gpu_info,const int3 & work_group_size,CLContext * context)176 absl::Status ClOperation::RestoreDeserialized(const ProgramCache& program_cache,
177 uint64_t fingerprint,
178 const GpuInfo& gpu_info,
179 const int3& work_group_size,
180 CLContext* context) {
181 kernel_fingerprint_ = fingerprint;
182 RETURN_IF_ERROR(
183 program_cache.GetKernel(kernel_fingerprint_, "main_function", &kernel_));
184 operation_->work_group_size_ = work_group_size;
185 operation_->RecalculateWorkGroupsCount();
186 RETURN_IF_ERROR(cl_args_.Init(gpu_info, &operation_->args_, context));
187 operation_->args_.ReleaseCPURepresentation();
188 return absl::OkStatus();
189 }
190
Tune(TuningType tuning_type,const GpuInfo & gpu_info,ProfilingCommandQueue * profiling_queue)191 absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
192 ProfilingCommandQueue* profiling_queue) {
193 std::vector<GPUOperation::DispatchInfo> possible_dispatches;
194 operation_->GetPossibleDispatches(tuning_type, gpu_info, kernel_.info_,
195 &possible_dispatches);
196 if (possible_dispatches.empty()) {
197 return absl::NotFoundError("No dispatch parameters to launch kernel");
198 }
199 if (possible_dispatches.size() == 1) {
200 operation_->work_group_size_ = possible_dispatches[0].work_group_size;
201 operation_->RecalculateWorkGroupsCount();
202 return absl::OkStatus();
203 } else {
204 std::vector<int3> work_group_sizes(possible_dispatches.size());
205 std::vector<int3> work_groups_counts(possible_dispatches.size());
206 for (int i = 0; i < possible_dispatches.size(); ++i) {
207 work_group_sizes[i] = possible_dispatches[i].work_group_size;
208 work_groups_counts[i] = possible_dispatches[i].work_groups_count;
209 }
210 RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
211 int best_work_group_index;
212 RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
213 kernel_, gpu_info, work_groups_counts, work_group_sizes,
214 &best_work_group_index));
215 operation_->work_group_size_ = work_group_sizes[best_work_group_index];
216 operation_->RecalculateWorkGroupsCount();
217 return absl::OkStatus();
218 }
219 }
220
221 } // namespace cl
222 } // namespace gpu
223 } // namespace tflite
224