/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/cl/cl_operation.h"

namespace tflite {
namespace gpu {
namespace cl {
namespace {
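// Computes, per grid dimension, how many work groups are needed to cover the
// grid (rounding up), then permutes the counts according to
// work_group_launch_order. For example, grid_size = (100, 50, 1) with
// work_group_size = (8, 4, 1) and launch order (0, 1, 2) yields
// work_groups_count = (13, 13, 1).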
int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size,
                        const int3& work_group_size,
                        const int3& work_group_launch_order) {
  int3 work_groups_count;
  if (grid_dimension == 1) {
    work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
    work_groups_count.y = 1;
    work_groups_count.z = 1;
  } else if (grid_dimension == 2) {
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = 1;
  } else {  // grid_dimension == 3
    int3 wgs;
    wgs.x = DivideRoundUp(grid_size.x, work_group_size.x);
    wgs.y = DivideRoundUp(grid_size.y, work_group_size.y);
    wgs.z = DivideRoundUp(grid_size.z, work_group_size.z);
    work_groups_count.x = wgs[work_group_launch_order[0]];
    work_groups_count.y = wgs[work_group_launch_order[1]];
    work_groups_count.z = wgs[work_group_launch_order[2]];
  }
  return work_groups_count;
}

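// Builds the preamble of OpenCL #defines shared by all generated kernels:
// float16 slice accessors, work-item/work-group ID aliases, vector type
// constructors, and the precision-dependent FLT/ACCUM_FLT4 type aliases.
// In F32_F16 mode, storage types are half-precision while accumulation
// stays in float.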
std::string GetCommonOpenCLDefines(CalculationsPrecision precision) {
  std::string result;

  result += "#define FLT16_0123(V) V.s0123\n";
  result += "#define FLT16_4567(V) V.s4567\n";
  result += "#define FLT16_89ab(V) V.s89ab\n";
  result += "#define FLT16_cdef(V) V.scdef\n";
  result += "#define GLOBAL_ID_0 get_global_id(0)\n";
  result += "#define GLOBAL_ID_1 get_global_id(1)\n";
  result += "#define GLOBAL_ID_2 get_global_id(2)\n";
  result += "#define LOCAL_ID_0 get_local_id(0)\n";
  result += "#define LOCAL_ID_1 get_local_id(1)\n";
  result += "#define LOCAL_ID_2 get_local_id(2)\n";
  result += "#define GROUP_ID_0 get_group_id(0)\n";
  result += "#define GROUP_ID_1 get_group_id(1)\n";
  result += "#define GROUP_ID_2 get_group_id(2)\n";
  result += "#define GROUP_SIZE_0 get_local_size(0)\n";
  result += "#define GROUP_SIZE_1 get_local_size(1)\n";
  result += "#define GROUP_SIZE_2 get_local_size(2)\n";
  result += "#define SUB_GROUP_LOCAL_ID get_sub_group_local_id()\n";
  result += "#define SUB_GROUP_BROADCAST(V, ID) sub_group_broadcast(V, ID)\n";
  result += "#define SIMD_LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define LOCAL_MEM_BARRIER barrier(CLK_LOCAL_MEM_FENCE)\n";
  result += "#define MAIN_FUNCTION __kernel void main_function\n";
  result += "#define INIT_FLOAT(value) (float)(value)\n";
  result += "#define INIT_FLOAT2(value) (float2)(value)\n";
  result += "#define INIT_FLOAT2v2(v0, v1) (float2)(v0, v1)\n";
  result += "#define INIT_FLOAT3(value) (float3)(value)\n";
  result += "#define INIT_FLOAT3v3(v0, v1, v2) (float3)(v0, v1, v2)\n";
  result += "#define INIT_FLOAT4(value) (float4)(value)\n";
  result += "#define INIT_FLOAT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
  result += "#define INIT_INT(value) (int)(value)\n";
  result += "#define INIT_INT2v2(v0, v1) (int2)(v0, v1)\n";
  result += "#define INIT_INT4v4(v0, v1, v2, v3) (int4)(v0, v1, v2, v3)\n";
  result += "#define CONVERT_TO_INT4(value) convert_int4(value)\n";
  switch (precision) {
    case CalculationsPrecision::F32:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT float\n";
      result += "#define FLT2 float2\n";
      result += "#define FLT3 float3\n";
      result += "#define FLT4 float4\n";
      result += "#define TO_FLT4 convert_float4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define INIT_FLT(value) (float)(value)\n";
      result += "#define INIT_FLT4(value) (float4)(value)\n";
      result +=
          "#define INIT_FLT4v4(v0, v1, v2, v3) (float4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 half4\n";
      result += "#define INIT_ACCUM_FLT4(value) (half4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_half4\n";
      result += "#define TO_ACCUM_FLT convert_half\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
    case CalculationsPrecision::F32_F16:
      result += "#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable\n";
      result += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n";
      result += "#define ACCUM_FLT4 float4\n";
      result += "#define INIT_ACCUM_FLT4(value) (float4)(value)\n";
      result += "#define FLT half\n";
      result += "#define FLT2 half2\n";
      result += "#define FLT3 half3\n";
      result += "#define FLT4 half4\n";
      result += "#define TO_FLT4 convert_half4\n";
      result += "#define TO_ACCUM_TYPE convert_float4\n";
      result += "#define TO_ACCUM_FLT convert_float\n";
      result += "#define INIT_FLT(value) (half)(value)\n";
      result += "#define INIT_FLT4(value) (half4)(value)\n";
      result += "#define INIT_FLT4v4(v0, v1, v2, v3) (half4)(v0, v1, v2, v3)\n";
      break;
  }
  return result;
}
}  // namespace

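// ClOperation wraps a GPUOperation together with its compiled CL kernel and
// argument bindings. A minimal usage sketch (the exact call order is managed
// by the calling code, e.g. the inference context; shown here only for
// orientation):
//
//   ClOperation op = ...;                // wraps a GPUOperation
//   RETURN_IF_ERROR(op.Compile(creation_context));
//   RETURN_IF_ERROR(op.UpdateParams());  // bind tensors, compute grid
//   RETURN_IF_ERROR(op.Tune(tuning_type, gpu_info, profiling_queue));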
ClOperation::ClOperation(ClOperation&& operation)
    : operation_(std::move(operation.operation_)),
      kernel_(std::move(operation.kernel_)),
      cl_args_(std::move(operation.cl_args_)) {}

ClOperation& ClOperation::operator=(ClOperation&& operation) {
  if (this != &operation) {
    operation_ = std::move(operation.operation_);
    kernel_ = std::move(operation.kernel_);
    cl_args_ = std::move(operation.cl_args_);
  }
  return *this;
}

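// Fuses the given operation into this one by delegating to the wrapped
// GPUOperation.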
absl::Status ClOperation::AddOperation(ClOperation* operation) {
  return operation_->AddOperation(operation->operation_.get());
}

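// Rebinds all source/destination tensors into cl_args_ and recomputes
// grid_size_ and work_groups_count_ from the current work-group
// configuration. Must be called whenever the operation's tensor set changes.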
absl::Status ClOperation::UpdateParams() {
  for (int i = 0; i < operation_->src_tensors_names_.size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->src_[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->src_tensors_names_[i],
                                          cl_spatial_tensor));
  }
  for (int i = 0; i < operation_->dst_tensors_names_.size(); ++i) {
    const auto* cl_spatial_tensor =
        dynamic_cast<const Tensor*>(operation_->dst_[i]);
    if (!cl_spatial_tensor) {
      return absl::InvalidArgumentError("Expected CLSpatialTensor.");
    }
    RETURN_IF_ERROR(cl_args_.SetObjectRef(operation_->dst_tensors_names_[i],
                                          cl_spatial_tensor));
  }
  RETURN_IF_ERROR(operation_->BindArguments(&cl_args_));
  operation_->grid_size_ = operation_->GetGridSize();
  operation_->work_groups_count_ = GetWorkGroupsCount(
      operation_->grid_dimension_, operation_->grid_size_,
      operation_->work_group_size_, operation_->work_group_launch_order_);
  return absl::OkStatus();
}

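// Assembles the kernel source, prepends the common OpenCL defines for the
// operation's precision, resolves arguments into the code, and compiles (or
// fetches from the program cache) the resulting "main_function" kernel.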
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
  operation_->AssembleCode(creation_context.GetGpuInfo());
  operation_->code_ =
      GetCommonOpenCLDefines(operation_->definition_.precision) +
      operation_->code_;
  RETURN_IF_ERROR(cl_args_.Init(
      creation_context.GetGpuInfo(),
      {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
      creation_context.context, &operation_->args_, &operation_->code_));
  RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
      operation_->code_, "main_function", operation_->compiler_options_,
      *creation_context.context, *creation_context.device, &kernel_));
  return operation_->PostCompileCheck(creation_context.GetGpuInfo(),
                                      kernel_.info_);
}

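// Variant of Compile() for operations restored from a serialized form: the
// kernel source in code_ is assumed to be already final, so only argument
// initialization and kernel compilation (or cache lookup) are performed.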
absl::Status ClOperation::CompileDeserialized(
    const CreationContext& creation_context) {
  RETURN_IF_ERROR(cl_args_.Init(creation_context.GetGpuInfo(),
                                &operation_->args_, creation_context.context));
  return creation_context.cache->GetOrCreateCLKernel(
      operation_->code_, "main_function", operation_->compiler_options_,
      *creation_context.context, *creation_context.device, &kernel_);
}

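// Selects a work-group size for the kernel. If the operation proposes a
// single candidate it is taken as-is; otherwise each candidate is profiled
// on the given queue and the fastest one wins.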
absl::Status ClOperation::Tune(TuningType tuning_type, const GpuInfo& gpu_info,
                               ProfilingCommandQueue* profiling_queue) {
  std::vector<int3> possible_work_groups;
  operation_->GetPossibleKernelWorkGroups(tuning_type, gpu_info, kernel_.info_,
                                          &possible_work_groups);
  if (possible_work_groups.empty()) {
    return absl::NotFoundError(
        "Can not find work_group size to launch kernel");
  }
  if (possible_work_groups.size() == 1) {
    operation_->work_group_size_ = possible_work_groups[0];
    operation_->work_groups_count_ = GetWorkGroupsCount(
        operation_->grid_dimension_, operation_->grid_size_,
        operation_->work_group_size_, operation_->work_group_launch_order_);
    return absl::OkStatus();
  } else {
    std::vector<int3> work_groups_count(possible_work_groups.size());
    for (int i = 0; i < work_groups_count.size(); ++i) {
      work_groups_count[i] = GetWorkGroupsCount(
          operation_->grid_dimension_, operation_->grid_size_,
          possible_work_groups[i], operation_->work_group_launch_order_);
    }
    RETURN_IF_ERROR(cl_args_.Bind(kernel_.kernel()));
    int best_work_group_index;
    RETURN_IF_ERROR(profiling_queue->GetBestWorkGroupIndex(
        kernel_, gpu_info, work_groups_count, possible_work_groups,
        &best_work_group_index));
    operation_->work_group_size_ = possible_work_groups[best_work_group_index];
    operation_->work_groups_count_ = GetWorkGroupsCount(
        operation_->grid_dimension_, operation_->grid_size_,
        operation_->work_group_size_, operation_->work_group_launch_order_);
    return absl::OkStatus();
  }
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite