/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // offset is used for calculating the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
};

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // offset is used for calculating the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 while
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
  // Same as width_offset except it's over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth)
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // a large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we get as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: since each
  // value loaded from memory is used in only one kernel iteration (which is
  // what makes this operation memory-intensive), we can shuffle the weights
  // so that they are stored in the exact order in which the kernel needs to
  // load them, making the memory accesses made by the kernel trivial. This
  // solves (2) because the trivial memory access pattern allows the CPU's
  // automatic prefetching to perform very well (no need even for preload
  // instructions), and it solves (3) because the values being loaded
  // concurrently are now contiguous in the address space, and thus don't
  // alias each other in the cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weights value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t. This is intended
  // to save runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanations on the 'signed int8_t trick' in matrix multiplication kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};
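
// Illustrative example (not part of the API), applying the formula above:
// with scale = 0.5 and zero_point = 127, the quantized value 131 represents
//   real_value = 0.5 * (131 - 127) = 2.0,
// and the real value 0.0 is represented exactly by the quantized value 127.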

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
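
// Illustrative usage (not part of the API): iterating over all indices of a
// 2x3 array. The index array starts zero-initialized, and NextIndex() returns
// false once the last index has been advanced past.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // Visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) in row-major order.
//   } while (NextIndex(2, dims, index));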

// Gets the offset of an index when reducing on the given axes. When reducing,
// the flattened offset does not change if the input index changes only along
// a reduced axis. For example, if you have a 3D tensor and you are reducing
// to 2D by eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) map
// to the same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // if we need to skip this axis
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
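
// Illustrative example (not part of the API): for a tensor with
// dims = {2, 3, 4}, reducing over axis 0 (num_axis = 1, axis = {0}) treats
// the output as shaped {3, 4}, so index (0, 1, 2) and index (1, 1, 2) both
// map to offset 1 * 4 + 2 = 6.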

// Since tensors with '0' in their shape are valid in TF, these offset
// functions allow that as long as the corresponding index is also 0. It is up
// to the calling ops to ensure that they perform verification checks on
// tensor shapes if they don't support a particular behavior.

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK((i0 == 0 && dims.sizes[0] == 0) ||
                (i0 >= 0 && i0 < dims.sizes[0]));
  TFLITE_DCHECK((i1 == 0 && dims.sizes[1] == 0) ||
                (i1 >= 0 && i1 < dims.sizes[1]));
  TFLITE_DCHECK((i2 == 0 && dims.sizes[2] == 0) ||
                (i2 >= 0 && i2 < dims.sizes[2]));
  TFLITE_DCHECK((i3 == 0 && dims.sizes[3] == 0) ||
                (i3 >= 0 && i3 < dims.sizes[3]));
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}
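
// Illustrative example (not part of the API): for a packed Dims<4> with
// sizes = {2, 3, 4, 5} (sizes[0] being the fastest-varying dimension) and
// strides = {1, 2, 6, 24} (see ComputeStrides() below),
//   Offset(dims, 1, 2, 3, 4) = 1 * 1 + 2 * 2 + 3 * 6 + 4 * 24 = 119.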

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
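
// Illustrative usage (not part of the API; the shape names are hypothetical):
// the variadic overload consumes (shape, index) pairs, so a kernel can check
// that the batch dimension of its input and output agree and read it once:
//
//   const int batches = MatchingDim(input_shape, 0, output_shape, 0);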

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Flat size calculation, checking if their extended shapes match.
inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0) {
  const int shape_dims = shape.DimensionsCount();
  const int check_shape_0_dims = check_shape_0.DimensionsCount();
  const int min_dims = std::min(shape_dims, check_shape_0_dims);

  for (int i = 0; i < min_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i),
                     check_shape_0.Dims(check_shape_0_dims - 1 - i));
  }
  for (int i = min_dims; i < shape_dims; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(shape_dims - 1 - i), 1);
  }
  for (int i = min_dims; i < check_shape_0_dims; ++i) {
    TFLITE_DCHECK_EQ(check_shape_0.Dims(check_shape_0_dims - 1 - i), 1);
  }
  return shape.FlatSize();
}
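
// Illustrative example (not part of the API): shapes {3, 4} and {1, 1, 3, 4}
// have matching extended shapes, since dimensions are compared from the
// trailing end and missing leading dimensions are treated as 1, so the call
// returns the common flat size 3 * 4 = 12. Shapes {3, 4} and {2, 3, 4} would
// fail the DCHECKs.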

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1),
                   flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(
      MatchingExtendedShapeFlatSize(shape, check_shape_1, check_shape_2),
      flat_size);
  return flat_size;
}

inline int MatchingExtendedShapeFlatSize(const RuntimeShape& shape,
                                         const RuntimeShape& check_shape_0,
                                         const RuntimeShape& check_shape_1,
                                         const RuntimeShape& check_shape_2,
                                         const RuntimeShape& check_shape_3) {
  const int flat_size = MatchingExtendedShapeFlatSize(shape, check_shape_0);
  TFLITE_DCHECK_EQ(MatchingExtendedShapeFlatSize(shape, check_shape_1,
                                                 check_shape_2, check_shape_3),
                   flat_size);
  return flat_size;
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
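
// Illustrative example (not part of the API): for an NHWC shape {2, 8, 8, 3},
// FlatSizeSkipDim(shape, 3) returns 2 * 8 * 8 = 128, i.e. the number of
// "pixels" over which an inner loop of depth 3 runs.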

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}
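
// Illustrative note (not part of the API): a Dims<N> is "packed" when each
// stride equals the product of all faster-varying sizes (as produced by
// ComputeStrides() below). A Dims<4> that views, say, every other element of
// a buffer (strides[0] == 2) is not packed, and this returns false.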

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
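
// Illustrative example (not part of the API): given a Dims<4> with
// sizes = {4, 10, 10, 1}, ComputeStrides() fills strides = {1, 4, 40, 400},
// which is exactly the packed layout accepted by IsPackedWithoutStrides().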

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  //   broadcast_shape[4] = b0 = a0.
  //   broadcast_shape[3] = b1; a1 = 1.
  //   broadcast_shape[2] = b2 = a2.
  //   broadcast_shape[1] = a3; b3 = 1.
  //   broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

typedef Conv3DParams Conv3DTransposeParams;

struct DepthToSpaceParams {
  int32_t block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32_t* output_multiplier_per_channel;
  const int32_t* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32_t zero_point;
};

struct PerChannelDequantizationParams {
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

struct FakeQuantParams {
  MinMax minmax;
  int32_t num_bits;
};

struct FullyConnectedParams {
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16_t axis;
  int16_t batch_dims;
};

struct L2NormalizationParams {
  // uint8_t inference params.
  int32_t input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32_t range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0/32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8_t inference params.
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32_t weights_zero_point;
  int32_t accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8_t axis_count;
  int16_t axis[4];
};

struct PackParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8_t left_padding_count;
  int32_t left_padding[5];
  int8_t right_padding_count;
  int32_t right_padding[5];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32_t input_offset;
  int32_t alpha_offset;
  int32_t output_offset;
  int32_t output_multiplier_1;
  int output_shift_1;
  int32_t output_multiplier_2;
  int output_shift_2;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8_t shape_count;
  int32_t shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
  bool half_pixel_centers;
};

struct SliceParams {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8_t inference params. Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0].
  int16_t* exp_lut;
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0].
  int16_t* one_over_one_plus_x_lut;
  uint8_t* uint8_table1;
  uint8_t* uint8_table2;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
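
// Illustrative usage (not part of the API): overload resolution picks the
// float / int32_t / int64_t variant from the argument types, so the same
// helper works for float and quantized kernels. For a params struct carrying
// all of the activation fields, such as ArithmeticParams above:
//
//   ArithmeticParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);   // fills float_activation_*
//   SetActivationParams(int32_t{0}, int32_t{255},
//                       &op_params);               // fills quantized_activation_*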

// Type trait to check if a given type has a size smaller than 4 bytes.
template <typename T>
struct is_small_integer
    : public std::integral_constant<bool,
                                    std::is_same<T, int8_t>::value ||
                                        std::is_same<T, uint8_t>::value ||
                                        std::is_same<T, int16_t>::value ||
                                        std::is_same<T, uint16_t>::value> {};

// Type trait to check if a given type is int32 or int64.
template <typename T>
struct is_int32_or_int64
    : public std::integral_constant<bool, std::is_same<T, int32_t>::value ||
                                              std::is_same<T, int64_t>::value> {
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_