• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
16 #define TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
17 
18 #include <stdint.h>
19 
20 #include <limits>
21 
22 #include "tensorflow/lite/c/builtin_op_data.h"
23 #include "tensorflow/lite/c/common.h"
24 
25 namespace tflite {
26 
27 // A fair number of functions in this header have historically been inline.
28 // It is ok to change functions to not be inline if the latency with
29 // benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
30 // made, move the newly non-inlined function declarations to the top of this
31 // header file.
32 
33 // Note: You must check if result is not null:
34 //
35 //   const TfLiteTensor* my_tensor = GetInput(context, node, kMyTensorIdx);
36 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
37 //
38 // This is because the index might point to the optional tensor constant
39 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
40 const TfLiteTensor* GetInput(const TfLiteContext* context,
41                              const TfLiteNode* node, int index);
42 
43 // Same as `GetInput` but returns a status and uses output argument for tensor.
44 //
45 //   const TfLiteTensor* my_tensor;
46 //   TF_LITE_ENSURE_OK(context,
47 //                     GetInputSafe(context, node, kMyTensorIdx, &my_tensor));
48 //   // can use my_tensor directly from here onwards, it is not nullptr
49 //
50 // Should be used in cases where the binary size is too large.
51 TfLiteStatus GetInputSafe(const TfLiteContext* context, const TfLiteNode* node,
52                           int index, const TfLiteTensor** tensor);
53 
54 // Note: You must check if result is not null:
55 //
56 //   TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
57 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
58 //
59 // This is because the index might point to the optional tensor constant
60 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
61 TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
62                                int index);
63 
64 // Note: You must check if result is not null:
65 //
66 //   TfLiteTensor* my_tensor = GetOutput(context, node, kMyTensorIdx);
67 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
68 //
69 // This is because the index might point to the optional tensor constant
70 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
71 TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
72                         int index);
73 
74 // Same as `GetOutput` but returns a status and uses output argument for tensor.
75 //
76 //   TfLiteTensor* my_tensor;
77 //   TF_LITE_ENSURE_OK(context,
78 //                     GetOutputSafe(context, node, kMyTensorIdx, &my_tensor));
79 //   // can use my_tensor directly from here onwards, it is not nullptr
80 //
81 // Should be used in cases where the binary size is too large.
82 TfLiteStatus GetOutputSafe(const TfLiteContext* context, const TfLiteNode* node,
83                            int index, TfLiteTensor** tensor);
84 
85 // Note: You must check if result is not null:
86 //
87 //   const TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
88 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
89 //
90 // This is because the index might point to the optional tensor constant
91 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
92 //
93 // Deprecated. GetInput has the same functionality.
94 const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
95                                            const TfLiteNode* node, int index);
96 
97 #ifndef TF_LITE_STATIC_MEMORY
98 // Note: You must check if result is not null:
99 //
100 //   TfLiteTensor* my_tensor = GetTemporary(context, node, kMyTensorIdx);
101 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
102 //
103 // This is because the index might point to the optional tensor constant
104 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
105 TfLiteTensor* GetTemporary(TfLiteContext* context, const TfLiteNode* node,
106                            int index);
107 
108 // Same as `GetTemporary` but returns a status and uses output argument for
109 // tensor.
110 //
111 //   TfLiteTensor* my_tensor;
112 //   TF_LITE_ENSURE_OK(context,
113 //                     GetTemporarySafe(context, node, kMyTensorIdx,
114 //                     &my_tensor));
115 //   // can use my_tensor directly from here onwards, it is not nullptr
116 //
117 // Should be used in cases where the binary size is too large.
118 TfLiteStatus GetTemporarySafe(const TfLiteContext* context,
119                               const TfLiteNode* node, int index,
120                               TfLiteTensor** tensor);
121 
122 // Note: You must check if result is not null:
123 //
124 //   const TfLiteTensor* my_tensor = GetIntermediates(context, node, kMyTensorIdx);
125 //   TF_LITE_ENSURE(context, my_tensor != nullptr);
126 //
127 // This is because the index might point to the optional tensor constant
128 // (kTfLiteOptionalTensor) in which case there is no tensor to return.
129 const TfLiteTensor* GetIntermediates(TfLiteContext* context,
130                                      const TfLiteNode* node, int index);
131 
132 // Same as `GetIntermediates` but returns a status and uses output argument for
133 // tensor.
134 //
135 //   TfLiteTensor* my_tensor;
136 //   TF_LITE_ENSURE_OK(context,
137 //                     GetIntermediatesSafe(context, node, kMyTensorIdx,
138 //                     &my_tensor));
139 //   // can use my_tensor directly from here onwards, it is not nullptr
140 //
141 // Should be used in cases where the binary size is too large.
142 TfLiteStatus GetIntermediatesSafe(const TfLiteContext* context,
143                                   const TfLiteNode* node, int index,
144                                   TfLiteTensor** tensor);
145 #endif  // TF_LITE_STATIC_MEMORY
146 
NumDimensions(const TfLiteTensor * t)147 inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
SizeOfDimension(const TfLiteTensor * t,int dim)148 inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
149   return t->dims->data[dim];
150 }
151 
NumInputs(const TfLiteNode * node)152 inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
NumOutputs(const TfLiteNode * node)153 inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
154 
155 #ifndef TF_LITE_STATIC_MEMORY
NumIntermediates(const TfLiteNode * node)156 inline int NumIntermediates(const TfLiteNode* node) {
157   return node->intermediates->size;
158 }
159 #endif  // TF_LITE_STATIC_MEMORY
160 
NumElements(const TfLiteIntArray * dims)161 inline int64_t NumElements(const TfLiteIntArray* dims) {
162   int64_t count = 1;
163   for (int i = 0; i < dims->size; ++i) {
164     count *= dims->data[i];
165   }
166   return count;
167 }
168 
NumElements(const TfLiteTensor * t)169 inline int64_t NumElements(const TfLiteTensor* t) {
170   return NumElements(t->dims);
171 }
172 
173 // Determines whether tensor is constant.
174 // TODO(b/138199592): Introduce new query which checks for constant OR
175 // persistent-read-only, which would be useful for most tensor kernels that
176 // are potentially dynamic based on the input tensor value availability at the
177 // time of prepare.
IsConstantTensor(const TfLiteTensor * tensor)178 inline bool IsConstantTensor(const TfLiteTensor* tensor) {
179   return tensor->allocation_type == kTfLiteMmapRo;
180 }
181 
182 // Determines whether tensor is dynamic. Note that a tensor can be non-const and
183 // not dynamic. This function specifically checks for a dynamic tensor.
IsDynamicTensor(const TfLiteTensor * tensor)184 inline bool IsDynamicTensor(const TfLiteTensor* tensor) {
185   return tensor->allocation_type == kTfLiteDynamic;
186 }
187 
188 // Sets tensor to dynamic.
SetTensorToDynamic(TfLiteTensor * tensor)189 inline void SetTensorToDynamic(TfLiteTensor* tensor) {
190   if (tensor->allocation_type != kTfLiteDynamic) {
191     tensor->allocation_type = kTfLiteDynamic;
192     tensor->data.raw = nullptr;
193   }
194 }
195 
196 // Sets tensor to persistent and read-only.
SetTensorToPersistentRo(TfLiteTensor * tensor)197 inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
198   if (tensor->allocation_type != kTfLitePersistentRo) {
199     tensor->allocation_type = kTfLitePersistentRo;
200     tensor->data.raw = nullptr;
201   }
202 }
203 
204 // Determines whether it is a hybrid op - one that has float inputs and
205 // quantized weights.
IsHybridOp(const TfLiteTensor * input,const TfLiteTensor * weight)206 inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
207   return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) &&
208           input->type == kTfLiteFloat32);
209 }
210 
211 // Check dimensionality match and populate OpData for Conv and DepthwiseConv.
212 TfLiteStatus PopulateConvolutionQuantizationParams(
213     TfLiteContext* context, const TfLiteTensor* input,
214     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
215     const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
216     int32_t* output_activation_min, int32_t* output_activation_max,
217     int32_t* per_channel_multiplier, int* per_channel_shift);
218 
219 TfLiteStatus PopulateConvolutionQuantizationParams(
220     TfLiteContext* context, const TfLiteTensor* input,
221     const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output,
222     const TfLiteFusedActivation& activation, int32_t* multiplier, int* shift,
223     int32_t* output_activation_min, int32_t* output_activation_max,
224     int32_t* per_channel_multiplier, int* per_channel_shift, int num_channels);
225 
226 // Calculates the multiplication factor for a quantized convolution (or
227 // quantized depthwise convolution) involving the given tensors. Returns an
228 // error if the scales of the tensors are not compatible.
229 TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
230                                               const TfLiteTensor* input,
231                                               const TfLiteTensor* filter,
232                                               const TfLiteTensor* bias,
233                                               TfLiteTensor* output,
234                                               double* multiplier);
235 
236 TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
237                                               const TfLiteTensor* input,
238                                               const TfLiteTensor* filter,
239                                               TfLiteTensor* output,
240                                               double* multiplier);
241 
242 // Calculates the useful quantized range of an activation layer given its
243 // activation tensor.
244 TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
245                                                TfLiteFusedActivation activation,
246                                                TfLiteTensor* output,
247                                                int32_t* act_min,
248                                                int32_t* act_max);
249 
250 // Calculates the useful range of an activation layer given its activation
251 // tensor.a
252 template <typename T>
CalculateActivationRange(TfLiteFusedActivation activation,T * activation_min,T * activation_max)253 void CalculateActivationRange(TfLiteFusedActivation activation,
254                               T* activation_min, T* activation_max) {
255   if (activation == kTfLiteActRelu) {
256     *activation_min = 0;
257     *activation_max = std::numeric_limits<T>::max();
258   } else if (activation == kTfLiteActRelu6) {
259     *activation_min = 0;
260     *activation_max = 6;
261   } else if (activation == kTfLiteActReluN1To1) {
262     *activation_min = -1;
263     *activation_max = 1;
264   } else {
265     *activation_min = std::numeric_limits<T>::lowest();
266     *activation_max = std::numeric_limits<T>::max();
267   }
268 }
269 
270 // Return true if the given tensors have the same shape.
271 bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
272 
273 // Calculates the output_shape that is necessary for element-wise operations
274 // with broadcasting involving the two input tensors.
275 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
276                                         const TfLiteTensor* input1,
277                                         const TfLiteTensor* input2,
278                                         TfLiteIntArray** output_shape);
279 
280 // Calculates the output_shape that is necessary for element-wise operations
281 // with broadcasting involving the three input tensors.
282 TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
283                                         const TfLiteTensor* input1,
284                                         const TfLiteTensor* input2,
285                                         const TfLiteTensor* input3,
286                                         TfLiteIntArray** output_shape);
287 
288 // Return the size of given type in bytes. Return 0 in case of string.
289 int TfLiteTypeGetSize(TfLiteType type);
290 
291 }  // namespace tflite
292 
293 #endif  // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
294