/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <initializer_list>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

enum class FusedActivationFunctionType : uint8_t {
  kNone,
  kRelu6,
  kRelu1,
  kRelu
};
enum class PaddingType : uint8_t { kNone, kSame, kValid };

struct PaddingValues {
  int16_t width;
  int16_t height;
  // The offset is used to compute the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 and
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
};

struct Padding3DValues {
  int16_t width;
  int16_t height;
  int16_t depth;
  // The offset is used to compute the "remaining" padding. For example, if
  // `width` is 1 and `width_offset` is 1, then padding_left is 1 and
  // padding_right is 1 + 1 = 2.
  int16_t width_offset;
  // Same as width_offset except it's over the height dimension.
  int16_t height_offset;
  // Same as width_offset except it's over the depth dimension.
  int16_t depth_offset;
};

// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
enum class FullyConnectedWeightsFormat : uint8_t {
  // Default format (flat 2D layout, the inner contiguous dimension
  // is input_depth, the outer non-contiguous dimension is output_depth).
  kDefault,
  // Summary: optimized layout for fast CPU runtime implementation,
  // aimed specifically at ARM CPUs at the moment, and specialized for
  // 8-bit quantized layers.
  //
  // The use case we're concerned with here is: 8-bit quantization,
  // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
  // a key application that drove this), very small batch size (e.g. 1 -- 4).
  //
  // Even with 8-bit quantization of weights, the performance of memory
  // accesses to the weights can become the dominant issue when
  // the batch size is small, so each weight value is used in only a few
  // arithmetic ops, i.e. the fully-connected node has a low arithmetic
  // intensity. The specific issues that arise are of three kinds:
  // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
  //     bound. That's the "good" issue to run into.
  // (2) One may run into sub-optimal pre-fetching: the data hasn't been
  //     prefetched into the cache by the time we need it.
  // (3) One may run into cache aliasing: multiple values that are
  //     pre-fetched, alias each other in the L1 cache (which typically
  //     has only 4-way set associativity in ARM CPUs) and thus evict
  //     each other before we get to using them.
  //
  // The point of this shuffling is to avoid issues (2) and (3) so that
  // we run as fast as possible given only the hard constraint (1).
  // This is achieved by turning the difficulty into a solution: the
  // difficulty, that each value loaded from memory is used in only one
  // kernel iteration, making this operation memory-intensive, hints at
  // the solution of shuffling the weights so that they are stored in
  // exactly the order in which the kernel needs to load them, so that the
  // memory accesses made by the kernel are trivial. This solves (2) because
  // the trivial memory access pattern allows the CPU's automatic prefetching
  // to perform very well (no need even for preload instructions), and this
  // solves (3) because the values being loaded concurrently are now
  // contiguous in the address space, and thus don't alias each other in the
  // cache.
  //
  // On ARM, we typically want our kernel to process a 4x16 block of weights
  // at a time, because:
  //   - 16 is the number of bytes in a NEON register.
  //   - 4 is how many rows we need to handle concurrently in the kernel in
  //     order to have sufficient mutual independence of instructions to
  //     maximize arithmetic throughput.
  //
  // Finally, the 'Int8' part in the name refers to the fact that this
  // weights format has each weight value encoded as a signed int8_t value,
  // even if the data type of the weights buffer is uint8_t. This is intended
  // to save runtime kernels the effort of having to XOR the top bit of these
  // bytes before using them in signed arithmetic; see this file for more
  // explanation of the 'signed int8_t trick' in matrix multiplication
  // kernels:
  //
  //   tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  kShuffled4x16Int8,
};

// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
// interpreted).
//
// The correspondence is as follows:
//
//   real_value = scale * (quantized_value - zero_point);
//
// In other words, zero_point designates which quantized value corresponds to
// the real 0 value, and scale designates the difference between the real
// values corresponding to consecutive quantized values differing by 1.
struct QuantizationParams {
  int32_t zero_point = 0;
  double scale = 0.0;
};
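
// Worked example of the correspondence above (illustrative only, the values
// are arbitrary): with scale = 0.5 and zero_point = 128, the quantized value
// 130 represents the real value 0.5 * (130 - 128) = 1.0, and the quantized
// value 128 represents exactly 0.0.
//
//   QuantizationParams qp;
//   qp.zero_point = 128;
//   qp.scale = 0.5;
//   // real = qp.scale * (q - qp.zero_point), so q == 130 -> real == 1.0.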

inline bool operator==(const QuantizationParams& qp1,
                       const QuantizationParams& qp2) {
  return qp1.zero_point == qp2.zero_point && qp1.scale == qp2.scale;
}

template <int N>
struct Dims {
  int sizes[N];
  int strides[N];
};

class RuntimeShape {
 public:
  // Shapes with dimensions up to 5 are stored directly in the structure,
  // while larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 5;

  RuntimeShape& operator=(RuntimeShape const&) = delete;

  RuntimeShape() : size_(0) {}

  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32_t[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  RuntimeShape(int shape_size, int32_t value) : size_(0) {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i) {
      SetDim(i, value);
    }
  }

  RuntimeShape(int dimensions_count, const int32_t* dims_data) : size_(0) {
    ReplaceWith(dimensions_count, dims_data);
  }

  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
    BuildFrom(init_list);
  }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else
      dims_pointer_ = new int32_t[size_];
#endif
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
  }

  bool operator==(const RuntimeShape& comp) const {
    return this->size_ == comp.size_ &&
           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
               0;
  }

  ~RuntimeShape() {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline int32_t DimensionsCount() const { return size_; }
  inline int32_t Dims(int i) const {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
  }
  inline void SetDim(int i, int32_t val) {
    TFLITE_DCHECK_GE(i, 0);
    TFLITE_DCHECK_LT(i, size_);
    if (size_ > kMaxSmallSize) {
      dims_pointer_[i] = val;
    } else {
      dims_[i] = val;
    }
  }

  inline int32_t* DimsData() {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  inline const int32_t* DimsData() const {
    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
  }
  // The caller must ensure that the shape is no bigger than 5-D.
  inline const int32_t* DimsDataUpTo5D() const { return dims_; }

  inline void Resize(int dimensions_count) {
    if (size_ > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      delete[] dims_pointer_;
#endif  // TF_LITE_STATIC_MEMORY
    }
    size_ = dimensions_count;
    if (dimensions_count > kMaxSmallSize) {
#ifdef TF_LITE_STATIC_MEMORY
      TFLITE_CHECK(false && "No shape resizing supported on this platform");
#else   // TF_LITE_STATIC_MEMORY
      dims_pointer_ = new int32_t[dimensions_count];
#endif  // TF_LITE_STATIC_MEMORY
    }
  }

  inline void ReplaceWith(int dimensions_count, const int32_t* dims_data) {
    Resize(dimensions_count);
    int32_t* dst_dims = DimsData();
    std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
  }

  template <typename T>
  inline void BuildFrom(const T& src_iterable) {
    const int dimensions_count =
        std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32_t* data = DimsData();
    for (auto it : src_iterable) {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should
  // be reduced, and (b) some kernels are strictly 4-D, but then the shapes of
  // their inputs should already be 4-D, so this function should not be needed.
  inline static RuntimeShape ExtendedShape(int new_shape_size,
                                           const RuntimeShape& shape) {
    return RuntimeShape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened
  // into a vector.
  inline int FlatSize() const {
    int buffer_size = 1;
    const int* dims_data = reinterpret_cast<const int*>(DimsData());
    for (int i = 0; i < size_; i++) {
      buffer_size *= dims_data[i];
    }
    return buffer_size;
  }

  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }

 private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
      : size_(0) {
    // If the following check fails, it is likely because a 4D-only kernel is
    // being used with an array of larger dimension count.
    TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i) {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32_t) * shape.DimensionsCount());
  }

  int32_t size_;
  union {
    int32_t dims_[kMaxSmallSize];
    int32_t* dims_pointer_;
  };
};
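
// Illustrative usage sketch (not part of the API, values are arbitrary): a
// typical NHWC activation shape and a couple of common queries.
//
//   RuntimeShape shape({1, 8, 8, 3});  // 4-D shape built from an init list.
//   // shape.DimensionsCount() == 4, shape.Dims(3) == 3.
//   // shape.FlatSize() == 1 * 8 * 8 * 3 == 192.
//   RuntimeShape ext = RuntimeShape::ExtendedShape(5, shape);
//   // ext is {1, 1, 8, 8, 3}: padded on the left with 1s up to 5-D.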

// Converts inference-style shape to legacy tflite::Dims<4>.
inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
  tflite::Dims<4> result;
  const int dimensions_count = array_shape.DimensionsCount();
  TFLITE_CHECK_LE(dimensions_count, 4);
  int cum_prod = 1;
  for (int i = 0; i < 4; i++) {
    const int new_dim =
        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
    result.sizes[i] = new_dim;
    result.strides[i] = cum_prod;
    cum_prod *= new_dim;
  }
  return result;
}
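
// Example of the conversion above (illustrative only): a RuntimeShape
// {2, 3, 4, 5} in row-major (last-dimension-fastest) order becomes a Dims<4>
// with sizes {5, 4, 3, 2} (fastest dimension first) and strides
// {1, 5, 20, 60}, so both describe the same contiguous 120-element buffer.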

// TODO(b/80418076): Move to legacy ops file, update invocations.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}

// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
  if (num_dims == 0) {
    return false;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(current != nullptr);
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0; --idx) {
    int current_val = current[idx] + carry;
    TFLITE_DCHECK_GE(dims[idx], current_val);
    if (dims[idx] == current_val) {
      current[idx] = 0;
    } else {
      current[idx] = current_val;
      carry = 0;
      break;
    }
  }
  return (carry == 0);
}
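
// Iteration sketch using NextIndex (illustrative only): for dims = {2, 3},
// starting from index {0, 0}, successive calls advance to {0, 1}, {0, 2},
// {1, 0}, {1, 1}, {1, 2} and then return false once the index wraps back to
// {0, 0}.
//
//   int dims[] = {2, 3};
//   int index[] = {0, 0};
//   do {
//     // ... use index ...
//   } while (NextIndex(2, dims, index));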

// Gets the offset of an index when reducing on the given axes. When reducing,
// the flattened offset does not change if the input index changes only on a
// reduced axis. For example, if you have a 3D tensor and you are reducing to
// 2D by eliminating axis 0, then index (0, 1, 2) and index (1, 1, 2) will map
// to the same flattened offset.
// TODO(kanlig): use Dims to represent dimensions.
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
                                  const int* index, const int num_axis,
                                  const int* axis) {
  if (num_dims == 0) {
    return 0;
  }
  TFLITE_DCHECK(dims != nullptr);
  TFLITE_DCHECK(index != nullptr);
  size_t offset = 0;
  for (int idx = 0; idx < num_dims; ++idx) {
    // Check whether we need to skip this axis.
    bool is_axis = false;
    if (axis != nullptr) {
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (idx == axis[axis_idx]) {
          is_axis = true;
          break;
        }
      }
    }
    if (!is_axis) {
      offset = offset * static_cast<size_t>(dims[idx]) +
               static_cast<size_t>(index[idx]);
    }
  }
  return offset;
}
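
// Worked example for ReducedOutputOffset (illustrative only): with
// dims = {4, 2, 3}, reducing over axis 0 (num_axis = 1, axis = {0}), the
// offset is computed only from the remaining axes 1 and 2, i.e.
// offset = index[1] * 3 + index[2]. Hence index (0, 1, 2) and index (1, 1, 2)
// both map to 1 * 3 + 2 = 5, as described above.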

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
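
// Example (illustrative only): for an NHWC shape {1, 8, 8, 3}, the element at
// batch 0, row 2, column 5, channel 1 is at
// Offset(shape, 0, 2, 5, 1) == ((0 * 8 + 2) * 8 + 5) * 3 + 1 == 64.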

inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
                  int i4) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
  TFLITE_DCHECK(i4 >= 0 && i4 < dims_data[4]);
  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
             dims_data[4] +
         i4;
}

inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
  TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
  TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
  TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]);
  TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]);
  return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] +
         i3 * dims.strides[3];
}

inline int Offset(const Dims<4>& dims, int* index) {
  return Offset(dims, index[0], index[1], index[2], index[3]);
}

inline int Offset(const RuntimeShape& shape, int* index) {
  return Offset(shape, index[0], index[1], index[2], index[3]);
}

// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
  TFLITE_DCHECK(index >= 0 && index < N);
  return array.sizes[index];
}

// Get common array size, DCHECKing that they all agree.
template <typename ArrayType1, typename ArrayType2>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return ArraySize(array1, index1);
}

template <typename ArrayType1, typename ArrayType2, typename... Args>
int MatchingArraySize(const ArrayType1& array1, int index1,
                      const ArrayType2& array2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2));
  return MatchingArraySize(array1, index1, args...);
}

// Get common shape dim, DCHECKing that they all agree.
inline int MatchingDim(const RuntimeShape& shape1, int index1,
                       const RuntimeShape& shape2, int index2) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return std::min(shape1.Dims(index1), shape2.Dims(index2));
}

template <typename... Args>
int MatchingDim(const RuntimeShape& shape1, int index1,
                const RuntimeShape& shape2, int index2, Args... args) {
  TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
  return MatchingDim(shape1, index1, args...);
}
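
// Usage sketch (illustrative only, the shape names are hypothetical): in a
// binary op kernel, the shared depth of two NHWC tensors can be read and
// checked in one step; extra (shape, index) pairs may be appended and are
// each checked against the first shape.
//
//   const int depth = MatchingDim(input1_shape, 3, input2_shape, 3);
//   // With debug checks enabled, this DCHECKs that both depths agree.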

// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= dims.sizes[i];
  }
  return flat_size;
}

TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
  return FlatSize(dims);
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  return size_1;
}

inline int MatchingElementsSize(const RuntimeShape& shape,
                                const RuntimeShape& check_shape_0,
                                const RuntimeShape& check_shape_1) {
  const int size_1 = shape.FlatSize();
  const int size_2 = check_shape_0.FlatSize();
  const int size_3 = check_shape_1.FlatSize();
  TFLITE_CHECK_EQ(size_1, size_2);
  TFLITE_CHECK_EQ(size_2, size_3);
  return size_1;
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return shape.FlatSize();
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2);
}

inline int MatchingFlatSize(const RuntimeShape& shape,
                            const RuntimeShape& check_shape_0,
                            const RuntimeShape& check_shape_1,
                            const RuntimeShape& check_shape_2,
                            const RuntimeShape& check_shape_3) {
  TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
  }
  return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
}

// Flat size calculation, checking that dimensions match with one or more other
// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return FlatSize(dims);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                            const Dims<N>& check_dims_1,
                            const Dims<N>& check_dims_2,
                            const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
  }
  return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
template <int N>
inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
  int flat_size = 1;
  for (int i = 0; i < N; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
  }
  return flat_size;
}

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return FlatSizeSkipDim(dims, skip_dim);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2);
}

template <int N>
inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                   const Dims<N>& check_dims_0,
                                   const Dims<N>& check_dims_1,
                                   const Dims<N>& check_dims_2,
                                   const Dims<N>& check_dims_3) {
  for (int i = 0; i < N; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
    }
  }
  return MatchingFlatSizeSkipDim(dims, skip_dim, check_dims_1, check_dims_2,
                                 check_dims_3);
}

// Data is required to be contiguous, and so many operators can use either the
// full array flat size or the flat size with one dimension skipped (commonly
// the depth).
inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
  const int dims_count = shape.DimensionsCount();
  TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
  const auto* dims_data = shape.DimsData();
  int flat_size = 1;
  for (int i = 0; i < dims_count; ++i) {
    flat_size *= (i == skip_dim) ? 1 : dims_data[i];
  }
  return flat_size;
}
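
// Example (illustrative only): for an NHWC shape {1, 8, 8, 3},
// FlatSizeSkipDim(shape, 3) == 1 * 8 * 8 == 64, i.e. the number of
// depth-sized pixels, while shape.FlatSize() == 192 counts every element.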

// A combination of MatchingFlatSize() and FlatSizeSkipDim().
inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return FlatSizeSkipDim(shape, skip_dim);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
}

inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
                                   const RuntimeShape& check_shape_0,
                                   const RuntimeShape& check_shape_1,
                                   const RuntimeShape& check_shape_2,
                                   const RuntimeShape& check_shape_3) {
  const int dims_count = shape.DimensionsCount();
  for (int i = 0; i < dims_count; ++i) {
    if (i != skip_dim) {
      TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
    }
  }
  return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
                                 check_shape_3);
}

template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
  int expected_stride = 1;
  for (int d = 0; d < N; d++) {
    if (dims.strides[d] != expected_stride) return false;
    expected_stride *= dims.sizes[d];
  }
  return true;
}

template <int N>
void ComputeStrides(Dims<N>* dims) {
  dims->strides[0] = 1;
  for (int d = 1; d < N; d++) {
    dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
  }
}
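
// Example (illustrative only): for a Dims<4> with sizes {3, 4, 5, 6},
// ComputeStrides() fills strides with {1, 3, 12, 60} (fastest dimension
// first), and IsPackedWithoutStrides() then returns true because each stride
// equals the product of the preceding sizes.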

enum class BroadcastableOpCategory : uint8_t {
  kNone,
  kNonBroadcast,               // Matching input shapes.
  kFirstInputBroadcastsFast,   // Fivefold nested loops.
  kSecondInputBroadcastsFast,  // Fivefold nested loops.
  kGenericBroadcast,           // Fall-back.
};

struct MinMax {
  float min;
  float max;
};
static_assert(sizeof(MinMax) == 8, "");

struct ActivationParams {
  FusedActivationFunctionType activation_type;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
};

struct ReluParams : public ActivationParams {
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
};

// Styles of resizing op usages. For example, kImageStyle can be used with a
// Pad op for pattern-specific optimization.
enum class ResizingCategory : uint8_t {
  kNone,
  kImageStyle,  // 4D, operating on inner dimensions, say {0, a, b, 0}.
  kGenericResize,
};

// For Add, Sub, Mul ops.
struct ArithmeticParams {
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8_t inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // Add / Sub, not Mul, uint8_t inference params.
  int left_shift;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64_t activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
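  // Worked example (illustrative only): applying the rules above to the
  // coalesced input shapes a = {2, 1, 3, 5, 4} and b = {2, 7, 3, 1, 4} gives
  // broadcast_shape = {4, 5, 3, 7, 2}.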
  int broadcast_shape[5];
};

struct ConcatenationParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct ComparisonParams {
  // uint8_t inference params.
  int left_shift;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int input1_shift;
  int32_t input2_offset;
  int32_t input2_multiplier;
  int input2_shift;
  // Shape dependent / common to inference types.
  bool is_broadcast;
};

struct ConvParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct Conv3DParams {
  Padding3DValues padding_values;
  int stride_width;
  int stride_height;
  int stride_depth;
  int dilation_width;
  int dilation_height;
  int dilation_depth;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct DepthToSpaceParams {
  int32_t block_size;
};

struct DepthwiseParams {
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  const int32_t* output_multiplier_per_channel;
  const int32_t* output_shift_per_channel;
};

struct DequantizationParams {
  double scale;
  int32_t zero_point;
};

struct PerChannelDequantizationParams {
  const float* scale;
  const int32_t* zero_point;
  int32_t quantized_dimension;
};

struct FakeQuantParams {
  MinMax minmax;
  int32_t num_bits;
};

struct FullyConnectedParams {
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // Mark the operands as cacheable if they are unchanging, e.g. weights.
  bool lhs_cacheable;
  bool rhs_cacheable;
  FullyConnectedWeightsFormat weights_format;
};

struct GatherParams {
  int16_t axis;
};

struct L2NormalizationParams {
  // uint8_t inference params.
  int32_t input_zero_point;
};

struct LocalResponseNormalizationParams {
  int32_t range;
  double bias;
  double alpha;
  double beta;
};

struct HardSwishParams {
  // zero_point of the input activations.
  int16_t input_zero_point;
  // zero_point of the output activations.
  int16_t output_zero_point;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // "relu-ish scale", which is 3.0/32768.
  // See the implementation of HardSwishPrepare.
  int16_t reluish_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int reluish_multiplier_exponent;
  // 16-bit fixed-point component of the multiplier to apply to go from the
  // "high-res input scale", which is the input scale multiplied by 2^7, to the
  // output scale.
  // See the implementation of HardSwishPrepare.
  int16_t output_multiplier_fixedpoint_int16;
  // exponent/bit-shift component of the aforementioned multiplier.
  int output_multiplier_exponent;
};

struct LogisticParams {
  // uint8_t inference params.
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct LstmCellParams {
  int32_t weights_zero_point;
  int32_t accum_multiplier;
  int accum_shift;
  int state_integer_bits;
};

struct MeanParams {
  int8_t axis_count;
  int16_t axis[4];
};

struct PackParams {
  int8_t axis;
  const int32_t* input_zeropoint;
  const float* input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct PadParams {
  int8_t left_padding_count;
  int32_t left_padding[4];
  int8_t right_padding_count;
  int32_t right_padding[4];
  ResizingCategory resizing_category;
};

struct PreluParams {
  int32_t input_offset;
  int32_t alpha_offset;
  int32_t output_offset;
  int32_t output_multiplier_1;
  int output_shift_1;
  int32_t output_multiplier_2;
  int output_shift_2;
};

struct PoolParams {
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct ReshapeParams {
  int8_t shape_count;
  int32_t shape[4];
};

struct ResizeBilinearParams {
  bool align_corners;
  // half_pixel_centers assumes pixels are of half the actual dimensions, and
  // yields more accurate resizes. Corresponds to the same argument for the
  // original TensorFlow op in TF2.0.
  bool half_pixel_centers;
};

struct ResizeNearestNeighborParams {
  bool align_corners;
  bool half_pixel_centers;
};

struct SliceParams {
  int8_t begin_count;
  int32_t begin[5];
  int8_t size_count;
  int32_t size[5];
};

struct SoftmaxParams {
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8_t inference params. Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float* table;
  // int16 LUT for exp(x), where x is uniformly distributed in [-10.0, 0.0].
  int16_t* exp_lut;
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in [0.0, 1.0].
  int16_t* one_over_one_plus_x_lut;
  uint8_t* uint8_table1;
  uint8_t* uint8_table2;
};

struct SpaceToBatchParams {
  // "Zero" padding for uint8_t means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams {
  int32_t block_size;
};

struct SplitParams {
  // Graphs that split into, say, 2000 nodes are encountered. The indices in
  // OperatorEdges are of type uint16_t.
  uint16_t num_split;
  int16_t axis;
};

struct SqueezeParams {
  int8_t squeeze_dims_count;
  int32_t squeeze_dims[4];
};

struct StridedSliceParams {
  int8_t start_indices_count;
  int32_t start_indices[5];
  int8_t stop_indices_count;
  int32_t stop_indices[5];
  int8_t strides_count;
  int32_t strides[5];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};

struct TanhParams {
  int32_t input_zero_point;
  int32_t input_range_radius;
  int32_t input_multiplier;
  int input_left_shift;
};

struct TransposeParams {
  int8_t perm_count;
  int32_t perm[5];
};

struct UnpackParams {
  uint16_t num_split;
  int16_t axis;
};

struct LeakyReluParams {
  float alpha;
  int32_t input_offset;
  int32_t output_offset;
  int32_t output_multiplier_alpha;
  int32_t output_shift_alpha;
  int32_t output_multiplier_identity;
  int32_t output_shift_identity;
};

template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
  params->float_activation_min = min;
  params->float_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int32_t min, int32_t max, P* params) {
  params->quantized_activation_min = min;
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32_t* min, int32_t* max) {
  *min = params.quantized_activation_min;
  *max = params.quantized_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, float* min, float* max) {
  *min = params.float_activation_min;
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
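
// Usage sketch (illustrative only): overload resolution on the min/max types
// picks which pair of fields is written or read, so the same helpers serve
// both float and quantized kernels.
//
//   ConvParams op_params;
//   SetActivationParams(0.0f, 6.0f, &op_params);  // writes float_activation_*
//   SetActivationParams(int32_t{0}, int32_t{255},
//                       &op_params);  // writes quantized_activation_*
//   float act_min, act_max;
//   GetActivationParams(op_params, &act_min, &act_max);  // reads float pair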

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_