• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_MODEL_H_
16 #define TENSORFLOW_LITE_TOCO_MODEL_H_
17 
18 #include <complex>
19 #include <functional>
20 #include <initializer_list>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "absl/types/optional.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/lite/toco/model_flags.pb.h"
30 #include "tensorflow/lite/toco/runtime/types.h"
31 #include "tensorflow/lite/toco/toco_port.h"
32 #include "tensorflow/lite/toco/toco_types.h"
33 
34 namespace toco {
35 
36 using tflite::QuantizationParams;
37 
// Discriminant for the Operator class hierarchy below: each value
// corresponds 1:1 to an Operator subclass.
// NOTE: the underlying type is uint8, so this enum is limited to 256 entries.
enum class OperatorType : uint8 {
  kNone,
  // General-purpose neural network operators.
  kAdd,
  kAddN,
  kAveragePool,
  kBatchMatMul,
  kBatchNormalization,
  kCeil,
  kConv,
  kConcatenation,
  kCos,
  kDepthwiseConv,
  kDepthToSpace,
  kSpaceToDepth,
  kDequantize,
  kDiv,
  kExp,
  kExpandDims,
  kFill,
  kFloorDiv,
  kFloorMod,
  kFullyConnected,
  kL2Normalization,
  kL2Pool,
  kLstmCell,
  kUnidirectionalSequenceLstm,
  kLocalResponseNormalization,
  kLog,
  kLogistic,
  kMaxPool,
  kFakeQuant,
  kMul,
  kOneHot,
  kRandomUniform,
  kRange,
  kRank,
  kRelu,
  kRelu1,
  kRelu6,
  kPRelu,
  kHardSwish,
  kSoftmax,
  kLogSoftmax,
  kSub,
  kTanh,
  kTransposeConv,
  kCast,
  kFloor,
  kRound,
  kGather,
  kResizeBilinear,
  kSin,
  kSpaceToBatchND,
  kPack,
  kBatchToSpaceND,
  kPad,
  kPadV2,
  kReduceProd,  // Reduction product
  kStridedSlice,
  kSlice,
  kSqueeze,
  kMean,
  kArgMax,
  // The SVDF Op is a decomposition of a densely connected Op into
  // low rank filters. For details:
  // https://research.google.com/pubs/pub43813.html
  kSvdf,
  // Special operators used for importing TensorFlow nodes.
  // The general intent is to have some graph transformation either
  // drop them or rewrite them as general-purpose operators.
  kAll,
  kAssert,
  kConcat,
  kConcatV2,
  kGreater,
  kGreaterEqual,
  kIdentity,
  kLess,
  kLessEqual,
  kReduceMax,  //  Reduction Max
  kMaximum,    //  Element-wise Maximum
  kReduceMin,  //  Reduction Min
  kMinimum,    //  Element-wise Minimum
  kMatMul,
  kMerge,
  kNeg,
  kReshape,
  kRsqrt,
  kShape,
  kSplit,
  kSplitV,
  kSqrt,
  kSquare,
  kSquaredDifference,
  kSum,
  kSwitch,
  kTile,
  kTranspose,
  kTopK_V2,
  kDynamicPartition,
  kDynamicStitch,
  // An unsupported TF operation. It's only needed to be able to represent TF
  // graph internally and is expected to be dropped by graph transformations.
  kUnsupported,
  // Finally, TensorFlow uses different conventions for axes ordering,
  // see AxesOrder, and this cannot always be resolved at the time of importing
  // nodes, as TensorFlow parameters may be constant-expression subgraphs
  // instead of being given as plain constant arrays. So we need to insert
  // special nodes in the graph to shuffle axes.
  kReorderAxes,
  kSegmentSum,
  kSelect,
  kSelectV2,
  kSparseToDense,
  kEqual,
  kNotEqual,
  kPow,
  kArgMin,
  kAny,
  kLogicalAnd,
  kLogicalNot,
  kLogicalOr,
  kCTCBeamSearchDecoder,
  kUnpack,
  kZerosLike,
  kResizeNearestNeighbor,
  kLeakyRelu,
  kAbs,
  kMirrorPad,
  kUnique,
  kUnidirectionalSequenceRnn,
  kBidirectionalSequenceLstm,
  kReverseV2,
  kBidirectionalSequenceRnn,
  kGatherNd,
  kWhere,
  kElu,
  kReverseSequence,
  kMatrixDiag,
  kMatrixSetDiag,
  kMatrixDiagV2,
  kMatrixSetDiagV2,
  kMatrixDiagV3,
  kMatrixSetDiagV3,
  kScatterNd,
  // Debugging operators.
  kNumericVerify
};
187 
// Helper to deal with TensorFlow arrays using a different ordering of
// dimensions ("axes") than our own.
// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
// we should have associative arrays mapping symbolic axes identifiers (like
// "output_depth") to dimensions. We would then not need this anymore.
//
// NOTE(review): the letters presumably follow the usual TF convention
// (N=batch, H=height, W=width, I/O=input/output depth, M=depth multiplier)
// — confirm against the axes-shuffling transformations that consume this.
enum class AxesOrder {
  kOneAxis,  // one-dimensional array, one unique axis.
  kCR,       // column-major matrix storage order. Our standard.
  kRC,       // row-major matrix storage order. TensorFlow default.
  kOHWI,     // Our standard for conv weights
  kHWIO,     // TensorFlow conv weights
  k1HWO,     // Our standard for DepthwiseConv weights
  kHWIM,     // TensorFlow DepthwiseConv weights
  kNHWC,     // TensorFlow activations
  kHWOI,     // TensorFlow back-prop conv weights
};
205 
// The type of the scalars in an array.
// Note that the type does not by itself tell whether the values in the array
// are non-quantized (can be accessed directly) or quantized (must be
// interpreted in conjunction with QuantizationParams).
//
// In practice though:
//   float values are never quantized
//   uint8 values are always quantized
//   int32 values are sometimes quantized (depending on whether
//   QuantizationParams are present).
//   complex values are never quantized
//   other types are never quantized at the moment.
//
// kNone means that we don't know the data type yet, or that we don't care
// because we'll be dropping the array anyway (e.g. some exotic array types
// may be involved only in debug-only subgraphs that we may not be interested
// in actually supporting).
//
// The inline "// N" comments record the ordinal value of selected entries;
// do not reorder or insert in the middle.
enum class ArrayDataType : uint8 {
  kNone,  // 0
  kBool,
  kFloat,
  kInt8,
  kUint8,
  kInt16,  // 5
  kUint16,
  kInt32,
  kUint32,
  kInt64,
  kUint64,  // 10
  kString,
  kComplex64,
  kFloat16,
  kFloat64,
  kComplex128,
};
241 
242 // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
243 template <ArrayDataType A>
244 struct DataTypeImpl {};
245 template <>
246 struct DataTypeImpl<ArrayDataType::kNone> {
247   typedef int Type;
248 };
249 template <>
250 struct DataTypeImpl<ArrayDataType::kBool> {
251   typedef bool Type;
252 };
253 template <>
254 struct DataTypeImpl<ArrayDataType::kFloat> {
255   typedef float Type;
256 };
257 template <>
258 struct DataTypeImpl<ArrayDataType::kInt8> {
259   typedef int8 Type;
260 };
261 template <>
262 struct DataTypeImpl<ArrayDataType::kUint8> {
263   typedef uint8 Type;
264 };
265 template <>
266 struct DataTypeImpl<ArrayDataType::kInt16> {
267   typedef int16 Type;
268 };
269 template <>
270 struct DataTypeImpl<ArrayDataType::kUint16> {
271   typedef uint16 Type;
272 };
273 template <>
274 struct DataTypeImpl<ArrayDataType::kInt32> {
275   typedef int32 Type;
276 };
277 template <>
278 struct DataTypeImpl<ArrayDataType::kUint32> {
279   typedef uint32 Type;
280 };
281 template <>
282 struct DataTypeImpl<ArrayDataType::kInt64> {
283   typedef int64 Type;
284 };
285 template <>
286 struct DataTypeImpl<ArrayDataType::kUint64> {
287   typedef uint64 Type;
288 };
289 template <>
290 struct DataTypeImpl<ArrayDataType::kString> {
291   typedef std::string Type;
292 };
293 template <>
294 struct DataTypeImpl<ArrayDataType::kComplex64> {
295   typedef std::complex<float> Type;
296 };
297 
298 template <ArrayDataType A>
299 using DataType = typename DataTypeImpl<A>::Type;
300 
301 // Base class for type-specific buffer types.
302 struct GenericBuffer {
303   // Non-default-constructible: only ArrayDataType-specific subclass
304   // objects may be constructed.
305   GenericBuffer() = delete;
306   // Non-copyable-or-movable: we should only store pointers-to-Buffer
307   // in containers, not Operators themselves, so there should be no
308   // copy or move.
309   GenericBuffer(const GenericBuffer&) = delete;
310   GenericBuffer(const GenericBuffer&&) = delete;
311 
312   // We need a virtual destructor so we can store pointers-to-Buffer
313   // in containers and have the containers call the right subclass destructor.
314   virtual ~GenericBuffer() {}
315 
316   virtual int Length() const = 0;
317 
318   const ArrayDataType type;
319 
320  protected:
321   // Constructor used by subclasses for specific ArrayDataType's.
322   explicit GenericBuffer(ArrayDataType t) : type(t) {}
323 };
324 
325 // Type-specific buffer, containing type-specific storage.
326 template <ArrayDataType A>
327 struct Buffer : GenericBuffer {
328   Buffer() : GenericBuffer(A) {}
329 
330   int Length() const override { return data.size(); }
331 
332   std::vector<DataType<A>> data;
333 };
334 
335 class Shape {
336  public:
337   // For Shape, we stick to half-way encapsulation for now:
338   // we hide the raw dims_ member, but expose it raw by accessors
339   // because from some brainstorming, it's not at all easy to
340   // anticipate which flavor of more hermetic encapsulation would
341   // actually buy us future-proof-ness without being needlessly
342   // cumbersome.
343   Shape() {}
344   Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
345 
346   void ReplaceDims(std::initializer_list<int> dim_list) {
347     dims_ = std::vector<int>(dim_list);
348   }
349 
350   const std::vector<int>& dims() const { return dims_; }
351   std::vector<int>* mutable_dims() { return &dims_; }
352   const int dimensions_count() const { return dims_.size(); }
353 
354   // We still have that one convenience accessor to avoid
355   // the awkward double bracket issue:  shape.dims()[i].
356   int dims(int i) const {
357     // Always check for out-of-bounds accesses, even in optimized builds where
358     // standard assertions are disabled. Out-of-bounds access here is a common
359     // occurrence.
360     CHECK_GE(i, 0);
361     CHECK_GT(dims_.size(), i);
362     return dims_[i];
363   }
364 
365   bool operator==(const Shape& comp) const {
366     return (this->dims_ == comp.dims());
367   }
368 
369   bool operator!=(const Shape& comp) const { return !((*this) == comp); }
370 
371  private:
372   std::vector<int> dims_;
373 };
374 
375 // Base class for all operator classes.
376 struct Operator {
377   // Non-default-constructible: only OperatorType-specific subclass
378   // objects may be constructed.
379   Operator() = delete;
380   // Non-copyable-or-movable: we should only store pointers-to-Operator
381   // in containers, not Operators themselves, so there should be no
382   // copy or move.
383   Operator(const Operator&) = delete;
384   Operator(const Operator&&) = delete;
385 
386   // We need a virtual destructor so we can store pointers-to-Operator
387   // in containers and have the containers call the right subclass destructor.
388   virtual ~Operator() {}
389 
390   // The specific type of operator. Corresponds 1:1 to subclasses.
391   const OperatorType type;
392 
393   // The activation function that may be fused into this operator,
394   // or None if no activation function is fused.
395   FusedActivationFunctionType fused_activation_function;
396 
397   // Input arrays: either activation arrays or constant array parameters.
398   // We refer to them by their name, not by their address; the mapping of
399   // names to addresses is given by the Model, which owns both Operator's and
400   // Array's. Thus, an Operator on its own doesn't contain much information,
401   // it is meant to be used in conjunction with the Model that owns it.
402   std::vector<std::string> inputs;
403 
404   // Output activation arrays. Same comments as for inputs apply here too.
405   std::vector<std::string> outputs;
406 
407   // If true, the operator has more outputs than are listed in the 'outputs'
408   // member. These need to be resolved by some graph transformation.
409   // This flag is only here to indicate that an operator should not be
410   // discarded as unused, even if from its 'outputs' member alone it
411   // looks unused.
412   bool unresolved_outputs = false;
413 
414   // A serialized tensorflow::NodeDef string.
415   // The field is filled only when importing from TensorFlow.
416   // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
417   // It's not guaranteed to be filled for other ops. Ops created by graph
418   // transformations won't have TensorFlow NodeDef.
419   std::string tensorflow_node_def;
420 
421  protected:
422   // Constructor used by subclasses for specific OperatorType's.
423   explicit Operator(OperatorType t)
424       : type(t),
425         fused_activation_function(FusedActivationFunctionType::kNone) {}
426 };
427 
// Padding types for Conv-like operators. This is how padding is typically
// specified in model files. But for inference, we will need to resolve this
// to a FixedPadding, see below. kNone is the default-constructed value
// (see Padding below), i.e. "not yet specified".
enum class PaddingType { kNone, kSame, kValid };
432 
// Padding as resolved for a specific layer shape, as needed for inference.
// For a given layer shape, a given padding type will resolve to a choice of
// a number of padding rows and columns, which we call the padding height and
// width respectively.
struct FixedPadding {
  int width = 0;   // number of padding columns
  int height = 0;  // number of padding rows
};
441 
442 // "Universal" padding struct containing both a generic PaddingType (as
443 // represented in a model file), and a FixedPadding (as needed for inference).
444 // The latter is resolved during the PropagateFixedSizes pass.
445 struct Padding {
446   FixedPadding& GetOrCreateFixedPadding() {
447     if (!fixed) {
448       FixedPadding* ptr = new FixedPadding;
449       fixed = std::unique_ptr<FixedPadding>(ptr);
450     }
451     return *fixed;
452   }
453 
454   Padding() : type(PaddingType::kNone) {}
455   PaddingType type;
456   std::unique_ptr<FixedPadding> fixed;
457 };
458 
// "Convolutional" layer, as represented in model files.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the Conv weights
//   inputs[2]: optional: the bias vector, specifying the biases for each output
//   channel.
//
// Outputs:
//   outputs[0]: required: the output activations array
//   outputs[1]: optional: the intermediate array of im2col-replicated input
//                         activations. Present when targeting implementations
//                         of Conv layers as Im2col+GEMM.
//
// TensorFlow equivalent: Conv2D
struct ConvOperator : Operator {
  ConvOperator() : Operator(OperatorType::kConv) {}
  // Padding; resolved to a FixedPadding during PropagateFixedSizes.
  Padding padding;
  // Strides; 0 means "not yet set" and is expected to be filled in by the
  // importer.
  int stride_width = 0;
  int stride_height = 0;
  // A dilation_rate of 0 is invalid and this field is an optional attribute.
  // Thus initializing it to 1 to allow default conv behavior when the
  // attribute is not present.
  int dilation_width_factor = 1;
  int dilation_height_factor = 1;
};
485 
486 // CTCBeamSearchDecoder operator:
487 //
488 // Inputs:
489 //   inputs[0]: required: the logits.
490 //   inputs[1]: required: sequence length.
491 //   inputs[2]: optional: beam width.
492 //   inputs[3]: optional: top paths.
493 //   inputs[4]: optional: merge repeated.
494 //
495 //  Outputs:
496 //    outputs[0]: decoded.
497 //    outputs[1]: log probability.
498 //
499 // TensorFlow equivalent: CTCBeamSearchDecoder
500 struct CTCBeamSearchDecoderOperator : Operator {
501   CTCBeamSearchDecoderOperator()
502       : Operator(OperatorType::kCTCBeamSearchDecoder) {}
503   int beam_width;
504   int top_paths;
505   bool merge_repeated = true;
506 };
507 
// Depthwise-separable convolution operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the DepthwiseConv weights
//   inputs[2]: optional: the bias vector, specifying the biases for each output
//   channel.
//
// TensorFlow equivalent: DepthwiseConv2dNative
struct DepthwiseConvOperator : Operator {
  DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {}
  Padding padding;
  // Strides and depth multiplier; 0 means "not yet set" — expected to be
  // filled in by the importer.
  int stride_height = 0;
  int stride_width = 0;
  int depth_multiplier = 0;
  // A dilation_rate of 0 is invalid and this field is an optional attribute.
  // Thus initializing it to 1 to allow default conv behavior when the
  // attribute is not present.
  int dilation_width_factor = 1;
  int dilation_height_factor = 1;
};
529 
// Depth-to-space transform operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: DepthToSpace
struct DepthToSpaceOperator : Operator {
  DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {}
  // Size of the spatial blocks; 0 means "not yet set".
  int block_size = 0;
};
540 
// Space-to-depth transform operator (inverse transform of DepthToSpace).
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: SpaceToDepth
struct SpaceToDepthOperator : Operator {
  SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {}
  // Size of the spatial blocks; 0 means "not yet set".
  int block_size = 0;
};
551 
// Fully-connected operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the FullyConnected weights
//   inputs[2]: optional: the bias vector, specifying the biases for each output
//   channel.
//
// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the
// input activations as a matrix, followed by a MatMul node.
struct FullyConnectedOperator : Operator {
  FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
  // Layout of the weights array; defaults to the standard layout.
  FullyConnectedWeightsFormat weights_format =
      FullyConnectedWeightsFormat::kDefault;

  // `keep_num_dims` is supported in the FullyConnected kernel version 5, but
  // it's never supported by Toco.
  bool keep_num_dims = false;
};
571 
// Dequantization operator, converting a quantized array of integers with
// quantization parameters specifying how these integers correspond to real
// numbers (see QuantizationParams) to an output activations array of
// floating-point values.
//
// In floating-point image models, there is typically a Dequantization
// operator at the very beginning, converting the input image RGB data,
// consisting of uint8 integer values, to floating-point input activations.
// That is where image model parameters such as "mean_value" and "std_value"
// are typically handled.
//
// This is the only operator type that converts from quantized to
// floating-point, and there is at the moment no operator type at all to
// convert from floating-point to quantized. Every other operator does either
// float->float or quantized->quantized.
//
// Inputs:
//   inputs[0]: required: the input quantized activations array
//
// TensorFlow equivalent: Dequantize
//
// Attribute-free: all information is carried by the inputs/outputs and
// their quantization parameters.
struct DequantizeOperator : Operator {
  DequantizeOperator() : Operator(OperatorType::kDequantize) {}
};
598 
// Numeric verification operator: takes a quantized array of integers with
// quantization parameters specifying how these integers correspond to real
// numbers (see QuantizationParams), and verifies them against an array of
// reference floating-point values. Debugging-only (see kNumericVerify).
//
// Inputs:
//   inputs[0]: required: the input quantized activations array
//   inputs[1]: required: the input reference activations array
//
// TensorFlow equivalent: Dequantize
struct NumericVerifyOperator : Operator {
  NumericVerifyOperator() : Operator(OperatorType::kNumericVerify) {}
};
613 
614 // Batch-normalization operator.
615 //
616 // We only support batch-normalization using pre-learned moments, so this is
617 // just
618 // computing (input - mean) * multiplier + offset. As such, this can be
619 // expressed as a combination of Add and Mul nodes, and indeed this is how
620 // we break it down during tooling for the purpose of fusing it into
621 // other operators.
622 //
623 // Inputs:
624 //   inputs[0]: required: the input activations array
625 //   inputs[1]: required: the learned mean array
626 //   inputs[2]: required: the learned multiplier array
627 //   inputs[3]: required: the learned offset array
628 //
629 // TensorFlow equivalent: a combination of Add and Mul nodes
630 struct BatchNormalizationOperator : Operator {
631   BatchNormalizationOperator()
632       : Operator(OperatorType::kBatchNormalization),
633         global_normalization(false) {}
634   bool global_normalization;
635 };
636 
// L2-normalization operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//
// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented
// by a sub-graph of operators implementing L2-normalization from lower-level
// arithmetic nodes; during tooling, we identify such sub-graphs and replace
// them by L2NormalizationOperator's. See IdentifyL2Normalization.
//
// Attribute-free: all information is carried by the inputs/outputs.
struct L2NormalizationOperator : Operator {
  L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {}
};
650 
// LSTM Cell operator.
//
// Inputs:
//   inputs[0]: required: the input data array
//   inputs[1]: required: the previous output activations array
//   inputs[2]: required: the learned weights array
//   inputs[3]: required: the learned biases array
//   inputs[4]: required: the previous output state
//   outputs[0]: required: the output activations array
//   outputs[1]: required: the new state array
//
// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented
// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify
// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell().
struct LstmCellOperator : Operator {
  // Symbolic indices into the 'inputs' vector.
  enum Inputs {
    DATA_INPUT = 0,
    PREV_ACTIV_INPUT = 1,
    WEIGHTS_INPUT = 2,
    BIASES_INPUT = 3,
    PREV_STATE_INPUT = 4,
    NUM_INPUTS = 5
  };
  // Symbolic indices into the 'outputs' vector. CONCAT_TEMP and ACTIV_TEMP
  // are presumably scratch arrays — confirm against the LSTM kernels.
  enum Outputs {
    ACTIV_OUTPUT = 0,
    STATE_OUTPUT = 1,
    CONCAT_TEMP = 2,
    ACTIV_TEMP = 3,
    NUM_OUTPUTS = 4
  };
  // Which LSTM kernel implementation this cell targets.
  enum KernelType {
    KERNEL_BASIC = 0,
    KERNEL_FULL = 1,
  };

  LstmCellOperator()
      : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {}

  KernelType kernel_type;
};
691 
// Unidirectional sequence LSTM operator: applies an LSTM over a whole
// sequence in a single (forward) direction. Carries no attributes of its
// own here; inputs/outputs presumably follow the corresponding TFLite op's
// convention — confirm against the importer/exporter.
struct UnidirectionalSequenceLstmOperator : Operator {
  UnidirectionalSequenceLstmOperator()
      : Operator(OperatorType::kUnidirectionalSequenceLstm) {}
};
696 
697 struct BidirectionalSequenceLstmOperator : Operator {
698   BidirectionalSequenceLstmOperator()
699       : Operator(OperatorType::kBidirectionalSequenceLstm) {}
700   bool merge_outputs;
701 };
702 
703 struct BidirectionalSequenceRnnOperator : Operator {
704   BidirectionalSequenceRnnOperator()
705       : Operator(OperatorType::kBidirectionalSequenceRnn) {}
706   bool merge_outputs;
707 };
708 
// Element-wise multiplication operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Mul
//
// Attribute-free: all information is carried by the inputs/outputs.
struct MulOperator : Operator {
  MulOperator() : Operator(OperatorType::kMul) {}
};
719 
// Element-wise Abs operator:
//   x -> abs(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: abs
//
// Attribute-free: all information is carried by the inputs/outputs.
struct AbsOperator : Operator {
  AbsOperator() : Operator(OperatorType::kAbs) {}
};
730 
// Element-wise HardSwish operator:
//   x -> x * relu6(x+3)/6
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: hard_swish
//
// Attribute-free: all information is carried by the inputs/outputs.
struct HardSwishOperator : Operator {
  HardSwishOperator() : Operator(OperatorType::kHardSwish) {}
};
741 
// Element-wise Elu operator:
//   f(x) -> exp(x) - 1 for x < 0, x for x >= 0.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Elu
//
// Attribute-free: all information is carried by the inputs/outputs.
struct EluOperator : Operator {
  EluOperator() : Operator(OperatorType::kElu) {}
};
752 
// Element-wise Relu operator:
//   x -> max(0, x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Relu
//
// Attribute-free: all information is carried by the inputs/outputs.
struct ReluOperator : Operator {
  ReluOperator() : Operator(OperatorType::kRelu) {}
};
763 
// Element-wise Relu1 operator:
//   x -> min(max(x, -1), 1)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. We can construct the operator with Minimum
// and Maximum operations.
//
// Attribute-free: all information is carried by the inputs/outputs.
struct Relu1Operator : Operator {
  Relu1Operator() : Operator(OperatorType::kRelu1) {}
};
775 
// Element-wise Relu6 operator:
//   x -> max(0, min(6, x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Relu6
//
// Attribute-free: all information is carried by the inputs/outputs.
struct Relu6Operator : Operator {
  Relu6Operator() : Operator(OperatorType::kRelu6) {}
};
786 
// Parametric Relu (PRelu) operator:
//   f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the alpha array
//
// Equivalent to keras.layers.PReLU.
//
// Attribute-free: alpha is supplied as an input array, not an attribute.
struct PReluOperator : Operator {
  PReluOperator() : Operator(OperatorType::kPRelu) {}
};
798 
// LeakyRelu operator:
//   x -> max(x, alpha * x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: LeakyRelu
struct LeakyReluOperator : Operator {
  LeakyReluOperator() : Operator(OperatorType::kLeakyRelu) {}

  // Slope of the function on the negative side.
  float alpha = 0.2f;  // 0.2 matches the default value for the TF op attribute.
};
811 
// Element-wise Logistic operator:
//   x -> Logistic(x) = 1 / (1 + exp(-x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sigmoid
//
// Attribute-free: all information is carried by the inputs/outputs.
struct LogisticOperator : Operator {
  LogisticOperator() : Operator(OperatorType::kLogistic) {}
};
822 
// Element-wise natural log operator:
//   x -> ln(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Log
//
// Attribute-free: all information is carried by the inputs/outputs.
struct LogOperator : Operator {
  LogOperator() : Operator(OperatorType::kLog) {}
};
833 
// Element-wise Tanh operator:
//   x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Tanh
//
// Attribute-free: all information is carried by the inputs/outputs.
struct TanhOperator : Operator {
  TanhOperator() : Operator(OperatorType::kTanh) {}
};
844 
// Element-wise Sin operator:
//   x -> Sin(x) = sin(x)
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Sin
//
// Attribute-free: all information is carried by the inputs/outputs.
struct SinOperator : Operator {
  SinOperator() : Operator(OperatorType::kSin) {}
};
855 
// Element-wise addition operator.
//
// Inputs:
//   inputs[0]: required: the left-hand side array
//   inputs[1]: required: the right-hand side array
//
// TensorFlow equivalent: Add
//
// Attribute-free: all information is carried by the inputs/outputs.
struct AddOperator : Operator {
  AddOperator() : Operator(OperatorType::kAdd) {}
};
866 
// Element-wise addition operator for N inputs.
//
// Inputs:
//   inputs[i]: The i-th array to add together to form the output.
//
// TensorFlow equivalent: AddN
//
// Attribute-free: all information is carried by the inputs/outputs.
struct AddNOperator : Operator {
  AddNOperator() : Operator(OperatorType::kAddN) {}
};
876 
// Concatenation operator: concatenates its inputs
// along the axis.
//
// Inputs: this operator accepts any number >= 1 of inputs.
//   inputs[i]: the i-th array to concatenate.
//
// TensorFlow equivalent: Concat.
struct ConcatenationOperator : Operator {
  ConcatenationOperator() : Operator(OperatorType::kConcatenation) {}
  // Axis along which to concatenate; defaults to 0 until set.
  int axis = 0;
};
888 
// Reordering dimensions. Used only during tooling to transform graphs from
// the TensorFlow format.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. This is only useful to convert between formats.
struct ReorderAxesOperator : Operator {
  ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {}
  // NOTE(review): both orders are left uninitialized here; creators of this
  // op are expected to set them before they are read.
  AxesOrder input_axes_order;
  AxesOrder output_axes_order;
};
901 
// Average-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: AveragePool
struct AveragePoolOperator : Operator {
  AveragePoolOperator() : Operator(OperatorType::kAveragePool) {}
  Padding padding;
  // Strides and window size ('k' presumably = kernel); 0 means "not yet
  // set" — expected to be filled in by the importer.
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};
916 
// Local response normalization operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: LRN
struct LocalResponseNormalizationOperator : Operator {
  LocalResponseNormalizationOperator()
      : Operator(OperatorType::kLocalResponseNormalization) {}

  // Parameters presumably mirror the attributes of TensorFlow's LRN op;
  // zero-initialized until set by the importer.
  int range = 0;
  float bias = 0.f;
  float alpha = 0.f;
  float beta = 0.f;
};
932 
// Max-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: MaxPool
struct MaxPoolOperator : Operator {
  MaxPoolOperator() : Operator(OperatorType::kMaxPool) {}
  Padding padding;
  // Strides and window size ('k' presumably = kernel); 0 means "not yet
  // set" — expected to be filled in by the importer.
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};
947 
// L2-pooling operator.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt.
struct L2PoolOperator : Operator {
  L2PoolOperator() : Operator(OperatorType::kL2Pool) {}
  Padding padding;
  // Strides and window size ('k' presumably = kernel); 0 means "not yet
  // set" — expected to be filled in by the importer.
  int stride_height = 0;
  int stride_width = 0;
  int kheight = 0;
  int kwidth = 0;
};
962 
// The expected [min, max] range of values in a given array.
// Used for quantization only.
// This information typically comes from special nodes found in quantized
// models, see FakeQuantOperator, and is used during quantization to resolve
// actual quantization parameters (see QuantizationParams).
struct MinMax {
  double min = 0.;
  double max = 0.;
};

// Two MinMax ranges compare equal iff both endpoints match exactly.
inline bool operator==(const MinMax& lhs, const MinMax& rhs) {
  return (lhs.min == rhs.min) && (lhs.max == rhs.max);
}

// Inequality is the exact negation of equality.
inline bool operator!=(const MinMax& lhs, const MinMax& rhs) {
  return !(lhs == rhs);
}
980 
981 // Fake-quantization operator. This does two things:
982 //   - Annotate its input and output arrays with MinMax information,
983 //   - Arithmetic-wise, this operator rounds incoming activation values
984 //     to the nearest representable value on the scale of 256
985 //     values from the min to the max value dictated by its MinMax info.
986 //
987 // Inputs:
988 //   inputs[0]: required: the input array
989 //   inputs[1]: optional: the 'min' value, if it has not yet been resolved
990 //              to a constant.
991 //   inputs[2]: optional: the 'max' value, if it has not yet been resolved
992 //              to a constant.
993 //
994 // TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs.
struct FakeQuantOperator : Operator {
  FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
  // The [min, max] range; may be null until the min/max inputs have been
  // resolved to constants (see the inputs description above).
  std::unique_ptr<MinMax> minmax;
  // Quantization bit depth; 8 matches the TensorFlow attribute default.
  int num_bits = 8;
  // When true, the quantized range excludes the lowest code (TF narrow_range).
  bool narrow_range = false;
};
1001 
1002 // Element-wise division operator.
1003 //
1004 // Inputs:
1005 //   inputs[0]: required: the left-hand side array
1006 //   inputs[1]: required: the right-hand side array
1007 //
1008 // TensorFlow equivalent: Div
struct DivOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  DivOperator() : Operator(OperatorType::kDiv) {}
};
1012 
1013 // Element-wise identity (x->x) operator.
1014 //
1015 // Inputs:
1016 //   inputs[0]: required: the input array
1017 //
1018 // TensorFlow equivalent: Identity
struct TensorFlowIdentityOperator : Operator {
  // Pure pass-through; no attributes.
  TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
};
1022 
1023 // Batch matrix multiplication operator. This comes from a tf.matmul where one
1024 // of the operands has rank 3 or more.
1025 //
1026 // Inputs:
1027 //   inputs[0]: required: the left-hand side matrix
1028 //   inputs[1]: required: the right-hand side matrix
1029 //
1030 // TensorFlow equivalent: MatMul
struct BatchMatMulOperator : Operator {
  BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {}
  // Mirror the TensorFlow BatchMatMul adj_x/adj_y attributes: whether to
  // adjoint (transpose) the corresponding operand before multiplying.
  bool adj_x = false;
  bool adj_y = false;
};
1036 
1037 // General matrix multiplication operator. We don't want to support general
1038 // matrix multiplication at inference time, so we resolve it during tooling
1039 // to more specific operator types, namely, FullyConnected.
1040 //
1041 // Inputs:
1042 //   inputs[0]: required: the left-hand side matrix
1043 //   inputs[1]: required: the right-hand side matrix
1044 //
1045 // TensorFlow equivalent: MatMul
struct TensorFlowMatMulOperator : Operator {
  TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {}
  // Mirror the TensorFlow MatMul transpose_a/transpose_b attributes: whether
  // to transpose the corresponding operand before multiplying.
  bool transpose_a = false;
  bool transpose_b = false;
};
1051 
1052 // Padding operator. Pads a tensor with zeros.
1053 //
1054 // Inputs:
1055 //   inputs[0]: required: the input array
1056 //   inputs[1]: required: the padding array
1057 //
1058 // This operation pads a `input` with zeros according to the `paddings` you
1059 // specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
1060 // rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
1061 // how many zeros to add before the contents of `input` in that dimension, and
1062 // `paddings[D, 1]` indicates how many zeros to add after the contents of
1063 // `input` in that dimension.
1064 //
1065 // TensorFlow equivalent: Pad
struct PadOperator : Operator {
  PadOperator() : Operator(OperatorType::kPad) {}

