• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"
17 
18 #include <string>
19 
20 namespace tflite {
21 namespace gpu {
22 namespace cl {
23 namespace {
// Computes the number of work groups to launch along each axis.
//
// grid_dimension: number of grid axes in use. 1 and 2 are handled explicitly;
//     every other value is treated as 3.
// grid_size, work_group_size: per-axis extents. For each axis in use the count
//     is DivideRoundUp(grid_size, work_group_size) on that axis (presumably a
//     rounded-up integer division; the helper is defined elsewhere).
// work_group_launch_order: for 2D and 3D grids, output axis i receives the
//     count computed for axis work_group_launch_order[i]. Ignored for 1D.
//
// Output axes that are not part of the grid (y and z for 1D, z for 2D) are
// set to 1.
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    // z is not part of a 2D grid, but a launch-order entry may still select
    // it below. Give it a defined value so such a lookup yields 1 instead of
    // reading an unassigned member.
    wgs.z = 1;
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}
50 
GetCommonOpenCLDefines(CalculationsPrecision precision)51 std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
52   std::string result;
53 
54   result += "#define FLT16_0123(V) V.s0123\n";
55   result += "#define FLT16_4567(V) V.s4567\n";
56   result += "#define FLT16_89ab(V) V.s89ab\n";
57   result += "#define FLT16_cdef(V) V.scdef\n";
58   result += "#define GLOBAL_ID_0 get_global_id(0)\n";
59   result += "#define GLOBAL_ID_1 get_global_id(1)\n";
60   result += "#define GLOBAL_ID_2 get_global_id(2)\n";
61   result += "#define LOCAL_ID_0 get_local_id(0)\n";
62   result += "#define LOCAL_ID_1 get_local_id(1)\n";
63   result += "#define LOCAL_ID_2 get_local_id(2)\n";
64   result += "#define GROUP_ID_0 get_group_id(0)\n";
65   result += "#define GROUP_ID_1 get_group_id(1)\n";
66   result += "#define GROUP_ID_2 get_group_id(2)\n";
67   result += "#define GROUP_SIZE_0 get_local_size(0)\n";
68   result += "#define GROUP_SIZE_1 get_local_size(1)\n";
69   result += "#define GROUP_SIZE_2 get_local_size(2)\n";
70   result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
71   result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
72   result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
73   result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
74   result += "#define MAIN_FUNCTION __kernel void main_function\n";
75   result += "#define INIT_FLOAT(value) (float)(value)\n";
76   result += "#define INIT_FLOAT2(value) (float2)(value)\n";
77   result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
78   result += "#define INIT_FLOAT3(value) (float3)(value)\n";
79   result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
80   result += "#define INIT_FLOAT4(value) (float4)(value)\n";
81   result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
82   result += "#define INIT_INT(value) (int)(value)\n";
83   result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
84   result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
85   result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
86   result +=
87       "#define SELECT_BY_INDEX_FROM_FLT4(value, index) (FLT[4]){(value).x, "
88       "(value).y, (value).z, (value).w}[index]\n";
89   switch (precision) {
90     case CalculationsPrecision::F32:
91       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
92       result += "#define ACCUM_FLT4 float4\n";
93       result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
94       result += "#define FLT float\n";
95       result += "#define FLT2 float2\n";
96       result += "#define FLT3 float3\n";
97       result += "#define FLT4 float4\n";
98       result += "#define TO_FLT4 convert_float4\n";
99       result += "#define TO_ACCUM_TYPE convert_float4\n";
100       result += "#define TO_ACCUM_FLT convert_float\n";
101       result += "#define TO_ACCUM_FLT2 convert_float2\n";
102       result += "#define TO_ACCUM_FLT3 convert_float3\n";
103       result += "#define TO_ACCUM_FLT4 convert_float4\n";
104       result += "#define INIT_FLT(value) (float)(value)\n";
105       result += "#define INIT_FLT4(value) (float4)(value)\n";
106       result +=
107           "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
108       break;
109     case CalculationsPrecision::F16:
110       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
111       result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
112       result += "#define ACCUM_FLT4 half4\n";
113       result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
114       result += "#define FLT half\n";
115       result += "#define FLT2 half2\n";
116       result += "#define FLT3 half3\n";
117       result += "#define FLT4 half4\n";
118       result += "#define TO_FLT4 convert_half4\n";
119       result += "#define TO_ACCUM_TYPE convert_half4\n";
120       result += "#define TO_ACCUM_FLT convert_half\n";
121       result += "#define TO_ACCUM_FLT2 convert_half2\n";
122       result += "#define TO_ACCUM_FLT3 convert_half3\n";
123       result += "#define TO_ACCUM_FLT4 convert_half4\n";
124       result += "#define INIT_FLT(value) (half)(value)\n";
125       result += "#define INIT_FLT4(value) (half4)(value)\n";
126       result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
127       break;
128     case CalculationsPrecision::F32_F16:
129       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
130       result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
131       result += "#define ACCUM_FLT4 float4\n";
132       result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
133       result += "#define FLT half\n";
134       result += "#define FLT2 half2\n";
135       result += "#define FLT3 half3\n";
136       result += "#define FLT4 half4\n";
137       result += "#define TO_FLT4 convert_half4\n";
138       result += "#define TO_ACCUM_TYPE convert_float4\n";
139       result += "#define TO_ACCUM_FLT convert_float\n";
140       result += "#define TO_ACCUM_FLT2 convert_float2\n";
141       result += "#define TO_ACCUM_FLT3 convert_float3\n";
142       result += "#define TO_ACCUM_FLT4 convert_float4\n";
143       result += "#define INIT_FLT(value) (half)(value)\n";
144       result += "#define INIT_FLT4(value) (half4)(value)\n";
145       result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
146       break;
147   }
148   return result;
149 }
150 }  // namespace
151 
AddOperation(ClOperation * operation)152 absl::Status ClOperation::AddOperation(ClOperation* operation) {
153   return operation_->AddOperation(operation->operation_.get());
154 }
155 
UpdateParams()156 absl::Status ClOperation::UpdateParams() {
157   for (int i = 0; i < operation_->src_tensors_names_.size(); ++i) {
158     const auto* cl_spatial_tensor =
159         dynamic_cast<const Tensor*>(operation_->src_[i]);
160     if (!cl_spatial_tensor) {
161       return absl::InvalidArgumentError("Expected CLSpatialTensor.");
162     }
163     RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->src_tensors_names_[i],
164                                           cl_spatial_tensor));
165   }
166   for (int i = 0; i < operation_->dst_tensors_names_.size(); ++i) {
167     const auto* cl_spatial_tensor =
168         dynamic_cast<const Tensor*>(operation_->dst_[i]);
169     if (!cl_spatial_tensor) {
170       return absl::InvalidArgumentError("Expected CLSpatialTensor.");
171     }
172     RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->dst_tensors_names_[i],
173                                           cl_spatial_tensor));
174   }
175   RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
176   operation_->grid_size_ = operation_->GetGridSize();
177   operation_->work_groups_count_ = GetWorkGroupsCount(
178       operation_->grid_dimension_, operation_->grid_size_,
179       operation_->work_group_size_, operation_->work_group_launch_order_);
180   return absl::OkStatus();
181 }
182 
Compile(const CreationContext & creation_context)183 absl::Status ClOperation::Compile(const CreationContext& creation_context) {
184   operation_->AssembleCode(creation_context.GetGpuInfo());
185   operation_->code_ =
186       GetCommonOpenCLDefines(operation_->definition_.precision) +
187       operation_->code_;
188   RETURN_IF_ERROR(cl_args_.Init(
189       creation_context.GetGpuInfo(),
190       {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
191       creation_context.context, &operation_->args_, &operation_->code_));
192   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
193       operation_->code_, "main_function", operation_->compiler_options_,
194       *creation_context.context, *creation_context.device, &kernel_,
195       &kernel_fingerprint_));
196   return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
197                                       kernel_.info_);
198 }
199 
InitFromCache(uint64_t fingerprint,const ProgramCache & program_cache)200 absl::Status ClOperation::InitFromCache(uint64_t fingerprint,
201                                         const ProgramCache& program_cache) {
202   kernel_fingerprint_ = fingerprint;
203   return program_cache.GetKernel(kernel_fingerprint_, "main_function",
204                                  &kernel_);
205 }
206 
RestoreDeserialized(const CreationContext & creation_context)207 absl::Status ClOperation::RestoreDeserialized(
208     const CreationContext& creation_context) {
209   return cl_args_.Init(creation_context.GetGpuInfo(), &operation_->args_,
210                        creation_context.context);
211 }
212 
Tune(TuningType tuning_type,const GpuInfo & gpu_info,ProfilingCommandQueue * profiling_queue)213 absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
214                                ProfilingCommandQueue* profiling_queue) {
215   std::vector<int3> possible_work_groups;
216   operation_->GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_.info_,
217                                           &possible_work_groups);
218   if (possible_work_groups.empty()) {
219     return absl::NotFoundError(
220         "Can not found work_group size to launch kernel");
221   }
222   if (possible_work_groups.size() == 1) {
223     operation_->work_group_size_ = possible_work_groups[0];
224     operation_->work_groups_count_ = GetWorkGroupsCount(
225         operation_->grid_dimension_, operation_->grid_size_,
226         operation_->work_group_size_, operation_->work_group_launch_order_);
227     return absl::OkStatus();
228   } else {
229     std::vector<int3> work_groups_count(possible_work_groups.size());
230     for (int i = 0; i < work_groups_count.size(); ++i) {
231       work_groups_count[i] = GetWorkGroupsCount(
232           operation_->grid_dimension_, operation_->grid_size_,
233           possible_work_groups[i], operation_->work_group_launch_order_);
234     }
235     RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
236     int best_work_group_index;
237     RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
238         kernel_, gpu_info, work_groups_count, possible_work_groups,
239         &best_work_group_index));
240     operation_->work_group_size_ = possible_work_groups[best_work_group_index];
241     operation_->work_groups_count_ = GetWorkGroupsCount(
242         operation_->grid_dimension_, operation_->grid_size_,
243         operation_->work_group_size_, operation_->work_group_launch_order_);
244     return absl::OkStatus();
245   }
246 }
247 
248 }  // namespace cl
249 }  // namespace gpu
250 }  // namespace tflite
251