/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_

#include <cstring>
#include <initializer_list>
#include <iterator>

// TODO: Remove once AOSP has external/absl setup.
#if __ANDROID__
#define ABSL_DEPRECATED(x)
#else
#include "absl/base/macros.h"
#endif  // __ANDROID__
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };

struct PaddingValues {
  int16 width;
  int16 height;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8 {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that we get
  // as close as possible to the hard limit (1). The difficulty that makes this
  // operation memory-intensive (each value loaded from memory is used in only
  // one kernel iteration) also points at the solution: shuffle the weights so
  // that they are stored in the exact order in which the kernel needs to load
  // them, making the kernel's memory access pattern trivial. This solves (2)
  // because the trivial access pattern lets the CPU's automatic prefetching
  // perform very well (no preload instructions are even needed), and it solves
  // (3) because the values being loaded concurrently are now contiguous in the
  // address space and thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8 value,
  // even if the data type of the weights buffer is uint8.  This is intended
  // to save runtime kernels from having to XOR the top bit of these bytes
  // before using them in signed arithmetic; see this file for more
  // explanations of the 'signed int8 trick' in matrix multiplication kernels:
  //
  //   tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
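
// Illustrative sketch of the 'signed int8 trick' mentioned above (an
// assumption based on the description here, not a definitive reference):
// converting a uint8 weight byte w into the signed encoding is a single XOR
// of the top bit, since (w ^ 0x80) reinterpreted as int8 equals w - 128.
//
//   uint8 w = 200;
//   int8 s = static_cast<int8>(w ^ 0x80);  // 200 ^ 0x80 = 72 == 200 - 128
//
// Pre-shuffled weights store this signed encoding directly, so the kernel can
// use the bytes in signed arithmetic without a per-element XOR at runtime.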

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32 zero_point = 0;
  double scale = 0.0;
};
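
// Worked example (illustrative, not from the original header): with
// scale = 0.5 and zero_point = 127, the quantized value 127 maps to the real
// value 0.0 and the quantized value 131 maps to 0.5 * (131 - 127) = 2.0.
//
//   QuantizationParams qp;
//   qp.zero_point = 127;
//   qp.scale = 0.5;
//   double real_value = qp.scale * (131 - qp.zero_point);  // 2.0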

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 4 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 4;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32[dimensions_count];
    }
  }

  RuntimeShape(int shape_size, int32 value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor.  We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
  }

  inline int32 DimensionsCount() const { return size_; }
  inline int32 Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32 val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32* DimsDataUpTo4D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32[dimensions_count];
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
    Resize(dimensions_count);
    int32* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = DimsData();
    for (int i = 0; i < size_; i++) {
      const int dim = dims_data[i];
      TFLITE_DCHECK_GE(dim, 1);
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32) * shape.DimensionsCount());
  }

  int32 size_;
  union {
    int32 dims_[kMaxSmallSize];
    int32* dims_pointer_;
  };
};
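
// Illustrative usage sketch (an assumption about call sites, not part of the
// original header); assumes a function body in code that includes this file.
// Small shapes are stored inline, larger ones on the heap, and ExtendedShape()
// left-pads a smaller shape with 1s up to the requested rank.
//
//   RuntimeShape shape({1, 8, 8, 3});  // 4-D, stored inline (<= kMaxSmallSize).
//   int flat = shape.FlatSize();       // 1 * 8 * 8 * 3 = 192
//   RuntimeShape ext = RuntimeShape::ExtendedShape(4, RuntimeShape({8, 3}));
//   // ext is now {1, 1, 8, 3}.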

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
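
// Worked example (illustrative, not from the original header): a RuntimeShape
// of {2, 3, 4, 5} (outermost to innermost) converts to a Dims<4> with
// sizes = {5, 4, 3, 2} and strides = {1, 5, 20, 60}; the dimension order is
// reversed and each stride is the running product of the inner sizes.
//
//   tflite::Dims<4> d = ToRuntimeDims(RuntimeShape({2, 3, 4, 5}));
//   // d.sizes   == {5, 4, 3, 2}
//   // d.strides == {1, 5, 20, 60}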

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
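
// Illustrative iteration sketch (not from the original header): NextIndex
// advances a multi-index in row-major order and returns false once it wraps
// past the last element, so a do/while loop visits every element exactly once.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // Visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
//   } while (NextIndex(2, dims, index));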

// Gets offset of index if reducing on axis. When reducing, the flattened
// offset will not change if the input index changes on the given axis. For
// example, if you have a 3D tensor and you are reducing to 2D by eliminating
// axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the same
// flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Check whether this axis is being reduced and should therefore be skipped.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
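
// Worked example (illustrative, not from the original header): for
// dims = {2, 3, 4} and a single reduced axis 0, the indices (0, 1, 2) and
// (1, 1, 2) both produce the offset 1 * 4 + 2 = 6, matching the comment above.
//
//   int dims[] = {2, 3, 4};
//   int axis[] = {0};
//   int index_a[] = {0, 1, 2};
//   int index_b[] = {1, 1, 2};
//   // ReducedOutputOffset(3, dims, index_a, 1, axis) == 6
//   // ReducedOutputOffset(3, dims, index_b, 1, axis) == 6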

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = shape.DimsDataUpTo4D();
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
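
// Worked example (illustrative, not from the original header): for a 4-D shape
// {1, 4, 4, 3}, Offset(shape, 0, 1, 2, 0) = ((0 * 4 + 1) * 4 + 2) * 3 + 0 = 18,
// i.e. the last dimension is the innermost, contiguous one.
//
//   RuntimeShape shape({1, 4, 4, 3});
//   int offset = Offset(shape, 0, 1, 2, 0);  // 18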

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return shape1.Dims(index1);
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
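
// Typical usage sketch (a hypothetical call site, not from the original
// header): kernels use MatchingDim() to assert and fetch a dimension that
// several tensors must share; the shape names below are placeholders.
//
//   // const int depth = MatchingDim(input_shape, 3, filter_shape, 3);
//   // DCHECKs that the two dims agree, then returns input_shape.Dims(3).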

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
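
// Worked example (illustrative, not from the original header): for a shape of
// {2, 3, 4, 5}, FlatSizeSkipDim(shape, 3) = 2 * 3 * 4 = 24, i.e. the flat size
// with the last (depth) dimension skipped.
//
//   RuntimeShape shape({2, 3, 4, 5});
//   int outer_size = FlatSizeSkipDim(shape, 3);  // 24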

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
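
// Worked example (illustrative, not from the original header): for a Dims<4>
// with sizes {3, 4, 5, 6}, ComputeStrides() fills strides = {1, 3, 12, 60},
// after which IsPackedWithoutStrides() returns true, since each stride is the
// product of the preceding sizes.
//
//   Dims<4> d;
//   d.sizes[0] = 3; d.sizes[1] = 4; d.sizes[2] = 5; d.sizes[3] = 6;
//   ComputeStrides(&d);                       // strides = {1, 3, 12, 60}
//   bool packed = IsPackedWithoutStrides(d);  // true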

enum class BroadcastableOpCategory : uint8 {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int left_shift;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8 inference params.
  int left_shift;
  int32 input1_offset;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_offset;
  int32 input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32 block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  int16 depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DequantizationParams {
  double scale;
  int32 zero_point;
};

struct FakeQuantParams {
  MinMax minmax;
  int32 num_bits;
};

struct FullyConnectedParams {
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16 input_rank;
  int16 axis;
};

struct L2NormalizationParams {
  // uint8 inference params.
  int32 input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32 range;
  double bias;
  double alpha;
  double beta;
};

struct LogisticParams {
  // uint8 inference params.
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32 weights_zero_point;
  int32 accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8 axis_count;
  int16 axis[4];
};

struct PadParams {
  int8 left_padding_count;
  int32 left_padding[4];
  int8 right_padding_count;
  int32 right_padding[4];
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8 shape_count;
  int32 shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
};

struct SliceParams {
  int8 begin_count;
  int32 begin[4];
  int8 size_count;
  int32 size[4];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params.  Used even when beta defaults to 1.0.
  int32 input_multiplier;
  int32 input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8 means padding with the output offset.
  int32 output_offset;
};

struct SpaceToDepthParams {
  int32 block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered.  The indices in
  // OperatorEdges are of type uint16.
  uint16 num_split;
  int16 axis;
};

struct SqueezeParams {
  int8 squeeze_dims_count;
  int32 squeeze_dims[4];
};

struct StridedSliceParams {
  int8 start_indices_count;
  int16 start_indices[4];
  int8 stop_indices_count;
  int16 stop_indices[4];
  int8 strides_count;
  int16 strides[4];

  int16 begin_mask;
  int16 ellipsis_mask;
  int16 end_mask;
  int16 new_axis_mask;
  int16 shrink_axis_mask;
};

struct TanhParams {
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8 perm_count;
  int32 perm[4];
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}
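
// Usage sketch (illustrative, not from the original header): overload
// resolution picks the float or quantized variant from the argument types, so
// the same helpers serve both float and uint8 kernels for any params struct
// that has the corresponding fields.
//
//   PoolParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // fills float_activation_*.
//   SetActivationParams(static_cast<int32>(0), static_cast<int32>(255),
//                       &op_params);              // fills quantized_activation_*.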

}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_