/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"

#include <string>
#include <vector>

namespace tflite {
namespace gpu {
namespace cl {
namespace {
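// Computes how many work groups are needed to cover the grid along each
// dimension and permutes the per-axis counts with work_group_launch_order.
// Worked example: grid_size = (100, 20, 4), work_group_size = (8, 4, 1) and
// work_group_launch_order = (1, 0, 2) give per-axis counts (13, 5, 4) and a
// returned value of (5, 13, 4).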
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}

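// Builds the preamble of #defines that is prepended to every generated
// kernel (see Compile below). The precision-independent macros wrap OpenCL
// built-ins (global/local ids, barriers, vector constructors); the switch
// then maps the FLT*/ACCUM_* type macros to float or half for the requested
// CalculationsPrecision. F32_F16 stores values as half but accumulates in
// float4.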
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
  std::string result;

  result += "#define FLT16_0123(V) V.s0123\n";
  result += "#define FLT16_4567(V) V.s4567\n";
  result += "#define FLT16_89ab(V) V.s89ab\n";
  result += "#define FLT16_cdef(V) V.scdef\n";
  result += "#define GLOBAL_ID_0 get_global_id(0)\n";
  result += "#define GLOBAL_ID_1 get_global_id(1)\n";
  result += "#define GLOBAL_ID_2 get_global_id(2)\n";
  result += "#define LOCAL_ID_0 get_local_id(0)\n";
  result += "#define LOCAL_ID_1 get_local_id(1)\n";
  result += "#define LOCAL_ID_2 get_local_id(2)\n";
  result += "#define GROUP_ID_0 get_group_id(0)\n";
  result += "#define GROUP_ID_1 get_group_id(1)\n";
  result += "#define GROUP_ID_2 get_group_id(2)\n";
  result += "#define GROUP_SIZE_0 get_local_size(0)\n";
  result += "#define GROUP_SIZE_1 get_local_size(1)\n";
  result += "#define GROUP_SIZE_2 get_local_size(2)\n";
  result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
  result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
  result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define MAIN_FUNCTION __kernel void main_function\n";
  result += "#define INIT_FLOAT(value) (float)(value)\n";
  result += "#define INIT_FLOAT2(value) (float2)(value)\n";
  result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
  result += "#define INIT_FLOAT3(value) (float3)(value)\n";
  result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
  result += "#define INIT_FLOAT4(value) (float4)(value)\n";
  result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
  result += "#define INIT_INT(value) (int)(value)\n";
  result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
  result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
  result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
  result +=
      "#define SELECT_BY_INDEX_FROM_FLT4(value, index) (FLT[4]){(value).x, "
      "(value).y, (value).z, (value).w}[index]\n";
  switch (precision) {
    case CalculationsPrecision::F32:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT float\n";
      result += "#define FLT2 float2\n";
      result += "#define FLT3 float3\n";
      result += "#define FLT4 float4\n";
      result += "#define TO_FLT4 convert_float4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define TO_ACCUM_FLT2 convert_float2\n";
      result += "#define TO_ACCUM_FLT3 convert_float3\n";
      result += "#define TO_ACCUM_FLT4 convert_float4\n";
      result += "#define INIT_FLT(value) (float)(value)\n";
      result += "#define INIT_FLT4(value) (float4)(value)\n";
      result +=
          "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 half4\n";
      result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_half4\n";
      result += "#define TO_ACCUM_FLT convert_half\n";
      result += "#define TO_ACCUM_FLT2 convert_half2\n";
      result += "#define TO_ACCUM_FLT3 convert_half3\n";
      result += "#define TO_ACCUM_FLT4 convert_half4\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F32_F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define TO_ACCUM_FLT2 convert_float2\n";
      result += "#define TO_ACCUM_FLT3 convert_float3\n";
      result += "#define TO_ACCUM_FLT4 convert_float4\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
  }
  return result;
}
}  // namespace

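// Chains another operation onto this one by delegating to the underlying
// GPUOperation (used to fuse compatible operations into a single kernel).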
absl::Status ClOperation::AddOperation(ClOperation* operation) {
  return operation_->AddOperation(operation->operation_.get());
}

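// Rebinds source/destination tensors to the kernel arguments and recomputes
// the grid size and work-group counts; intended to be called when the
// tensors backing this operation change.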
absl::Status ClOperation::UpdateParams() {
  for (int i = 0; i < operation_->src_tensors_names_.size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->src_[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->src_tensors_names_[i],
                                          cl_spatial_tensor));
  }
  for (int i = 0; i < operation_->dst_tensors_names_.size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->dst_[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->dst_tensors_names_[i],
                                          cl_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
  operation_->grid_size_ = operation_->GetGridSize();
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}

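// Assembles the kernel source, prepends the common OpenCL defines for the
// operation's precision, resolves argument references in the code, and
// compiles the resulting "main_function" kernel (or fetches it from the
// program cache).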
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
  operation_->AssembleCode(creation_context.GetGpuInfo());
  operation_->code_ =
      GetCommonOpenCLDefines(operation_->definition_.precision) +
      operation_->code_;
  RETURN_IF_ERROR(cl_args_.Init(
      creation_context.GetGpuInfo(),
      {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
      creation_context.context, &operation_->args_, &operation_->code_));
  RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
      operation_->code_, "main_function", operation_->compiler_options_,
      *creation_context.context, *creation_context.device, &kernel_,
      &kernel_fingerprint_));
  return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
                                      kernel_.info_);
}

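// Restores a previously compiled kernel from the program cache by its
// fingerprint, skipping code generation and compilation.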
absl::Status ClOperation::InitFromCache(uint64_t fingerprint,
                                        const ProgramCache& program_cache) {
  kernel_fingerprint_ = fingerprint;
  return program_cache.GetKernel(kernel_fingerprint_, "main_function",
                                 &kernel_);
}

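// Re-initializes the CL argument bindings for an operation restored from a
// serialized model; the kernel itself is recovered separately (e.g. via
// InitFromCache).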
absl::Status ClOperation::RestoreDeserialized(
    const CreationContext& creation_context) {
  return cl_args_.Init(creation_context.GetGpuInfo(), &operation_->args_,
                       creation_context.context);
}

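// Selects a work-group size for the kernel. With a single candidate it is
// taken as-is; otherwise every candidate is profiled on profiling_queue and
// the fastest one is kept. work_groups_count_ is recomputed to match.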
absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
                               ProfilingCommandQueue* profiling_queue) {
  std::vector<int3> possible_work_groups;
  operation_->GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_.info_,
                                          &possible_work_groups);
  if (possible_work_groups.empty()) {
    return absl::NotFoundError(
        "Can not find work_group size to launch kernel.");
  }
  if (possible_work_groups.size() == 1) {
    operation_->work_group_size_ = possible_work_groups[0];
    operation_->work_groups_count_ = GetWorkGroupsCount(
        operation_->grid_dimension_, operation_->grid_size_,
        operation_->work_group_size_, operation_->work_group_launch_order_);
    return absl::OkStatus();
  } else {
    std::vector<int3> work_groups_count(possible_work_groups.size());
    for (int i = 0; i < work_groups_count.size(); ++i) {
      work_groups_count[i] = GetWorkGroupsCount(
          operation_->grid_dimension_, operation_->grid_size_,
          possible_work_groups[i], operation_->work_group_launch_order_);
    }
    RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
    int best_work_group_index;
    RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
        kernel_, gpu_info, work_groups_count, possible_work_groups,
        &best_work_group_index));
    operation_->work_group_size_ = possible_work_groups[best_work_group_index];
    operation_->work_groups_count_ = GetWorkGroupsCount(
        operation_->grid_dimension_, operation_->grid_size_,
        operation_->work_group_size_, operation_->work_group_launch_order_);
    return absl::OkStatus();
  }
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite