/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstring>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };

struct PaddingValues {
  int16 width;
  int16 height;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8 {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used only in
  // one kernel iteration, making this operation memory-intensive, hints at
  // the solution, of shuffling the weights so that they are stored in the
  // exact order as the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8 value,
  // even if the data type of the weights buffer is uint8.  This is intended
  // to save runtime kernels the effort to have to XOR the top bit of these
  // bytes before using them in signed arithmetic, see this file for more
  // explanations on the 'signed int8 trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
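
// For illustration, a minimal sketch of the 'signed int8 trick' mentioned
// above (hypothetical helper, not part of this header): XOR-ing the top bit
// of a uint8 byte yields the bit pattern of that value minus 128, so kernels
// can reinterpret the bytes as int8 and use them directly in signed
// arithmetic.
//
//   inline int8 ToSignedInt8(uint8 b) {
//     // (b ^ 0x80) reinterpreted as int8 equals b - 128.
//     return static_cast<int8>(b ^ 0x80);
//   }
//   // ToSignedInt8(0) == -128, ToSignedInt8(128) == 0, ToSignedInt8(255) == 127.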

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real values
// corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32 zero_point = 0;
  double scale = 0.0;
};
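
// Worked example (hypothetical values): with scale = 0.5 and zero_point = 128,
// the quantized value 130 represents 0.5 * (130 - 128) = 1.0, and the
// quantized value 128 represents exactly 0.0.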

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 4 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 4;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  RuntimeShape(int shape_size, int32 value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor.  We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline int32 DimensionsCount() const { return size_; }
  inline int32 Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32 val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32* DimsDataUpTo4D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
    Resize(dimensions_count);
    int32* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = DimsData();
    for (int i = 0; i < size_; i++) {
      const int dim = dims_data[i];
      TFLITE_DCHECK_GE(dim, 1);
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32) * shape.DimensionsCount());
  }

  int32 size_;
  union {
    int32 dims_[kMaxSmallSize];
    int32* dims_pointer_;
  };
};
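
// Usage sketch (illustrative values):
//
//   RuntimeShape shape({1, 8, 8, 3});  // 4-D, stored inline (<= kMaxSmallSize).
//   shape.DimensionsCount();           // 4
//   shape.Dims(3);                     // 3
//   shape.FlatSize();                  // 1 * 8 * 8 * 3 = 192
//
//   // Pad a 2-D shape on the left with 1s to make it 4-D:
//   RuntimeShape ext = RuntimeShape::ExtendedShape(4, RuntimeShape({8, 3}));
//   // ext is {1, 1, 8, 3}.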

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
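
// Worked example (hypothetical values): for a RuntimeShape {2, 3, 4, 5}
// (outermost dimension first), ToRuntimeDims() yields sizes = {5, 4, 3, 2}
// (innermost dimension first) and strides = {1, 5, 20, 60}.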

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
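
// Illustration (hypothetical values): with dims = {2, 3}, repeatedly calling
// NextIndex starting from current = {0, 0} visits
//   {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}
// and then returns false, leaving current wrapped back to {0, 0}.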

// Gets offset of index if reducing on axis. When reducing, the flattened
// offset will not change if the input index changes on the given axis. For
// example, if you have a 3D tensor and you are reducing to 2D by eliminating
// axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the same
// flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // if we need to skip this axis
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
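
// Worked example (hypothetical values): with dims = {2, 3, 4}, num_axis = 1,
// and axis = {0}, both index (0, 1, 2) and index (1, 1, 2) yield offset
// (0 * 3 + 1) * 4 + 2 = 6, since the reduced axis 0 is skipped.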

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = shape.DimsDataUpTo4D();
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
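
// Worked example (hypothetical values): for a shape of {2, 3, 4, 5},
// Offset(shape, 1, 2, 3, 4) = ((1 * 3 + 2) * 4 + 3) * 5 + 4 = 119, i.e. the
// last element of the 120-element buffer.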

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return shape1.Dims(index1);
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
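
// Worked example (hypothetical values): for a RuntimeShape of {2, 8, 8, 3},
// FlatSizeSkipDim(shape, 3) = 2 * 8 * 8 = 128, the element count with the
// depth dimension skipped.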

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
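
// Illustration (hypothetical values): for a Dims<4> with sizes = {5, 4, 3, 2}
// (innermost dimension first), ComputeStrides() produces
// strides = {1, 5, 20, 60}, after which IsPackedWithoutStrides() returns true.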

enum class BroadcastableOpCategory : uint8 {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a Pad
// op for pattern-specific optimization.
enum class ResizingCategory : uint8 {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int left_shift;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8 inference params.
  int left_shift;
  int32 input1_offset;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_offset;
  int32 input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32 block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  int16 depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DequantizationParams {
  double scale;
  int32 zero_point;
};

struct FakeQuantParams {
  MinMax minmax;
  int32 num_bits;
};

struct FullyConnectedParams {
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16 axis;
};

struct L2NormalizationParams {
  // uint8 inference params.
  int32 input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32 range;
  double bias;
  double alpha;
  double beta;
};

struct LogisticParams {
  // uint8 inference params.
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32 weights_zero_point;
  int32 accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8 axis_count;
  int16 axis[4];
};

struct PackParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8 left_padding_count;
  int32 left_padding[4];
  int8 right_padding_count;
  int32 right_padding[4];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32 input_offset;
  int32 alpha_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8 shape_count;
  int32 shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
};

struct SliceParams {
  int8 begin_count;
  int32 begin[4];
  int8 size_count;
  int32 size[4];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params.  Used even when beta defaults to 1.0.
  int32 input_multiplier;
  int32 input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8 means padding with the output offset.
  int32 output_offset;
};

struct SpaceToDepthParams {
  int32 block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered.  The indices in
  // OperatorEdges are of type uint16.
  uint16 num_split;
  int16 axis;
};

struct SqueezeParams {
  int8 squeeze_dims_count;
  int32 squeeze_dims[4];
};

struct StridedSliceParams {
  int8 start_indices_count;
  int16 start_indices[4];
  int8 stop_indices_count;
  int16 stop_indices[4];
  int8 strides_count;
  int16 strides[4];

  int16 begin_mask;
  int16 ellipsis_mask;
  int16 end_mask;
  int16 new_axis_mask;
  int16 shrink_axis_mask;
};

struct TanhParams {
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8 perm_count;
  int32 perm[4];
};

struct UnpackParams {
  uint16 num_split;
  int16 axis;
};

struct LeakyReluParams {
  float alpha;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}
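
// Usage sketch (illustrative): the overloads above dispatch on the min/max
// types, so the same call sites work for float and quantized activations.
//
//   ArithmeticParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // Sets float_activation_*.
//   SetActivationParams(static_cast<int32>(0), static_cast<int32>(255),
//                       &op_params);              // Sets quantized_activation_*.
//
//   float act_min, act_max;
//   GetActivationParams(op_params, &act_min, &act_max);  // Reads float_activation_*.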

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_