  // Per-dimension zero counts to add before/after the input contents,
  // presumably resolved from the paddings array (inputs[1]) — empty until then.
  std::vector<int> left_padding;
  std::vector<int> right_padding;
};
1072 
1073 // PaddingV2 operator. Pads a tensor with the given constant value.
1074 //
1075 // Inputs:
1076 //   inputs[0]: required: the input array
1077 //   inputs[1]: required: the padding array
1078 //   inputs[2]: required: the scalar constant_values
1079 //
1080 // This operation pads input according to the paddings and constant_values you
1081 // specify. paddings is an integer tensor with shape [Dn, 2], where n is the
1082 // rank of input. For each dimension D of input, paddings[D, 0] indicates how
1083 // many padding values to add before the contents of input in that dimension,
1084 // and paddings[D, 1] indicates how many padding values to add after the
1085 // contents of input in that dimension. constant_values is a scalar tensor of
1086 // the same type as input that indicates the value to use for padding input.
1087 //
1088 // TensorFlow equivalent: PadV2
struct PadV2Operator : Operator {
  PadV2Operator() : Operator(OperatorType::kPadV2) {}

  // Per-dimension pad counts to add before/after the input contents,
  // presumably resolved from the paddings array (inputs[1]) — empty until then.
  std::vector<int> left_padding;
  std::vector<int> right_padding;
};
1095 
1096 // Strided slice operator.
1097 //
1098 // Inputs:
1099 //   inputs[0]: required: the input array
1100 //   inputs[1]: required: the begin array
1101 //   inputs[2]: required: the end array
1102 //   inputs[3]: optional: the strides array
1103 //
1104 // TensorFlow equivalent: StridedSlice
1105 struct StridedSliceOperator : Operator {
1106   StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {}
1107 
1108   std::vector<int> start_indices;
1109   std::vector<int> stop_indices;
1110   std::vector<int> strides;
1111 
1112   int begin_mask;
1113   int ellipsis_mask;
1114   int end_mask;
1115   int new_axis_mask;
1116   int shrink_axis_mask;
1117 
1118   StridedSliceOperator(const StridedSliceOperator& other)
1119       : Operator(OperatorType::kStridedSlice) {
1120     inputs = other.inputs;
1121     outputs = other.outputs;
1122 
1123     start_indices = other.start_indices;
1124     stop_indices = other.stop_indices;
1125     strides = other.strides;
1126 
1127     begin_mask = other.begin_mask;
1128     ellipsis_mask = other.ellipsis_mask;
1129     end_mask = other.end_mask;
1130     new_axis_mask = other.new_axis_mask;
1131     shrink_axis_mask = other.shrink_axis_mask;
1132   }
1133 
1134   void PadIndices(int dim_count) {
1135     // Add indices and mask bits to fully include extra dimensions
1136     CHECK_GE(dim_count, start_indices.size());
1137     CHECK_EQ(start_indices.size(), stop_indices.size());
1138     CHECK_EQ(stop_indices.size(), strides.size());
1139 
1140     for (int i = start_indices.size(); i < dim_count; i++) {
1141       start_indices.push_back(0);
1142       stop_indices.push_back(0);
1143       strides.push_back(1);
1144       begin_mask |= 1 << i;
1145       end_mask |= 1 << i;
1146     }
1147   }
1148 
1149   void ReverseIndices() {
1150     CHECK_EQ(start_indices.size(), stop_indices.size());
1151     CHECK_EQ(stop_indices.size(), strides.size());
1152 
1153     std::reverse(start_indices.begin(), start_indices.end());
1154     std::reverse(stop_indices.begin(), stop_indices.end());
1155     std::reverse(strides.begin(), strides.end());
1156 
1157     begin_mask = toco::port::ReverseBits32(static_cast<uint32>(begin_mask)) >>
1158                  (32 - start_indices.size());
1159     ellipsis_mask =
1160         toco::port::ReverseBits32(static_cast<uint32>(ellipsis_mask)) >>
1161         (32 - start_indices.size());
1162     end_mask = toco::port::ReverseBits32(static_cast<uint32>(end_mask)) >>
1163                (32 - start_indices.size());
1164     new_axis_mask =
1165         toco::port::ReverseBits32(static_cast<uint32>(new_axis_mask)) >>
1166         (32 - start_indices.size());
1167     shrink_axis_mask =
1168         toco::port::ReverseBits32(static_cast<uint32>(shrink_axis_mask)) >>
1169         (32 - start_indices.size());
1170   }
1171 };
1172 
1173 // Reshaping operator, reshaping its input array to a two-dimensional shape
1174 // (a "matrix"). This is used in the TensorFlow format, in conjunction with
1175 // MatMul nodes, to implement fully-connected layers.
1176 //
1177 // Inputs:
1178 //   inputs[0]: required: the input array
1179 //   inputs[1]: optional: the output tensor shape
1180 //
1181 // TensorFlow equivalent: Reshape --- except that we only support a special case
1182 // here, where the output shape is a matrix (2D) shape.
struct TensorFlowReshapeOperator : Operator {
  TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {}
  // Target shape; presumably cached from inputs[1] once it is resolved to a
  // constant — empty until then.
  std::vector<int> shape;
};
1187 
1188 // Removes dimensions of size 1 from the shape of a tensor.
1189 // https://www.tensorflow.org/api_docs/python/tf/squeeze
1190 //
1191 // Inputs:
1192 //   inputs[0]: required: the input array
1193 //
1194 // TensorFlow equivalent: Squeeze
struct SqueezeOperator : Operator {
  SqueezeOperator() : Operator(OperatorType::kSqueeze) {}

  // Dimension indices to remove; empty means squeeze all size-1 dimensions
  // (TF Squeeze semantics).
  std::vector<int> squeeze_dims;
};
1200 
1201 // Inputs:
1202 //   inputs[0]: required: the output shape
1203 //   inputs[1]: required: the weights
1204 //   inputs[2]: required: the input activations array
1205 //   inputs[3]: optional: the bias vector, specifying the biases for each output
1206 //                        channel.
1207 //   NOTE: The input activations is NOT the first input.
1208 //
1209 //
1210 // Outputs:
1211 //   outputs[0]: required: the output activations array
1212 //
1213 // TensorFlow equivalent: Conv2DBackpropInput
struct TransposeConvOperator : Operator {
  // Named indices into 'inputs'; note DATA_INPUT (the activations) is the
  // third input, not the first — see the comment above this struct.
  enum Inputs {
    OUTPUT_SHAPE = 0,
    WEIGHTS = 1,
    DATA_INPUT = 2,
    BIAS = 3,
  };

  TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {}
  // Padding scheme (type declared elsewhere in this file).
  Padding padding;
  // Strides along the spatial dimensions.
  int stride_width = 0;
  int stride_height = 0;
  // Dilation is possible with transpose convolution, but Tensorflow does not
  // currently support it, so we omit it.
};
1229 
1230 // Given a tensor input, this operation calculates element-wise exponential
1231 // (y = e^x).
1232 //
1233 // Inputs:
1234 //   inputs[0]: required: input tensor
1235 //
1236 // TensorFlow equivalent: Exp
struct ExpOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  ExpOperator() : Operator(OperatorType::kExp) {}
};
1240 
// Given a tensor input, this operation calculates the element-wise cosine
// (y = cos(x)).
//
// Inputs:
//   inputs[0]: required: input tensor
//
// TensorFlow equivalent: Cos
struct CosOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  CosOperator() : Operator(OperatorType::kCos) {}
};
1251 
1252 // Given a tensor input, this operation inserts a dimension of 1 at the
1253 // dimension index axis of input's shape. The dimension index axis starts at
1254 // zero; if you specify a negative number for axis it is counted backward from
1255 // the end.
1256 //
1257 // Inputs:
1258 //   inputs[0]: required: input tensor
1259 //   inputs[1]: required: 0-D (scalar). Specifies the dimension index at which
1260 //   to expand the shape of input
1261 //
1262 // TensorFlow equivalent: ExpandDims
struct ExpandDimsOperator : Operator {
  // No attributes: the axis comes from inputs[1] (see comment above).
  ExpandDimsOperator() : Operator(OperatorType::kExpandDims) {}
};
1266 
1267 // Creates a tensor of shape dims and fills it with the given scalar value.
1268 // Output type will be the same as the given scalar value.
1269 //
1270 // Inputs:
1271 //   inputs[0]: required: 1-D (int32) - the shape of the output tensor
1272 //   inputs[1]: required: 0-D (scalar) - value to fill the tensor with
1273 //
1274 // TensorFlow equivalent: Fill
struct FillOperator : Operator {
  // No attributes: shape and fill value come from the inputs (see above).
  FillOperator() : Operator(OperatorType::kFill) {}
};
1278 
1279 // Element-wise floor division operator.
1280 //
1281 // Inputs:
1282 //   inputs[0]: required: the left-hand side array
1283 //   inputs[1]: required: the right-hand side array
1284 //
1285 // TensorFlow equivalent: FloorDiv
struct FloorDivOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  FloorDivOperator() : Operator(OperatorType::kFloorDiv) {}
};
1289 
1290 // Element-wise floor mod operator.
1291 //
1292 // Inputs:
1293 //   inputs[0]: required: the left-hand side array
1294 //   inputs[1]: required: the right-hand side array
1295 //
1296 // TensorFlow equivalent: FloorMod
struct FloorModOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  FloorModOperator() : Operator(OperatorType::kFloorMod) {}
};
1300 
1301 struct RandomUniformOperator : Operator {
1302   RandomUniformOperator() : Operator(OperatorType::kRandomUniform) {}
1303   ArrayDataType dtype = ArrayDataType::kNone;
1304   int64 seed;
1305   int64 seed2;
1306 };
1307 
1308 // Creates a sequence of numbers that begins at start and extends by increments
1309 // of delta up to but not including limit.
1310 //
1311 // The dtype of the resulting tensor is inferred from the inputs unless it is
1312 // provided explicitly.
1313 //
1314 // Inputs:
1315 //   inputs[0]: required: the start
1316 //   inputs[1]: required: the limit
1317 //   inputs[2]: required: the delta
1318 //
1319 // TensorFlow equivalent: Range
struct RangeOperator : Operator {
  RangeOperator() : Operator(OperatorType::kRange) {}
  // Output element type; kNone means it should be inferred from the inputs
  // (see the comment above this struct).
  ArrayDataType dtype = ArrayDataType::kNone;
};
1324 
1325 // Rank operator. Extracts the rank of the tensor.
1326 //
1327 // Inputs:
1328 //   inputs[0]: required: the input array
1329 //
1330 // This operation outputs a 0-D int32 Tensor representing the rank of input.
1331 //
1332 // TensorFlow equivalent: Rank.
struct TensorFlowRankOperator : Operator {
  TensorFlowRankOperator() : Operator(OperatorType::kRank) {}
  // Element type of the 0-D output; int32 matches TF's Rank output type.
  ArrayDataType output_data_type = ArrayDataType::kInt32;
};
1337 
1338 // Element-wise negation (-x) operator.
1339 //
1340 // Inputs:
1341 //   inputs[0]: required: the input array
1342 //
1343 // TensorFlow equivalent: Neg
struct NegOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  NegOperator() : Operator(OperatorType::kNeg) {}
};
1347 
1348 // Element-wise select operator choosing elements from inputs[1] or input[2]
1349 //
1350 // Inputs:
1351 //  inputs[0]: required: boolean mask per index
1352 //  inputs[1]: required: tensor of values if true
1353 //  inputs[2]: required: tensor of values if false
1354 //
1355 //  TensorFlow equivalent: Select
struct SelectOperator : Operator {
  // No attributes: condition and branch values come from the inputs (above).
  SelectOperator() : Operator(OperatorType::kSelect) {}
};
1359 
1360 // Element-wise reciprocal-square-root (x^-0.5) operator.
1361 //
1362 // Inputs:
1363 //   inputs[0]: required: the input array
1364 //
1365 // TensorFlow equivalent: Rsqrt
struct TensorFlowRsqrtOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {}
};
1369 
1370 // Stacks a list of rank-R tensors into one rank-(R+1) tensor.
1371 //
1372 // Packs the list of tensors in values into a tensor with rank one higher than
1373 // each tensor in values, by packing them along the axis dimension. Given a list
1374 // of length N of tensors of shape (A, B, C);.
1375 //
1376 // Inputs: this operator accepts any number >= 1 of inputs.
1377 //   inputs[i]: the i-th array to merge.
1378 //
1379 // TensorFlow equivalent: Pack
1380 struct PackOperator : Operator {
1381   PackOperator() : Operator(OperatorType::kPack) {}
1382   int values_count;
1383   int axis = 0;
1384   ArrayDataType dtype = ArrayDataType::kNone;
1385 };
1386 
1387 // Shape operator. Extracts the shape of the tensor.
1388 //
1389 // Inputs:
1390 //   inputs[0]: required: the input array
1391 //
1392 // This operation outputs a 1-D integer tensor representing the shape of
1393 // the input.
1394 //
1395 // TensorFlow equivalent: Shape.
struct TensorFlowShapeOperator : Operator {
  TensorFlowShapeOperator() : Operator(OperatorType::kShape) {}
  // Element type of the 1-D shape output (TF Shape's out_type, default int32).
  ArrayDataType output_data_type = ArrayDataType::kInt32;
};
1400 
1401 // Element-wise square-root (x^0.5) operator.
1402 //
1403 // Inputs:
1404 //   inputs[0]: required: the input array
1405 //
1406 // TensorFlow equivalent: Sqrt
struct TensorFlowSqrtOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {}
};
1410 
1411 // Element-wise square (x*x) operator.
1412 //
1413 // Inputs:
1414 //   inputs[0]: required: the input array
1415 //
1416 // TensorFlow equivalent: Square
struct TensorFlowSquareOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {}
};
1420 
1421 // Element-wise squared difference ((x-y)*(x-y)) operator.
1422 //
1423 // Inputs:
1424 //   inputs[0]: required: the left-hand side array
1425 //   inputs[1]: required: the right-hand side array
1426 //
1427 // TensorFlow equivalent: SquaredDifference
struct SquaredDifferenceOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  SquaredDifferenceOperator() : Operator(OperatorType::kSquaredDifference) {}
};
1431 
1432 // Transposes a tensor.
1433 //
1434 // By default, this operation performs a regular matrix transpose on 2-D input
1435 // tensors.
1436 //
1437 // Inputs:
1438 //   inputs[0]: required: the input array
1439 //
1440 // TensorFlow equivalent: Transpose
struct TransposeOperator : Operator {
  TransposeOperator() : Operator(OperatorType::kTranspose) {}
  // Dimension permutation; presumably resolved from a perm input when it
  // becomes constant — empty until then.
  std::vector<int> perm;
};
1445 
1446 // Element-wise subtraction operator.
1447 //
1448 // Inputs:
1449 //   inputs[0]: required: the left-hand side array
1450 //   inputs[1]: required: the right-hand side array
1451 //
1452 // TensorFlow equivalent: Sub
struct SubOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  SubOperator() : Operator(OperatorType::kSub) {}
};
1456 
1457 // Sum reduction: computes the sum of all of entries across the axes.
1458 //
1459 // Inputs:
1460 //   inputs[0]: required: the input array
1461 //
1462 // TensorFlow equivalent: Sum
struct TensorFlowSumOperator : Operator {
  TensorFlowSumOperator() : Operator(OperatorType::kSum) {}
  // Axes to reduce over.
  std::vector<int> axis;
  // Whether reduced dimensions are retained with size 1 (TF keep_dims).
  bool keep_dims = false;
};
1468 
1469 // Prod reduction: computes the product of all of entries across the axes.
1470 //
1471 // Inputs:
1472 //   inputs[0]: required: the input array
1473 //
1474 // TensorFlow equivalent: Prod
struct TensorFlowProdOperator : Operator {
  TensorFlowProdOperator() : Operator(OperatorType::kReduceProd) {}
  // Axes to reduce over.
  std::vector<int> axis;
  // Whether reduced dimensions are retained with size 1 (TF keep_dims).
  bool keep_dims = false;
};
1480 
1481 // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
1482 //
1483 // Inputs:
1484 //   inputs[0]: required: the input array
1485 //   inputs[1]: required: int array with length of rank(input[0])
struct TensorFlowTileOperator : Operator {
  // No attributes: multiples come from inputs[1] (see comment above).
  TensorFlowTileOperator() : Operator(OperatorType::kTile) {}
};
1489 
1490 // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
struct SliceOperator : Operator {
  SliceOperator() : Operator(OperatorType::kSlice) {}

  // Per-dimension start offsets and extents of the slice, presumably resolved
  // from constant inputs — empty until then.
  std::vector<int> begin;
  std::vector<int> size;
};
1497 
1498 // TensorFlow Split equivalent. Refer to TensorFlow documentation for details.
1499 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1500 // support graph transformations to other operator types by matching sub-graphs.
struct TensorFlowSplitOperator : Operator {
  TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {}
  // Number of output slices (TF Split's num_split attribute).
  int num_split = 0;
};
1505 
1506 // TensorFlow SplitV equivalent. Refer to TensorFlow documentation for details.
struct TensorFlowSplitVOperator : Operator {
  TensorFlowSplitVOperator() : Operator(OperatorType::kSplitV) {}
  // Number of output slices (TF SplitV's num_split attribute).
  int num_split = 0;
};
1511 
1512 // TensorFlow Concat equivalent. Refer to TensorFlow documentation for details.
1513 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1514 // support graph transformations to other operator types by matching sub-graphs.
1515 // Concretely, once the concat dim becomes known, if it is the depth
1516 // dimension then we can change this op into a DepthConcatenation op.
1517 // Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatOperator : Operator {
  // Placeholder only; expected to be rewritten away (see comment above).
  TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {}
};
1521 
1522 // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
1523 // details.
1524 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1525 // support graph transformations to other operator types by matching sub-graphs.
1526 // Concretely, once the concat dim becomes known, if it is the depth
1527 // dimension then we can change this op into a DepthConcatenation op.
1528 // Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatV2Operator : Operator {
  // Placeholder only; expected to be rewritten away (see comment above).
  TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {}
};
1532 
1533 // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
1534 //
1535 // Inputs: this operator accepts any number >= 1 of inputs.
1536 //   inputs[i]: the i-th array to merge.
1537 //
1538 // It is expected that graph transformations will drop all but exactly one
1539 // of the inputs, at which point the Merge node will be equivalent to an
1540 // Identity node forwarding the remaining input.
1541 //
1542 // Note: We do not currently support runtime control flow: we only support
1543 // control flow that can be resolved at tooling time (independently of input
1544 // activations).
struct TensorFlowMergeOperator : Operator {
  // No attributes; resolved at tooling time (see comment above).
  TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {}
};
1548 
1549 // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
1550 //
1551 // Inputs:
1552 //   inputs[0]: required: the input array
1553 //   inputs[1]: required: the boolean predicate, given as an array of size 1
1554 //     and of type kBool, will determine which output gets selected.
1555 //
1556 // Outputs: a TensorFlow Switch node always has exactly two outputs. Depending
1557 // on the boolean value that the input predicate resolves to (see note below),
1558 // one or the other of the outputs will be 'selected': the input array will be
1559 // forwarded to the 'selected output' as if by a Identity node, while the other
1560 // output will be discarded, and any graph edge connecting that discarded output
1561 // will be dropped. The rule for selecting outputs is as follows:
1562 //   outputs[0] will be selected if the input predicate resolves to 'true'.
1563 //   outputs[1] will be selected if the input predicate resolves to 'false'.
1564 //
1565 // Note: We do not currently support runtime control flow: we only support
1566 // control flow that can be resolved at tooling time (independently of input
1567 // activations).
struct TensorFlowSwitchOperator : Operator {
  // No attributes; output selection is resolved at tooling time (see above).
  TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {}
};
1571 
1572 // TensorFlow All equivalent. Refer to TensorFlow documentation for details.
1573 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1574 // support graph transformations to other operator types by matching sub-graphs.
1575 // Typically, this is only used as an input to an Assert node, so can be
1576 // removed as an unused node as we drop Assert nodes.
struct TensorFlowAllOperator : Operator {
  // Placeholder only; typically dropped along with Assert (see above).
  TensorFlowAllOperator() : Operator(OperatorType::kAll) {}
};
1580 
1581 // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
1582 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1583 // support graph transformations to other operator types by matching sub-graphs.
1584 // Typically, we just drop Assert nodes.
struct TensorFlowAssertOperator : Operator {
  // Placeholder only; typically dropped (see comment above).
  TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {}
};
1588 
1589 // TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
1590 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1591 // support graph transformations to other operator types by matching sub-graphs.
1592 // Typically, this is only used as an input to an Assert node, so can be
1593 // removed as an unused node as we drop Assert nodes.
struct TensorFlowLessOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowLessOperator() : Operator(OperatorType::kLess) {}
};
1597 
1598 // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
1599 // details.
1600 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1601 // support graph transformations to other operator types by matching sub-graphs.
1602 // Typically, this is only used as an input to an Assert node, so can be
1603 // removed as an unused node as we drop Assert nodes.
struct TensorFlowLessEqualOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {}
};
1607 
// TensorFlow Greater equivalent. Refer to TensorFlow documentation for
// details.
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {}
};
1616 
1617 // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
1618 // details.
1619 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1620 // support graph transformations to other operator types by matching sub-graphs.
1621 // Typically, this is only used as an input to an Assert node, so can be
1622 // removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterEqualOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {}
};
1626 
1627 // TensorFlow Equal equivalent. Refer to TensorFlow documentation for
1628 // details.
1629 // Not fully supported, just a placeholder to handle TensorFlow graphs and
1630 // support graph transformations to other operator types by matching sub-graphs.
1631 // Typically, this is only used as an input to an Assert node, so can be
1632 // removed as an unused node as we drop Assert nodes.
struct TensorFlowEqualOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {}
};
1636 
1637 // TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for
1638 // details.
struct TensorFlowNotEqualOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {}
};
1642 
1643 // Max reduction: computes the max of all of entries across the axes.
1644 //
1645 // Inputs:
1646 //   inputs[0]: required: the input array
1647 //
1648 // TensorFlow equivalent: Max
struct TensorFlowMaxOperator : Operator {
  TensorFlowMaxOperator() : Operator(OperatorType::kReduceMax) {}
  // Axes to reduce over.
  std::vector<int> axis;
  // Whether reduced dimensions are retained with size 1 (TF keep_dims).
  bool keep_dims = false;
};
1654 
1655 // Min reduction: computes the min of all of entries across the axes.
1656 //
1657 // Inputs:
1658 //   inputs[0]: required: the input array
1659 //
1660 // TensorFlow equivalent: Min
struct TensorFlowMinOperator : Operator {
  TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {}
  // Axes to reduce over.
  std::vector<int> axis;
  // Whether reduced dimensions are retained with size 1 (TF keep_dims).
  bool keep_dims = false;
};
1666 
1667 // Element-wise maximum operator. Currently it only supports scalar as
1668 // the second operand.
1669 //
1670 // Inputs:
1671 //   inputs[0]: required: the left-hand side array
1672 //   inputs[1]: required: the right-hand side array
1673 //
1674 // TensorFlow equivalent: Maximum
struct TensorFlowMaximumOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {}
};
1678 
1679 // Element-wise minimum operator. Currently it only supports scalar as
1680 // the second operand.
1681 //
1682 // Inputs:
1683 //   inputs[0]: required: the left-hand side array
1684 //   inputs[1]: required: the right-hand side array
1685 //
1686 // TensorFlow equivalent: Minimum
struct TensorFlowMinimumOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {}
};
1690 
1691 // General TF operation, unsupported by tf.mini. Expected to be dropped by
1692 // graph transformations.
struct TensorFlowUnsupportedOperator : Operator {
  TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {}

  // The original TF operation type. Used for diagnostic purposes.
  std::string tensorflow_op;
  // A boolean indicating if the unsupported op should be treated as quantized.
  bool quantized = false;
  // A boolean indicating if the unsupported op output should allow float values
  // in quantized mode.
  bool support_output_type_float_in_quantized_op = false;
  // Output data types, one entry per output.
  std::vector<ArrayDataType> output_data_types;
  // Output shapes, one entry per output.
  std::vector<Shape> output_shapes;
};
1708 
1709 // Softmax activation function.
1710 //
1711 // Inputs:
1712 //   inputs[0]: required: the input array
1713 //
1714 // TensorFlow equivalent: Softmax
struct SoftmaxOperator : Operator {
  SoftmaxOperator() : Operator(OperatorType::kSoftmax) {}
  // Scaling applied to the logits — presumably the TF Lite softmax 'beta'
  // attribute; confirm against the exporter.
  float beta = 0.f;
};
1719 
1720 // LogSoftmax activation function.
1721 //
1722 // Inputs:
1723 //   inputs[0]: required: the logits input array
1724 //
1725 // TensorFlow equivalent: LogSoftmax
struct LogSoftmaxOperator : Operator {
  LogSoftmaxOperator() : Operator(OperatorType::kLogSoftmax) {}

  // Lower bound used for the (quantized) output range, per the rationale
  // below.
  //
  // LogSoftmax can in principal have very large negative output, depending on
  // the input size.  However, input x_i that is less than x_max-10 is
  // accumulated as exp(x_i-x_max), which is truncated to zero.
  //
  // Since we effectively disregard smallish inputs in the normalizing factor,
  // we also drop them in the output (set to minimum output), and in doing so
  // make better use of the quantization range / resolution.
  static constexpr float kOutputRangeMin = -16.0;
};
1738 
1739 // Cast operator.
1740 //
1741 // Inputs:
1742 //   inputs[0]: required: the input array
1743 //
1744 // TensorFlow equivalent: Cast
struct CastOperator : Operator {
  CastOperator() : Operator(OperatorType::kCast) {}
  // Source and destination element types; kNone until resolved.
  ArrayDataType src_data_type = ArrayDataType::kNone;
  ArrayDataType dst_data_type = ArrayDataType::kNone;
};
1750 
1751 // Floor operator.
1752 //
1753 // Inputs:
1754 //   inputs[0]: required: the input array
1755 //
1756 // TensorFlow equivalent: Floor
struct FloorOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  FloorOperator() : Operator(OperatorType::kFloor) {}
};
1760 
1761 // Ceil operator.
1762 //
1763 // Inputs:
1764 //   inputs[0]: required: the input array
1765 //
1766 // TensorFlow equivalent: Ceil
struct CeilOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  CeilOperator() : Operator(OperatorType::kCeil) {}
};
1770 
1771 // Round operator.
1772 //
1773 // Inputs:
1774 //   inputs[0]: required: the input array
1775 //
1776 // TensorFlow equivalent: Round
struct RoundOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  RoundOperator() : Operator(OperatorType::kRound) {}
};
1780 
1781 // Gather operator. It gathers slices from params according to indices.
1782 // Only 1-D indices are supported at the moment.
1783 //
1784 // Inputs:
1785 //   inputs[0]: required: the params array
1786 //   inputs[1]: required: the indices to gather
1787 //   inputs[2]: optional: axis
1788 //
1789 // TensorFlow equivalent: Gather
struct GatherOperator : Operator {
  GatherOperator() : Operator(OperatorType::kGather) {}
  // Axis is populated explicitly or implicitly from the axis input by
  // ResolveGatherAttributes. An empty axis indicates that the axis has not yet
  // been resolved.
  absl::optional<int> axis;

  // This field is not used by the standard TF Lite export but it is still
  // needed for legacy Gather implementations.
  int input_rank = 0;
};
1801 
1802 // GatherNd operator. It gathers slices from params according to indices.
1803 //
1804 // Inputs:
1805 //   inputs[0]: required: the params array
1806 //   inputs[1]: required: the indices to gather
1807 //
1808 // TensorFlow equivalent: GatherNd
struct GatherNdOperator : Operator {
  // No attributes: fully described by its inputs and outputs.
  GatherNdOperator() : Operator(OperatorType::kGatherNd) {}
};
1812 
// ArgMax operator. It returns the index of the maximum value along axis.
//
// Inputs:
//   inputs[0]: required: the input tensor
//   inputs[1]: optional: 0-D (scalar) axis
//
// TensorFlow equivalent: ArgMax
struct ArgMaxOperator : Operator {
  ArgMaxOperator() : Operator(OperatorType::kArgMax) {}
  // Data type of the produced indices. Defaults to int64, matching the
  // default output_type of TF's ArgMax.
  ArrayDataType output_data_type = ArrayDataType::kInt64;
};
1824 
// ArgMin operator. It returns the index of the minimum value along axis.
//
// Inputs:
//   inputs[0]: required: the input tensor
//   inputs[1]: optional: 0-D (scalar) axis
//
// TensorFlow equivalent: ArgMin
struct ArgMinOperator : Operator {
  ArgMinOperator() : Operator(OperatorType::kArgMin) {}
  // Data type of the produced indices. Defaults to int64, matching the
  // default output_type of TF's ArgMin.
  ArrayDataType output_data_type = ArrayDataType::kInt64;
};
1836 
// ResizeBilinear operator. It resizes input images with bilinear
// interpolation.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the new image size
//
// TensorFlow equivalent: ResizeBilinear
struct ResizeBilinearOperator : Operator {
  ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {}

  // Mirrors the TF ResizeBilinear 'align_corners' attribute.
  bool align_corners = false;
  // Mirrors the TF ResizeBilinear 'half_pixel_centers' attribute.
  bool half_pixel_centers = false;
};
1851 
// ResizeNearestNeighborOperator operator. It resizes input images with
// nearest neighbor interpolation.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the new image size
//
// TensorFlow equivalent: ResizeNearestNeighbor
struct ResizeNearestNeighborOperator : Operator {
  ResizeNearestNeighborOperator()
      : Operator(OperatorType::kResizeNearestNeighbor) {}

  // Mirrors the TF ResizeNearestNeighbor 'align_corners' attribute.
  bool align_corners = false;
  // Mirrors the TF ResizeNearestNeighbor 'half_pixel_centers' attribute.
  bool half_pixel_centers = false;
};
1867 
// SpaceToBatchND operator. It divides spatial dimensions into a grid of
// blocks and interleaves these blocks with the batch dimension. Currently,
// only 2-d blocks are supported.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the block shape
//   inputs[2]: required: the paddings
//
// TensorFlow equivalent: SpaceToBatchND
struct SpaceToBatchNDOperator : Operator {
  SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {}

  // Cached copies of the constant block-shape / paddings inputs, split into
  // per-dimension before/after padding amounts.
  std::vector<int> block_shape;
  std::vector<int> before_paddings;
  std::vector<int> after_paddings;
};
1885 
// BatchToSpaceND operator. Rearranges data from batch into blocks of
// spatial data. Currently, only 2-d blocks are supported.
//
// Inputs:
//   inputs[0]: required: the input array
//   inputs[1]: required: the block shape
//   inputs[2]: required: the crops
//
// TensorFlow equivalent: BatchToSpaceND
struct BatchToSpaceNDOperator : Operator {
  BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}

  // Cached copies of the constant block-shape / crops inputs, split into
  // per-dimension before/after crop amounts.
  std::vector<int> block_shape;
  std::vector<int> before_crops;
  std::vector<int> after_crops;
};
1902 
// Mean operator. Reduces the input by taking the mean over 'axis'.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Mean
struct MeanOperator : Operator {
  MeanOperator() : Operator(OperatorType::kMean) {}

