• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"
17 
18 namespace tflite {
19 namespace gpu {
20 namespace cl {
21 namespace {
// Computes how many work groups are launched along each axis.
//
// The count for each grid axis is DivideRoundUp(grid_size, work_group_size).
// - 1D grids: only x is computed; y and z are fixed to 1.
// - 2D/3D grids: the counts are permuted by `work_group_launch_order`, so
//   output axis i receives the count of grid axis
//   work_group_launch_order[i].
// - 2D grids: the output z is always 1.
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 launched;
  if (grid_dimension == 1) {
    launched.x = DivideRoundUp(grid_size.x, work_group_size.x);
    launched.y = 1;
    launched.z = 1;
    return launched;
  }

  // Per-grid-axis counts, before applying the launch order.
  int3 per_axis;
  per_axis.x = DivideRoundUp(grid_size.x, work_group_size.x);
  per_axis.y = DivideRoundUp(grid_size.y, work_group_size.y);
  if (grid_dimension == 2) {
    launched.x = per_axis[work_group_launch_order[0]];
    launched.y = per_axis[work_group_launch_order[1]];
    launched.z = 1;
    return launched;
  }

  // Any other value is treated as a 3D grid.
  per_axis.z = DivideRoundUp(grid_size.z, work_group_size.z);
  launched.x = per_axis[work_group_launch_order[0]];
  launched.y = per_axis[work_group_launch_order[1]];
  launched.z = per_axis[work_group_launch_order[2]];
  return launched;
}
48 
GetCommonOpenCLDefines(CalculationsPrecision precision)49 std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
50   std::string result;
51 
52   result += "#define FLT16_0123(V) V.s0123\n";
53   result += "#define FLT16_4567(V) V.s4567\n";
54   result += "#define FLT16_89ab(V) V.s89ab\n";
55   result += "#define FLT16_cdef(V) V.scdef\n";
56   result += "#define GLOBAL_ID_0 get_global_id(0)\n";
57   result += "#define GLOBAL_ID_1 get_global_id(1)\n";
58   result += "#define GLOBAL_ID_2 get_global_id(2)\n";
59   result += "#define LOCAL_ID_0 get_local_id(0)\n";
60   result += "#define LOCAL_ID_1 get_local_id(1)\n";
61   result += "#define LOCAL_ID_2 get_local_id(2)\n";
62   result += "#define GROUP_ID_0 get_group_id(0)\n";
63   result += "#define GROUP_ID_1 get_group_id(1)\n";
64   result += "#define GROUP_ID_2 get_group_id(2)\n";
65   result += "#define GROUP_SIZE_0 get_local_size(0)\n";
66   result += "#define GROUP_SIZE_1 get_local_size(1)\n";
67   result += "#define GROUP_SIZE_2 get_local_size(2)\n";
68   result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
69   result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
70   result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
71   result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
72   result += "#define MAIN_FUNCTION __kernel void main_function\n";
73   result += "#define INIT_FLOAT(value) (float)(value)\n";
74   result += "#define INIT_FLOAT2(value) (float2)(value)\n";
75   result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
76   result += "#define INIT_FLOAT3(value) (float3)(value)\n";
77   result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
78   result += "#define INIT_FLOAT4(value) (float4)(value)\n";
79   result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
80   result += "#define INIT_INT(value) (int)(value)\n";
81   result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
82   result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
83   result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
84   switch (precision) {
85     case CalculationsPrecision::F32:
86       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
87       result += "#define ACCUM_FLT4 float4\n";
88       result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
89       result += "#define FLT float\n";
90       result += "#define FLT2 float2\n";
91       result += "#define FLT3 float3\n";
92       result += "#define FLT4 float4\n";
93       result += "#define TO_FLT4 convert_float4\n";
94       result += "#define TO_ACCUM_TYPE convert_float4\n";
95       result += "#define TO_ACCUM_FLT convert_float\n";
96       result += "#define INIT_FLT(value) (float)(value)\n";
97       result += "#define INIT_FLT4(value) (float4)(value)\n";
98       result +=
99           "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
100       break;
101     case CalculationsPrecision::F16:
102       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
103       result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
104       result += "#define ACCUM_FLT4 half4\n";
105       result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
106       result += "#define FLT half\n";
107       result += "#define FLT2 half2\n";
108       result += "#define FLT3 half3\n";
109       result += "#define FLT4 half4\n";
110       result += "#define TO_FLT4 convert_half4\n";
111       result += "#define TO_ACCUM_TYPE convert_half4\n";
112       result += "#define TO_ACCUM_FLT convert_half\n";
113       result += "#define INIT_FLT(value) (half)(value)\n";
114       result += "#define INIT_FLT4(value) (half4)(value)\n";
115       result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
116       break;
117     case CalculationsPrecision::F32_F16:
118       result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
119       result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
120       result += "#define ACCUM_FLT4 float4\n";
121       result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
122       result += "#define FLT half\n";
123       result += "#define FLT2 half2\n";
124       result += "#define FLT3 half3\n";
125       result += "#define FLT4 half4\n";
126       result += "#define TO_FLT4 convert_half4\n";
127       result += "#define TO_ACCUM_TYPE convert_float4\n";
128       result += "#define TO_ACCUM_FLT convert_float\n";
129       result += "#define INIT_FLT(value) (half)(value)\n";
130       result += "#define INIT_FLT4(value) (half4)(value)\n";
131       result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
132       break;
133   }
134   return result;
135 }
136 }  // namespace
137 
ClOperation(ClOperation && operation)138 ClOperation::ClOperation(ClOperation&& operation)
139     : operation_(std::move(operation.operation_)),
140       kernel_(std::move(operation.kernel_)),
141       cl_args_(std::move(operation.cl_args_)) {}
142 
operator =(ClOperation && operation)143 ClOperation& ClOperation::operator=(ClOperation&& operation) {
144   if (this != &operation) {
145     operation_ = std::move(operation.operation_);
146     kernel_ = std::move(operation.kernel_);
147     cl_args_ = std::move(operation.cl_args_);
148   }
149   return *this;
150 }
151 
AddOperation(ClOperation * operation)152 absl::Status ClOperation::AddOperation(ClOperation* operation) {
153   return operation_->AddOperation(operation->operation_.get());
154 }
155 
UpdateParams()156 absl::Status ClOperation::UpdateParams() {
157   for (int i = 0; i < operation_->src_tensors_names_.size(); ++i) {
158     const auto* cl_spatial_tensor =
159         dynamic_cast<const Tensor*>(operation_->src_[i]);
160     if (!cl_spatial_tensor) {
161       return absl::InvalidArgumentError("Expected CLSpatialTensor.");
162     }
163     RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->src_tensors_names_[i],
164                                           cl_spatial_tensor));
165   }
166   for (int i = 0; i < operation_->dst_tensors_names_.size(); ++i) {
167     const auto* cl_spatial_tensor =
168         dynamic_cast<const Tensor*>(operation_->dst_[i]);
169     if (!cl_spatial_tensor) {
170       return absl::InvalidArgumentError("Expected CLSpatialTensor.");
171     }
172     RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->dst_tensors_names_[i],
173                                           cl_spatial_tensor));
174   }
175   RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
176   operation_->grid_size_ = operation_->GetGridSize();
177   operation_->work_groups_count_ = GetWorkGroupsCount(
178       operation_->grid_dimension_, operation_->grid_size_,
179       operation_->work_group_size_, operation_->work_group_launch_order_);
180   return absl::OkStatus();
181 }
182 
Compile(const CreationContext & creation_context)183 absl::Status ClOperation::Compile(const CreationContext& creation_context) {
184   operation_->AssembleCode(creation_context.GetGpuInfo());
185   operation_->code_ =
186       GetCommonOpenCLDefines(operation_->definition_.precision) +
187       operation_->code_;
188   RETURN_IF_ERROR(cl_args_.Init(
189       creation_context.GetGpuInfo(),
190       {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
191       creation_context.context, &operation_->args_, &operation_->code_));
192   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
193       operation_->code_, "main_function", operation_->compiler_options_,
194       *creation_context.context, *creation_context.device, &kernel_));
195   return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
196                                       kernel_.info_);
197 }
198 
CompileDeserialized(const CreationContext & creation_context)199 absl::Status ClOperation::CompileDeserialized(
200     const CreationContext& creation_context) {
201   RETURN_IF_ERROR(cl_args_.Init(creation_context.GetGpuInfo(),
202                                 &operation_->args_, creation_context.context));
203   return creation_context.cache->GetOrCreateCLKernel(
204       operation_->code_, "main_function", operation_->compiler_options_,
205       *creation_context.context, *creation_context.device, &kernel_);
206 }
207 
Tune(TuningType tuning_type,const GpuInfo & gpu_info,ProfilingCommandQueue * profiling_queue)208 absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
209                                ProfilingCommandQueue* profiling_queue) {
210   std::vector<int3> possible_work_groups;
211   operation_->GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_.info_,
212                                           &possible_work_groups);
213   if (possible_work_groups.empty()) {
214     return absl::NotFoundError(
215         "Can not found work_group size to launch kernel");
216   }
217   if (possible_work_groups.size() == 1) {
218     operation_->work_group_size_ = possible_work_groups[0];
219     operation_->work_groups_count_ = GetWorkGroupsCount(
220         operation_->grid_dimension_, operation_->grid_size_,
221         operation_->work_group_size_, operation_->work_group_launch_order_);
222     return absl::OkStatus();
223   } else {
224     std::vector<int3> work_groups_count(possible_work_groups.size());
225     for (int i = 0; i < work_groups_count.size(); ++i) {
226       work_groups_count[i] = GetWorkGroupsCount(
227           operation_->grid_dimension_, operation_->grid_size_,
228           possible_work_groups[i], operation_->work_group_launch_order_);
229     }
230     RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
231     int best_work_group_index;
232     RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
233         kernel_, gpu_info, work_groups_count, possible_work_groups,
234         &best_work_group_index));
235     operation_->work_group_size_ = possible_work_groups[best_work_group_index];
236     operation_->work_groups_count_ = GetWorkGroupsCount(
237         operation_->grid_dimension_, operation_->grid_size_,
238         operation_->work_group_size_, operation_->work_group_launch_order_);
239     return absl::OkStatus();
240   }
241 }
242 
243 }  // namespace cl
244 }  // namespace gpu
245 }  // namespace tflite
246