/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <type_traits>  // For std::integral_constant / std::is_same below.

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // The offset is used to compute the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset, except over the height dimension.
  int16_t height_offset;
};

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // The offset is used to compute the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset, except over the height dimension.
  int16_t height_offset;
  // Same as width_offset, except over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout; the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth).
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // a large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), and a very small batch size
  // (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small: each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty -- that each value loaded from memory is used in only
  // one kernel iteration, making this operation memory-intensive -- hints at
  // the solution of shuffling the weights so that they are stored in the
  // exact order in which the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, and thus don't alias each other in the
  // cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t. This is intended
  // to save runtime kernels from having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanations of the 'signed int8_t trick' in matrix multiplication
  // kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
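
// Sketch of the shuffled layout (illustrative restatement of the comment
// above, not a normative layout definition): the shuffled buffer stores one
// 4x16 block after another -- 16 consecutive bytes from each of 4 consecutive
// output rows -- in the order the kernel consumes them, so each kernel
// iteration reads one contiguous 64-byte chunk.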

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}
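
// Worked example: with scale = 0.5 and zero_point = 128 (typical for uint8_t
// quantization), the quantized value 130 represents the real value
// 0.5 * (130 - 128) = 1.0, and the real value 0.0 maps back to the quantized
// value 128.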

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
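
// Illustrative usage (a sketch, not part of this header's API surface):
// NextIndex visits indices in row-major order, advancing the last dimension
// fastest, and returns false after wrapping back to all zeros.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // Visits {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}.
//   } while (NextIndex(2, dims, index));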

// Gets the offset of an index when reducing on an axis. When reducing, the
// flattened offset does not change if the input index changes on the given
// axis. For example, if you have a 3D tensor and you are reducing to 2D by
// eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the
// same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Skip this dimension if it is an axis being reduced over.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
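
// Worked example: for dims = {2, 3, 4}, reducing over axis 0, both
// index (0, 1, 2) and index (1, 1, 2) skip dims[0] and yield
// offset = (0 * 3 + 1) * 4 + 2 = 6, matching the comment above.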

// Since tensors with '0' in their shape are valid in TF, these offset
// functions allow that as long as the corresponding index is also 0. It is up
// to the calling ops to ensure that they perform verification checks on
// tensor shapes if they don't support a particular behavior.

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK((i0 == 0 && dims.sizes[0] == 0) ||
                (i0 >= 0 && i0 < dims.sizes[0]));
  TFLITE_DCHECK((i1 == 0 && dims.sizes[1] == 0) ||
                (i1 >= 0 && i1 < dims.sizes[1]));
  TFLITE_DCHECK((i2 == 0 && dims.sizes[2] == 0) ||
                (i2 >= 0 && i2 < dims.sizes[2]));
  TFLITE_DCHECK((i3 == 0 && dims.sizes[3] == 0) ||
                (i3 >= 0 && i3 < dims.sizes[3]));
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}
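
// Note that in Dims<4>, index 0 is the innermost (fastest-varying) dimension,
// so with strides produced by ComputeStrides() (defined later in this file)
// the offset reduces to
//   i0 + i1 * d0 + i2 * d0 * d1 + i3 * d0 * d1 * d2.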

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
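
// Typical usage (illustrative; input_shape and filter_shape are hypothetical
// kernel locals): a shared channel count can be read as
//   const int depth = MatchingDim(input_shape, 3, filter_shape, 3);
// which DCHECKs that both shapes agree on dimension 3 before returning it.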

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Flat size calculation, checking if their extended shapes match.
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0) {
  const int shape_dims = shape.DimensionsCount();
  const int check_shape_0_dims = check_shape_0.DimensionsCount();
  const int min_dims = std::min(shape_dims, check_shape_0_dims);

  for (int i = 0; i < min_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i),
                     check_shape_0.Dims(check_shape_0_dims - 1 - i));
  }
  for (int i = min_dims; i < shape_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i), 1);
  }
  for (int i = min_dims; i < check_shape_0_dims; ++i) {
    TFLITE_DCHECK_EQ(check_shape_0.Dims(check_shape_0_dims - 1 - i), 1);
  }
  return shape.FlatSize();
}
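
// For example, shapes {1, 3, 4} and {3, 4} have matching extended shapes: the
// trailing dimensions agree and the extra leading dimension is 1, so the call
// above returns the flat size 1 * 3 * 4 = 12.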

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1),
                   flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(
      MatchingExtendedShapeFlatSize(shape, check_shape_1, check_shape_2),
      flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2,
                                         const RuntimeShape& check_shape_3) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1,
                                                 check_shape_2, check_shape_3),
                   flat_size);
  return flat_size;
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
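
// For example, for an NHWC shape {2, 5, 5, 3}, FlatSizeSkipDim(shape, 3)
// returns 2 * 5 * 5 = 50 (the number of pixels across the batch), whereas
// shape.FlatSize() would return 150.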

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
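
// For example, for a Dims<4> with sizes {3, 4, 5, 6}, ComputeStrides()
// produces strides {1, 3, 12, 60}, which IsPackedWithoutStrides() then
// accepts as a fully packed layout.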

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};
717 typedef Conv3DParams Conv3DTransposeParams;
718 
719 struct DepthToSpaceParams {
720   int32_t block_size;
721 };
722 
723 struct DepthwiseParams {
724   PaddingType padding_type;
725   PaddingValues padding_values;
726   int16_t stride_width;
727   int16_t stride_height;
728   int16_t dilation_width_factor;
729   int16_t dilation_height_factor;
730   int16_t depth_multiplier;
731   // uint8_t inference params.
732   // TODO(b/65838351): Use smaller types if appropriate.
733   int32_t input_offset;
734   int32_t weights_offset;
735   int32_t output_offset;
736   int32_t output_multiplier;
737   int output_shift;
738   // uint8_t, etc, activation params.
739   int32_t quantized_activation_min;
740   int32_t quantized_activation_max;
741   // float activation params.
742   float float_activation_min;
743   float float_activation_max;
744   const int32_t* output_multiplier_per_channel;
745   const int32_t* output_shift_per_channel;
746 };
747 
748 struct DequantizationParams {
749   double scale;
750   int32_t zero_point;
751 };
752 
753 struct PerChannelDequantizationParams {
754   const float* scale;
755   const int32_t* zero_point;
756   int32_t quantized_dimension;
757 };
758 
759 struct FakeQuantParams {
760   MinMax minmax;
761   int32_t num_bits;
762 };
763 
764 struct FullyConnectedParams {
765   // uint8_t inference params.
766   // TODO(b/65838351): Use smaller types if appropriate.
767   int32_t input_offset;
768   int32_t weights_offset;
769   int32_t output_offset;
770   int32_t output_multiplier;
771   int output_shift;
772   // uint8_t, etc, activation params.
773   int32_t quantized_activation_min;
774   int32_t quantized_activation_max;
775   // float activation params.
776   float float_activation_min;
777   float float_activation_max;
778   // Mark the operands as cacheable if they are unchanging, e.g. weights.
779   bool lhs_cacheable;
780   bool rhs_cacheable;
781   FullyConnectedWeightsFormat weights_format;
782 };
783 
784 struct GatherParams {
785   int16_t axis;
786   int16_t batch_dims;
787 };
788 
789 struct L2NormalizationParams {
790   // uint8_t inference params.
791   int32_t input_zero_point;
792 };
793 
794 struct LocalResponseNormalizationParams {
795   int32_t range;
796   double bias;
797   double alpha;
798   double beta;
799 };
800 
801 struct HardSwishParams {
802   // zero_point of the input activations.
803   int16_t input_zero_point;
804   // zero_point of the output activations.
805   int16_t output_zero_point;
806   // 16bit fixed-point component of the multiplier to apply to go from the
807   // "high-res input scale", which is the input scale multiplied by 2^7, to the
808   // "relu-ish scale", which 3.0/32768.
809   // See the implementation of HardSwishPrepare.
810   int16_t reluish_multiplier_fixedpoint_int16;
811   // exponent/bit-shift component of the aforementioned multiplier.
812   int reluish_multiplier_exponent;
813   // 16bit fixed-point component of the multiplier to apply to go from the
814   // "high-res input scale", which is the input scale multiplied by 2^7, to the
815   // output scale.
816   // See the implementation of HardSwishPrepare.
817   int16_t output_multiplier_fixedpoint_int16;
818   // exponent/bit-shift component of the aforementioned multiplier.
819   int output_multiplier_exponent;
820 };
821 
822 struct LogisticParams {
823   // uint8_t inference params.
824   int32_t input_zero_point;
825   int32_t input_range_radius;
826   int32_t input_multiplier;
827   int input_left_shift;
828 };
829 
830 struct LstmCellParams {
831   int32_t weights_zero_point;
832   int32_t accum_multiplier;
833   int accum_shift;
834   int state_integer_bits;
835 };
836 
837 struct MeanParams {
838   int8_t axis_count;
839   int16_t axis[4];
840 };
841 
842 struct PackParams {
843   int8_t axis;
844   const int32_t* input_zeropoint;
845   const float* input_scale;
846   uint16_t inputs_count;
847   int32_t output_zeropoint;
848   float output_scale;
849 };
850 
851 struct PadParams {
852   int8_t left_padding_count;
853   int32_t left_padding[5];
854   int8_t right_padding_count;
855   int32_t right_padding[5];
856   ResizingCategory resizing_category;
857 };
858 
859 struct PreluParams {
860   int32_t input_offset;
861   int32_t alpha_offset;
862   int32_t output_offset;
863   int32_t output_multiplier_1;
864   int output_shift_1;
865   int32_t output_multiplier_2;
866   int output_shift_2;
867 };
868 
869 struct PoolParams {
870   FusedActivationFunctionType activation;
871   PaddingType padding_type;
872   PaddingValues padding_values;
873   int stride_height;
874   int stride_width;
875   int filter_height;
876   int filter_width;
877   // uint8_t, etc, activation params.
878   int32_t quantized_activation_min;
879   int32_t quantized_activation_max;
880   // float activation params.
881   float float_activation_min;
882   float float_activation_max;
883 };
884 
885 struct ReshapeParams {
886   int8_t shape_count;
887   int32_t shape[4];
888 };
889 
890 struct ResizeBilinearParams {
891   bool align_corners;
892   // half_pixel_centers assumes pixels are of half the actual dimensions, and
893   // yields more accurate resizes. Corresponds to the same argument for the
894   // original TensorFlow op in TF2.0.
895   bool half_pixel_centers;
896 };
897 
898 struct ResizeNearestNeighborParams {
899   bool align_corners;
900   bool half_pixel_centers;
901 };
902 
903 struct SliceParams {
904   int8_t begin_count;
905   int32_t begin[5];
906   int8_t size_count;
907   int32_t size[5];
908 };
909 
910 struct SoftmaxParams {
911   // beta is not really used (not a Tensorflow parameter) and not implemented
912   // for LogSoftmax.
913   double beta;
914   // uint8_t inference params.  Used even when beta defaults to 1.0.
915   int32_t input_multiplier;
916   int32_t input_left_shift;
917   // Reverse scaling is only used by LogSoftmax.
918   int32_t reverse_scaling_divisor;
919   int32_t reverse_scaling_right_shift;
920   int diff_min;
921   int32_t zero_point;
922   float scale;
923   float* table;
924   // int16 LUT for exp(x), where x uniform distributed between [-10.0 , 0.0]
925   int16_t* exp_lut;
926   // int16 LUT for 1 / (1 + x), where x uniform distributed between [0.0 , 1.0]
927   int16_t* one_over_one_plus_x_lut;
928   uint8_t* uint8_table1;
929   uint8_t* uint8_table2;
930 };

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
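
// Overload resolution selects the matching field pair from the argument
// types. For example (illustrative; op_params is a hypothetical
// ArithmeticParams local):
//   SetActivationParams(0.0f, 6.0f, &op_params);  // float_activation_*
//   SetActivationParams(int32_t{0}, int32_t{255},
//                       &op_params);              // quantized_activation_*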

// Type trait to check whether a given type has a size smaller than 4 bytes.
template <typename T>
struct is_small_integer
    : public std::integral_constant<bool,
                                    std::is_same<T, int8_t>::value ||
                                        std::is_same<T, uint8_t>::value ||
                                        std::is_same<T, int16_t>::value ||
                                        std::is_same<T, uint16_t>::value> {};

// Type trait to check whether a given type is int32 or int64.
template <typename T>
struct is_int32_or_int64
    : public std::integral_constant<bool, std::is_same<T, int32_t>::value ||
                                              std::is_same<T, int64_t>::value> {
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_