1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <initializer_list>
22 
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 
25 namespace tflite {
26 
27 enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
28 enum class PaddingType : uint8 { kNone, kSame, kValid };
29 
30 struct PaddingValues {
31   int16 width;
32   int16 height;
33   // offset is used for calculating "remaining" padding. For example, if
34   // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
35   // padding_right is 1 + 1 = 2.
36   int16 width_offset;
37   // Same as width_offset except it's over the height dimension.
38   int16 height_offset;
39 };
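
// Illustrative sketch (not part of the original header): recovering the
// per-side horizontal padding from PaddingValues, following the comment
// above. The helper name is hypothetical.
inline void GetHorizontalPaddingExample(const PaddingValues& padding,
                                        int* padding_left,
                                        int* padding_right) {
  // With width == 1 and width_offset == 1 this gives left == 1, right == 2.
  *padding_left = padding.width;
  *padding_right = padding.width + padding.width_offset;
}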
40 
41 // This enumeration allows for non-default formats for the weights array
42 // of a fully-connected operator, allowing the use of special optimized
43 // runtime paths.
44 enum class FullyConnectedWeightsFormat : uint8 {
45   // Default format (flat 2D layout, the inner contiguous dimension
46   // is input_depth, the outer non-contiguous dimension is output_depth)
47   kDefault,
48   // Summary: optimized layout for fast CPU runtime implementation,
49   // aimed specifically at ARM CPUs at the moment, and specialized for
50   // 8-bit quantized layers.
51   //
52   // The use case we're concerned with here is: 8-bit quantization,
53   // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
54   // a key application that drove this), very small batch size (e.g. 1 -- 4).
55   //
56   // Even with 8-bit quantization of weights, the performance of memory
57   // accesses to the weights can become the dominant issue when
58   // the batch size is small, so each weight value is used in only a few
59   // arithmetic ops, i.e. the fully-connected node has a low arithmetic
60   // intensity. The specific issues that arise are of three kinds:
61   // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
62   //     bound. That's the "good" issue to run into.
63   // (2) One may run into sub-optimal pre-fetching: the data hasn't been
64   //     prefetched into the cache by the time we need it.
65   // (3) One may run into cache aliasing: multiple values that are
66   //     pre-fetched, alias each other in the L1 cache (which typically
67   //     has only 4-way set associativity in ARM CPUs) and thus evict
68   //     each other before we get to using them.
69   //
70   // The point of this shuffling is to avoid issues (2) and (3) so that
71   // we get as fast as possible given only the hard constraint (1).
72   // This is achieved by turning the difficulty into a solution: the
73   // difficulty, that each value loaded from memory is used only in
74   // one kernel iteration, making this operation memory-intensive, hints at
75   // the solution, of shuffling the weights so that they are stored in the
76   // exact order as the kernel needs to load them, so that the memory
77   // accesses made by the kernel are trivial. This solves (2) because the
78   // trivial memory access pattern allows the CPU's automatic prefetching
79   // to perform very well (no need even for preload instructions), and this
80   // solves (3) because the values being loaded concurrently are now
81   // contiguous in the address space, thus don't alias each other in the cache.
82   //
83   // On ARM, we typically want our kernel to process a 4x16 block of weights
84   // at a time, because:
85   //   - 16 is the number of bytes in a NEON register.
86   //   - 4 is how many rows we need to handle concurrently in the kernel in
87   //     order to have sufficient mutual independence of instructions to
88   //     maximize arithmetic throughput.
89   //
90   // Finally, the 'Int8' part in the name refers to the fact that this
91   // weights format has each weights value encoded as a signed int8 value,
92   // even if the data type of the weights buffer is uint8.  This is intended
93   // to save runtime kernels the effort to have to XOR the top bit of these
94   // bytes before using them in signed arithmetic, see this file for more
95   // explanations on the 'signed int8 trick' in matrix multiplication kernels:
96   //
97   //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
98   //
99   kShuffled4x16Int8,
100 };
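
// Illustrative sketch (not part of the original header): the "signed int8
// trick" referenced in the kShuffled4x16Int8 comment above. Flipping the top
// bit of a uint8 weight byte is equivalent to subtracting 128, reinterpreting
// the value as a signed int8, so kernels reading an already-shuffled buffer
// can skip this step. The helper name is hypothetical.
inline int8_t ToSignedInt8WeightExample(uint8_t unsigned_weight) {
  return static_cast<int8_t>(unsigned_weight ^ 0x80);
}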
101 
102 // Quantization parameters, determining the mapping of quantized values
103 // to real values (i.e. determining how quantized values are mathematically
104 // interpreted).
105 //
106 // The correspondence is as follows:
107 //
108 //   real_value = scale * (quantized_value - zero_point);
109 //
110 // In other words, zero_point designates which quantized value corresponds to
111 // the real 0 value, and scale designates the difference between the real values
112 // corresponding to consecutive quantized values differing by 1.
113 struct QuantizationParams {
114   int32 zero_point = 0;
115   double scale = 0.0;
116 };
117 
118 inline bool operator==(const QuantizationParams& qp1,
119                        const QuantizationParams& qp2) {
120   return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
121 }
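
// Illustrative sketch (not part of the original header): applying the
// correspondence documented above to map a quantized value back to a real
// value. The helper name is hypothetical.
inline double DequantizeExample(int32 quantized_value,
                                const QuantizationParams& qp) {
  // real_value = scale * (quantized_value - zero_point)
  return qp.scale * (quantized_value - qp.zero_point);
}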
122 
123 template <int N>
124 struct Dims {
125   int sizes[N];
126   int strides[N];
127 };
128 
129 class RuntimeShape {
130  public:
131   // Shapes with dimensions up to 4 are stored directly in the structure, while
132   // larger shapes are separately allocated.
133   static constexpr int kMaxSmallSize = 4;
134 
135   RuntimeShape& operator=(RuntimeShape const&) = delete;
136 
137   RuntimeShape() : size_(0) {}
138 
139   explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
140     if (dimensions_count > kMaxSmallSize) {
141 #ifdef TF_LITE_STATIC_MEMORY
142       TFLITE_CHECK(false && "No shape resizing supported on this platform");
143 #else   // TF_LITE_STATIC_MEMORY
144       dims_pointer_ = new int32[dimensions_count];
145 #endif  // TF_LITE_STATIC_MEMORY
146     }
147   }
148 
149   RuntimeShape(int shape_size, int32 value) : size_(0) {
150     Resize(shape_size);
151     for (int i = 0; i < shape_size; ++i) {
152       SetDim(i, value);
153     }
154   }
155 
156   RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
157     ReplaceWith(dimensions_count, dims_data);
158   }
159 
160   RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
161     BuildFrom(init_list);
162   }
163 
164   // Avoid using this constructor.  We should be able to delete it when C++17
165   // rolls out.
166   RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
167     if (size_ > kMaxSmallSize) {
168       dims_pointer_ = new int32[size_];
169     }
170     std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
171   }
172 
173   bool operator==(const RuntimeShape& comp) const {
174     return this->size_ == comp.size_ &&
175            std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
176   }
177 
178   ~RuntimeShape() {
179     if (size_ > kMaxSmallSize) {
180 #ifdef TF_LITE_STATIC_MEMORY
181       TFLITE_CHECK(false && "No shape resizing supported on this platform");
182 #else   // TF_LITE_STATIC_MEMORY
183       delete[] dims_pointer_;
184 #endif  // TF_LITE_STATIC_MEMORY
185     }
186   }
187 
188   inline int32 DimensionsCount() const { return size_; }
189   inline int32 Dims(int i) const {
190     TFLITE_DCHECK_GE(i, 0);
191     TFLITE_DCHECK_LT(i, size_);
192     return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
193   }
194   inline void SetDim(int i, int32 val) {
195     TFLITE_DCHECK_GE(i, 0);
196     TFLITE_DCHECK_LT(i, size_);
197     if (size_ > kMaxSmallSize) {
198       dims_pointer_[i] = val;
199     } else {
200       dims_[i] = val;
201     }
202   }
203 
204   inline int32* DimsData() {
205     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
206   }
207   inline const int32* DimsData() const {
208     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
209   }
210   // The caller must ensure that the shape is no bigger than 4-D.
211   inline const int32* DimsDataUpTo4D() const { return dims_; }
212 
213   inline void Resize(int dimensions_count) {
214     if (size_ > kMaxSmallSize) {
215 #ifdef TF_LITE_STATIC_MEMORY
216       TFLITE_CHECK(false && "No shape resizing supported on this platform");
217 #else   // TF_LITE_STATIC_MEMORY
218       delete[] dims_pointer_;
219 #endif  // TF_LITE_STATIC_MEMORY
220     }
221     size_ = dimensions_count;
222     if (dimensions_count > kMaxSmallSize) {
223 #ifdef TF_LITE_STATIC_MEMORY
224       TFLITE_CHECK(false && "No shape resizing supported on this platform");
225 #else   // TF_LITE_STATIC_MEMORY
226       dims_pointer_ = new int32[dimensions_count];
227 #endif  // TF_LITE_STATIC_MEMORY
228     }
229   }
230 
231   inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
232     Resize(dimensions_count);
233     int32* dst_dims = DimsData();
234     std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
235   }
236 
237   template <typename T>
238   inline void BuildFrom(const T& src_iterable) {
239     const int dimensions_count =
240         std::distance(src_iterable.begin(), src_iterable.end());
241     Resize(dimensions_count);
242     int32* data = DimsData();
243     for (auto it : src_iterable) {
244       *data = it;
245       ++data;
246     }
247   }
248 
249   // This will probably be factored out. Old code made substantial use of 4-D
250   // shapes, and so this function is used to extend smaller shapes. Note that
251   // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
252   // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
253   // inputs should already be 4-D, so this function should not be needed.
254   inline static RuntimeShape ExtendedShape(int new_shape_size,
255                                            const RuntimeShape& shape) {
256     return RuntimeShape(new_shape_size, shape, 1);
257   }
258 
259   inline void BuildFrom(const std::initializer_list<int> init_list) {
260     BuildFrom<const std::initializer_list<int>>(init_list);
261   }
262 
263   // Returns the total count of elements, that is the size when flattened into a
264   // vector.
265   inline int FlatSize() const {
266     int buffer_size = 1;
267     const int* dims_data = reinterpret_cast<const int*>(DimsData());
268     for (int i = 0; i < size_; i++) {
269       buffer_size *= dims_data[i];
270     }
271     return buffer_size;
272   }
273 
274   bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
275 
276  private:
277   // For use only by ExtendedShape(), written to guarantee (return-value) copy
278   // elision in C++17.
279   // This creates a shape padded to the desired size with the specified value.
280   RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
281       : size_(0) {
282     // If the following check fails, it is likely because a 4D-only kernel is
283     // being used with an array of larger dimension count.
284     TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
285     Resize(new_shape_size);
286     const int size_increase = new_shape_size - shape.DimensionsCount();
287     for (int i = 0; i < size_increase; ++i) {
288       SetDim(i, pad_value);
289     }
290     std::memcpy(DimsData() + size_increase, shape.DimsData(),
291                 sizeof(int32) * shape.DimensionsCount());
292   }
293 
294   int32 size_;
295   union {
296     int32 dims_[kMaxSmallSize];
297     int32* dims_pointer_;
298   };
299 };
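
// Illustrative sketch (not part of the original header): typical construction
// and inspection of RuntimeShape objects. The function name is hypothetical.
inline int RuntimeShapeUsageExample() {
  // A 4-D shape {1, 8, 8, 3}; it fits in kMaxSmallSize, so it is stored
  // inline rather than heap-allocated.
  const RuntimeShape shape({1, 8, 8, 3});
  // Left-pad a 2-D shape {8, 3} with 1s to make it 4-D: {1, 1, 8, 3}.
  const RuntimeShape extended =
      RuntimeShape::ExtendedShape(4, RuntimeShape({8, 3}));
  // FlatSize() is the element count: 1 * 8 * 8 * 3 = 192 and
  // 1 * 1 * 8 * 3 = 24.
  return shape.FlatSize() + extended.FlatSize();
}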
300 
301 // Converts inference-style shape to legacy tflite::Dims<4>.
302 inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
303   tflite::Dims<4> result;
304   const int dimensions_count = array_shape.DimensionsCount();
305   TFLITE_CHECK_LE(dimensions_count, 4);
306   int cum_prod = 1;
307   for (int i = 0; i < 4; i++) {
308     const int new_dim =
309         (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
310     result.sizes[i] = new_dim;
311     result.strides[i] = cum_prod;
312     cum_prod *= new_dim;
313   }
314   return result;
315 }
316 
317 // TODO(b/80418076): Move to legacy ops file, update invocations.
318 inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
319   return RuntimeShape(
320       {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
321 }
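
// Illustrative sketch (not part of the original header): ToRuntimeDims()
// reverses the dimension order and fills in cumulative strides, and
// DimsToShape() reverses the order back. The function name is hypothetical.
inline void DimsRoundTripExample() {
  const RuntimeShape shape({2, 3, 4, 5});
  const Dims<4> dims = ToRuntimeDims(shape);        // sizes   = {5, 4, 3, 2}
                                                    // strides = {1, 5, 20, 60}
  const RuntimeShape restored = DimsToShape(dims);  // {2, 3, 4, 5} again
  TFLITE_DCHECK(shape == restored);
}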
322 
323 // Gets next index to iterate through a multidimensional array.
324 inline bool NextIndex(const int num_dims, const int* dims, int* current) {
325   if (num_dims == 0) {
326     return false;
327   }
328   TFLITE_DCHECK(dims != nullptr);
329   TFLITE_DCHECK(current != nullptr);
330   int carry = 1;
331   for (int idx = num_dims - 1; idx >= 0; --idx) {
332     int current_val = current[idx] + carry;
333     TFLITE_DCHECK_GE(dims[idx], current_val);
334     if (dims[idx] == current_val) {
335       current[idx] = 0;
336     } else {
337       current[idx] = current_val;
338       carry = 0;
339       break;
340     }
341   }
342   return (carry == 0);
343 }
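
// Illustrative sketch (not part of the original header): using NextIndex() to
// walk every index of a {2, 3} array in row-major order. The function name is
// hypothetical.
inline int NextIndexUsageExample() {
  const int dims[] = {2, 3};
  int index[] = {0, 0};
  int visited = 0;
  // Visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2); NextIndex() returns
  // false once the index wraps back around to (0,0).
  do {
    ++visited;
  } while (NextIndex(2, dims, index));
  return visited;  // 6
}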
344 
345 // Gets offset of index if reducing on axis. When reducing, the flattened offset
346 // will not change, if the input index changes on the given axis. For example,
347 // if you have a 3D tensor and you are reducing to 2D by eliminating axis 0,
348 // then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened
349 // offset.
350 // TODO(kanlig): use Dims to represent dimensions.
351 inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
352                                   const int* index, const int num_axis,
353                                   const int* axis) {
354   if (num_dims == 0) {
355     return 0;
356   }
357   TFLITE_DCHECK(dims != nullptr);
358   TFLITE_DCHECK(index != nullptr);
359   size_t offset = 0;
360   for (int idx = 0; idx < num_dims; ++idx) {
361     // if we need to skip this axis
362     bool is_axis = false;
363     if (axis != nullptr) {
364       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
365         if (idx == axis[axis_idx]) {
366           is_axis = true;
367           break;
368         }
369       }
370     }
371     if (!is_axis) {
372       offset = offset * static_cast<size_t>(dims[idx]) +
373                static_cast<size_t>(index[idx]);
374     }
375   }
376   return offset;
377 }
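
// Illustrative sketch (not part of the original header): when reducing a
// {2, 3, 4} tensor over axis 0, indices that differ only along axis 0 map to
// the same output offset, as described above. The function name is
// hypothetical.
inline void ReducedOutputOffsetExample() {
  const int dims[] = {2, 3, 4};
  const int axis[] = {0};
  const int index_a[] = {0, 1, 2};
  const int index_b[] = {1, 1, 2};
  // Both yield offset 1 * 4 + 2 = 6 into the reduced {3, 4} output.
  TFLITE_DCHECK_EQ(ReducedOutputOffset(3, dims, index_a, 1, axis),
                   ReducedOutputOffset(3, dims, index_b, 1, axis));
}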
378 
379 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
380   TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
381   const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
382   TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
383   TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
384   TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
385   TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
386   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
387 }
388 
389 inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
390   TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
391   TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
392   TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
393   TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
394   return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
395          i3 * dims.strides[3];
396 }
397 
398 inline int Offset(const Dims<4>& dims, int* index) {
399   return Offset(dims, index[0], index[1], index[2], index[3]);
400 }
401 
402 inline int Offset(const RuntimeShape& shape, int* index) {
403   return Offset(shape, index[0], index[1], index[2], index[3]);
404 }
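
// Illustrative sketch (not part of the original header): the row-major offset
// produced by Offset() for a 4-D shape. The function name is hypothetical.
inline int OffsetUsageExample() {
  const RuntimeShape shape({2, 3, 4, 5});
  // ((1 * 3 + 2) * 4 + 3) * 5 + 4 = 119, i.e. the last element of the
  // 120-element buffer.
  return Offset(shape, 1, 2, 3, 4);
}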
405 
406 // Get array size, DCHECKing that the dim index is in range.
407 //
408 // Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
409 // already performs this check.
410 template <int N>
411 int ArraySize(const Dims<N>& array, int index) {
412   TFLITE_DCHECK(index >= 0 && index < N);
413   return array.sizes[index];
414 }
415 
416 // Get common array size, DCHECKing that they all agree.
417 template <typename ArrayType1, typename ArrayType2>
418 int MatchingArraySize(const ArrayType1& array1, int index1,
419                       const ArrayType2& array2, int index2) {
420   TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
421   return ArraySize(array1, index1);
422 }
423 
424 template <typename ArrayType1, typename ArrayType2, typename... Args>
425 int MatchingArraySize(const ArrayType1& array1, int index1,
426                       const ArrayType2& array2, int index2, Args... args) {
427   TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
428   return MatchingArraySize(array1, index1, args...);
429 }
430 
431 // Get common shape dim, DCHECKing that they all agree.
432 inline int MatchingDim(const RuntimeShape& shape1, int index1,
433                        const RuntimeShape& shape2, int index2) {
434   TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
435   return shape1.Dims(index1);
436 }
437 
438 template <typename... Args>
439 int MatchingDim(const RuntimeShape& shape1, int index1,
440                 const RuntimeShape& shape2, int index2, Args... args) {
441   TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
442   return MatchingDim(shape1, index1, args...);
443 }
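
// Illustrative sketch (not part of the original header): MatchingDim() is
// typically used to read a dimension that two shapes must share, e.g. the
// input depth of a convolution and the depth of its filter. The function name
// is hypothetical.
inline int MatchingDimUsageExample() {
  const RuntimeShape input_shape({1, 8, 8, 16});
  const RuntimeShape filter_shape({32, 3, 3, 16});
  // DCHECKs that both depths (axis 3) agree, then returns 16.
  return MatchingDim(input_shape, 3, filter_shape, 3);
}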
444 
445 // Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
446 template <int N>
447 inline int FlatSize(const Dims<N>& dims) {
448   int flat_size = 1;
449   for (int i = 0; i < N; ++i) {
450     flat_size *= dims.sizes[i];
451   }
452   return flat_size;
453 }
454 
455 TFLITE_DEPRECATED("Prefer FlatSize.")
456 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
457   return FlatSize(dims);
458 }
459 
460 inline int MatchingElementsSize(const RuntimeShape& shape,
461                                 const RuntimeShape& check_shape_0) {
462   const int size_1 = shape.FlatSize();
463   const int size_2 = check_shape_0.FlatSize();
464   TFLITE_CHECK_EQ(size_1, size_2);
465   return size_1;
466 }
467 
468 inline int MatchingElementsSize(const RuntimeShape& shape,
469                                 const RuntimeShape& check_shape_0,
470                                 const RuntimeShape& check_shape_1) {
471   const int size_1 = shape.FlatSize();
472   const int size_2 = check_shape_0.FlatSize();
473   const int size_3 = check_shape_1.FlatSize();
474   TFLITE_CHECK_EQ(size_1, size_2);
475   TFLITE_CHECK_EQ(size_2, size_3);
476   return size_1;
477 }
478 
479 // Flat size calculation, checking that dimensions match with one or more other
480 // arrays.
481 inline int MatchingFlatSize(const RuntimeShape& shape,
482                             const RuntimeShape& check_shape_0) {
483   TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
484   const int dims_count = shape.DimensionsCount();
485   for (int i = 0; i < dims_count; ++i) {
486     TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
487   }
488   return shape.FlatSize();
489 }
490 
491 inline int MatchingFlatSize(const RuntimeShape& shape,
492                             const RuntimeShape& check_shape_0,
493                             const RuntimeShape& check_shape_1) {
494   TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
495   const int dims_count = shape.DimensionsCount();
496   for (int i = 0; i < dims_count; ++i) {
497     TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
498   }
499   return MatchingFlatSize(shape, check_shape_1);
500 }
501 
502 inline int MatchingFlatSize(const RuntimeShape& shape,
503                             const RuntimeShape& check_shape_0,
504                             const RuntimeShape& check_shape_1,
505                             const RuntimeShape& check_shape_2) {
506   TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
507   const int dims_count = shape.DimensionsCount();
508   for (int i = 0; i < dims_count; ++i) {
509     TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
510   }
511   return MatchingFlatSize(shape, check_shape_1, check_shape_2);
512 }
513 
514 inline int MatchingFlatSize(const RuntimeShape& shape,
515                             const RuntimeShape& check_shape_0,
516                             const RuntimeShape& check_shape_1,
517                             const RuntimeShape& check_shape_2,
518                             const RuntimeShape& check_shape_3) {
519   TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
520   const int dims_count = shape.DimensionsCount();
521   for (int i = 0; i < dims_count; ++i) {
522     TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
523   }
524   return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
525 }
526 
527 // Flat size calculation, checking that dimensions match with one or more other
528 // arrays.
529 template <int N>
530 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
531   for (int i = 0; i < N; ++i) {
532     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
533   }
534   return FlatSize(dims);
535 }
536 
537 template <int N>
538 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
539                             const Dims<N>& check_dims_1) {
540   for (int i = 0; i < N; ++i) {
541     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
542   }
543   return MatchingFlatSize(dims, check_dims_1);
544 }
545 
546 template <int N>
547 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
548                             const Dims<N>& check_dims_1,
549                             const Dims<N>& check_dims_2) {
550   for (int i = 0; i < N; ++i) {
551     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
552   }
553   return MatchingFlatSize(dims, check_dims_1, check_dims_2);
554 }
555 
556 template <int N>
557 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
558                             const Dims<N>& check_dims_1,
559                             const Dims<N>& check_dims_2,
560                             const Dims<N>& check_dims_3) {
561   for (int i = 0; i < N; ++i) {
562     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
563   }
564   return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
565 }
566 
567 // Data is required to be contiguous, and so many operators can use either the
568 // full array flat size or the flat size with one dimension skipped (commonly
569 // the depth).
570 template <int N>
571 inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
572   TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
573   int flat_size = 1;
574   for (int i = 0; i < N; ++i) {
575     flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
576   }
577   return flat_size;
578 }
579 
580 // A combination of MatchingFlatSize() and FlatSizeSkipDim().
581 template <int N>
582 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
583                                    const Dims<N>& check_dims_0) {
584   for (int i = 0; i < N; ++i) {
585     if (i != skip_dim) {
586       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
587     }
588   }
589   return FlatSizeSkipDim(dims, skip_dim);
590 }
591 
592 template <int N>
593 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
594                                    const Dims<N>& check_dims_0,
595                                    const Dims<N>& check_dims_1) {
596   for (int i = 0; i < N; ++i) {
597     if (i != skip_dim) {
598       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
599     }
600   }
601   return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
602 }
603 
604 template <int N>
605 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
606                                    const Dims<N>& check_dims_0,
607                                    const Dims<N>& check_dims_1,
608                                    const Dims<N>& check_dims_2) {
609   for (int i = 0; i < N; ++i) {
610     if (i != skip_dim) {
611       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
612     }
613   }
614   return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
615 }
616 
617 template <int N>
618 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
619                                    const Dims<N>& check_dims_0,
620                                    const Dims<N>& check_dims_1,
621                                    const Dims<N>& check_dims_2,
622                                    const Dims<N>& check_dims_3) {
623   for (int i = 0; i < N; ++i) {
624     if (i != skip_dim) {
625       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
626     }
627   }
628   return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
629                                  check_dims_3);
630 }
631 
632 // Data is required to be contiguous, and so many operators can use either the
633 // full array flat size or the flat size with one dimension skipped (commonly
634 // the depth).
635 inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
636   const int dims_count = shape.DimensionsCount();
637   TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
638   const auto* dims_data = shape.DimsData();
639   int flat_size = 1;
640   for (int i = 0; i < dims_count; ++i) {
641     flat_size *= (i == skip_dim) ? 1 : dims_data[i];
642   }
643   return flat_size;
644 }
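
// Illustrative sketch (not part of the original header): FlatSizeSkipDim()
// with the innermost (depth) dimension skipped gives the number of
// depth-sized slices a kernel has to visit. The function name is hypothetical.
inline int FlatSizeSkipDimUsageExample() {
  const RuntimeShape shape({2, 8, 8, 3});
  // 2 * 8 * 8 = 128; the depth of 3 at axis 3 is skipped.
  return FlatSizeSkipDim(shape, 3);
}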
645 
646 // A combination of MatchingFlatSize() and FlatSizeSkipDim().
647 inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
648                                    const RuntimeShape& check_shape_0) {
649   const int dims_count = shape.DimensionsCount();
650   for (int i = 0; i < dims_count; ++i) {
651     if (i != skip_dim) {
652       TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
653     }
654   }
655   return FlatSizeSkipDim(shape, skip_dim);
656 }
657 
658 inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
659                                    const RuntimeShape& check_shape_0,
660                                    const RuntimeShape& check_shape_1) {
661   const int dims_count = shape.DimensionsCount();
662   for (int i = 0; i < dims_count; ++i) {
663     if (i != skip_dim) {
664       TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
665     }
666   }
667   return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
668 }
669 
670 inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
671                                    const RuntimeShape& check_shape_0,
672                                    const RuntimeShape& check_shape_1,
673                                    const RuntimeShape& check_shape_2) {
674   const int dims_count = shape.DimensionsCount();
675   for (int i = 0; i < dims_count; ++i) {
676     if (i != skip_dim) {
677       TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
678     }
679   }
680   return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
681 }
682 
683 inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
684                                    const RuntimeShape& check_shape_0,
685                                    const RuntimeShape& check_shape_1,
686                                    const RuntimeShape& check_shape_2,
687                                    const RuntimeShape& check_shape_3) {
688   const int dims_count = shape.DimensionsCount();
689   for (int i = 0; i < dims_count; ++i) {
690     if (i != skip_dim) {
691       TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
692     }
693   }
694   return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
695                                  check_shape_3);
696 }
697 
698 template <int N>
699 bool IsPackedWithoutStrides(const Dims<N>& dims) {
700   int expected_stride = 1;
701   for (int d = 0; d < N; d++) {
702     if (dims.strides[d] != expected_stride) return false;
703     expected_stride *= dims.sizes[d];
704   }
705   return true;
706 }
707 
708 template <int N>
709 void ComputeStrides(Dims<N>* dims) {
710   dims->strides[0] = 1;
711   for (int d = 1; d < N; d++) {
712     dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
713   }
714 }
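
// Illustrative sketch (not part of the original header): filling in the
// strides of a legacy Dims<4> with ComputeStrides() and then verifying that
// the layout is contiguous. The function name is hypothetical.
inline bool ComputeStridesUsageExample() {
  Dims<4> dims;
  dims.sizes[0] = 5;
  dims.sizes[1] = 4;
  dims.sizes[2] = 3;
  dims.sizes[3] = 2;
  ComputeStrides(&dims);                // strides become {1, 5, 20, 60}
  return IsPackedWithoutStrides(dims);  // true
}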
715 
716 enum class BroadcastableOpCategory : uint8 {
717   kNone,
718   kNonBroadcast,               // Matching input shapes.
719   kFirstInputBroadcastsFast,   // Fivefold nested loops.
720   kSecondInputBroadcastsFast,  // Fivefold nested loops.
721   kGenericBroadcast,           // Fall-back.
722 };
723 
724 struct MinMax {
725   float min;
726   float max;
727 };
728 static_assert(sizeof(MinMax) == 8, "");
729 
730 struct ActivationParams {
731   FusedActivationFunctionType activation_type;
732   // uint8, etc, activation params.
733   int32 quantized_activation_min;
734   int32 quantized_activation_max;
735 };
736 
737 struct ReluParams : public ActivationParams {
738   int32 input_offset;
739   int32 output_offset;
740   int32 output_multiplier;
741   int32 output_shift;
742 };
743 
744 // Styles of resizing op usages. For example, kImageStyle can be used with a Pad
745 // op for pattern-specific optimization.
746 enum class ResizingCategory : uint8 {
747   kNone,
748   kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
749   kGenericResize,
750 };
751 
752 // For Add, Sub, Mul ops.
753 struct ArithmeticParams {
754   // Shape dependent / common to data / op types.
755   BroadcastableOpCategory broadcast_category;
756   // uint8 inference params.
757   int32 input1_offset;
758   int32 input2_offset;
759   int32 output_offset;
760   int32 output_multiplier;
761   int output_shift;
762   // Add / Sub, not Mul, uint8 inference params.
763   int left_shift;
764   int32 input1_multiplier;
765   int input1_shift;
766   int32 input2_multiplier;
767   int input2_shift;
768   // uint8, etc, activation params.
769   int32 quantized_activation_min;
770   int32 quantized_activation_max;
771   // float activation params.
772   float float_activation_min;
773   float float_activation_max;
774 
775   // Processed output dimensions.
776   // Let input "a" be the one that broadcasts in the faster-changing dimension.
777   // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
778   // {b0, b1, b2, b3, b4},
779   // broadcast_shape[4] = b0 = a0.
780   // broadcast_shape[3] = b1; a1 = 1.
781   // broadcast_shape[2] = b2 = a2.
782   // broadcast_shape[1] = a3; b3 = 1.
783   // broadcast_shape[0] = b4 = a4.
784   int broadcast_shape[5];
785 };
786 
787 struct ConcatenationParams {
788   int8 axis;
789   const int32* input_zeropoint;
790   const float* input_scale;
791   uint16 inputs_count;
792   int32 output_zeropoint;
793   float output_scale;
794 };
795 
796 struct ComparisonParams {
797   // uint8 inference params.
798   int left_shift;
799   int32 input1_offset;
800   int32 input1_multiplier;
801   int input1_shift;
802   int32 input2_offset;
803   int32 input2_multiplier;
804   int input2_shift;
805   // Shape dependent / common to inference types.
806   bool is_broadcast;
807 };
808 
809 struct ConvParams {
810   PaddingType padding_type;
811   PaddingValues padding_values;
812   // TODO(starka): This was just "stride", so check that width+height is OK.
813   int16 stride_width;
814   int16 stride_height;
815   int16 dilation_width_factor;
816   int16 dilation_height_factor;
817   // uint8 inference params.
818   // TODO(b/65838351): Use smaller types if appropriate.
819   int32 input_offset;
820   int32 weights_offset;
821   int32 output_offset;
822   int32 output_multiplier;
823   int output_shift;
824   // uint8, etc, activation params.
825   int32 quantized_activation_min;
826   int32 quantized_activation_max;
827   // float activation params.
828   float float_activation_min;
829   float float_activation_max;
830 };
831 
832 struct DepthToSpaceParams {
833   int32 block_size;
834 };
835 
836 struct DepthwiseParams {
837   PaddingType padding_type;
838   PaddingValues padding_values;
839   int16 stride_width;
840   int16 stride_height;
841   int16 dilation_width_factor;
842   int16 dilation_height_factor;
843   int16 depth_multiplier;
844   // uint8 inference params.
845   // TODO(b/65838351): Use smaller types if appropriate.
846   int32 input_offset;
847   int32 weights_offset;
848   int32 output_offset;
849   int32 output_multiplier;
850   int output_shift;
851   // uint8, etc, activation params.
852   int32 quantized_activation_min;
853   int32 quantized_activation_max;
854   // float activation params.
855   float float_activation_min;
856   float float_activation_max;
857   const int32* output_multiplier_per_channel;
858   const int32* output_shift_per_channel;
859 };
860 
861 struct DequantizationParams {
862   double scale;
863   int32 zero_point;
864 };
865 
866 struct FakeQuantParams {
867   MinMax minmax;
868   int32 num_bits;
869 };
870 
871 struct FullyConnectedParams {
872   // uint8 inference params.
873   // TODO(b/65838351): Use smaller types if appropriate.
874   int32 input_offset;
875   int32 weights_offset;
876   int32 output_offset;
877   int32 output_multiplier;
878   int output_shift;
879   // uint8, etc, activation params.
880   int32 quantized_activation_min;
881   int32 quantized_activation_max;
882   // float activation params.
883   float float_activation_min;
884   float float_activation_max;
885   // Mark the operands as cacheable if they are unchanging, e.g. weights.
886   bool lhs_cacheable;
887   bool rhs_cacheable;
888   FullyConnectedWeightsFormat weights_format;
889 };
890 
891 struct GatherParams {
892   int16 axis;
893 };
894 
895 struct L2NormalizationParams {
896   // uint8 inference params.
897   int32 input_zero_point;
898 };
899 
900 struct LocalResponseNormalizationParams {
901   int32 range;
902   double bias;
903   double alpha;
904   double beta;
905 };
906 
907 struct HardSwishParams {
908   // zero_point of the input activations.
909   int16_t input_zero_point;
910   // zero_point of the output activations.
911   int16_t output_zero_point;
912   // 16bit fixed-point component of the multiplier to apply to go from the
913   // "high-res input scale", which is the input scale multiplied by 2^7, to the
914   // "relu-ish scale", which is 3.0/32768.
915   // See the implementation of HardSwishPrepare.
916   int16_t reluish_multiplier_fixedpoint_int16;
917   // exponent/bit-shift component of the aforementioned multiplier.
918   int reluish_multiplier_exponent;
919   // 16bit fixed-point component of the multiplier to apply to go from the
920   // "high-res input scale", which is the input scale multiplied by 2^7, to the
921   // output scale.
922   // See the implementation of HardSwishPrepare.
923   int16_t output_multiplier_fixedpoint_int16;
924   // exponent/bit-shift component of the aforementioned multiplier.
925   int output_multiplier_exponent;
926 };
927 
928 struct LogisticParams {
929   // uint8 inference params.
930   int32 input_zero_point;
931   int32 input_range_radius;
932   int32 input_multiplier;
933   int input_left_shift;
934 };
935 
936 struct LstmCellParams {
937   int32 weights_zero_point;
938   int32 accum_multiplier;
939   int accum_shift;
940   int state_integer_bits;
941 };
942 
943 struct MeanParams {
944   int8 axis_count;
945   int16 axis[4];
946 };
947 
948 struct PackParams {
949   int8 axis;
950   const int32* input_zeropoint;
951   const float* input_scale;
952   uint16 inputs_count;
953   int32 output_zeropoint;
954   float output_scale;
955 };
956 
957 struct PadParams {
958   int8 left_padding_count;
959   int32 left_padding[4];
960   int8 right_padding_count;
961   int32 right_padding[4];
962   ResizingCategory resizing_category;
963 };
964 
965 struct PreluParams {
966   int32 input_offset;
967   int32 alpha_offset;
968   int32 output_offset;
969   int32 output_multiplier;
970   int output_shift;
971 };
972 
973 struct PoolParams {
974   FusedActivationFunctionType activation;
975   PaddingType padding_type;
976   PaddingValues padding_values;
977   int stride_height;
978   int stride_width;
979   int filter_height;
980   int filter_width;
981   // uint8, etc, activation params.
982   int32 quantized_activation_min;
983   int32 quantized_activation_max;
984   // float activation params.
985   float float_activation_min;
986   float float_activation_max;
987 };
988 
989 struct ReshapeParams {
990   int8 shape_count;
991   int32 shape[4];
992 };
993 
994 struct ResizeBilinearParams {
995   bool align_corners;
996   // half_pixel_centers assumes pixels are of half the actual dimensions, and
997   // yields more accurate resizes. Corresponds to the same argument for the
998   // original TensorFlow op in TF2.0.
999   bool half_pixel_centers;
1000 };
1001 
1002 struct ResizeNearestNeighborParams {
1003   bool align_corners;
1004 };
1005 
1006 struct SliceParams {
1007   int8 begin_count;
1008   int32 begin[4];
1009   int8 size_count;
1010   int32 size[4];
1011 };
1012 
1013 struct SoftmaxParams {
1014   // beta is not really used (not a TensorFlow parameter) and not implemented
1015   // for LogSoftmax.
1016   double beta;
1017   // uint8 inference params.  Used even when beta defaults to 1.0.
1018   int32 input_multiplier;
1019   int32 input_left_shift;
1020   // Reverse scaling is only used by LogSoftmax.
1021   int32 reverse_scaling_divisor;
1022   int32 reverse_scaling_right_shift;
1023   int diff_min;
1024   int32_t zero_point;
1025   float scale;
1026   float* table;
1027 };
1028 
1029 struct SpaceToBatchParams {
1030   // "Zero" padding for uint8 means padding with the output offset.
1031   int32 output_offset;
1032 };
1033 
1034 struct SpaceToDepthParams {
1035   int32 block_size;
1036 };
1037 
1038 struct SplitParams {
1039   // Graphs that split into, say, 2000 nodes are encountered.  The indices in
1040   // OperatorEdges are of type uint16.
1041   uint16 num_split;
1042   int16 axis;
1043 };
1044 
1045 struct SqueezeParams {
1046   int8 squeeze_dims_count;
1047   int32 squeeze_dims[4];
1048 };
1049 
1050 struct StridedSliceParams {
1051   int8 start_indices_count;
1052   int32 start_indices[4];
1053   int8 stop_indices_count;
1054   int32 stop_indices[4];
1055   int8 strides_count;
1056   int32 strides[4];
1057 
1058   int16 begin_mask;
1059   int16 ellipsis_mask;
1060   int16 end_mask;
1061   int16 new_axis_mask;
1062   int16 shrink_axis_mask;
1063 };
1064 
1065 struct TanhParams {
1066   int32 input_zero_point;
1067   int32 input_range_radius;
1068   int32 input_multiplier;
1069   int input_left_shift;
1070 };
1071 
1072 struct TransposeParams {
1073   int8 perm_count;
1074   int32 perm[4];
1075 };
1076 
1077 struct UnpackParams {
1078   uint16 num_split;
1079   int16 axis;
1080 };
1081 
1082 struct LeakyReluParams {
1083   float alpha;
1084   int32 input_offset;
1085   int32 alpha_offset;
1086   int32 output_offset;
1087   int32 output_multiplier;
1088   int output_shift;
1089 };
1090 
1091 template <typename P>
1092 inline void SetActivationParams(float min, float max, P* params) {
1093   params->float_activation_min = min;
1094   params->float_activation_max = max;
1095 }
1096 
1097 template <typename P>
1098 inline void SetActivationParams(int32 min, int32 max, P* params) {
1099   params->quantized_activation_min = min;
1100   params->quantized_activation_max = max;
1101 }
1102 
1103 template <typename P>
1104 inline void GetActivationParams(const P& params, int32* min, int32* max) {
1105   *min = params.quantized_activation_min;
1106   *max = params.quantized_activation_max;
1107 }
1108 
1109 template <typename P>
1110 inline void GetActivationParams(const P& params, float* min, float* max) {
1111   *min = params.float_activation_min;
1112   *max = params.float_activation_max;
1113 }
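
// Illustrative sketch (not part of the original header): the overloads above
// let shared kernel code set and read activation bounds without caring
// whether the op runs in float or quantized mode; overload resolution picks
// the matching fields. The function name is hypothetical.
inline float ActivationParamsUsageExample() {
  ArithmeticParams params;
  SetActivationParams(0.0f, 6.0f, &params);  // picks the float overload
  float act_min = 0.0f;
  float act_max = 0.0f;
  GetActivationParams(params, &act_min, &act_max);  // reads the float bounds
  return act_max - act_min;  // 6.0f
}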
1114 
1115 }  // namespace tflite
1116 
1117 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
1118