/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"

#include <string>

namespace tflite {
namespace gpu {
namespace cl {
namespace {
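// Builds the common block of OpenCL #defines prepended to every generated
// kernel: global/local/group id helpers, vector initializers, and the
// FLT/ACCUM_FLT type aliases selected by `precision` (for F32_F16, FLT types
// are half while accumulation stays in float).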
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
  std::string result;

  result += "#define FLT16_0123(V) V.s0123\n";
  result += "#define FLT16_4567(V) V.s4567\n";
  result += "#define FLT16_89ab(V) V.s89ab\n";
  result += "#define FLT16_cdef(V) V.scdef\n";
  result += "#define GLOBAL_ID_0 get_global_id(0)\n";
  result += "#define GLOBAL_ID_1 get_global_id(1)\n";
  result += "#define GLOBAL_ID_2 get_global_id(2)\n";
  result += "#define LOCAL_ID_0 get_local_id(0)\n";
  result += "#define LOCAL_ID_1 get_local_id(1)\n";
  result += "#define LOCAL_ID_2 get_local_id(2)\n";
  result += "#define GROUP_ID_0 get_group_id(0)\n";
  result += "#define GROUP_ID_1 get_group_id(1)\n";
  result += "#define GROUP_ID_2 get_group_id(2)\n";
  result += "#define GROUP_SIZE_0 get_local_size(0)\n";
  result += "#define GROUP_SIZE_1 get_local_size(1)\n";
  result += "#define GROUP_SIZE_2 get_local_size(2)\n";
  result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
  result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
  result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define MAIN_FUNCTION __kernel void main_function\n";
  result += "#define INIT_FLOAT(value) (float)(value)\n";
  result += "#define INIT_FLOAT2(value) (float2)(value)\n";
  result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
  result += "#define INIT_FLOAT3(value) (float3)(value)\n";
  result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
  result += "#define INIT_FLOAT4(value) (float4)(value)\n";
  result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
  result += "#define INIT_INT(value) (int)(value)\n";
  result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
  result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
  result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
  switch (precision) {
    case CalculationsPrecision::F32:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT float\n";
      result += "#define FLT2 float2\n";
      result += "#define FLT3 float3\n";
      result += "#define FLT4 float4\n";
      result += "#define TO_FLT4 convert_float4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define TO_ACCUM_FLT2 convert_float2\n";
      result += "#define TO_ACCUM_FLT3 convert_float3\n";
      result += "#define TO_ACCUM_FLT4 convert_float4\n";
      result += "#define INIT_FLT(value) (float)(value)\n";
      result += "#define INIT_FLT4(value) (float4)(value)\n";
      result +=
          "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 half4\n";
      result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_half4\n";
      result += "#define TO_ACCUM_FLT convert_half\n";
      result += "#define TO_ACCUM_FLT2 convert_half2\n";
      result += "#define TO_ACCUM_FLT3 convert_half3\n";
      result += "#define TO_ACCUM_FLT4 convert_half4\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F32_F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define TO_ACCUM_FLT2 convert_float2\n";
      result += "#define TO_ACCUM_FLT3 convert_float3\n";
      result += "#define TO_ACCUM_FLT4 convert_float4\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
  }
  result += "#define bool2 uchar2\n";
  result += "#define bool3 uchar3\n";
  result += "#define bool4 uchar4\n";
  return result;
}
}  // namespace

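// Refreshes the kernel argument bindings from the operation's current
// source and destination tensors, then lets the operation rebind its
// remaining arguments and recompute grid size and work-group counts.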
absl::Status ClOperation::UpdateParams() {
  for (int i = 0; i < operation_->GetSrcTensorsNames().size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->GetSrcTensors()[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->GetSrcTensorsNames()[i],
                                          cl_spatial_tensor));
  }
  for (int i = 0; i < operation_->GetDstTensorsNames().size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->GetDstTensors()[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->GetDstTensorsNames()[i],
                                          cl_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
  operation_->RecalculateGridSize();
  operation_->RecalculateWorkGroupsCount();
  return absl::OkStatus();
}

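// Replaces the source tensor at `index` and updates the matching object
// reference in the kernel arguments.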
absl::Status ClOperation::SetSrcTensor(int index, Tensor* tensor) {
  operation_->SetSrc(tensor, index);
  return cl_args_.SetObjectRef(operation_->GetSrcTensorsNames()[index], tensor);
}

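// Replaces the destination tensor at `index` and updates the matching object
// reference in the kernel arguments.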
absl::Status ClOperation::SetDstTensor(int index, Tensor* tensor) {
  operation_->SetDst(tensor, index);
  return cl_args_.SetObjectRef(operation_->GetDstTensorsNames()[index], tensor);
}

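// Prepends the common OpenCL defines to the operation's code, resolves the
// argument declarations into the final kernel source, and compiles (or
// fetches from the program cache) the "main_function" kernel. Typical use,
// as a sketch rather than a guaranteed contract: Compile() once, then
// UpdateParams() and optionally Tune() before the first dispatch.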
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
  operation_->code_ =
      GetCommonOpenCLDefines(operation_->GetDefinition().precision) +
      operation_->code_;
  RETURN_IF_ERROR(cl_args_.Init(
      creation_context.GetGpuInfo(),
      creation_context.context, &operation_->args_, &operation_->code_));
  operation_->args_.ReleaseCPURepresentation();
  RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
      operation_->code_, "main_function", operation_->compiler_options_,
      *creation_context.context, *creation_context.device, &kernel_,
      &kernel_fingerprint_));
  return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
                                      kernel_.info_);
}

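// Restores a previously compiled kernel from the program cache by its
// fingerprint instead of recompiling, then re-initializes the argument
// bindings for the given context.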
absl::Status ClOperation::RestoreDeserialized(const ProgramCache& program_cache,
                                              uint64_t fingerprint,
                                              const GpuInfo& gpu_info,
                                              const int3& work_group_size,
                                              CLContext* context) {
  kernel_fingerprint_ = fingerprint;
  RETURN_IF_ERROR(
      program_cache.GetKernel(kernel_fingerprint_, "main_function", &kernel_));
  operation_->work_group_size_ = work_group_size;
  operation_->RecalculateWorkGroupsCount();
  RETURN_IF_ERROR(cl_args_.Init(gpu_info, &operation_->args_, context));
  operation_->args_.ReleaseCPURepresentation();
  return absl::OkStatus();
}

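// Picks a work-group size for the kernel: if only one dispatch candidate
// exists it is used directly; otherwise the candidates are profiled on the
// provided queue and the fastest one is selected.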
absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
                               ProfilingCommandQueue* profiling_queue) {
  std::vector<GPUOperation::DispatchInfo> possible_dispatches;
  operation_->GetPossibleDispatches(tuning_type, gpu_info, kernel_.info_,
                                    &possible_dispatches);
  if (possible_dispatches.empty()) {
    return absl::NotFoundError("No dispatch parameters to launch kernel");
  }
  if (possible_dispatches.size() == 1) {
    operation_->work_group_size_ = possible_dispatches[0].work_group_size;
    operation_->RecalculateWorkGroupsCount();
    return absl::OkStatus();
  } else {
    std::vector<int3> work_group_sizes(possible_dispatches.size());
    std::vector<int3> work_groups_counts(possible_dispatches.size());
    for (int i = 0; i < possible_dispatches.size(); ++i) {
      work_group_sizes[i] = possible_dispatches[i].work_group_size;
      work_groups_counts[i] = possible_dispatches[i].work_groups_count;
    }
    RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
    int best_work_group_index;
    RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
        kernel_, gpu_info, work_groups_counts, work_group_sizes,
        &best_work_group_index));
    operation_->work_group_size_ = work_group_sizes[best_work_group_index];
    operation_->RecalculateWorkGroupsCount();
    return absl::OkStatus();
  }
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite