/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <cstring>  // std::memcpy (used by CreateVectorCopyData below)
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"

37 namespace tflite {
38 namespace gpu {
39 
40 absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
41                                     TfLiteNode** tflite_node,
42                                     TfLiteRegistration** registration);
43 
44 DataType ToDataType(TfLiteType type);
45 
46 absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc);
47 
48 absl::Status ExtractAxisFromIndex(const TfLiteTensor& tflite_tensor, int index,
49                                   Axis* axis);
50 
51 absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
52                                             TensorRef<BHWC>* tensor_ref);
53 
54 // Populates quantization parameters for non-constant UInt8/Int8 tensors.
55 // This helps the delegate emulate quantized inference with
56 // QuantizeAndDequantize.
57 absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
58                                  QuantizationParams* quant_params);
59 
60 int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
61                                     const TfLiteNode* tflite_node);
62 
63 int GetNumberOfConstInputsForNode(const TfLiteContext* context,
64                                   const TfLiteNode* tflite_node);
65 
66 absl::Status CheckInputsOutputs(const TfLiteContext* context,
67                                 const TfLiteNode* tflite_node,
68                                 int runtime_inputs, int outputs);
69 
70 absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
71                                       const TfLiteNode* tflite_node,
72                                       int runtime_inputs, int const_inputs,
73                                       int outputs);
74 
75 void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
76                              float* dst);
77 
78 template <typename T>
DequantizeConstantTensor(const TfLiteTensor & tensor,const T * source_data,float * dequantized_data)79 inline void DequantizeConstantTensor(const TfLiteTensor& tensor,
80                                      const T* source_data,
81                                      float* dequantized_data) {
82   TfLiteAffineQuantization* quant_params =
83       static_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
84   if (quant_params->scale->size > 1) {
85     // Tensor is per-channel quantized.
86     PerChannelDequantizationParams op_params;
87     op_params.zero_point = quant_params->zero_point->data;
88     op_params.scale = quant_params->scale->data;
89     op_params.quantized_dimension = quant_params->quantized_dimension;
90     reference_ops::PerChannelDequantize(op_params, GetTensorShape(&tensor),
91                                         source_data, GetTensorShape(&tensor),
92                                         dequantized_data);
93   } else {
94     DequantizationParams op_params;
95     op_params.zero_point = tensor.params.zero_point;
96     op_params.scale = tensor.params.scale;
97     reference_ops::Dequantize(op_params, GetTensorShape(&tensor), source_data,
98                               GetTensorShape(&tensor), dequantized_data);
99   }
100 }
101 
102 template <typename T>
CreateVectorCopyData(const TfLiteTensor & tensor,T * tensor_data)103 absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) {
104   if (tensor.bytes % sizeof(T) != 0) {
105     return absl::InvalidArgumentError(
106         absl::StrCat("Input data size ", tensor.bytes,
107                      " is not aligned to expected type: ", sizeof(T)));
108   }
109   std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes);
110   return absl::OkStatus();
111 }
112 
113 template <>
114 absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
115                                          float* tensor_data);
116 
117 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape);
118 
119 absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions);
120 
121 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape);
122 
123 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape);
124 
125 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape);
126 
127 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape);
128 
129 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape);
130 
131 // If there is fused activation present, then there will be another node created
132 // that will have identical output as the given node. New operation node will
133 // depend on the given node output.
134 absl::Status MaybeFuseActivation(TfLiteFusedActivation fused_activation,
135                                  GraphFloat32* graph, Node* node);
136 
137 }  // namespace gpu
138 }  // namespace tflite
139 
140 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
141