• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
17 
18 #include <stddef.h>
19 #include <stdint.h>
20 #include <string.h>
21 
22 #include <any>
23 #include <limits>
24 #include <string>
25 #include <vector>
26 
27 #include "fp16.h"  // from @FP16
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/lite/c/builtin_op_data.h"
31 #include "tensorflow/lite/c/common.h"
32 #include "tensorflow/lite/context_util.h"
33 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
34 #include "tensorflow/lite/delegates/gpu/common/model.h"
35 #include "tensorflow/lite/delegates/gpu/common/operations.h"
36 #include "tensorflow/lite/delegates/gpu/common/shape.h"
37 #include "tensorflow/lite/delegates/gpu/common/status.h"
38 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
39 #include "tensorflow/lite/kernels/kernel_util.h"
40 
41 namespace tflite {
42 namespace gpu {
43 namespace {
44 
45 // Creates a node that consumes output from the given node. Because output need
46 // to stay the same, newly created node will inherit the output from the given
47 // node, which will in turn get newly created copy of output. This is necessary
48 // to preserve reference consistency if another node was pointing at that
49 // output:
50 //   node(output)
51 // will turn into:
52 //   node(copy(output)) <- passthrough_node(output)
NewPassthroughNode(GraphFloat32 * graph,Node * node,const Value * output,Node ** passthru_node)53 absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
54                                 const Value* output, Node** passthru_node) {
55   *passthru_node = graph->NewNode();
56   // Make copies for every output in the original node.
57   RETURN_IF_ERROR(graph->SetProducer((*passthru_node)->id, output->id));
58   Value* copy_output = graph->NewValue();
59   RETURN_IF_ERROR(graph->SetProducer(node->id, copy_output->id));
60   RETURN_IF_ERROR(graph->AddConsumer((*passthru_node)->id, copy_output->id));
61   copy_output->tensor = output->tensor;
62   copy_output->tensor.ref = -1;
63   return absl::OkStatus();
64 }
65 
66 }  // namespace
67 
GetNodeAndRegistration(TfLiteContext * context,int node_id,TfLiteNode ** tflite_node,TfLiteRegistration ** registration)68 absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
69                                     TfLiteNode** tflite_node,
70                                     TfLiteRegistration** registration) {
71   if (context->GetNodeAndRegistration(context, node_id, tflite_node,
72                                       registration) != kTfLiteOk) {
73     return absl::InvalidArgumentError(absl::StrCat(
74         "Couldn't get node and registration info for op: ", node_id));
75   }
76   return absl::OkStatus();
77 }
78 
ToDataType(TfLiteType type)79 DataType ToDataType(TfLiteType type) {
80   switch (type) {
81     case kTfLiteFloat32:
82       return DataType::FLOAT32;
83     case kTfLiteInt32:
84       return DataType::INT32;
85     case kTfLiteInt64:
86       return DataType::INT64;
87     case kTfLiteInt8:
88       return DataType::INT8;
89     case kTfLiteUInt8:
90       return DataType::UINT8;
91     case kTfLiteBool:
92       return DataType::BOOL;
93     default:
94       return DataType::UNKNOWN;
95   }
96 }
97 
ExtractTensorShape(const TfLiteTensor & tflite_tensor,BHWC * bhwc)98 absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
99   const TfLiteIntArray* dims = tflite_tensor.dims;
100   switch (dims->size) {
101     case 1:
102       // B layout
103       *bhwc = BHWC(dims->data[0], 1, 1, 1);
104       return absl::OkStatus();
105     case 2:
106       // BC layout
107       *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
108       return absl::OkStatus();
109     case 3:
110       // BWC layout
111       *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
112       return absl::OkStatus();
113     case 4:
114       // BHWC layout
115       *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
116       return absl::OkStatus();
117     default:
118       return absl::InvalidArgumentError(absl::StrCat(
119           "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
120           "\" has bad input dims size: ", dims->size, "."));
121   }
122 }
123 
ExtractAxisFromIndex(const TfLiteTensor & tflite_tensor,int index,Axis * axis)124 absl::Status ExtractAxisFromIndex(const TfLiteTensor& tflite_tensor, int index,
125                                   Axis* axis) {
126   const TfLiteIntArray* dims = tflite_tensor.dims;
127   if (index < 0) {
128     index = dims->size + index;
129   }
130   if (index < 0 || index >= dims->size) {
131     return absl::OutOfRangeError("Index for axis out of range");
132   }
133   std::vector<Axis> index_to_axis;
134   switch (dims->size) {
135     case 1:
136       // B layout
137       index_to_axis = {Axis::BATCH};
138       break;
139     case 2:
140       // BC layout
141       index_to_axis = {Axis::BATCH, Axis::CHANNELS};
142       break;
143     case 3:
144       // BWC layout
145       index_to_axis = {Axis::BATCH, Axis::WIDTH, Axis::CHANNELS};
146       break;
147     case 4:
148       // BHWC layout
149       index_to_axis = {Axis::BATCH, Axis::HEIGHT, Axis::WIDTH, Axis::CHANNELS};
150       break;
151     default:
152       return absl::UnavailableError("Unknown layout.");
153   }
154   *axis = index_to_axis[index];
155   return absl::OkStatus();
156 }
157 
ConvertTfLiteTensorToTensorRef(const TfLiteTensor & tflite_tensor,TensorRef<BHWC> * tensor_ref)158 absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
159                                             TensorRef<BHWC>* tensor_ref) {
160   tensor_ref->type = ToDataType(tflite_tensor.type);
161   return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
162 }
163 
PopulateQuantParams(const TfLiteTensor & tensor,QuantizationParams * quant_params)164 absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
165                                  QuantizationParams* quant_params) {
166   const TfLiteQuantization& quant = tensor.quantization;
167   if (quant.type != TfLiteQuantizationType::kTfLiteAffineQuantization) {
168     return absl::InvalidArgumentError(
169         absl::StrCat("Tensor not quantized: ", std::string(tensor.name)));
170   }
171   const TfLiteAffineQuantization* params =
172       static_cast<const TfLiteAffineQuantization*>(quant.params);
173   if (params->scale->size > 1) {
174     return absl::InvalidArgumentError(
175         absl::StrCat("Non-constant per-channel quantized tensor: ",
176                      std::string(tensor.name)));
177   }
178   const float scale = params->scale->data[0];
179   const float zero_point = static_cast<float>(params->zero_point->data[0]);
180 
181   float qmin_value = 0;
182   float qmax_value = 0;
183   if (tensor.type == kTfLiteUInt8) {
184     qmin_value = static_cast<float>(std::numeric_limits<uint8_t>::min());
185     qmax_value = static_cast<float>(std::numeric_limits<uint8_t>::max());
186   } else if (tensor.type == kTfLiteInt8) {
187     qmin_value = static_cast<float>(std::numeric_limits<int8_t>::min());
188     qmax_value = static_cast<float>(std::numeric_limits<int8_t>::max());
189   } else {
190     return absl::InvalidArgumentError(absl::StrCat(
191         "Type invalid for quantized tensor: ", std::string(tensor.name)));
192   }
193   quant_params->min = scale * (static_cast<float>(qmin_value) - zero_point);
194   quant_params->max = scale * (static_cast<float>(qmax_value) - zero_point);
195   quant_params->scale = scale;
196 
197   return absl::OkStatus();
198 }
199 
GetNumberOfRuntimeInputsForNode(const TfLiteContext * context,const TfLiteNode * tflite_node)200 int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
201                                     const TfLiteNode* tflite_node) {
202   int number_of_runtime_inputs = 0;
203   for (int i = 0; i < NumInputs(tflite_node); i++) {
204     const TfLiteTensor* tensor =
205         GetOptionalInputTensor(context, tflite_node, i);
206     if (tensor != nullptr && !IsConstantTensor(tensor)) {
207       number_of_runtime_inputs++;
208     }
209   }
210   return number_of_runtime_inputs;
211 }
212 
GetNumberOfConstInputsForNode(const TfLiteContext * context,const TfLiteNode * tflite_node)213 int GetNumberOfConstInputsForNode(const TfLiteContext* context,
214                                   const TfLiteNode* tflite_node) {
215   return NumInputs(tflite_node) -
216          GetNumberOfRuntimeInputsForNode(context, tflite_node);
217 }
218 
CheckInputsOutputs(const TfLiteContext * context,const TfLiteNode * tflite_node,int runtime_inputs,int outputs)219 absl::Status CheckInputsOutputs(const TfLiteContext* context,
220                                 const TfLiteNode* tflite_node,
221                                 int runtime_inputs, int outputs) {
222   const int runtime_inputs_from_model =
223       GetNumberOfRuntimeInputsForNode(context, tflite_node);
224   if (runtime_inputs_from_model != runtime_inputs) {
225     return absl::InternalError(absl::StrCat(
226         "Expected ", runtime_inputs, " runtime input tensor(s), but node has ",
227         runtime_inputs_from_model, " runtime input(s)."));
228   }
229   const int outputs_from_model = NumOutputs(tflite_node);
230   if (outputs_from_model != outputs) {
231     return absl::InternalError(absl::StrCat("Expected ", outputs,
232                                             " output tensor(s), but node has ",
233                                             outputs_from_model, " output(s)."));
234   }
235   return absl::OkStatus();
236 }
237 
CheckInputsConstsOutputs(const TfLiteContext * context,const TfLiteNode * tflite_node,int runtime_inputs,int const_inputs,int outputs)238 absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
239                                       const TfLiteNode* tflite_node,
240                                       int runtime_inputs, int const_inputs,
241                                       int outputs) {
242   const int const_inputs_from_model =
243       GetNumberOfConstInputsForNode(context, tflite_node);
244   if (const_inputs_from_model != const_inputs) {
245     return absl::InternalError(absl::StrCat(
246         "Expected ", const_inputs, " const input tensor(s), but node has ",
247         const_inputs_from_model, " const input(s)."));
248   }
249   return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
250 }
251 
ConvertFloat16ToFloat32(size_t num_elements,const uint16_t * src,float * dst)252 void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
253                              float* dst) {
254   for (size_t i = 0; i < num_elements; i++) {
255     *dst++ = fp16_ieee_to_fp32_value(*src++);
256   }
257 }
258 
259 template <>
CreateVectorCopyData(const TfLiteTensor & tensor,float * tensor_data)260 absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
261                                          float* tensor_data) {
262   switch (tensor.type) {
263     case kTfLiteFloat32:
264       std::memcpy(tensor_data, tensor.data.f, tensor.bytes);
265       break;
266     case kTfLiteFloat16:
267       ConvertFloat16ToFloat32(
268           NumElements(&tensor),
269           reinterpret_cast<uint16_t const*>(tensor.data.f16), tensor_data);
270       break;
271     case kTfLiteInt8:
272       DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data);
273       break;
274     case kTfLiteUInt8:
275       DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data);
276       break;
277     case kTfLiteInt32:
278       DequantizeConstantTensor(tensor, tensor.data.i32, tensor_data);
279       break;
280     default:
281       return absl::InvalidArgumentError(
282           "Unsupported data type for float32 tensor");
283   }
284   return absl::OkStatus();
285 }
286 
GetDimensionString(const TfLiteIntArray * dimensions)287 const std::string GetDimensionString(const TfLiteIntArray* dimensions) {
288   return absl::StrJoin(TfLiteIntArrayView(dimensions), "x");
289 }
290 
SetAllDimensions(const TfLiteIntArray * dimensions,Scalar * shape)291 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
292   if (dimensions->size < 0) {
293     return absl::InvalidArgumentError("Invalid Scalar dimensions");
294   }
295   for (int i = 0; i < dimensions->size; ++i) {
296     if (dimensions->data[i] != 1) {
297       return absl::InvalidArgumentError(absl::StrCat(
298           GetDimensionString(dimensions), "  cannot be reduced to scalar."));
299     }
300   }
301   shape->v = 1;
302   return absl::OkStatus();
303 }
304 
CheckIfLinearConvertible(const TfLiteIntArray * dimensions)305 absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions) {
306   if (dimensions->size <= 0) {
307     return absl::InvalidArgumentError("Dimension is empty.");
308   }
309   for (int i = 0; i < dimensions->size - 1; ++i) {
310     if (dimensions->data[i] != 1) {
311       return absl::InvalidArgumentError(absl::StrCat(
312           GetDimensionString(dimensions), "  cannot be reduced to linear."));
313     }
314   }
315   return absl::OkStatus();
316 }
317 
SetAllDimensions(const TfLiteIntArray * dimensions,Linear * shape)318 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
319   RETURN_IF_ERROR(CheckIfLinearConvertible(dimensions));
320   shape->v = dimensions->data[dimensions->size - 1];
321   return absl::OkStatus();
322 }
323 
SetAllDimensions(const TfLiteIntArray * dimensions,HWC * shape)324 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
325   if (dimensions->size == 3) {
326     shape->h = dimensions->data[0];
327     shape->w = dimensions->data[1];
328     shape->c = dimensions->data[2];
329     return absl::OkStatus();
330   }
331   if (dimensions->size == 4) {
332     if (dimensions->data[0] != 1) {
333       return absl::UnimplementedError("Batch size is not equal to 1.");
334     }
335     shape->h = dimensions->data[1];
336     shape->w = dimensions->data[2];
337     shape->c = dimensions->data[3];
338     return absl::OkStatus();
339   }
340   return absl::InvalidArgumentError(
341       absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of "
342                    "shape 1xHxWxC but got ",
343                    GetDimensionString(dimensions)));
344 }
345 
SetAllDimensions(const TfLiteIntArray * dimensions,HW * shape)346 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
347   if (dimensions->size != 2) {
348     return absl::InvalidArgumentError(
349         absl::StrCat("Expected a 2D tensor of shape HxW but got ",
350                      GetDimensionString(dimensions)));
351   }
352   shape->h = dimensions->data[0];
353   shape->w = dimensions->data[1];
354   return absl::OkStatus();
355 }
356 
SetAllDimensions(const TfLiteIntArray * dimensions,OHWI * shape)357 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
358   if (dimensions->size != 4) {
359     return absl::InvalidArgumentError(
360         absl::StrCat("Expected a 4D tensor of shape OxHxWxI but got ",
361                      GetDimensionString(dimensions)));
362   }
363   shape->o = dimensions->data[0];
364   shape->h = dimensions->data[1];
365   shape->w = dimensions->data[2];
366   shape->i = dimensions->data[3];
367   return absl::OkStatus();
368 }
369 
SetAllDimensions(const TfLiteIntArray * dimensions,BHWC * shape)370 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
371   if (dimensions->size != 4) {
372     return absl::InvalidArgumentError(
373         absl::StrCat("Expected a 4D tensor of shape BxHxWxC but got ",
374                      GetDimensionString(dimensions)));
375   }
376   shape->b = dimensions->data[0];
377   shape->h = dimensions->data[1];
378   shape->w = dimensions->data[2];
379   shape->c = dimensions->data[3];
380   return absl::OkStatus();
381 }
382 
383 // If there is fused activation present, then there will be another node created
384 // that will have identical output as the given node. New operation node will
385 // depend on the given node output.
MaybeFuseActivation(TfLiteFusedActivation fused_activation,GraphFloat32 * graph,Node * node)386 absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
387                                  GraphFloat32* graph, Node* node) {
388   const auto outputs = graph->FindOutputs(node->id);
389   if (outputs.size() != 1) {
390     return absl::InternalError("Number of outputs != 1");
391   }
392   switch (fused_activation) {
393     case kTfLiteActNone:
394       // Nothing to do here
395       return absl::OkStatus();
396     case kTfLiteActRelu:
397     case kTfLiteActReluN1To1:
398     case kTfLiteActRelu6: {
399       ReLUAttributes attr;
400       attr.clip = fused_activation == kTfLiteActRelu
401                       ? 0.0f
402                       : (fused_activation == kTfLiteActReluN1To1 ? 1.0f : 6.0f);
403       Node* activation_node;
404       RETURN_IF_ERROR(
405           NewPassthroughNode(graph, node, outputs[0], &activation_node));
406       activation_node->operation.type = ToString(OperationType::RELU);
407       activation_node->operation.attributes = attr;
408       return absl::OkStatus();
409     }
410     case kTfLiteActTanh: {
411       Node* activation_node;
412       RETURN_IF_ERROR(
413           NewPassthroughNode(graph, node, outputs[0], &activation_node));
414       activation_node->operation.type = ToString(OperationType::TANH);
415       return absl::OkStatus();
416     }
417     case kTfLiteActSigmoid: {
418       Node* activation_node;
419       RETURN_IF_ERROR(
420           NewPassthroughNode(graph, node, outputs[0], &activation_node));
421       activation_node->operation.type = ToString(OperationType::SIGMOID);
422       return absl::OkStatus();
423     } break;
424     default:
425       return absl::NotFoundError(
426           absl::StrCat("Unsupported fused activation: ", fused_activation));
427   }
428 }
429 
430 }  // namespace gpu
431 }  // namespace tflite
432