1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
17
18 #include <stddef.h>
19 #include <stdint.h>
20 #include <string.h>
21
22 #include <any>
23 #include <limits>
24 #include <string>
25 #include <vector>
26
27 #include "fp16.h" // from @FP16
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/lite/c/builtin_op_data.h"
31 #include "tensorflow/lite/c/common.h"
32 #include "tensorflow/lite/context_util.h"
33 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
34 #include "tensorflow/lite/delegates/gpu/common/model.h"
35 #include "tensorflow/lite/delegates/gpu/common/operations.h"
36 #include "tensorflow/lite/delegates/gpu/common/shape.h"
37 #include "tensorflow/lite/delegates/gpu/common/status.h"
38 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
39 #include "tensorflow/lite/kernels/kernel_util.h"
40
41 namespace tflite {
42 namespace gpu {
43 namespace {
44
45 // Creates a node that consumes output from the given node. Because output need
46 // to stay the same, newly created node will inherit the output from the given
47 // node, which will in turn get newly created copy of output. This is necessary
48 // to preserve reference consistency if another node was pointing at that
49 // output:
50 // node(output)
51 // will turn into:
52 // node(copy(output)) <- passthrough_node(output)
NewPassthroughNode(GraphFloat32 * graph,Node * node,const Value * output,Node ** passthru_node)53 absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
54 const Value* output, Node** passthru_node) {
55 *passthru_node = graph->NewNode();
56 // Make copies for every output in the original node.
57 RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
58 Value* copy_output = graph->NewValue();
59 RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
60 RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
61 copy_output->tensor = output->tensor;
62 copy_output->tensor.ref = -1;
63 return absl::OkStatus();
64 }
65
66 } // namespace
67
GetNodeAndRegistration(TfLiteContext * context,int node_id,TfLiteNode ** tflite_node,TfLiteRegistration ** registration)68 absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
69 TfLiteNode** tflite_node,
70 TfLiteRegistration** registration) {
71 if (context->GetNodeAndRegistration(context, node_id, tflite_node,
72 registration) != kTfLiteOk) {
73 return absl::InvalidArgumentError(absl::StrCat(
74 "Couldn't get node and registration info for op: ", node_id));
75 }
76 return absl::OkStatus();
77 }
78
ToDataType(TfLiteType type)79 DataType ToDataType(TfLiteType type) {
80 switch (type) {
81 case kTfLiteFloat32:
82 return DataType::FLOAT32;
83 case kTfLiteInt32:
84 return DataType::INT32;
85 case kTfLiteInt64:
86 return DataType::INT64;
87 case kTfLiteInt8:
88 return DataType::INT8;
89 case kTfLiteUInt8:
90 return DataType::UINT8;
91 case kTfLiteBool:
92 return DataType::BOOL;
93 default:
94 return DataType::UNKNOWN;
95 }
96 }
97
ExtractTensorShape(const TfLiteTensor & tflite_tensor,BHWC * bhwc)98 absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
99 const TfLiteIntArray* dims = tflite_tensor.dims;
100 switch (dims->size) {
101 case 1:
102 // B layout
103 *bhwc = BHWC(dims->data[0], 1, 1, 1);
104 return absl::OkStatus();
105 case 2:
106 // BC layout
107 *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
108 return absl::OkStatus();
109 case 3:
110 // BWC layout
111 *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
112 return absl::OkStatus();
113 case 4:
114 // BHWC layout
115 *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
116 return absl::OkStatus();
117 default:
118 return absl::InvalidArgumentError(absl::StrCat(
119 "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
120 "\" has bad input dims size: ", dims->size, "."));
121 }
122 }
123
ExtractAxisFromIndex(const TfLiteTensor & tflite_tensor,int index,Axis * axis)124 absl::Status ExtractAxisFromIndex(const TfLiteTensor& tflite_tensor, int index,
125 Axis* axis) {
126 const TfLiteIntArray* dims = tflite_tensor.dims;
127 if (index < 0) {
128 index = dims->size + index;
129 }
130 if (index < 0 || index >= dims->size) {
131 return absl::OutOfRangeError("Index for axis out of range");
132 }
133 std::vector<Axis> index_to_axis;
134 switch (dims->size) {
135 case 1:
136 // B layout
137 index_to_axis = {Axis::BATCH};
138 break;
139 case 2:
140 // BC layout
141 index_to_axis = {Axis::BATCH, Axis::CHANNELS};
142 break;
143 case 3:
144 // BWC layout
145 index_to_axis = {Axis::BATCH, Axis::WIDTH, Axis::CHANNELS};
146 break;
147 case 4:
148 // BHWC layout
149 index_to_axis = {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS};
150 break;
151 default:
152 return absl::UnavailableError("Unknown layout.");
153 }
154 *axis = index_to_axis[index];
155 return absl::OkStatus();
156 }
157
ConvertTfLiteTensorToTensorRef(const TfLiteTensor & tflite_tensor,TensorRef<BHWC> * tensor_ref)158 absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
159 TensorRef<BHWC>* tensor_ref) {
160 tensor_ref->type = ToDataType(tflite_tensor.type);
161 return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
162 }
163
PopulateQuantParams(const TfLiteTensor & tensor,QuantizationParams * quant_params)164 absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
165 QuantizationParams* quant_params) {
166 const TfLiteQuantization& quant = tensor.quantization;
167 if (quant.type != TfLiteQuantizationType::kTfLiteAffineQuantization) {
168 return absl::InvalidArgumentError(
169 absl::StrCat("Tensor not quantized: ", std::string(tensor.name)));
170 }
171 const TfLiteAffineQuantization* params =
172 static_cast<const TfLiteAffineQuantization*>(quant.params);
173 if (params->scale->size > 1) {
174 return absl::InvalidArgumentError(
175 absl::StrCat("Non-constant per-channel quantized tensor: ",
176 std::string(tensor.name)));
177 }
178 const float scale = params->scale->data[0];
179 const float zero_point = static_cast<float>(params->zero_point->data[0]);
180
181 float qmin_value = 0;
182 float qmax_value = 0;
183 if (tensor.type == kTfLiteUInt8) {
184 qmin_value = static_cast<float>(std::numeric_limits<uint8_t>::min());
185 qmax_value = static_cast<float>(std::numeric_limits<uint8_t>::max());
186 } else if (tensor.type == kTfLiteInt8) {
187 qmin_value = static_cast<float>(std::numeric_limits<int8_t>::min());
188 qmax_value = static_cast<float>(std::numeric_limits<int8_t>::max());
189 } else {
190 return absl::InvalidArgumentError(absl::StrCat(
191 "Type invalid for quantized tensor: ", std::string(tensor.name)));
192 }
193 quant_params->min = scale * (static_cast<float>(qmin_value) - zero_point);
194 quant_params->max = scale * (static_cast<float>(qmax_value) - zero_point);
195 quant_params->scale = scale;
196
197 return absl::OkStatus();
198 }
199
GetNumberOfRuntimeInputsForNode(const TfLiteContext * context,const TfLiteNode * tflite_node)200 int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
201 const TfLiteNode* tflite_node) {
202 int number_of_runtime_inputs = 0;
203 for (int i = 0; i < NumInputs(tflite_node); i++) {
204 const TfLiteTensor* tensor =
205 GetOptionalInputTensor(context, tflite_node, i);
206 if (tensor != nullptr && !IsConstantTensor(tensor)) {
207 number_of_runtime_inputs++;
208 }
209 }
210 return number_of_runtime_inputs;
211 }
212
GetNumberOfConstInputsForNode(const TfLiteContext * context,const TfLiteNode * tflite_node)213 int GetNumberOfConstInputsForNode(const TfLiteContext* context,
214 const TfLiteNode* tflite_node) {
215 return NumInputs(tflite_node) -
216 GetNumberOfRuntimeInputsForNode(context, tflite_node);
217 }
218
CheckInputsOutputs(const TfLiteContext * context,const TfLiteNode * tflite_node,int runtime_inputs,int outputs)219 absl::Status CheckInputsOutputs(const TfLiteContext* context,
220 const TfLiteNode* tflite_node,
221 int runtime_inputs, int outputs) {
222 const int runtime_inputs_from_model =
223 GetNumberOfRuntimeInputsForNode(context, tflite_node);
224 if (runtime_inputs_from_model != runtime_inputs) {
225 return absl::InternalError(absl::StrCat(
226 "Expected ", runtime_inputs, " runtime input tensor(s), but node has ",
227 runtime_inputs_from_model, " runtime input(s)."));
228 }
229 const int outputs_from_model = NumOutputs(tflite_node);
230 if (outputs_from_model != outputs) {
231 return absl::InternalError(absl::StrCat("Expected ", outputs,
232 " output tensor(s), but node has ",
233 outputs_from_model, " output(s)."));
234 }
235 return absl::OkStatus();
236 }
237
CheckInputsConstsOutputs(const TfLiteContext * context,const TfLiteNode * tflite_node,int runtime_inputs,int const_inputs,int outputs)238 absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
239 const TfLiteNode* tflite_node,
240 int runtime_inputs, int const_inputs,
241 int outputs) {
242 const int const_inputs_from_model =
243 GetNumberOfConstInputsForNode(context, tflite_node);
244 if (const_inputs_from_model != const_inputs) {
245 return absl::InternalError(absl::StrCat(
246 "Expected ", const_inputs, " const input tensor(s), but node has ",
247 const_inputs_from_model, " const input(s)."));
248 }
249 return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
250 }
251
ConvertFloat16ToFloat32(size_t num_elements,const uint16_t * src,float * dst)252 void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
253 float* dst) {
254 for (size_t i = 0; i < num_elements; i++) {
255 *dst++ = fp16_ieee_to_fp32_value(*src++);
256 }
257 }
258
259 template <>
CreateVectorCopyData(const TfLiteTensor & tensor,float * tensor_data)260 absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
261 float* tensor_data) {
262 switch (tensor.type) {
263 case kTfLiteFloat32:
264 std::memcpy(tensor_data, tensor.data.f, tensor.bytes);
265 break;
266 case kTfLiteFloat16:
267 ConvertFloat16ToFloat32(
268 NumElements(&tensor),
269 reinterpret_cast<uint16_t const*>(tensor.data.f16), tensor_data);
270 break;
271 case kTfLiteInt8:
272 DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data);
273 break;
274 case kTfLiteUInt8:
275 DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data);
276 break;
277 case kTfLiteInt32:
278 DequantizeConstantTensor(tensor, tensor.data.i32, tensor_data);
279 break;
280 default:
281 return absl::InvalidArgumentError(
282 "Unsupported data type for float32 tensor");
283 }
284 return absl::OkStatus();
285 }
286
GetDimensionString(const TfLiteIntArray * dimensions)287 const std::string GetDimensionString(const TfLiteIntArray* dimensions) {
288 return absl::StrJoin(TfLiteIntArrayView(dimensions), "x");
289 }
290
SetAllDimensions(const TfLiteIntArray * dimensions,Scalar * shape)291 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
292 if (dimensions->size < 0) {
293 return absl::InvalidArgumentError("Invalid Scalar dimensions");
294 }
295 for (int i = 0; i < dimensions->size; ++i) {
296 if (dimensions->data[i] != 1) {
297 return absl::InvalidArgumentError(absl::StrCat(
298 GetDimensionString(dimensions), " cannot be reduced to scalar."));
299 }
300 }
301 shape->v = 1;
302 return absl::OkStatus();
303 }
304
CheckIfLinearConvertible(const TfLiteIntArray * dimensions)305 absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions) {
306 if (dimensions->size <= 0) {
307 return absl::InvalidArgumentError("Dimension is empty.");
308 }
309 for (int i = 0; i < dimensions->size - 1; ++i) {
310 if (dimensions->data[i] != 1) {
311 return absl::InvalidArgumentError(absl::StrCat(
312 GetDimensionString(dimensions), " cannot be reduced to linear."));
313 }
314 }
315 return absl::OkStatus();
316 }
317
SetAllDimensions(const TfLiteIntArray * dimensions,Linear * shape)318 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
319 RETURN_IF_ERROR(CheckIfLinearConvertible(dimensions));
320 shape->v = dimensions->data[dimensions->size - 1];
321 return absl::OkStatus();
322 }
323
SetAllDimensions(const TfLiteIntArray * dimensions,HWC * shape)324 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
325 if (dimensions->size == 3) {
326 shape->h = dimensions->data[0];
327 shape->w = dimensions->data[1];
328 shape->c = dimensions->data[2];
329 return absl::OkStatus();
330 }
331 if (dimensions->size == 4) {
332 if (dimensions->data[0] != 1) {
333 return absl::UnimplementedError("Batch size is not equal to 1.");
334 }
335 shape->h = dimensions->data[1];
336 shape->w = dimensions->data[2];
337 shape->c = dimensions->data[3];
338 return absl::OkStatus();
339 }
340 return absl::InvalidArgumentError(
341 absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of "
342 "shape 1xHxWxC but got ",
343 GetDimensionString(dimensions)));
344 }
345
SetAllDimensions(const TfLiteIntArray * dimensions,HW * shape)346 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
347 if (dimensions->size != 2) {
348 return absl::InvalidArgumentError(
349 absl::StrCat("Expected a 2D tensor of shape HxW but got ",
350 GetDimensionString(dimensions)));
351 }
352 shape->h = dimensions->data[0];
353 shape->w = dimensions->data[1];
354 return absl::OkStatus();
355 }
356
SetAllDimensions(const TfLiteIntArray * dimensions,OHWI * shape)357 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
358 if (dimensions->size != 4) {
359 return absl::InvalidArgumentError(
360 absl::StrCat("Expected a 4D tensor of shape OxHxWxI but got ",
361 GetDimensionString(dimensions)));
362 }
363 shape->o = dimensions->data[0];
364 shape->h = dimensions->data[1];
365 shape->w = dimensions->data[2];
366 shape->i = dimensions->data[3];
367 return absl::OkStatus();
368 }
369
SetAllDimensions(const TfLiteIntArray * dimensions,BHWC * shape)370 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
371 if (dimensions->size != 4) {
372 return absl::InvalidArgumentError(
373 absl::StrCat("Expected a 4D tensor of shape BxHxWxC but got ",
374 GetDimensionString(dimensions)));
375 }
376 shape->b = dimensions->data[0];
377 shape->h = dimensions->data[1];
378 shape->w = dimensions->data[2];
379 shape->c = dimensions->data[3];
380 return absl::OkStatus();
381 }
382
383 // If there is fused activation present, then there will be another node created
384 // that will have identical output as the given node. New operation node will
385 // depend on the given node output.
MaybeFuseActivation(TfLiteFusedActivation fused_activation,GraphFloat32 * graph,Node * node)386 absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
387 GraphFloat32* graph, Node* node) {
388 const auto outputs = graph->FindOutputs(node->id);
389 if (outputs.size() != 1) {
390 return absl::InternalError("Number of outputs != 1");
391 }
392 switch (fused_activation) {
393 case kTfLiteActNone:
394 // Nothing to do here
395 return absl::OkStatus();
396 case kTfLiteActRelu:
397 case kTfLiteActReluN1To1:
398 case kTfLiteActRelu6: {
399 ReLUAttributes attr;
400 attr.clip = fused_activation == kTfLiteActRelu
401 ? 0.0f
402 : (fused_activation == kTfLiteActReluN1To1 ? 1.0f : 6.0f);
403 Node* activation_node;
404 RETURN_IF_ERROR(
405 NewPassthroughNode(graph, node, outputs[0], &activation_node));
406 activation_node->operation.type = ToString(OperationType::RELU);
407 activation_node->operation.attributes = attr;
408 return absl::OkStatus();
409 }
410 case kTfLiteActTanh: {
411 Node* activation_node;
412 RETURN_IF_ERROR(
413 NewPassthroughNode(graph, node, outputs[0], &activation_node));
414 activation_node->operation.type = ToString(OperationType::TANH);
415 return absl::OkStatus();
416 }
417 case kTfLiteActSigmoid: {
418 Node* activation_node;
419 RETURN_IF_ERROR(
420 NewPassthroughNode(graph, node, outputs[0], &activation_node));
421 activation_node->operation.type = ToString(OperationType::SIGMOID);
422 return absl::OkStatus();
423 } break;
424 default:
425 return absl::NotFoundError(
426 absl::StrCat("Unsupported fused activation: ", fused_activation));
427 }
428 }
429
430 } // namespace gpu
431 } // namespace tflite
432