  // Dimensions to reduce over.
  std::vector<int> axis;
  // If true, reduced dimensions are retained with size 1.
  bool keep_dims = false;
};
1915 
1916 // Svdf operator:
1917 //
1918 // Inputs:
1919 //   inputs[0]: required: the input array
1920 //   inputs[1]: required: weights_feature
1921 //   inputs[2]: required: weights_time
1922 //   inputs[3]: optional: bias
1923 struct SvdfOperator : Operator {
1924   SvdfOperator() : Operator(OperatorType::kSvdf) {}
1925   int rank;
1926 };
1927 
// TopKV2 operator. Returns the values and indices of the top_k largest
// entries along the last dimension.
//
// Inputs:
//    input tensor and top_k scalar.
struct TopKV2Operator : Operator {
  TopKV2Operator() : Operator(OperatorType::kTopK_V2) {}
};
1935 
1936 // DynamicPartition operator:
1937 //
1938 // Inputs:
1939 //  inputs[0]: required: data.
1940 //  inputs[1]: required: partitions.
1941 //
1942 // TensorFlow equivalent: DynamicPartition
1943 struct DynamicPartitionOperator : Operator {
1944   DynamicPartitionOperator() : Operator(OperatorType::kDynamicPartition) {}
1945   int num_partitions;
1946 };
1947 
1948 // DynamicStitch operator:
1949 //
1950 // Inputs:
1951 //  inputs[0,N): required: indices.
1952 //  inputs[N,2N): required: data.
1953 //
1954 // TensorFlow equivalent: DynamicStitch/ParallelDynamicStitch
1955 struct DynamicStitchOperator : Operator {
1956   DynamicStitchOperator() : Operator(OperatorType::kDynamicStitch) {}
1957   int num_partitions;
1958 };
1959 
1960 // SparseToDense operator:
1961 //
1962 // Inputs:
1963 // Inputs[0]: required: sparse_indices.
1964 // Inputs[1]: required: output_shape.
1965 // Inputs[2]: required: sparse_values.
1966 //
1967 // TensorFlow equivalent: SparseToDense.
1968 struct SparseToDenseOperator : Operator {
1969   SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {}
1970   bool validate_indices;
1971 };
1972 
// Pow operator: element-wise x^y with broadcasting.
//
// Inputs:
// Inputs[0]: required: A tensor.
// Inputs[1]: required: A tensor.
//
// TensorFlow equivalent: Pow.
struct PowOperator : Operator {
  PowOperator() : Operator(OperatorType::kPow) {}
};
1983 
// Any operator: logical-OR reduction over the given axes.
//
// Inputs:
// Inputs[0]: required: A boolean input tensor.
// Inputs[1]: required: reduction_indices.
//
// TensorFlow equivalent: tf.reduce_any.
struct TensorFlowAnyOperator : Operator {
  TensorFlowAnyOperator() : Operator(OperatorType::kAny) {}
  // Dimensions to reduce over.
  std::vector<int> axis;
  // If true, reduced dimensions are retained with size 1.
  bool keep_dims = false;
};
1996 
// LogicalAnd operator: element-wise logical AND with broadcasting.
//
// Inputs:
// Inputs[0]: required: A boolean tensor.
// Inputs[1]: required: A boolean tensor.
//
// TensorFlow equivalent: tf.logical_and.
struct LogicalAndOperator : Operator {
  LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {}
};
2007 
// LogicalNot operator: element-wise logical negation.
//
// Inputs:
// Inputs[0]: required: A boolean tensor.
//
// TensorFlow equivalent: tf.logical_not.
struct LogicalNotOperator : Operator {
  LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
};
2017 
// OneHot operator:
//
// Inputs:
// Inputs[0]: required: indices.
// Inputs[1]: required: depth.
// Inputs[2]: required: on_value.
// Inputs[3]: required: off_value.
//
// TensorFlow equivalent: OneHot.
struct OneHotOperator : Operator {
  // Symbolic indices into the 'inputs' vector, matching the order above.
  enum Inputs {
    INDICES_INPUT = 0,
    DEPTH_INPUT = 1,
    ON_VALUE_INPUT = 2,
    OFF_VALUE_INPUT = 3,
  };

  OneHotOperator() : Operator(OperatorType::kOneHot) {}
  // Axis at which the new one-hot dimension is inserted; -1 (the TF default)
  // means a new innermost axis.
  int axis = -1;
};
2038 
// LogicalOr operator: element-wise logical OR with broadcasting.
//
// Inputs:
// Inputs[0]: required: A Bool tensor.
// Inputs[1]: required: A Bool tensor.
//
// TensorFlow equivalent: LogicalOr.
struct LogicalOrOperator : Operator {
  LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
};
2049 
2050 // Unpack operator:
2051 //
2052 // Inputs:
2053 // Inputs[0]: required: A boolean input tensor.
2054 // Inputs[1]: required: reduction_indices.
2055 //
2056 // TensorFlow equivalent: tf.unstack.
2057 struct UnpackOperator : Operator {
2058   UnpackOperator() : Operator(OperatorType::kUnpack) {}
2059   int num;
2060   int axis;
2061   ArrayDataType dtype = ArrayDataType::kNone;
2062 };
2063 
// ZerosLike operator: produces a tensor of zeros with the same shape and
// dtype as the input.
//
// Inputs:
// inputs[0]: required: the input array
//
// TensorFlow equivalent: tf.zeros_like
struct TensorFlowZerosLikeOperator : Operator {
  TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
};
2073 
// ReverseV2 operator: reverses the input along the given axes.
//
// Inputs:
// Inputs[0]: required: the input array.
// Inputs[1]: presumably the axes to reverse, as in TF's ReverseV2 —
//            TODO(review): confirm against the importer.
//
// TensorFlow equivalent: ReverseV2.
struct ReverseV2Operator : Operator {
  ReverseV2Operator() : Operator(OperatorType::kReverseV2) {}
};
2083 
2084 enum class MirrorPadMode { kNone, kSymmetric, kReflect };
2085 
2086 // MirrorPad Operator:
2087 //
2088 // Inputs:
2089 // Inputs[0]: required: input tensor to be padded.
2090 // Inputs[1]: required: 2 Column matrix specifying padding sizes. The number of
2091 // rows must be the same as the rank of the input.
2092 // Inputs[2]: required: REFLECT or SYMMETRIC.
2093 //
2094 // TensorFlow equivalent: MirrorPad.
2095 struct MirrorPadOperator : Operator {
2096   MirrorPadOperator() : Operator(OperatorType::kMirrorPad) {}
2097   // mode is either SYMMETRIC or REFLECT.
2098   MirrorPadMode mode;
2099 };
2100 
2101 // ReverseSequence operator:
2102 //
2103 // Inputs:
2104 // Inputs[0]: required: the input array.
2105 // Inputs[1]: required: the lengths of the elements to be reversed.
2106 //
2107 // TensorFlow equivalent: tf.reverse_sequence.
2108 struct ReverseSequenceOperator : Operator {
2109   ReverseSequenceOperator() : Operator(OperatorType::kReverseSequence) {}
2110   int seq_dim;
2111   int batch_dim = 0;
2112 };
2113 
// Unique Operator: returns the unique elements of the input, plus an index
// array mapping each input element to its position in the unique output.
//
// Inputs:
//   inputs[0]: required: the input array
//
// TensorFlow equivalent: Unique
struct UniqueOperator : Operator {
  UniqueOperator() : Operator(OperatorType::kUnique) {}
  // Data type of the produced index output (TF's out_idx attribute).
  ArrayDataType idx_out_type = ArrayDataType::kInt32;
};
2124 
2125 struct UnidirectionalSequenceRnnOperator : Operator {
2126   UnidirectionalSequenceRnnOperator()
2127       : Operator(OperatorType::kUnidirectionalSequenceRnn) {}
2128   bool time_major;
2129   FusedActivationFunctionType fused_activation_function;
2130 };
2131 
// Where Operator:
// Return the coordinates of the true values in condition tensor in row-major
// order.
//
// Inputs:
//  inputs[0]: required: boolean condition tensor
//
//  TensorFlow equivalent: Where
struct WhereOperator : Operator {
  WhereOperator() : Operator(OperatorType::kWhere) {}
};
2143 
// Matrix Diag Operator:
// Construct a batched diagonal tensor with given batched diagonal values.
//
// Inputs: A tensor of values that will be on the diagonal of the returned
//         tensor.
struct MatrixDiagOperator : Operator {
  MatrixDiagOperator() : Operator(OperatorType::kMatrixDiag) {}
};
2151 
// Matrix Diag Operator V2:
// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported: it carries 4 extra inputs compared to MatrixDiag, and
// behaves like MatrixDiag only when default parameters are used.
struct MatrixDiagV2Operator : Operator {
  MatrixDiagV2Operator() : Operator(OperatorType::kMatrixDiagV2) {}
};
2159 
// Matrix Diag Operator V3:
// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported: it carries 5 extra inputs compared to MatrixDiag, and
// behaves like MatrixDiag only when default parameters are used.
// V3 differs from V2 only by an extra attribute (align) which controls the
// alignment of diagonals in the band matrix (compact) format.
// The alignment in V2 contradicts the default alignment in V3, so V2 is
// skipped. (It has never been, and should never be, exposed in the public
// API.)
struct MatrixDiagV3Operator : Operator {
  MatrixDiagV3Operator() : Operator(OperatorType::kMatrixDiagV3) {}
};
2171 
// Matrix Set Diag Operator:
// Construct a batched diagonal tensor with given input and diagonal values.
// Input is a rank (k+1) tensor of values.
// diagonal is a rank (k) tensor of values that will be on the diagonal
// of the returned output. Output is rank k+1.
struct MatrixSetDiagOperator : Operator {
  MatrixSetDiagOperator() : Operator(OperatorType::kMatrixSetDiag) {}
};
2181 
// Matrix Set Diag Operator V2:
// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported: it carries 1 extra input compared to MatrixSetDiag,
// and behaves like MatrixSetDiag only when default parameters are used.
struct MatrixSetDiagV2Operator : Operator {
  MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
};
2189 
// Matrix Set Diag Operator V3:
// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported: it carries 2 extra inputs compared to MatrixSetDiag,
// and behaves like MatrixSetDiag only when default parameters are used.
// V3 differs from V2 only by an extra attribute (align) which controls the
// alignment of diagonals in the band matrix (compact) format.
// The alignment in V2 contradicts the default alignment in V3, so V2 is
// skipped. (It has never been, and should never be, exposed in the public
// API.)
struct MatrixSetDiagV3Operator : Operator {
  MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
};
2201 
// ScatterNd operator: scatters updates into a tensor of a given shape,
// presumably taking (indices, updates, shape) inputs like TF's ScatterNd —
// TODO(review): confirm the input order against the importer.
//
// TensorFlow equivalent: ScatterNd.
struct ScatterNdOperator : Operator {
  ScatterNdOperator() : Operator(OperatorType::kScatterNd) {}
};
2205 
// SegmentSum operator: sums elements of the input along segments, presumably
// taking (data, segment_ids) inputs like TF's SegmentSum — TODO(review):
// confirm against the importer.
//
// TensorFlow equivalent: SegmentSum.
struct SegmentSumOperator : Operator {
  SegmentSumOperator() : Operator(OperatorType::kSegmentSum) {}
};
2209 
// Alloc's are used for transient arrays only. An Alloc specifies which
// interval of the "transient_data" workspace buffer passed to inference
// functions is to be used for the transient array at hand. The 'start' and
// 'end' values are offsets from the start of the workspace buffer, expressed
// in bytes. (Presumably a half-open [start, end) interval — confirm with the
// allocator that produces these.)
struct Alloc {
  int64 start = 0;
  int64 end = 0;
};
2218 
// Orders Allocs by increasing start offset, so allocations can be kept in
// sorted containers / sorted vectors.
inline bool operator<(const Alloc& a, const Alloc& b) {
  return a.start < b.start;
}
2222 
// Array represents an array (either a constant parameter array or an
// activations array) in a Model.
struct Array {
  // Returns the constant buffer, typed as holding elements of data type A.
  // Requires that a buffer of that type is already attached.
  template <ArrayDataType A>
  const Buffer<A>& GetBuffer() const {
    DCHECK(buffer);
    DCHECK(buffer->type == A);
    return *static_cast<const Buffer<A>*>(buffer.get());
  }
  // Returns the constant buffer typed as holding elements of data type A,
  // creating an empty buffer of that type if none is attached yet.
  template <ArrayDataType A>
  Buffer<A>& GetMutableBuffer() {
    if (!buffer) {
      Buffer<A>* ptr = new Buffer<A>;
      buffer = std::unique_ptr<GenericBuffer>(ptr);
    }
    DCHECK(buffer);
    DCHECK(buffer->type == A);
    return *static_cast<Buffer<A>*>(buffer.get());
  }
  // Returns the transient-workspace allocation, creating it if absent.
  Alloc& GetOrCreateAlloc() {
    if (!alloc) {
      alloc = std::unique_ptr<Alloc>(new Alloc);
    }
    return *alloc;
  }
  // Returns the MinMax record, creating it if absent.
  MinMax& GetOrCreateMinMax() {
    if (!minmax) {
      minmax = std::unique_ptr<MinMax>(new MinMax);
    }
    return *minmax;
  }
  // Returns the MinMax record; it must already exist.
  MinMax& GetMinMax() const {
    DCHECK(minmax);
    return *minmax;
  }
  // Returns the quantization parameters, creating them if absent.
  QuantizationParams& GetOrCreateQuantizationParams() {
    if (!quantization_params) {
      quantization_params =
          std::unique_ptr<QuantizationParams>(new QuantizationParams);
    }
    return *quantization_params;
  }
  // Returns the quantization parameters; they must already exist.
  QuantizationParams& GetQuantizationParams() const {
    DCHECK(quantization_params);
    return *quantization_params;
  }

