/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };

struct PaddingValues {
  int16 width;
  int16 height;
  // Offset is used for calculating the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 and
  // padding_right is 1 + 1 = 2.
  int16 width_offset;
  // Same as width_offset except it's over the height dimension.
  int16 height_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8 {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into the solution: since each
  // value loaded from memory is used in only one kernel iteration (which is
  // what makes this operation memory-intensive), we shuffle the weights so
  // that they are stored in exactly the order in which the kernel loads them,
  // making the kernel's memory accesses trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and it
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space and thus don't alias each other in the
  // cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weight value encoded as a signed int8 value,
  // even if the data type of the weights buffer is uint8. This is intended
  // to save runtime kernels the effort of XORing the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanation of the 'signed int8 trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
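
// Illustrative sketch of the "signed int8 trick" mentioned above (comment
// only; the actual conversion is done by the toco graph transformation
// referenced in the path above, and the value below is an arbitrary example).
// XORing the top bit of a uint8 weight byte and reinterpreting the result as
// int8 subtracts 128:
//
//   uint8 w_uint8 = 200;                               // stored weight byte
//   int8 w_int8 = static_cast<int8>(w_uint8 ^ 0x80);   // == 200 - 128 == 72
//
// Kernels consuming kShuffled4x16Int8 expect the weights buffer to already
// hold this signed encoding, so the bytes can feed directly into signed 8-bit
// multiply-accumulate instructions.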

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32 zero_point = 0;
  double scale = 0.0;
};
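
// Illustrative sketch (comment only; the values are arbitrary): with
// scale = 0.5 and zero_point = 128, the quantized value 130 represents the
// real value 0.5 * (130 - 128) = 1.0, and the real value 0 maps exactly to
// the quantized value 128.
//
//   QuantizationParams qp;
//   qp.zero_point = 128;
//   qp.scale = 0.5;
//   double real_value = qp.scale * (130 - qp.zero_point);  // == 1.0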

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 4 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 4;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  RuntimeShape(int shape_size, int32 value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline int32 DimensionsCount() const { return size_; }
  inline int32 Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32 val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32* DimsDataUpTo4D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
    Resize(dimensions_count);
    int32* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should
  // be reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into
  // a vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = reinterpret_cast<const int*>(DimsData());
    for (int i = 0; i < size_; i++) {
      buffer_size *= dims_data[i];
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32) * shape.DimensionsCount());
  }

  int32 size_;
  union {
    int32 dims_[kMaxSmallSize];
    int32* dims_pointer_;
  };
};
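
// Illustrative usage sketch (comment only; the shapes are arbitrary):
//
//   RuntimeShape shape({2, 3, 4, 5});  // 4-D, stored inline (<= kMaxSmallSize).
//   shape.DimensionsCount();           // == 4
//   shape.Dims(1);                     // == 3
//   shape.FlatSize();                  // == 2 * 3 * 4 * 5 == 120
//
//   // Pad a 2-D shape up to 4-D with leading 1s, as many kernels expect:
//   const RuntimeShape ext = RuntimeShape::ExtendedShape(4, RuntimeShape({3, 5}));
//   // ext is {1, 1, 3, 5}.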

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
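
// Illustrative sketch (comment only): visiting every index of a 2 x 3 array
// in row-major order. Starting from the all-zero index, the loop body sees
// {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}, and NextIndex() returns false
// once the last index wraps back to {0,0}.
//
//   const int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // ... use index ...
//   } while (NextIndex(2, dims, index));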

// Gets the offset of an index when reducing on the given axes. When reducing,
// the flattened offset does not change when the input index changes along a
// reduced axis. For example, if you have a 3D tensor and you are reducing to
// 2D by eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) map to
// the same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // if we need to skip this axis
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
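
// Illustrative sketch (comment only): reducing a {2, 3, 4} tensor over axis 0.
// Indices that differ only in the reduced dimension share an output offset:
//
//   const int dims[] = {2, 3, 4};
//   const int axis[] = {0};
//   const int index_a[] = {0, 1, 2};
//   const int index_b[] = {1, 1, 2};
//   ReducedOutputOffset(3, dims, index_a, 1, axis);  // == 1 * 4 + 2 == 6
//   ReducedOutputOffset(3, dims, index_b, 1, axis);  // == 6 as well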

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}
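
// Illustrative sketch (comment only): for a 4-D (e.g. NHWC) shape {1, 8, 8, 3},
// the element at (batch 0, row 2, column 5, channel 1) lives at
// ((0 * 8 + 2) * 8 + 5) * 3 + 1 == 64:
//
//   RuntimeShape shape({1, 8, 8, 3});
//   int offset = Offset(shape, 0, 2, 5, 1);  // == 64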

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return shape1.Dims(index1);
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
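
// Illustrative sketch (comment only): for an NHWC shape {2, 8, 8, 3}, skipping
// the depth dimension treats it as 1, yielding the number of depth-sized
// slices in the buffer:
//
//   RuntimeShape shape({2, 8, 8, 3});
//   FlatSizeSkipDim(shape, 3);  // == 2 * 8 * 8 == 128
//   shape.FlatSize();           // == 384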

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
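
// Illustrative sketch (comment only): Dims<4> stores sizes and strides from
// the fastest-varying dimension outwards, so ComputeStrides() produces a
// packed layout and IsPackedWithoutStrides() then returns true:
//
//   Dims<4> dims;
//   dims.sizes[0] = 3;  // e.g. depth
//   dims.sizes[1] = 8;  // width
//   dims.sizes[2] = 8;  // height
//   dims.sizes[3] = 1;  // batch
//   ComputeStrides(&dims);         // strides == {1, 3, 24, 192}
//   IsPackedWithoutStrides(dims);  // == true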

enum class BroadcastableOpCategory : uint8 {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32 input_offset;
  int32 output_offset;
  int32 output_multiplier;
  int32 output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8 {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int left_shift;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8 inference params.
  int left_shift;
  int32 input1_offset;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_offset;
  int32 input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32 block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  int16 depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32* output_multiplier_per_channel;
  const int32* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32 zero_point;
};

struct FakeQuantParams {
  MinMax minmax;
  int32 num_bits;
};

struct FullyConnectedParams {
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16 axis;
};

struct L2NormalizationParams {
  // uint8 inference params.
  int32 input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32 range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0 / 32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8 inference params.
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32 weights_zero_point;
  int32 accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8 axis_count;
  int16 axis[4];
};

struct PackParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8 left_padding_count;
  int32 left_padding[4];
  int8 right_padding_count;
  int32 right_padding[4];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32 input_offset;
  int32 alpha_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8 shape_count;
  int32 shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
};

struct SliceParams {
  int8 begin_count;
  int32 begin[4];
  int8 size_count;
  int32 size[4];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params. Used even when beta defaults to 1.0.
  int32 input_multiplier;
  int32 input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8 means padding with the output offset.
  int32 output_offset;
};

struct SpaceToDepthParams {
  int32 block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16.
  uint16 num_split;
  int16 axis;
};

struct SqueezeParams {
  int8 squeeze_dims_count;
  int32 squeeze_dims[4];
};

struct StridedSliceParams {
  int8 start_indices_count;
  int32 start_indices[4];
  int8 stop_indices_count;
  int32 stop_indices[4];
  int8 strides_count;
  int32 strides[4];

  int16 begin_mask;
  int16 ellipsis_mask;
  int16 end_mask;
  int16 new_axis_mask;
  int16 shrink_axis_mask;
};

struct TanhParams {
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8 perm_count;
  int32 perm[4];
};

struct UnpackParams {
  uint16 num_split;
  int16 axis;
};

struct LeakyReluParams {
  float alpha;
  int32 input_offset;
  int32 alpha_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}
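
// Illustrative sketch (comment only): overload resolution selects the
// quantized or the float fields of a params struct based on the argument
// types.
//
//   ConvParams op_params;
//   SetActivationParams(static_cast<int32>(0), static_cast<int32>(255),
//                       &op_params);              // sets quantized_activation_{min,max}
//   SetActivationParams(0.0f, 6.0f, &op_params);  // sets float_activation_{min,max}
//
//   int32 qmin, qmax;
//   GetActivationParams(op_params, &qmin, &qmax);  // qmin == 0, qmax == 255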

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_