/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_

#include <cstring>
#include <iterator>

// TODO: Remove once AOSP has external/absl setup.
#if __ANDROID__
#define ABSL_DEPRECATED(x)
#else
#include "absl/base/macros.h"
#endif  // __ANDROID__
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };

struct PaddingValues {
  int16 width;
  int16 height;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8 {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth).
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // a large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), and a very small batch size
  // (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched alias each other in the L1 cache (which typically
  //     has only 4-way set associativity on ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty (each value loaded from memory is used in only one kernel
  // iteration, making this operation memory-intensive) hints at the
  // solution of shuffling the weights so that they are stored in exactly
  // the order in which the kernel needs to load them, so that the memory
  // accesses made by the kernel are trivial. This solves (2) because the
  // trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (there is no need even for preload instructions),
  // and it solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, and thus don't alias each other in the
  // cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part of the name refers to the fact that this
  // weights format has each weight value encoded as a signed int8 value,
  // even if the data type of the weights buffer is uint8. This is intended
  // to save runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanation of the 'signed int8 trick' in matrix multiplication kernels:
  //
  //   tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};
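
// A minimal sketch of the 'signed int8 trick' mentioned above, assuming the
// usual correspondence where the signed value is the stored uint8 minus 128
// (the names below are illustrative, not part of this API):
//
//   uint8 stored = ...;  // byte from a uint8 weights buffer
//   int8 s = static_cast<int8>(stored ^ 0x80);  // same value as int(stored) - 128
//
// kShuffled4x16Int8 stores weights with this bit flip already applied, so
// kernels can feed the bytes directly into signed 8-bit arithmetic.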

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32 zero_point = 0;
  double scale = 0.0;
};
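
// Illustrative example (hypothetical numbers): with scale = 0.5 and
// zero_point = 10, the quantized value 10 represents the real value 0.0 and
// the quantized value 14 represents 0.5 * (14 - 10) = 2.0.
//
//   QuantizationParams qp;
//   qp.scale = 0.5;
//   qp.zero_point = 10;
//   double real = qp.scale * (42 - qp.zero_point);  // 16.0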

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};
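
// Note on the conventions used with Dims<N> elsewhere in this file (see
// ComputeStrides() and Offset() below): sizes[0] is the fastest-varying
// dimension, and for a packed array strides[0] == 1 and
// strides[d] == strides[d - 1] * sizes[d - 1].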

class RuntimeShape {
 public:
  // Shapes with dimensions up to 4 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 4;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32[dimensions_count];
    }
  }

  RuntimeShape(int shape_size, int32 value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
      dims_pointer_ = new int32[size_];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
  }

  inline int32 DimensionsCount() const { return size_; }
  inline int32 Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32 val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32* DimsDataUpTo4D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
      delete[] dims_pointer_;
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
      dims_pointer_ = new int32[dimensions_count];
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
    Resize(dimensions_count);
    int32* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should
  // be reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into
  // a vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = DimsData();
    for (int i = 0; i < size_; i++) {
      const int dim = dims_data[i];
      TFLITE_DCHECK_GE(dim, 1);
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32) * shape.DimensionsCount());
  }

  int32 size_;
  union {
    int32 dims_[kMaxSmallSize];
    int32* dims_pointer_;
  };
};
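
// Illustrative usage (hypothetical shape values, not part of the API):
//
//   RuntimeShape shape({1, 8, 8, 3});    // 4-D shape, stored inline.
//   int dims = shape.DimensionsCount();  // 4
//   int channels = shape.Dims(3);        // 3
//   int elements = shape.FlatSize();     // 1 * 8 * 8 * 3 = 192
//
//   // Pad a 2-D shape out to 4-D with leading 1s, as many kernels expect:
//   RuntimeShape small({8, 3});
//   RuntimeShape padded = RuntimeShape::ExtendedShape(4, small);  // {1, 1, 8, 3}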

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
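
// Illustrative example (hypothetical shape): a RuntimeShape {2, 8, 8, 3} maps
// to a Dims<4> with sizes {3, 8, 8, 2} (dimension order reversed, so sizes[0]
// is the fastest-varying dimension) and packed strides {1, 3, 24, 192}.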

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
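
// Illustrative iteration (hypothetical dims): for dims = {2, 3}, starting from
// current = {0, 0}, successive calls yield {0, 1}, {0, 2}, {1, 0}, {1, 1},
// {1, 2}, and then return false after wrapping back to {0, 0}. The last index
// varies fastest, matching the row-major Offset() functions below.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // ... visit the element at `index` ...
//   } while (NextIndex(2, dims, index));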

// Gets offset of index if reducing on axis. When reducing, the flattened
// offset will not change if the input index changes on the given axis. For
// example, if you have a 3D tensor and you are reducing to 2D by eliminating
// axis 0, then index (0, 1, 2) and index (1, 1, 2) will map to the same
// flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Check whether we need to skip this axis.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
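
// Worked example (hypothetical sizes), matching the comment above: for a 3D
// tensor with dims = {2, 3, 4}, reducing over axis 0 (axis = {0}, num_axis =
// 1), both index (0, 1, 2) and index (1, 1, 2) skip axis 0 and accumulate
// 1 * 4 + 2 = 6, i.e. the same flattened output offset.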

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = shape.DimsDataUpTo4D();
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
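
// In other words, for a 4-D shape with dimensions {d0, d1, d2, d3} this is the
// row-major offset ((i0 * d1 + i1) * d2 + i2) * d3 + i3. For example
// (hypothetical shape {1, 8, 8, 3}, e.g. read as NHWC),
// Offset(shape, 0, 2, 1, 0) == ((0 * 8 + 2) * 8 + 1) * 3 + 0 == 51.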

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return shape1.Dims(index1);
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
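
// For example (hypothetical shape), for a RuntimeShape {2, 8, 8, 3},
// FlatSizeSkipDim(shape, 3) returns 2 * 8 * 8 = 128, i.e. the element count
// with the last ("depth") dimension skipped.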

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
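
// Illustrative usage (hypothetical sizes): for sizes {4, 10, 10, 1},
// ComputeStrides() produces packed strides {1, 4, 40, 400}, which
// IsPackedWithoutStrides() then accepts.
//
//   Dims<4> dims;
//   dims.sizes[0] = 4; dims.sizes[1] = 10; dims.sizes[2] = 10; dims.sizes[3] = 1;
//   ComputeStrides(&dims);                        // strides = {1, 4, 40, 400}
//   TFLITE_DCHECK(IsPackedWithoutStrides(dims));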

enum class BroadcastableOpCategory : uint8 {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32 input1_offset;
  int32 input2_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int left_shift;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8 axis;
  const int32* input_zeropoint;
  const float* input_scale;
  uint16 inputs_count;
  int32 output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8 inference params.
  int left_shift;
  int32 input1_offset;
  int32 input1_multiplier;
  int input1_shift;
  int32 input2_offset;
  int32 input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32 block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16 stride_width;
  int16 stride_height;
  int16 dilation_width_factor;
  int16 dilation_height_factor;
  int16 depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DequantizationParams {
  double scale;
  int32 zero_point;
};

struct FakeQuantParams {
  MinMax minmax;
  int32 num_bits;
};

struct FullyConnectedParams {
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32 input_offset;
  int32 weights_offset;
  int32 output_offset;
  int32 output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16 input_rank;
  int16 axis;
};

struct L2NormalizationParams {
  // uint8 inference params.
  int32 input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32 range;
  double bias;
  double alpha;
  double beta;
};

struct LogisticParams {
  // uint8 inference params.
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32 weights_zero_point;
  int32 accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8 axis_count;
  int16 axis[4];
};

struct PadParams {
  int8 left_padding_count;
  int32 left_padding[4];
  int8 right_padding_count;
  int32 right_padding[4];
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8 shape_count;
  int32 shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
};

struct SliceParams {
  int8 begin_count;
  int32 begin[4];
  int8 size_count;
  int32 size[4];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params. Used even when beta defaults to 1.0.
  int32 input_multiplier;
  int32 input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32 reverse_scaling_divisor;
  int32 reverse_scaling_right_shift;
  int diff_min;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8 means padding with the output offset.
  int32 output_offset;
};

struct SpaceToDepthParams {
  int32 block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16.
  uint16 num_split;
  int16 axis;
};

struct SqueezeParams {
  int8 squeeze_dims_count;
  int32 squeeze_dims[4];
};

struct StridedSliceParams {
  int8 start_indices_count;
  int16 start_indices[4];
  int8 stop_indices_count;
  int16 stop_indices[4];
  int8 strides_count;
  int16 strides[4];

  int16 begin_mask;
  int16 ellipsis_mask;
  int16 end_mask;
  int16 new_axis_mask;
  int16 shrink_axis_mask;
};

struct TanhParams {
  int32 input_zero_point;
  int32 input_range_radius;
  int32 input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8 perm_count;
  int32 perm[4];
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}
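
// Illustrative usage (hypothetical values): overload resolution picks the
// float or quantized fields based on the argument types, so the same helper
// works with any params struct that has the corresponding members.
//
//   ConvParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // sets float_activation_*
//   SetActivationParams(static_cast<int32>(0), static_cast<int32>(255),
//                       &op_params);              // sets quantized_activation_*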

}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_