  // The data type of the actual elements of this array, that is:
  //  - If there is a buffer (see 'buffer' member), it must be of the same
  //    type.
  //  - If there is no buffer, meaning that this is a runtime (i.e. activations)
  //    array, then this specifies the type of elements that there will be
  //    at runtime.
  //
  // Note that this only specifies the storage type of elements; this does
  // not specify whether these are to be treated as 'real' or 'quantized'
  // values.
  // That is decided by whether the 'quantization_params' member is null.
  ArrayDataType data_type = ArrayDataType::kNone;
  // The final value that data_type should have at the end of graph
  // transformations.
  ArrayDataType final_data_type = ArrayDataType::kNone;
  // The dimensions of this array --- this specifies both sizes and strides
  // (the storage layout).
  //
  // Issues with shape handling that remain include:
  //   - No way to distinguish between 0-dimensional dims and missing dims.
  //   - No way to describe dims that may be runtime-variable.
  //   - Addressing of dims by integer index differs in different graph formats
  //     (TensorFlow vs. other frameworks vs. what we have informally grown
  //     within toco).
  //     This is currently quite messy; see ReorderAxesOperator which is how we
  //     bridge some of these discrepancies at the moment. This is overdue for
  //     a redesign; I'm thinking that it would be nice to have more flexible
  //     dims that allow mapping 1:1, cleanly, dims as they are in various
  //     formats, then explicitly convert between different conventions.

  // Proto-style accessors for the shape (stored in 'array_shape' below).
  bool has_shape() const { return array_shape != nullptr; }
  const Shape& shape() const {
    CHECK(has_shape());
    return *array_shape;
  }
  Shape* mutable_shape() {
    if (!array_shape) {
      array_shape.reset(new Shape);
    }
    return array_shape.get();
  }
  void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; }
  void clear_shape() { array_shape = nullptr; }

  // The constant buffer backing this array. This is non-null if and only if
  // this is a constant parameter array. Conversely, this is null for
  // activations arrays.
  //
  // Note that this buffer is pure storage. In the case of quantized values,
  // it only stores the quantized values, it does not know by itself about the
  // quantization parameters necessary to interpret these values, that is
  // in the separate 'quantization_params' field. In fact, this 'buffer' field
  // does not even know whether values are quantized. It only has a data_type,
  // which must equal the 'data_type' member here, and which only describes
  // the storage type of elements; it does not tell whether they are quantized,
  // i.e. whether they are to be interpreted with quantization_params.
  std::unique_ptr<GenericBuffer> buffer;
  // Only for activation arrays (i.e. when 'buffer' is null).
  // Only for code generation.
  //
  // Describes the allocation of this array within the workspace buffer
  // allocated for all transient arrays.
  std::unique_ptr<Alloc> alloc;
  // Describes the [min, max] range of values
  // to be assumed when determining quantization_params.
  //
  // Only used for quantization. In fact, only used for determining
  // quantization_params.
  //
  // Used for both constant arrays (those having a 'buffer') and non-constant
  // arrays (activations). Indeed, it is important to use the same min-max range
  // as was used during training, even if that min-max range is slightly wrong
  // w.r.t. actual buffer elements. Doing otherwise would defeat the point of
  // re-training for quantization.
  std::unique_ptr<MinMax> minmax;
  // Quantization parameters. The non-null-ness of this pointer is what
  // defines whether this array is quantized or not.
  //
  // If this is non-null, then these quantization parameters are to be used
  // to assign a meaning as real numbers to the elements of this array.
  std::unique_ptr<QuantizationParams> quantization_params;
  // narrow_range is a detail of how toco handles FakeQuant operators with
  // narrow_range, see
  // https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars
  //
  // For more context about what that is useful for, see the big comment in
  // graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
  //
  // The narrow_range flag applies only to quantized arrays, and changes
  // their quantization in the following way when it is set to 'true':
  // 1. The computation of {zero_point, scale} from {min, max} needs to be
  //    amended so that the real min value will get quantized to
  //    (min_quantized_value + 1) instead of just (min_quantized_value).
  //    E.g. for uint8 quantization, the real min value should get quantized to
  //    the uint8 value 1, not 0.
  // 2. Quantized values should get clamped to the interval
  //    [min_quantized_value + 1, max_quantized_value]. Equivalently, the
  //    min_quantized_value should get nudged to (min_quantized_value + 1).
  // The reason why 1. does not imply 2. is that real values may not belong to
  // the stated [min, max] interval. Concretely, weights recorded at the last
  // learning step may not fall in the [min, max] interval recorded over
  // previous learning steps, as the values evolve across learning steps.
  //
  // Rationale why this is directly a field on Array:
  // - This can't be just a field on FakeQuantOperator, because
  //   FakeQuantOperators are gone (DropFakeQuant) before we get to using that
  //   information (Quantize). We need a place to store that bit in the interim.
  // - This can't be in QuantizationParams because we need to record this
  //   ahead of quantization, and QuantizationParams are only created during
  //   quantization.
  // - This could be in MinMax, but that would be an abuse of what MinMax is
  //   about, and would break existing code that assumes that a MinMax is just
  //   a min and a max. Unlike MinMax which is agnostic as to the quantized
  //   data type, narrow_range refers to values in the quantized data type.
  bool narrow_range = false;

 private:
  std::unique_ptr<Shape> array_shape;
};
2392 
2393 // Our Model struct, represents an entire model (our "top-level" struct).
2394 // Owns everything.
2395 class Model {
2396  public:
2397   using ArrayMap = std::unordered_map<std::string, std::unique_ptr<Array>>;
2398 
2399   bool HasArray(const std::string& name) const {
2400     return arrays.count(name) > 0;
2401   }
2402   Array& GetArray(const std::string& name) const {
2403     DCHECK(HasArray(name)) << "Array not found: " << name;
2404     return *arrays.at(name);
2405   }
2406   Array& GetOrCreateArray(const std::string& name) {
2407     // Make sure name is not used by an optional array
2408     DCHECK(!optional_arrays.count(name));
2409     if (!HasArray(name)) {
2410       Array* ptr = new Array;
2411       arrays[name] = std::unique_ptr<Array>(ptr);
2412     }
2413     Array& result = GetArray(name);
2414     return result;
2415   }
2416   void CreateOptionalArray(const std::string& name) {
2417     DCHECK(!arrays.count(name) && !optional_arrays.count(name));
2418     optional_arrays.insert(name);
2419   }
2420   bool IsOptionalArray(const std::string& name) const {
2421     return optional_arrays.count(name);
2422   }
2423 
2424   // Note that this invalidates all array iterators.
2425   void EraseArray(const std::string& name) { arrays.erase(name); }
2426   void EraseArrays(std::function<bool(const std::string&)> discardable) {
2427     for (auto it = arrays.begin(); it != arrays.end();) {
2428       if (discardable(it->first)) {
2429         it = arrays.erase(it);
2430       } else {
2431         ++it;
2432       }
2433     }
2434   }
2435   const ArrayMap& GetArrayMap() const { return arrays; }
2436   ArrayMap& GetMutableArrayMap() { return arrays; }
2437 
2438   int64 ArithmeticOpsCount() const { return ops_count; }
2439 
2440   void AddInvalidInputArray(std::string invalid_input_array) {
2441     invalid_input_arrays_.insert(invalid_input_array);
2442   }
2443 
2444   const std::unordered_set<std::string>& GetInvalidInputArrays() const {
2445     return invalid_input_arrays_;
2446   }
2447 
2448   // Optional arrays are used for optional tensors,
2449   // these tensors do not have data, but with reserved names as op inputs.
2450   std::set<std::string> optional_arrays;
2451 
2452   // The list of operators. Notice how it's a list of unique_ptr's, implying
2453   // that the Model is what owns Operator's and keeps them alive.
2454   std::vector<std::unique_ptr<Operator>> operators;
2455 
2456   // Generic flags, a place where we combine information passed to us via
2457   // command-line parameters (e.g. --input_width=N) with information that
2458   // we may or may not find in the input model file.
2459   ModelFlags flags;
2460   // For code-generation only: required size of the transient_data buffer
2461   std::size_t transient_data_size = 0;
2462   // For code-generation only: required alignment of the transient_data buffer
2463   std::size_t transient_data_alignment = 0;
2464   // Arithmetic operations performed in the model.
2465   int64 ops_count = 0;
2466 
2467  private:
2468   // The associative array mapping names to Array's.
2469   // Notice how it's a container of unique_ptr's, implying
2470   // that the Model is what owns Array's and keeps them alive.
2471   // The Operator's refer to these Array's by their name strings, not by their
2472   // addresses. See Operator::inputs, Operator::outputs.
2473   std::unordered_map<std::string, std::unique_ptr<Array>> arrays;
2474 
2475   // Invalid input arrays.
2476   std::unordered_set<std::string> invalid_input_arrays_;
2477 };
2478 
// OperatorSignature contains the information required to make versioning
// decisions.
struct OperatorSignature {
  // The operator. Non-owning; must outlive this signature.
  const Operator* op;

  // The model in which the operator resides. Non-owning.
  const Model* model;
};
2488 }  // namespace toco
2489 
2490 #endif  // TENSORFLOW_LITE_TOCO_MODEL_H_
2491