/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstring>
#include <initializer_list>
#include <iterator>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };

struct PaddingValues {
  int16 width;
  int16 height;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8 {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used only in
  // one kernel iteration, making this operation memory-intensive, hints at
  // the solution, of shuffling the weights so that they are stored in the
  // exact order as the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, thus don't alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8 value,
  // even if the data type of the weights buffer is uint8. This is intended
  // to save runtime kernels the effort to have to XOR the top bit of these
  // bytes before using them in signed arithmetic, see this file for more
  // explanations on the 'signed int8 trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
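// Illustrative sketch (not part of this header's API): the 'signed int8
// trick' referenced above amounts to flipping the top bit of each byte, i.e.
// reinterpreting a uint8 weight w as the signed value w - 128. Assuming a raw
// uint8 weights buffer, a conversion would look like:
//
//   inline int8 ToSignedInt8(uint8 w) {
//     // XOR of the top bit maps [0, 255] onto [-128, 127], subtracting 128.
//     return static_cast<int8>(w ^ 0x80);
//   }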

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32 zero_point = 0;
  double scale = 0.0;
};

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}
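// A minimal worked example of the mapping above (illustrative only, values
// are made up): with scale = 0.5 and zero_point = 128, the quantized value
// 130 represents the real value 0.5 * (130 - 128) = 1.0.
//
//   QuantizationParams qp;
//   qp.zero_point = 128;
//   qp.scale = 0.5;
//   const double real_value = qp.scale * (130 - qp.zero_point);  // == 1.0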

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 4 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 4;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  RuntimeShape(int shape_size, int32 value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline int32 DimensionsCount() const { return size_; }
  inline int32 Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32 val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32* DimsDataUpTo4D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
    Resize(dimensions_count);
    int32* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should
  // be reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into
  // a vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = DimsData();
    for (int i = 0; i < size_; i++) {
      const int dim = dims_data[i];
      TFLITE_DCHECK_GE(dim, 1);
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32) * shape.DimensionsCount());
  }

  int32 size_;
  union {
    int32 dims_[kMaxSmallSize];
    int32* dims_pointer_;
  };
};
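// A brief usage sketch of RuntimeShape (illustrative only, values are made
// up): shapes of up to 4 dimensions live inline in the object, and
// ExtendedShape() left-pads smaller shapes with 1s for legacy 4-D code paths.
//
//   const RuntimeShape shape({2, 3, 4});  // 3-D shape, FlatSize() == 24.
//   const RuntimeShape extended =
//       RuntimeShape::ExtendedShape(4, shape);  // {1, 2, 3, 4}
//   TFLITE_DCHECK_EQ(extended.Dims(0), 1);
//   TFLITE_DCHECK_EQ(extended.FlatSize(), shape.FlatSize());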

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets the next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
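// A minimal iteration sketch using NextIndex() (illustrative only): visits
// every index of a {2, 3} array in row-major order; NextIndex() returns false
// once the index wraps back around to all zeros.
//
//   const int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // ... use index[0], index[1] ...
//   } while (NextIndex(2, dims, index));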

// Gets the offset of an index when reducing on an axis. When reducing, the
// flattened offset will not change if the input index changes on the given
// axis. For example, if you have a 3D tensor and you are reducing to 2D by
// eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the
// same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Skip this axis if it is one of the axes being reduced over.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
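// A worked example (illustrative only): for dims = {2, 3, 4} with axis = {0}
// being reduced, the indices (0, 1, 2) and (1, 1, 2) both yield
// offset = 1 * 4 + 2 = 6, i.e. position (1, 2) in the reduced {3, 4} output.
//
//   const int dims[] = {2, 3, 4};
//   const int index[] = {1, 1, 2};
//   const int axis[] = {0};
//   const size_t offset = ReducedOutputOffset(3, dims, index, 1, axis);  // 6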

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = shape.DimsDataUpTo4D();
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}
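// A small sketch of the row-major offset computation above (illustrative
// only): for a 4-D RuntimeShape {N, H, W, C}, element (b, y, x, c) lives at
// ((b * H + y) * W + x) * C + c. For example, with shape {1, 2, 3, 4},
// Offset(shape, 0, 1, 2, 3) == ((0 * 2 + 1) * 3 + 2) * 4 + 3 == 23, the last
// element. Note that the Dims<4> overloads index in the reverse (legacy)
// order: with strides as computed by ToRuntimeDims() or ComputeStrides(), i0
// is the fastest-varying dimension.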

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return shape1.Dims(index1);
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
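// A short worked example (illustrative only): for a Dims<4> with
// sizes = {4, 3, 2, 1} (fastest-varying dimension first, as produced by
// ToRuntimeDims()), ComputeStrides() yields strides = {1, 4, 12, 24}, and
// IsPackedWithoutStrides() then returns true.
//
//   Dims<4> dims;
//   dims.sizes[0] = 4; dims.sizes[1] = 3; dims.sizes[2] = 2; dims.sizes[3] = 1;
//   ComputeStrides(&dims);  // strides == {1, 4, 12, 24}
//   TFLITE_DCHECK(IsPackedWithoutStrides(dims));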

enum class BroadcastableOpCategory : uint8 {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8 {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int left_shift;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  //   broadcast_shape[4] = b0 = a0.
  //   broadcast_shape[3] = b1; a1 = 1.
  //   broadcast_shape[2] = b2 = a2.
  //   broadcast_shape[1] = a3; b3 = 1.
  //   broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8 inference params.
  int left_shift;
  int32 input1_offset;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_offset;
  int32 input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32 block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  int16 depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DequantizationParams {
  double scale;
  int32 zero_point;
};

struct FakeQuantParams {
  MinMax minmax;
  int32 num_bits;
};

struct FullyConnectedParams {
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16 axis;
};

struct L2NormalizationParams {
  // uint8 inference params.
  int32 input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32 range;
  double bias;
  double alpha;
  double beta;
};

struct LogisticParams {
  // uint8 inference params.
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32 weights_zero_point;
  int32 accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8 axis_count;
  int16 axis[4];
};

struct PackParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8 left_padding_count;
  int32 left_padding[4];
  int8 right_padding_count;
  int32 right_padding[4];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32 input_offset;
  int32 alpha_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8 shape_count;
  int32 shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
};

struct SliceParams {
  int8 begin_count;
  int32 begin[4];
  int8 size_count;
  int32 size[4];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params. Used even when beta defaults to 1.0.
  int32 input_multiplier;
  int32 input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8 means padding with the output offset.
  int32 output_offset;
};

struct SpaceToDepthParams {
  int32 block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16.
  uint16 num_split;
  int16 axis;
};

struct SqueezeParams {
  int8 squeeze_dims_count;
  int32 squeeze_dims[4];
};

struct StridedSliceParams {
  int8 start_indices_count;
  int16 start_indices[4];
  int8 stop_indices_count;
  int16 stop_indices[4];
  int8 strides_count;
  int16 strides[4];

  int16 begin_mask;
  int16 ellipsis_mask;
  int16 end_mask;
  int16 new_axis_mask;
  int16 shrink_axis_mask;
};

struct TanhParams {
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8 perm_count;
  int32 perm[4];
};

struct UnpackParams {
  uint16 num_split;
  int16 axis;
};

struct LeakyReluParams {
  float alpha;
};
template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}
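// A minimal usage sketch (illustrative only): the overloads above dispatch on
// the min/max types, so the same call site works for both the float and the
// quantized activation range of any params struct with the matching fields.
//
//   PoolParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);             // float_activation_*
//   SetActivationParams(int32(0), int32(255), &op_params);   // quantized_activation_*
//   float fmin, fmax;
//   GetActivationParams(op_params, &fmin, &fmax);  // fmin == 0.0f, fmax == 6.0f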

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_