• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16 
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/util/ptr_util.h"
22 
23 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
24 // graph_transformation module.
25 #include "tensorflow/lite/schema/schema_generated.h"
26 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
29 #include "tensorflow/lite/toco/tflite/custom_operator.h"
30 #include "tensorflow/lite/toco/tflite/simple_operator.h"
31 #include "tensorflow/lite/toco/tflite/types.h"
32 #include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
33 
34 namespace toco {
35 
36 namespace tflite {
37 
38 class AveragePool
39     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
40                              ::tflite::BuiltinOptions_Pool2DOptions> {
41  public:
42   using BuiltinOperator::BuiltinOperator;
43 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const44   flatbuffers::Offset<TfLiteOptions> WriteOptions(
45       const TocoOperator& op,
46       flatbuffers::FlatBufferBuilder* builder) const override {
47     auto padding = Padding::Serialize(op.padding.type);
48     auto activation_function =
49         ActivationFunction::Serialize(op.fused_activation_function);
50     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
51                                          op.stride_height, op.kwidth,
52                                          op.kheight, activation_function);
53   }
54 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const55   void ReadOptions(const TfLiteOptions& options,
56                    TocoOperator* op) const override {
57     op->padding.type = Padding::Deserialize(options.padding());
58     op->stride_width = options.stride_w();
59     op->stride_height = options.stride_h();
60     op->kwidth = options.filter_width();
61     op->kheight = options.filter_height();
62     op->fused_activation_function =
63         ActivationFunction::Deserialize(options.fused_activation_function());
64   }
65 
GetVersion(const OperatorSignature & op_signature) const66   int GetVersion(const OperatorSignature& op_signature) const override {
67     const string& input_name = op_signature.op->inputs[0];
68     const Array& input_array = op_signature.model->GetArray(input_name);
69     if (input_array.data_type == ArrayDataType::kInt8) {
70       return 2;
71     }
72     return 1;
73   }
74 };
75 
76 class Convolution
77     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
78                              ::tflite::BuiltinOptions_Conv2DOptions> {
79  public:
80   using BuiltinOperator::BuiltinOperator;
81 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const82   flatbuffers::Offset<TfLiteOptions> WriteOptions(
83       const TocoOperator& op,
84       flatbuffers::FlatBufferBuilder* builder) const override {
85     auto padding = Padding::Serialize(op.padding.type);
86     auto activation_function =
87         ActivationFunction::Serialize(op.fused_activation_function);
88     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
89                                          op.stride_height, activation_function,
90                                          op.dilation_width_factor,
91                                          op.dilation_height_factor);
92   }
93 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const94   void ReadOptions(const TfLiteOptions& options,
95                    TocoOperator* op) const override {
96     op->padding.type = Padding::Deserialize(options.padding());
97     op->stride_width = options.stride_w();
98     op->stride_height = options.stride_h();
99     op->dilation_width_factor = options.dilation_w_factor();
100     op->dilation_height_factor = options.dilation_h_factor();
101     op->fused_activation_function =
102         ActivationFunction::Deserialize(options.fused_activation_function());
103   }
104 
GetVersion(const OperatorSignature & op_signature) const105   int GetVersion(const OperatorSignature& op_signature) const override {
106     const string& input_name = op_signature.op->inputs[0];
107     const string& filter_name = op_signature.op->inputs[1];
108     const string& output_name = op_signature.op->outputs[0];
109     const Array& input_array = op_signature.model->GetArray(input_name);
110     const Array& filter_array = op_signature.model->GetArray(filter_name);
111     const Array& output_array = op_signature.model->GetArray(output_name);
112     // If the op has signed int8 inputs and outputs, its version 3.
113     if (input_array.data_type == ArrayDataType::kInt8 &&
114         filter_array.data_type == ArrayDataType::kInt8 &&
115         output_array.data_type == ArrayDataType::kInt8) {
116       return 3;
117     }
118     // If the op is a signed int8 hybrid operation, we need to return
119     // version 2.
120     if (input_array.data_type == ArrayDataType::kFloat &&
121         filter_array.data_type == ArrayDataType::kInt8 &&
122         output_array.data_type == ArrayDataType::kFloat) {
123       return 2;
124     }
125     return 1;
126   }
127 };
128 
129 class DepthwiseConvolution
130     : public BuiltinOperator<DepthwiseConvOperator,
131                              ::tflite::DepthwiseConv2DOptions,
132                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
133  public:
134   using BuiltinOperator::BuiltinOperator;
135 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const136   flatbuffers::Offset<TfLiteOptions> WriteOptions(
137       const TocoOperator& op,
138       flatbuffers::FlatBufferBuilder* builder) const override {
139     auto padding = Padding::Serialize(op.padding.type);
140     auto activation_function =
141         ActivationFunction::Serialize(op.fused_activation_function);
142     return ::tflite::CreateDepthwiseConv2DOptions(
143         *builder, padding, op.stride_width, op.stride_height,
144         op.depth_multiplier, activation_function, op.dilation_width_factor,
145         op.dilation_height_factor);
146   }
147 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const148   void ReadOptions(const TfLiteOptions& options,
149                    TocoOperator* op) const override {
150     op->padding.type = Padding::Deserialize(options.padding());
151     op->stride_width = options.stride_w();
152     op->stride_height = options.stride_h();
153     op->depth_multiplier = options.depth_multiplier();
154     op->fused_activation_function =
155         ActivationFunction::Deserialize(options.fused_activation_function());
156     op->dilation_width_factor = options.dilation_w_factor();
157     op->dilation_height_factor = options.dilation_h_factor();
158   }
159 
GetVersion(const OperatorSignature & op_signature) const160   int GetVersion(const OperatorSignature& op_signature) const override {
161     const auto& conv_op =
162         static_cast<const DepthwiseConvOperator&>(*op_signature.op);
163     const string& input_name = op_signature.op->inputs[0];
164     const string& filter_name = op_signature.op->inputs[1];
165     const string& output_name = op_signature.op->outputs[0];
166     const Array& input_array = op_signature.model->GetArray(input_name);
167     const Array& filter_array = op_signature.model->GetArray(filter_name);
168     const Array& output_array = op_signature.model->GetArray(output_name);
169     // If the op has signed int8 inputs and outputs, its version 3.
170     if (input_array.data_type == ArrayDataType::kInt8 &&
171         filter_array.data_type == ArrayDataType::kInt8 &&
172         output_array.data_type == ArrayDataType::kInt8) {
173       return 3;
174     }
175     if (conv_op.dilation_width_factor != 1 ||
176         conv_op.dilation_height_factor != 1) {
177       return 2;
178     }
179     return 1;
180   }
181 };
182 
183 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
184                                    ::tflite::BuiltinOptions_AddOptions> {
185  public:
186   using BuiltinOperator::BuiltinOperator;
187 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const188   flatbuffers::Offset<TfLiteOptions> WriteOptions(
189       const TocoOperator& op,
190       flatbuffers::FlatBufferBuilder* builder) const override {
191     auto activation_function =
192         ActivationFunction::Serialize(op.fused_activation_function);
193     return ::tflite::CreateAddOptions(*builder, activation_function);
194   }
195 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const196   void ReadOptions(const TfLiteOptions& options,
197                    TocoOperator* op) const override {
198     op->fused_activation_function =
199         ActivationFunction::Deserialize(options.fused_activation_function());
200   }
201 
GetVersion(const OperatorSignature & op_signature) const202   int GetVersion(const OperatorSignature& op_signature) const override {
203     const string& input_name = op_signature.op->inputs[0];
204     const Array& input_array = op_signature.model->GetArray(input_name);
205     // Version 2 supports signed int8 input types.
206     if (input_array.data_type == ArrayDataType::kInt8) {
207       return 2;
208     }
209     return 1;
210   }
211 };
212 
213 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
214                                     ::tflite::BuiltinOptions_AddNOptions> {
215  public:
216   using BuiltinOperator::BuiltinOperator;
217 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const218   flatbuffers::Offset<TfLiteOptions> WriteOptions(
219       const TocoOperator& op,
220       flatbuffers::FlatBufferBuilder* builder) const override {
221     return ::tflite::CreateAddNOptions(*builder);
222   }
223 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const224   void ReadOptions(const TfLiteOptions& options,
225                    TocoOperator* op) const override {}
226 
GetVersion(const OperatorSignature & op_signature) const227   int GetVersion(const OperatorSignature& op_signature) const override {
228     return 1;
229   }
230 };
231 
232 class SpaceToBatchND
233     : public BuiltinOperator<SpaceToBatchNDOperator,
234                              ::tflite::SpaceToBatchNDOptions,
235                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
236  public:
237   using BuiltinOperator::BuiltinOperator;
238 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const239   flatbuffers::Offset<TfLiteOptions> WriteOptions(
240       const TocoOperator& op,
241       flatbuffers::FlatBufferBuilder* builder) const override {
242     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
243   }
244 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const245   void ReadOptions(const TfLiteOptions& options,
246                    TocoOperator* op) const override {}
247 
GetVersion(const OperatorSignature & op_signature) const248   int GetVersion(const OperatorSignature& op_signature) const override {
249     const string& input_name = op_signature.op->inputs[0];
250     const Array& input_array = op_signature.model->GetArray(input_name);
251     // If the op take int8 input, it is version 2.
252     if (input_array.data_type == ArrayDataType::kInt8) {
253       return 2;
254     }
255     return 1;
256   }
257 };
258 
259 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
260                                    ::tflite::BuiltinOptions_SubOptions> {
261  public:
262   using BuiltinOperator::BuiltinOperator;
263 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const264   flatbuffers::Offset<TfLiteOptions> WriteOptions(
265       const TocoOperator& op,
266       flatbuffers::FlatBufferBuilder* builder) const override {
267     auto activation_function =
268         ActivationFunction::Serialize(op.fused_activation_function);
269     return ::tflite::CreateSubOptions(*builder, activation_function);
270   }
271 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const272   void ReadOptions(const TfLiteOptions& options,
273                    TocoOperator* op) const override {
274     op->fused_activation_function =
275         ActivationFunction::Deserialize(options.fused_activation_function());
276   }
277 
GetVersion(const OperatorSignature & op_signature) const278   int GetVersion(const OperatorSignature& op_signature) const override {
279     const string& input_name = op_signature.op->inputs[0];
280     const Array& input_array = op_signature.model->GetArray(input_name);
281     // If the op take int8 input, it is version 2.
282     if (input_array.data_type == ArrayDataType::kInt8) {
283       return 2;
284     }
285     return 1;
286   }
287 };
288 
289 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
290                                    ::tflite::BuiltinOptions_DivOptions> {
291  public:
292   using BuiltinOperator::BuiltinOperator;
293 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const294   flatbuffers::Offset<TfLiteOptions> WriteOptions(
295       const TocoOperator& op,
296       flatbuffers::FlatBufferBuilder* builder) const override {
297     auto activation_function =
298         ActivationFunction::Serialize(op.fused_activation_function);
299     return ::tflite::CreateDivOptions(*builder, activation_function);
300   }
301 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const302   void ReadOptions(const TfLiteOptions& options,
303                    TocoOperator* op) const override {
304     op->fused_activation_function =
305         ActivationFunction::Deserialize(options.fused_activation_function());
306   }
307 
GetVersion(const OperatorSignature & op_signature) const308   int GetVersion(const OperatorSignature& op_signature) const override {
309     return 1;
310   }
311 };
312 
313 class BatchToSpaceND
314     : public BuiltinOperator<BatchToSpaceNDOperator,
315                              ::tflite::BatchToSpaceNDOptions,
316                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
317  public:
318   using BuiltinOperator::BuiltinOperator;
319 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const320   flatbuffers::Offset<TfLiteOptions> WriteOptions(
321       const TocoOperator& op,
322       flatbuffers::FlatBufferBuilder* builder) const override {
323     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
324   }
325 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const326   void ReadOptions(const TfLiteOptions& options,
327                    TocoOperator* op) const override {}
328 
GetVersion(const OperatorSignature & op_signature) const329   int GetVersion(const OperatorSignature& op_signature) const override {
330     const string& input_name = op_signature.op->inputs[0];
331     const Array& input_array = op_signature.model->GetArray(input_name);
332     // If the op take int8 input, it is version 2.
333     if (input_array.data_type == ArrayDataType::kInt8) {
334       return 2;
335     }
336     return 1;
337   }
338 };
339 
340 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
341                                     ::tflite::BuiltinOptions_CastOptions> {
342  public:
343   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const344   flatbuffers::Offset<TfLiteOptions> WriteOptions(
345       const TocoOperator& op,
346       flatbuffers::FlatBufferBuilder* builder) const override {
347     return ::tflite::CreateCastOptions(*builder,
348                                        DataType::Serialize(op.src_data_type),
349                                        DataType::Serialize(op.dst_data_type));
350   }
351 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const352   void ReadOptions(const TfLiteOptions& options,
353                    TocoOperator* op) const override {
354     op->src_data_type = DataType::Deserialize(options.in_data_type());
355     op->dst_data_type = DataType::Deserialize(options.out_data_type());
356   }
357 
GetVersion(const OperatorSignature & op_signature) const358   int GetVersion(const OperatorSignature& op_signature) const override {
359     return 1;
360   }
361 };
362 
363 class Concatenation
364     : public BuiltinOperator<ConcatenationOperator,
365                              ::tflite::ConcatenationOptions,
366                              ::tflite::BuiltinOptions_ConcatenationOptions> {
367  public:
368   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const369   flatbuffers::Offset<TfLiteOptions> WriteOptions(
370       const TocoOperator& op,
371       flatbuffers::FlatBufferBuilder* builder) const override {
372     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
373   }
374 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const375   void ReadOptions(const TfLiteOptions& options,
376                    TocoOperator* op) const override {
377     op->axis = options.axis();
378   }
379 
GetVersion(const OperatorSignature & op_signature) const380   int GetVersion(const OperatorSignature& op_signature) const override {
381     const string& input_name = op_signature.op->inputs[0];
382     const Array& input_array = op_signature.model->GetArray(input_name);
383     // If the op take int8 input, it is version 2.
384     if (input_array.data_type == ArrayDataType::kInt8) {
385       return 2;
386     }
387     return 1;
388   }
389 };
390 
391 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
392  public:
393   using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const394   void WriteOptions(const TocoOperator& op,
395                     flexbuffers::Builder* fbb) const override {
396     fbb->Int("block_size", op.block_size);
397   }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const398   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
399     op->block_size = m["block_size"].AsInt64();
400   }
401 
GetVersion(const OperatorSignature & op_signature) const402   int GetVersion(const OperatorSignature& op_signature) const override {
403     return 1;
404   }
405 };
406 
407 class FakeQuant
408     : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
409                              ::tflite::BuiltinOptions_FakeQuantOptions> {
410  public:
411   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const412   flatbuffers::Offset<TfLiteOptions> WriteOptions(
413       const TocoOperator& op,
414       flatbuffers::FlatBufferBuilder* builder) const override {
415     return ::tflite::CreateFakeQuantOptions(
416         *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
417   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const418   void ReadOptions(const TfLiteOptions& options,
419                    TocoOperator* op) const override {
420     auto* minmax = new MinMax;
421     minmax->min = options.min();
422     minmax->max = options.max();
423     op->minmax.reset(minmax);
424     op->num_bits = options.num_bits();
425     op->narrow_range = options.narrow_range();
426   }
GetVersion(const OperatorSignature & op_signature) const427   int GetVersion(const OperatorSignature& op_signature) const override {
428     const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
429     return fq_op.narrow_range ? 2 : 1;
430   }
431 };
432 
433 class FullyConnected
434     : public BuiltinOperator<FullyConnectedOperator,
435                              ::tflite::FullyConnectedOptions,
436                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
437  public:
438   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const439   flatbuffers::Offset<TfLiteOptions> WriteOptions(
440       const TocoOperator& op,
441       flatbuffers::FlatBufferBuilder* builder) const override {
442     auto activation_function =
443         ActivationFunction::Serialize(op.fused_activation_function);
444     ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
445     switch (op.weights_format) {
446       case FullyConnectedWeightsFormat::kDefault:
447         tflite_weights_format =
448             ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
449         break;
450       case FullyConnectedWeightsFormat::kShuffled4x16Int8:
451         tflite_weights_format =
452             ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
453         break;
454       default:
455         LOG(ERROR) << "Unhandled FC weights format";
456         tflite_weights_format =
457             ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
458     }
459     return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
460                                                  tflite_weights_format);
461   }
462 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const463   void ReadOptions(const TfLiteOptions& options,
464                    TocoOperator* op) const override {
465     op->fused_activation_function =
466         ActivationFunction::Deserialize(options.fused_activation_function());
467     switch (options.weights_format()) {
468       case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
469         op->weights_format = FullyConnectedWeightsFormat::kDefault;
470         break;
471       case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
472         op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
473         break;
474       default:
475         LOG(ERROR) << "Unhandled FC weights format";
476         op->weights_format = FullyConnectedWeightsFormat::kDefault;
477     }
478   }
479 
480   // +-----------------+--------------------+--------------------------+
481   // |                 |    Weight::Default | Weight::Shuffled4x16Int8 |
482   // +-----------------+--------------------+--------------------------+
483   // | Float           |                  1 |                        2 |
484   // | Quantized Uint8 |                  1 |                        2 |
485   // | Hybrid          |                  3 |                        3 |
486   // | Quantized Int8  |                  4 |                        4 |
487   // +-----------------+--------------------+--------------------------+
GetVersion(const OperatorSignature & op_signature) const488   int GetVersion(const OperatorSignature& op_signature) const override {
489     const auto& fc_op =
490         static_cast<const FullyConnectedOperator&>(*op_signature.op);
491     const string& input_name = op_signature.op->inputs[0];
492     const string& weights_name = op_signature.op->inputs[1];
493     const string& output_name = op_signature.op->outputs[0];
494     const Array& input_array = op_signature.model->GetArray(input_name);
495     const Array& weights_array = op_signature.model->GetArray(weights_name);
496     const Array& output_array = op_signature.model->GetArray(output_name);
497     // Int8 fully fixed point kernel is at version 4.
498     if (input_array.data_type == ArrayDataType::kInt8 &&
499         weights_array.data_type == ArrayDataType::kInt8 &&
500         output_array.data_type == ArrayDataType::kInt8) {
501       return 4;
502     }
503     // If the op is a signed int8 hybrid operation, we need to return
504     // version 3.
505     if (input_array.data_type == ArrayDataType::kFloat &&
506         weights_array.data_type == ArrayDataType::kInt8 &&
507         output_array.data_type == ArrayDataType::kFloat) {
508       return 3;
509     }
510     // For float and uint8 fixed point kernels, if the weight is
511     // Shuffled4x16Int8, is is version 2.
512     if (fc_op.weights_format ==
513         FullyConnectedWeightsFormat::kShuffled4x16Int8) {
514       return 2;
515     }
516 
517     // Otherwise (weight is default), the version is 1.
518     return 1;
519   }
520 };
521 
522 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
523                                       ::tflite::BuiltinOptions_GatherOptions> {
524  public:
525   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const526   flatbuffers::Offset<TfLiteOptions> WriteOptions(
527       const TocoOperator& op,
528       flatbuffers::FlatBufferBuilder* builder) const override {
529     int axis = op.axis ? op.axis.value() : 0;
530     return ::tflite::CreateGatherOptions(*builder, axis);
531   }
532 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const533   void ReadOptions(const TfLiteOptions& options,
534                    TocoOperator* op) const override {
535     op->axis = {options.axis()};
536   }
537 
GetVersion(const OperatorSignature & op_signature) const538   int GetVersion(const OperatorSignature& op_signature) const override {
539     const string& input_name = op_signature.op->inputs[0];
540     const Array& input_array = op_signature.model->GetArray(input_name);
541     // If the op take int8 input, it is version 2.
542     if (input_array.data_type == ArrayDataType::kInt8) {
543       return 2;
544     }
545     return 1;
546   }
547 };
548 
549 class GatherNd
550     : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
551                              ::tflite::BuiltinOptions_GatherNdOptions> {
552  public:
553   using BuiltinOperator::BuiltinOperator;
554 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const555   flatbuffers::Offset<TfLiteOptions> WriteOptions(
556       const TocoOperator& op,
557       flatbuffers::FlatBufferBuilder* builder) const override {
558     return ::tflite::CreateGatherNdOptions(*builder);
559   }
560 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const561   void ReadOptions(const TfLiteOptions& options,
562                    TocoOperator* op) const override {}
563 
GetVersion(const OperatorSignature & op_signature) const564   int GetVersion(const OperatorSignature& op_signature) const override {
565     return 1;
566   }
567 };
568 
569 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
570                                     ::tflite::BuiltinOptions_SVDFOptions> {
571  public:
572   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const573   flatbuffers::Offset<TfLiteOptions> WriteOptions(
574       const TocoOperator& op,
575       flatbuffers::FlatBufferBuilder* builder) const override {
576     auto activation_function =
577         ActivationFunction::Serialize(op.fused_activation_function);
578     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
579   }
580 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const581   void ReadOptions(const TfLiteOptions& options,
582                    TocoOperator* op) const override {
583     op->fused_activation_function =
584         ActivationFunction::Deserialize(options.fused_activation_function());
585     op->rank = options.rank();
586   }
587 
GetVersion(const OperatorSignature & op_signature) const588   int GetVersion(const OperatorSignature& op_signature) const override {
589     const string& input_name = op_signature.op->inputs[0];
590     const string& weights_feature_name = op_signature.op->inputs[1];
591     const string& output_name = op_signature.op->outputs[0];
592     const Array& input_array = op_signature.model->GetArray(input_name);
593     const Array& weights_feature_array =
594         op_signature.model->GetArray(weights_feature_name);
595     const Array& output_array = op_signature.model->GetArray(output_name);
596     // If the op is a signed int8 hybrid operation, we need to return
597     // version 2.
598     if (input_array.data_type == ArrayDataType::kFloat &&
599         weights_feature_array.data_type == ArrayDataType::kInt8 &&
600         output_array.data_type == ArrayDataType::kFloat) {
601       return 2;
602     }
603     return 1;
604   }
605 };
606 
607 class L2Normalization
608     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
609                              ::tflite::BuiltinOptions_L2NormOptions> {
610  public:
611   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const612   flatbuffers::Offset<TfLiteOptions> WriteOptions(
613       const TocoOperator& op,
614       flatbuffers::FlatBufferBuilder* builder) const override {
615     auto activation_function =
616         ActivationFunction::Serialize(op.fused_activation_function);
617     return ::tflite::CreateL2NormOptions(*builder, activation_function);
618   }
619 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const620   void ReadOptions(const TfLiteOptions& options,
621                    TocoOperator* op) const override {
622     op->fused_activation_function =
623         ActivationFunction::Deserialize(options.fused_activation_function());
624   }
625 
GetVersion(const OperatorSignature & op_signature) const626   int GetVersion(const OperatorSignature& op_signature) const override {
627     const string& output_name = op_signature.op->outputs[0];
628     const Array& output_array = op_signature.model->GetArray(output_name);
629     // Version 2 supports signed int8 input types.
630     if (output_array.data_type == ArrayDataType::kInt8) {
631       return 2;
632     }
633     return 1;
634   }
635 };
636 
637 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
638                                       ::tflite::BuiltinOptions_Pool2DOptions> {
639  public:
640   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const641   flatbuffers::Offset<TfLiteOptions> WriteOptions(
642       const TocoOperator& op,
643       flatbuffers::FlatBufferBuilder* builder) const override {
644     auto padding = Padding::Serialize(op.padding.type);
645     auto activation_function =
646         ActivationFunction::Serialize(op.fused_activation_function);
647     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
648                                          op.stride_height, op.kwidth,
649                                          op.kheight, activation_function);
650   }
651 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const652   void ReadOptions(const TfLiteOptions& options,
653                    TocoOperator* op) const override {
654     op->padding.type = Padding::Deserialize(options.padding());
655     op->stride_width = options.stride_w();
656     op->stride_height = options.stride_h();
657     op->kwidth = options.filter_width();
658     op->kheight = options.filter_height();
659     op->fused_activation_function =
660         ActivationFunction::Deserialize(options.fused_activation_function());
661   }
662 
GetVersion(const OperatorSignature & op_signature) const663   int GetVersion(const OperatorSignature& op_signature) const override {
664     return 1;
665   }
666 };
667 
668 class LocalResponseNormalization
669     : public BuiltinOperator<
670           LocalResponseNormalizationOperator,
671           ::tflite::LocalResponseNormalizationOptions,
672           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
673  public:
674   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const675   flatbuffers::Offset<TfLiteOptions> WriteOptions(
676       const TocoOperator& op,
677       flatbuffers::FlatBufferBuilder* builder) const override {
678     return ::tflite::CreateLocalResponseNormalizationOptions(
679         *builder, op.range, op.bias, op.alpha, op.beta);
680   }
681 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const682   void ReadOptions(const TfLiteOptions& options,
683                    TocoOperator* op) const override {
684     op->range = options.radius();
685     op->bias = options.bias();
686     op->alpha = options.alpha();
687     op->beta = options.beta();
688   }
689 
GetVersion(const OperatorSignature & op_signature) const690   int GetVersion(const OperatorSignature& op_signature) const override {
691     return 1;
692   }
693 };
694 
695 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
696                                        ::tflite::BuiltinOptions_Pool2DOptions> {
697  public:
698   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const699   flatbuffers::Offset<TfLiteOptions> WriteOptions(
700       const TocoOperator& op,
701       flatbuffers::FlatBufferBuilder* builder) const override {
702     auto padding = Padding::Serialize(op.padding.type);
703     auto activation_function =
704         ActivationFunction::Serialize(op.fused_activation_function);
705     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
706                                          op.stride_height, op.kwidth,
707                                          op.kheight, activation_function);
708   }
709 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const710   void ReadOptions(const TfLiteOptions& options,
711                    TocoOperator* op) const override {
712     op->padding.type = Padding::Deserialize(options.padding());
713     op->stride_width = options.stride_w();
714     op->stride_height = options.stride_h();
715     op->kwidth = options.filter_width();
716     op->kheight = options.filter_height();
717     op->fused_activation_function =
718         ActivationFunction::Deserialize(options.fused_activation_function());
719   }
720 
GetVersion(const OperatorSignature & op_signature) const721   int GetVersion(const OperatorSignature& op_signature) const override {
722     const string& input_name = op_signature.op->inputs[0];
723     const Array& input_array = op_signature.model->GetArray(input_name);
724     if (input_array.data_type == ArrayDataType::kInt8) {
725       return 2;
726     }
727     return 1;
728   }
729 };
730 
731 class Maximum : public SimpleOperator<TensorFlowMaximumOperator> {
732  public:
Maximum()733   explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {}
GetVersion(const OperatorSignature & op_signature) const734   int GetVersion(const OperatorSignature& op_signature) const override {
735     const string& input_name = op_signature.op->inputs[0];
736     const Array& input_array = op_signature.model->GetArray(input_name);
737     // Version 2 supports signed int8 input types.
738     if (input_array.data_type == ArrayDataType::kInt8) {
739       return 2;
740     }
741     return 1;
742   }
743 };
744 
745 class Minimum : public SimpleOperator<TensorFlowMinimumOperator> {
746  public:
Minimum()747   explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {}
GetVersion(const OperatorSignature & op_signature) const748   int GetVersion(const OperatorSignature& op_signature) const override {
749     const string& input_name = op_signature.op->inputs[0];
750     const Array& input_array = op_signature.model->GetArray(input_name);
751     // Version 2 supports signed int8 input types.
752     if (input_array.data_type == ArrayDataType::kInt8) {
753       return 2;
754     }
755     return 1;
756   }
757 };
758 
759 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
760                                    ::tflite::BuiltinOptions_MulOptions> {
761  public:
762   using BuiltinOperator::BuiltinOperator;
763 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const764   flatbuffers::Offset<TfLiteOptions> WriteOptions(
765       const TocoOperator& op,
766       flatbuffers::FlatBufferBuilder* builder) const override {
767     auto activation_function =
768         ActivationFunction::Serialize(op.fused_activation_function);
769     return ::tflite::CreateMulOptions(*builder, activation_function);
770   }
771 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const772   void ReadOptions(const TfLiteOptions& options,
773                    TocoOperator* op) const override {
774     op->fused_activation_function =
775         ActivationFunction::Deserialize(options.fused_activation_function());
776   }
777 
GetVersion(const OperatorSignature & op_signature) const778   int GetVersion(const OperatorSignature& op_signature) const override {
779     const string& input_name = op_signature.op->inputs[0];
780     const Array& input_array = op_signature.model->GetArray(input_name);
781     // Version 2 supports signed int8 input types.
782     if (input_array.data_type == ArrayDataType::kInt8) {
783       return 2;
784     }
785     return 1;
786   }
787 };
788 
789 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
790                                    ::tflite::BuiltinOptions_PadOptions> {
791  public:
792   using BuiltinOperator::BuiltinOperator;
793 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const794   flatbuffers::Offset<TfLiteOptions> WriteOptions(
795       const TocoOperator& op,
796       flatbuffers::FlatBufferBuilder* builder) const override {
797     return ::tflite::CreatePadOptions(*builder);
798   }
799 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const800   void ReadOptions(const TfLiteOptions& options,
801                    TocoOperator* op) const override {}
802 
GetVersion(const OperatorSignature & op_signature) const803   int GetVersion(const OperatorSignature& op_signature) const override {
804     const string& input_name = op_signature.op->inputs[0];
805     const Array& input_array = op_signature.model->GetArray(input_name);
806     // If the op take int8 input, it is version 2.
807     if (input_array.data_type == ArrayDataType::kInt8) {
808       return 2;
809     }
810     return 1;
811   }
812 };
813 
814 class Tile
815     : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
816                              ::tflite::BuiltinOptions_TileOptions> {
817   using BuiltinOperator::BuiltinOperator;
818 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const819   flatbuffers::Offset<TfLiteOptions> WriteOptions(
820       const TocoOperator& op,
821       flatbuffers::FlatBufferBuilder* builder) const override {
822     return ::tflite::CreateTileOptions(*builder);
823   }
824 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const825   void ReadOptions(const TfLiteOptions& options,
826                    TocoOperator* op) const override {}
GetVersion(const OperatorSignature & op_signature) const827   int GetVersion(const OperatorSignature& op_signature) const override {
828     return 1;
829   }
830 };
831 
832 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
833                                      ::tflite::BuiltinOptions_PadV2Options> {
834  public:
835   using BuiltinOperator::BuiltinOperator;
836 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const837   flatbuffers::Offset<TfLiteOptions> WriteOptions(
838       const TocoOperator& op,
839       flatbuffers::FlatBufferBuilder* builder) const override {
840     return ::tflite::CreatePadV2Options(*builder);
841   }
842 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const843   void ReadOptions(const TfLiteOptions& options,
844                    TocoOperator* op) const override {}
845 
GetVersion(const OperatorSignature & op_signature) const846   int GetVersion(const OperatorSignature& op_signature) const override {
847     const string& input_name = op_signature.op->inputs[0];
848     const Array& input_array = op_signature.model->GetArray(input_name);
849     // If the op take int8 input, it is version 2.
850     if (input_array.data_type == ArrayDataType::kInt8) {
851       return 2;
852     }
853     return 1;
854   }
855 };
856 
857 class Reshape
858     : public BuiltinOperator<TensorFlowReshapeOperator,
859                              ::tflite::ReshapeOptions,
860                              ::tflite::BuiltinOptions_ReshapeOptions> {
861  public:
862   using BuiltinOperator::BuiltinOperator;
863 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const864   flatbuffers::Offset<TfLiteOptions> WriteOptions(
865       const TocoOperator& op,
866       flatbuffers::FlatBufferBuilder* builder) const override {
867     return ::tflite::CreateReshapeOptions(*builder,
868                                           builder->CreateVector(op.shape));
869   }
870 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const871   void ReadOptions(const TfLiteOptions& options,
872                    TocoOperator* op) const override {
873     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
874                      options.new_shape()->end());
875   }
876 
GetVersion(const OperatorSignature & op_signature) const877   int GetVersion(const OperatorSignature& op_signature) const override {
878     return 1;
879   }
880 };
881 
882 class Softmax
883     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
884                              ::tflite::BuiltinOptions_SoftmaxOptions> {
885  public:
886   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const887   flatbuffers::Offset<TfLiteOptions> WriteOptions(
888       const TocoOperator& op,
889       flatbuffers::FlatBufferBuilder* builder) const override {
890     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
891   }
892 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const893   void ReadOptions(const TfLiteOptions& options,
894                    TocoOperator* op) const override {
895     op->beta = options.beta();
896   }
897 
GetVersion(const OperatorSignature & op_signature) const898   int GetVersion(const OperatorSignature& op_signature) const override {
899     const string& input_name = op_signature.op->inputs[0];
900     const Array& input_array = op_signature.model->GetArray(input_name);
901     if (input_array.data_type == ArrayDataType::kInt8) {
902       return 2;
903     }
904     return 1;
905   }
906 };
907 
908 class SpaceToDepth
909     : public BuiltinOperator<SpaceToDepthOperator,
910                              ::tflite::SpaceToDepthOptions,
911                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
912  public:
913   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const914   flatbuffers::Offset<TfLiteOptions> WriteOptions(
915       const TocoOperator& op,
916       flatbuffers::FlatBufferBuilder* builder) const override {
917     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
918   }
919 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const920   void ReadOptions(const TfLiteOptions& options,
921                    TocoOperator* op) const override {
922     op->block_size = options.block_size();
923   }
924 
GetVersion(const OperatorSignature & op_signature) const925   int GetVersion(const OperatorSignature& op_signature) const override {
926     const string& input_name = op_signature.op->inputs[0];
927     const Array& input_array = op_signature.model->GetArray(input_name);
928     // If the op take int8 input, it is version 2.
929     if (input_array.data_type == ArrayDataType::kInt8) {
930       return 2;
931     }
932     return 1;
933   }
934 };
935 
936 class Transpose
937     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
938                              ::tflite::BuiltinOptions_TransposeOptions> {
939  public:
940   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const941   flatbuffers::Offset<TfLiteOptions> WriteOptions(
942       const TocoOperator& op,
943       flatbuffers::FlatBufferBuilder* builder) const override {
944     return ::tflite::CreateTransposeOptions(*builder);
945   }
946 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const947   void ReadOptions(const TfLiteOptions& options,
948                    TocoOperator* op) const override {}
949 
GetVersion(const OperatorSignature & op_signature) const950   int GetVersion(const OperatorSignature& op_signature) const override {
951     const string& input_name = op_signature.op->inputs[0];
952     const Array& input_array = op_signature.model->GetArray(input_name);
953     // If the op take int8 input, it is version 2.
954     if (input_array.data_type == ArrayDataType::kInt8) {
955       return 2;
956     }
957     return 1;
958   }
959 };
960 
961 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
962                                     ::tflite::BuiltinOptions_LSTMOptions> {
963  public:
964   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const965   flatbuffers::Offset<TfLiteOptions> WriteOptions(
966       const TocoOperator& op,
967       flatbuffers::FlatBufferBuilder* builder) const override {
968     ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL;
969     switch (op.kernel_type) {
970       case LstmCellOperator::KERNEL_BASIC:
971         kernel_type = ::tflite::LSTMKernelType_BASIC;
972         break;
973       case LstmCellOperator::KERNEL_FULL:
974         kernel_type = ::tflite::LSTMKernelType_FULL;
975         break;
976       default:
977         return -1;
978     }
979 
980     // Current toco converter only supports tanh, no clip.
981     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
982                                        ::tflite::ActivationFunctionType_TANH,
983                                        /*cell_clip=*/0.0,
984                                        /*proj_clip=*/0.0, kernel_type);
985   }
986 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const987   void ReadOptions(const TfLiteOptions& options,
988                    TocoOperator* op) const override {
989     // Only support tanh activation, so check that tflite type is tanh.
990     CHECK(options.fused_activation_function() ==
991           ::tflite::ActivationFunctionType_TANH);
992 
993     switch (options.kernel_type()) {
994       case ::tflite::LSTMKernelType_BASIC:
995         op->kernel_type = LstmCellOperator::KERNEL_BASIC;
996         break;
997       case ::tflite::LSTMKernelType_FULL:
998         op->kernel_type = LstmCellOperator::KERNEL_FULL;
999         break;
1000     }
1001   }
1002 
GetVersion(const OperatorSignature & op_signature) const1003   int GetVersion(const OperatorSignature& op_signature) const override {
1004     const auto& lstm_op =
1005         static_cast<const LstmCellOperator&>(*op_signature.op);
1006     switch (lstm_op.kernel_type) {
1007       case LstmCellOperator::KERNEL_FULL: {
1008         // If the input tensor is float and a weight is int8, this is a version
1009         // 3 hybrid operation.
1010         const string& input_name = op_signature.op->inputs[0];
1011         const string& weights_name = op_signature.op->inputs[2];
1012         const string& output_name = op_signature.op->outputs[0];
1013         const Array& input_array = op_signature.model->GetArray(input_name);
1014         const Array& weights_array = op_signature.model->GetArray(weights_name);
1015         const Array& output_array = op_signature.model->GetArray(output_name);
1016         if (input_array.data_type == ArrayDataType::kFloat &&
1017             weights_array.data_type == ArrayDataType::kInt8 &&
1018             output_array.data_type == ArrayDataType::kFloat) {
1019           return 3;
1020         }
1021         return 1;
1022       }
1023       case LstmCellOperator::KERNEL_BASIC:
1024         // KERNEL_BASIC was added in version 2.
1025         return 2;
1026     }
1027   }
1028 
GetMutatingInputVariables(const Operator & op) const1029   std::vector<bool> GetMutatingInputVariables(
1030       const Operator& op) const override {
1031     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
1032 
1033     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1034     switch (lstm_op.kernel_type) {
1035       case LstmCellOperator::KERNEL_FULL: {
1036         mutating_input_variables[kInputActivationStateTensor] = true;
1037         mutating_input_variables[kInputCellStateTensor] = true;
1038         break;
1039       }
1040       case LstmCellOperator::KERNEL_BASIC: {
1041         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
1042         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
1043         break;
1044       }
1045     }
1046     return mutating_input_variables;
1047   }
1048 };
1049 
1050 class UnidirectionalSequenceLstm
1051     : public BuiltinOperator<
1052           UnidirectionalSequenceLstmOperator,
1053           ::tflite::UnidirectionalSequenceLSTMOptions,
1054           ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
1055  public:
1056   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1057   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1058       const TocoOperator& op,
1059       flatbuffers::FlatBufferBuilder* builder) const override {
1060     // Current toco converter only supports tanh, no clip.
1061     return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
1062         *builder, /*fused_activation_function=*/
1063         ::tflite::ActivationFunctionType_TANH,
1064         /*cell_clip=*/0.0,
1065         /*proj_clip=*/0.0,
1066         /*time_major=*/true);
1067   }
1068 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1069   void ReadOptions(const TfLiteOptions& options,
1070                    TocoOperator* op) const override {
1071     // Only support tanh activation, so check that tflite type is tanh.
1072     DCHECK(options.fused_activation_function() ==
1073            ::tflite::ActivationFunctionType_TANH);
1074   }
1075 
GetVersion(const OperatorSignature & op_signature) const1076   int GetVersion(const OperatorSignature& op_signature) const override {
1077     // If the input tensor is float and a weight is int8, this is a version
1078     // 2 hybrid operation.
1079     const string& input_name = op_signature.op->inputs[0];
1080     const string& weights_name = op_signature.op->inputs[2];
1081     const string& output_name = op_signature.op->outputs[0];
1082     const Array& input_array = op_signature.model->GetArray(input_name);
1083     const Array& weights_array = op_signature.model->GetArray(weights_name);
1084     const Array& output_array = op_signature.model->GetArray(output_name);
1085     if (input_array.data_type == ArrayDataType::kFloat &&
1086         weights_array.data_type == ArrayDataType::kInt8 &&
1087         output_array.data_type == ArrayDataType::kFloat) {
1088       return 2;
1089     }
1090     return 1;
1091   }
1092 
GetMutatingInputVariables(const Operator & op) const1093   std::vector<bool> GetMutatingInputVariables(
1094       const Operator& op) const override {
1095     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1096     mutating_input_variables[kInputActivationStateTensor] = true;
1097     mutating_input_variables[kInputCellStateTensor] = true;
1098     return mutating_input_variables;
1099   }
1100 };
1101 
1102 class BidirectionalSequenceLstm
1103     : public BuiltinOperator<
1104           BidirectionalSequenceLstmOperator,
1105           ::tflite::BidirectionalSequenceLSTMOptions,
1106           ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
1107  public:
1108   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1109   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1110       const TocoOperator& op,
1111       flatbuffers::FlatBufferBuilder* builder) const override {
1112     // Current toco converter only supports tanh, no clip.
1113     return ::tflite::CreateBidirectionalSequenceLSTMOptions(
1114         *builder, /*fused_activation_function=*/
1115         ::tflite::ActivationFunctionType_TANH,
1116         /*cell_clip=*/0.0,
1117         /*proj_clip=*/0.0,
1118         /*merge_outputs=*/op.merge_outputs,
1119         /*time_major=*/true);
1120   }
1121 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1122   void ReadOptions(const TfLiteOptions& options,
1123                    TocoOperator* op) const override {
1124     // Only support tanh activation, so check that tflite type is tanh.
1125     DCHECK(options.fused_activation_function() ==
1126            ::tflite::ActivationFunctionType_TANH);
1127     op->merge_outputs = options.merge_outputs();
1128   }
1129 
GetVersion(const OperatorSignature & op_signature) const1130   int GetVersion(const OperatorSignature& op_signature) const override {
1131     return 1;
1132   }
1133 
GetMutatingInputVariables(const Operator & op) const1134   std::vector<bool> GetMutatingInputVariables(
1135       const Operator& op) const override {
1136     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1137     // Forward input activation state.
1138     mutating_input_variables[35] = true;
1139     // Forward input cell state.
1140     mutating_input_variables[36] = true;
1141     // Backward input activation state.
1142     mutating_input_variables[37] = true;
1143     // Backward input cell state.
1144     mutating_input_variables[38] = true;
1145     return mutating_input_variables;
1146   }
1147 };
1148 
1149 class BidirectionalSequenceRnn
1150     : public BuiltinOperator<
1151           BidirectionalSequenceRnnOperator,
1152           ::tflite::BidirectionalSequenceRNNOptions,
1153           ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
1154  public:
1155   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1156   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1157       const TocoOperator& op,
1158       flatbuffers::FlatBufferBuilder* builder) const override {
1159     // Current toco converter only supports tanh, no clip.
1160     return ::tflite::CreateBidirectionalSequenceRNNOptions(
1161         *builder, /*time_major=*/true,
1162         /*fused_activation_function=*/
1163         ::tflite::ActivationFunctionType_TANH,
1164         /*merge_outputs=*/op.merge_outputs);
1165   }
1166 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1167   void ReadOptions(const TfLiteOptions& options,
1168                    TocoOperator* op) const override {
1169     // Only support tanh activation, so check that tflite type is tanh.
1170     DCHECK(options.fused_activation_function() ==
1171            ::tflite::ActivationFunctionType_TANH);
1172     op->merge_outputs = options.merge_outputs();
1173   }
1174 
GetVersion(const OperatorSignature & op_signature) const1175   int GetVersion(const OperatorSignature& op_signature) const override {
1176     return 1;
1177   }
1178 
GetMutatingInputVariables(const Operator & op) const1179   std::vector<bool> GetMutatingInputVariables(
1180       const Operator& op) const override {
1181     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1182     // Forward hidden state.
1183     mutating_input_variables[4] = true;
1184     // Backward hidden state.
1185     mutating_input_variables[8] = true;
1186     return mutating_input_variables;
1187   }
1188 };
1189 
1190 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
1191                                     ::tflite::BuiltinOptions_ReducerOptions> {
1192  public:
1193   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1194   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1195       const TocoOperator& op,
1196       flatbuffers::FlatBufferBuilder* builder) const override {
1197     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1198   }
1199 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1200   void ReadOptions(const TfLiteOptions& options,
1201                    TocoOperator* op) const override {
1202     op->keep_dims = options.keep_dims();
1203   }
1204 
GetVersion(const OperatorSignature & op_signature) const1205   int GetVersion(const OperatorSignature& op_signature) const override {
1206     return 1;
1207   }
1208 };
1209 
1210 class Sum
1211     : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1212                              ::tflite::BuiltinOptions_ReducerOptions> {
1213  public:
1214   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1215   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1216       const TocoOperator& op,
1217       flatbuffers::FlatBufferBuilder* builder) const override {
1218     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1219   }
1220 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1221   void ReadOptions(const TfLiteOptions& options,
1222                    TocoOperator* op) const override {
1223     op->keep_dims = options.keep_dims();
1224   }
1225 
GetVersion(const OperatorSignature & op_signature) const1226   int GetVersion(const OperatorSignature& op_signature) const override {
1227     return 1;
1228   }
1229 };
1230 
1231 class ReduceMax
1232     : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1233                              ::tflite::BuiltinOptions_ReducerOptions> {
1234  public:
1235   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1236   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1237       const TocoOperator& op,
1238       flatbuffers::FlatBufferBuilder* builder) const override {
1239     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1240   }
1241 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1242   void ReadOptions(const TfLiteOptions& options,
1243                    TocoOperator* op) const override {
1244     op->keep_dims = options.keep_dims();
1245   }
1246 
GetVersion(const OperatorSignature & op_signature) const1247   int GetVersion(const OperatorSignature& op_signature) const override {
1248     const string& input_name = op_signature.op->inputs[0];
1249     const Array& input_array = op_signature.model->GetArray(input_name);
1250     // If the op take int8 input, it is version 2.
1251     if (input_array.data_type == ArrayDataType::kInt8) {
1252       return 2;
1253     }
1254     return 1;
1255   }
1256 };
1257 
1258 class ReduceMin
1259     : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1260                              ::tflite::BuiltinOptions_ReducerOptions> {
1261  public:
1262   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1263   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1264       const TocoOperator& op,
1265       flatbuffers::FlatBufferBuilder* builder) const override {
1266     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1267   }
1268 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1269   void ReadOptions(const TfLiteOptions& options,
1270                    TocoOperator* op) const override {
1271     op->keep_dims = options.keep_dims();
1272   }
1273 
GetVersion(const OperatorSignature & op_signature) const1274   int GetVersion(const OperatorSignature& op_signature) const override {
1275     const string& input_name = op_signature.op->inputs[0];
1276     const Array& input_array = op_signature.model->GetArray(input_name);
1277     // If the op take int8 input, it is version 2.
1278     if (input_array.data_type == ArrayDataType::kInt8) {
1279       return 2;
1280     }
1281     return 1;
1282   }
1283 };
1284 
1285 class ReduceProd
1286     : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1287                              ::tflite::BuiltinOptions_ReducerOptions> {
1288  public:
1289   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1290   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1291       const TocoOperator& op,
1292       flatbuffers::FlatBufferBuilder* builder) const override {
1293     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1294   }
1295 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1296   void ReadOptions(const TfLiteOptions& options,
1297                    TocoOperator* op) const override {
1298     op->keep_dims = options.keep_dims();
1299   }
1300 
GetVersion(const OperatorSignature & op_signature) const1301   int GetVersion(const OperatorSignature& op_signature) const override {
1302     return 1;
1303   }
1304 };
1305 
1306 class ReduceAny
1307     : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1308                              ::tflite::BuiltinOptions_ReducerOptions> {
1309  public:
1310   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1311   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1312       const TocoOperator& op,
1313       flatbuffers::FlatBufferBuilder* builder) const override {
1314     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1315   }
1316 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1317   void ReadOptions(const TfLiteOptions& options,
1318                    TocoOperator* op) const override {
1319     op->keep_dims = options.keep_dims();
1320   }
1321 
GetVersion(const OperatorSignature & op_signature) const1322   int GetVersion(const OperatorSignature& op_signature) const override {
1323     return 1;
1324   }
1325 };
1326 
1327 class Relu6 : public SimpleOperator<Relu6Operator> {
1328  public:
Relu6()1329   explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {}
GetVersion(const OperatorSignature & op_signature) const1330   int GetVersion(const OperatorSignature& op_signature) const override {
1331     const string& input_name = op_signature.op->inputs[0];
1332     const Array& input_array = op_signature.model->GetArray(input_name);
1333     // Version 2 supports signed int8 input types.
1334     if (input_array.data_type == ArrayDataType::kInt8) {
1335       return 2;
1336     }
1337     return 1;
1338   }
1339 };
1340 
1341 class ResizeBilinear
1342     : public BuiltinOperator<ResizeBilinearOperator,
1343                              ::tflite::ResizeBilinearOptions,
1344                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1345  public:
1346   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1347   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1348       const TocoOperator& op,
1349       flatbuffers::FlatBufferBuilder* builder) const override {
1350     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
1351   }
1352 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1353   void ReadOptions(const TfLiteOptions& options,
1354                    TocoOperator* op) const override {
1355     op->align_corners = options.align_corners();
1356   }
1357 
GetVersion(const OperatorSignature & op_signature) const1358   int GetVersion(const OperatorSignature& op_signature) const override {
1359     const string& input_name = op_signature.op->inputs[0];
1360     const Array& input_array = op_signature.model->GetArray(input_name);
1361     // If the op takes int8 input, it is version 2.
1362     if (input_array.data_type == ArrayDataType::kInt8) {
1363       return 2;
1364     }
1365     return 1;
1366   }
1367 };
1368 
1369 class ResizeNearestNeighbor
1370     : public BuiltinOperator<
1371           ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1372           ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1373  public:
1374   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1375   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1376       const TocoOperator& op,
1377       flatbuffers::FlatBufferBuilder* builder) const override {
1378     return ::tflite::CreateResizeNearestNeighborOptions(*builder,
1379                                                         op.align_corners);
1380   }
1381 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1382   void ReadOptions(const TfLiteOptions& options,
1383                    TocoOperator* op) const override {
1384     op->align_corners = options.align_corners();
1385   }
1386 
GetVersion(const OperatorSignature & op_signature) const1387   int GetVersion(const OperatorSignature& op_signature) const override {
1388     const string& input_name = op_signature.op->inputs[0];
1389     const Array& input_array = op_signature.model->GetArray(input_name);
1390     // Version 2 supports signed int8 input types.
1391     if (input_array.data_type == ArrayDataType::kInt8) {
1392       return 2;
1393     }
1394     return 1;
1395   }
1396 };
1397 
1398 class Squeeze
1399     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1400                              ::tflite::BuiltinOptions_SqueezeOptions> {
1401  public:
1402   using BuiltinOperator::BuiltinOperator;
1403 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1404   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1405       const TocoOperator& op,
1406       flatbuffers::FlatBufferBuilder* builder) const override {
1407     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1408     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1409   }
1410 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1411   void ReadOptions(const TfLiteOptions& options,
1412                    TocoOperator* op) const override {
1413     op->squeeze_dims.insert(op->squeeze_dims.end(),
1414                             options.squeeze_dims()->begin(),
1415                             options.squeeze_dims()->end());
1416   }
1417 
GetVersion(const OperatorSignature & op_signature) const1418   int GetVersion(const OperatorSignature& op_signature) const override {
1419     return 1;
1420   }
1421 };
1422 
1423 class Split
1424     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1425                              ::tflite::BuiltinOptions_SplitOptions> {
1426  public:
1427   using BuiltinOperator::BuiltinOperator;
1428 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1429   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1430       const TocoOperator& op,
1431       flatbuffers::FlatBufferBuilder* builder) const override {
1432     return ::tflite::CreateSplitOptions(*builder, op.num_split);
1433   }
1434 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1435   void ReadOptions(const TfLiteOptions& options,
1436                    TocoOperator* op) const override {
1437     op->num_split = options.num_splits();
1438   }
1439 
GetVersion(const OperatorSignature & op_signature) const1440   int GetVersion(const OperatorSignature& op_signature) const override {
1441     const string& input_name = op_signature.op->inputs[0];
1442     const Array& input_array = op_signature.model->GetArray(input_name);
1443     // If the op take int8 input, it is version 2, for int32 it's version 3.
1444     if (input_array.data_type == ArrayDataType::kInt8) {
1445       return 2;
1446     } else if (input_array.data_type == ArrayDataType::kInt32) {
1447       return 3;
1448     }
1449     return 1;
1450   }
1451 };
1452 
1453 class SplitV
1454     : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1455                              ::tflite::BuiltinOptions_SplitVOptions> {
1456  public:
1457   using BuiltinOperator::BuiltinOperator;
1458 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1459   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1460       const TocoOperator& op,
1461       flatbuffers::FlatBufferBuilder* builder) const override {
1462     return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1463   }
1464 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1465   void ReadOptions(const TfLiteOptions& options,
1466                    TocoOperator* op) const override {
1467     op->num_split = options.num_splits();
1468   }
1469 
GetVersion(const OperatorSignature & op_signature) const1470   int GetVersion(const OperatorSignature& op_signature) const override {
1471     return 1;
1472   }
1473 };
1474 
1475 class StridedSlice
1476     : public BuiltinOperator<StridedSliceOperator,
1477                              ::tflite::StridedSliceOptions,
1478                              ::tflite::BuiltinOptions_StridedSliceOptions> {
1479  public:
1480   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1481   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1482       const TocoOperator& op,
1483       flatbuffers::FlatBufferBuilder* builder) const override {
1484     return ::tflite::CreateStridedSliceOptions(
1485         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1486         op.new_axis_mask, op.shrink_axis_mask);
1487   }
1488 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1489   void ReadOptions(const TfLiteOptions& options,
1490                    TocoOperator* op) const override {
1491     op->begin_mask = options.begin_mask();
1492     op->end_mask = options.end_mask();
1493     op->ellipsis_mask = options.ellipsis_mask();
1494     op->new_axis_mask = options.new_axis_mask();
1495     op->shrink_axis_mask = options.shrink_axis_mask();
1496   }
1497 
GetVersion(const OperatorSignature & op_signature) const1498   int GetVersion(const OperatorSignature& op_signature) const override {
1499     const string& input_name = op_signature.op->inputs[0];
1500     const Array& input_array = op_signature.model->GetArray(input_name);
1501     // If the op take int8 input, it is version 2.
1502     if (input_array.data_type == ArrayDataType::kInt8) {
1503       return 2;
1504     }
1505     return 1;
1506   }
1507 };
1508 
1509 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1510                                        ::tflite::BuiltinOptions_TopKV2Options> {
1511  public:
1512   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1513   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1514       const TocoOperator& op,
1515       flatbuffers::FlatBufferBuilder* builder) const override {
1516     return ::tflite::CreateTopKV2Options(*builder);
1517   }
1518 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1519   void ReadOptions(const TfLiteOptions& options,
1520                    TocoOperator* op) const override {}
1521 
GetVersion(const OperatorSignature & op_signature) const1522   int GetVersion(const OperatorSignature& op_signature) const override {
1523     const string& input_name = op_signature.op->inputs[0];
1524     const Array& input_array = op_signature.model->GetArray(input_name);
1525     if (input_array.data_type == ArrayDataType::kInt8) {
1526       return 2;
1527     }
1528     return 1;
1529   }
1530 };
1531 
1532 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1533                                       ::tflite::BuiltinOptions_ArgMaxOptions> {
1534  public:
1535   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1536   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1537       const TocoOperator& op,
1538       flatbuffers::FlatBufferBuilder* builder) const override {
1539     return ::tflite::CreateArgMaxOptions(
1540         *builder, DataType::Serialize(op.output_data_type));
1541   }
1542 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1543   void ReadOptions(const TfLiteOptions& options,
1544                    TocoOperator* op) const override {
1545     op->output_data_type = DataType::Deserialize(options.output_type());
1546   }
1547 
GetVersion(const OperatorSignature & op_signature) const1548   int GetVersion(const OperatorSignature& op_signature) const override {
1549     const string& input_name = op_signature.op->inputs[0];
1550     const Array& input_array = op_signature.model->GetArray(input_name);
1551     if (input_array.data_type == ArrayDataType::kInt8) {
1552       return 2;
1553     }
1554 
1555     return 1;
1556   }
1557 };
1558 
1559 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1560                                       ::tflite::BuiltinOptions_ArgMinOptions> {
1561  public:
1562   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1563   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1564       const TocoOperator& op,
1565       flatbuffers::FlatBufferBuilder* builder) const override {
1566     return ::tflite::CreateArgMinOptions(
1567         *builder, DataType::Serialize(op.output_data_type));
1568   }
1569 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1570   void ReadOptions(const TfLiteOptions& options,
1571                    TocoOperator* op) const override {
1572     op->output_data_type = DataType::Deserialize(options.output_type());
1573   }
1574 
GetVersion(const OperatorSignature & op_signature) const1575   int GetVersion(const OperatorSignature& op_signature) const override {
1576     const string& input_name = op_signature.op->inputs[0];
1577     const Array& input_array = op_signature.model->GetArray(input_name);
1578     if (input_array.data_type == ArrayDataType::kInt8) {
1579       return 2;
1580     }
1581 
1582     return 1;
1583   }
1584 };
1585 
1586 class TransposeConv
1587     : public BuiltinOperator<TransposeConvOperator,
1588                              ::tflite::TransposeConvOptions,
1589                              ::tflite::BuiltinOptions_TransposeConvOptions> {
1590  public:
1591   using BuiltinOperator::BuiltinOperator;
1592 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1593   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1594       const TocoOperator& op,
1595       flatbuffers::FlatBufferBuilder* builder) const override {
1596     auto padding = Padding::Serialize(op.padding.type);
1597     return ::tflite::CreateTransposeConvOptions(
1598         *builder, padding, op.stride_width, op.stride_height);
1599   }
1600 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1601   void ReadOptions(const TfLiteOptions& options,
1602                    TocoOperator* op) const override {
1603     op->padding.type = Padding::Deserialize(options.padding());
1604     op->stride_width = options.stride_w();
1605     op->stride_height = options.stride_h();
1606   }
1607 
GetVersion(const OperatorSignature & op_signature) const1608   int GetVersion(const OperatorSignature& op_signature) const override {
1609     return 1;
1610   }
1611 };
1612 
1613 class SparseToDense
1614     : public BuiltinOperator<SparseToDenseOperator,
1615                              ::tflite::SparseToDenseOptions,
1616                              ::tflite::BuiltinOptions_SparseToDenseOptions> {
1617  public:
1618   using BuiltinOperator::BuiltinOperator;
1619 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1620   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1621       const TocoOperator& op,
1622       flatbuffers::FlatBufferBuilder* builder) const override {
1623     return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1624   }
1625 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1626   void ReadOptions(const TfLiteOptions& options,
1627                    TocoOperator* op) const override {
1628     op->validate_indices = options.validate_indices();
1629   }
1630 
GetVersion(const OperatorSignature & op_signature) const1631   int GetVersion(const OperatorSignature& op_signature) const override {
1632     return 1;
1633   }
1634 };
1635 
1636 class ExpandDims
1637     : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1638                              ::tflite::BuiltinOptions_ExpandDimsOptions> {
1639  public:
1640   using BuiltinOperator::BuiltinOperator;
1641 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1642   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1643       const TocoOperator& op,
1644       flatbuffers::FlatBufferBuilder* builder) const override {
1645     return ::tflite::CreateExpandDimsOptions(*builder);
1646   }
1647 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1648   void ReadOptions(const TfLiteOptions& options,
1649                    TocoOperator* op) const override {}
1650 
GetVersion(const OperatorSignature & op_signature) const1651   int GetVersion(const OperatorSignature& op_signature) const override {
1652     return 1;
1653   }
1654 };
1655 
1656 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1657                                     ::tflite::BuiltinOptions_PackOptions> {
1658  public:
1659   using BuiltinOperator::BuiltinOperator;
1660 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1661   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1662       const TocoOperator& op,
1663       flatbuffers::FlatBufferBuilder* builder) const override {
1664     return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1665   }
1666 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1667   void ReadOptions(const TfLiteOptions& options,
1668                    TocoOperator* op) const override {
1669     op->values_count = options.values_count();
1670     op->axis = options.axis();
1671   }
1672 
GetVersion(const OperatorSignature & op_signature) const1673   int GetVersion(const OperatorSignature& op_signature) const override {
1674     const string& input_name = op_signature.op->inputs[0];
1675     const Array& input_array = op_signature.model->GetArray(input_name);
1676     // If the op take int8 input, it is version 2.
1677     if (input_array.data_type == ArrayDataType::kInt8) {
1678       return 2;
1679     }
1680     return 1;
1681   }
1682 };
1683 
1684 class Shape
1685     : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1686                              ::tflite::BuiltinOptions_ShapeOptions> {
1687  public:
1688   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1689   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1690       const TocoOperator& op,
1691       flatbuffers::FlatBufferBuilder* builder) const override {
1692     return ::tflite::CreateShapeOptions(
1693         *builder, DataType::Serialize(op.output_data_type));
1694   }
1695 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1696   void ReadOptions(const TfLiteOptions& options,
1697                    TocoOperator* op) const override {
1698     op->output_data_type = DataType::Deserialize(options.out_type());
1699   }
1700 
GetVersion(const OperatorSignature & op_signature) const1701   int GetVersion(const OperatorSignature& op_signature) const override {
1702     return 1;
1703   }
1704 };
1705 
1706 class Slice : public SimpleOperator<SliceOperator> {
1707  public:
Slice()1708   explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {}
GetVersion(const OperatorSignature & op_signature) const1709   int GetVersion(const OperatorSignature& op_signature) const override {
1710     const string& input_name = op_signature.op->inputs[0];
1711     const Array& input_array = op_signature.model->GetArray(input_name);
1712     // Version 2 supports signed int8 input types.
1713     if (input_array.data_type == ArrayDataType::kInt8) {
1714       return 2;
1715     }
1716     return 1;
1717   }
1718 };
1719 
1720 class Tanh : public SimpleOperator<TanhOperator> {
1721  public:
Tanh()1722   explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {}
GetVersion(const OperatorSignature & op_signature) const1723   int GetVersion(const OperatorSignature& op_signature) const override {
1724     const string& input_name = op_signature.op->inputs[0];
1725     const Array& input_array = op_signature.model->GetArray(input_name);
1726     // Version 2 supports signed int8 input types.
1727     if (input_array.data_type == ArrayDataType::kInt8) {
1728       return 2;
1729     }
1730     return 1;
1731   }
1732 };
1733 
1734 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1735                                       ::tflite::BuiltinOptions_OneHotOptions> {
1736  public:
1737   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1738   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1739       const TocoOperator& op,
1740       flatbuffers::FlatBufferBuilder* builder) const override {
1741     return ::tflite::CreateOneHotOptions(*builder, op.axis);
1742   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1743   void ReadOptions(const TfLiteOptions& options,
1744                    TocoOperator* op) const override {
1745     op->axis = options.axis();
1746   }
1747 
GetVersion(const OperatorSignature & op_signature) const1748   int GetVersion(const OperatorSignature& op_signature) const override {
1749     return 1;
1750   }
1751 };
1752 
1753 class CTCBeamSearchDecoder
1754     : public CustomOperator<CTCBeamSearchDecoderOperator> {
1755  public:
1756   using CustomOperator::CustomOperator;
1757 
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1758   void WriteOptions(const TocoOperator& op,
1759                     flexbuffers::Builder* fbb) const override {
1760     fbb->Int("beam_width", op.beam_width);
1761     fbb->Int("top_paths", op.top_paths);
1762     fbb->Bool("merge_repeated", op.merge_repeated);
1763   }
1764 
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1765   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1766     op->beam_width = m["beam_width"].AsInt32();
1767     op->top_paths = m["top_paths"].AsInt32();
1768     op->merge_repeated = m["merge_repeated"].AsBool();
1769   }
1770 
GetVersion(const OperatorSignature & op_signature) const1771   int GetVersion(const OperatorSignature& op_signature) const override {
1772     return 1;
1773   }
1774 };
1775 
1776 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1777                                       ::tflite::BuiltinOptions_UnpackOptions> {
1778  public:
1779   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1780   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1781       const TocoOperator& op,
1782       flatbuffers::FlatBufferBuilder* builder) const override {
1783     return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1784   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1785   void ReadOptions(const TfLiteOptions& options,
1786                    TocoOperator* op) const override {
1787     op->num = options.num();
1788     op->axis = options.axis();
1789   }
1790 
GetVersion(const OperatorSignature & op_signature) const1791   int GetVersion(const OperatorSignature& op_signature) const override {
1792     return 1;
1793   }
1794 };
1795 
1796 class LeakyRelu
1797     : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1798                              ::tflite::BuiltinOptions_LeakyReluOptions> {
1799  public:
1800   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1801   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1802       const TocoOperator& op,
1803       flatbuffers::FlatBufferBuilder* builder) const override {
1804     return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1805   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1806   void ReadOptions(const TfLiteOptions& options,
1807                    TocoOperator* op) const override {
1808     op->alpha = options.alpha();
1809   }
1810 
GetVersion(const OperatorSignature & op_signature) const1811   int GetVersion(const OperatorSignature& op_signature) const override {
1812     return 1;
1813   }
1814 };
1815 
1816 class Logistic : public SimpleOperator<LogisticOperator> {
1817  public:
Logistic()1818   explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {}
GetVersion(const OperatorSignature & op_signature) const1819   int GetVersion(const OperatorSignature& op_signature) const override {
1820     const string& input_name = op_signature.op->inputs[0];
1821     const Array& input_array = op_signature.model->GetArray(input_name);
1822     // Version 2 supports signed int8 input types.
1823     if (input_array.data_type == ArrayDataType::kInt8) {
1824       return 2;
1825     }
1826     return 1;
1827   }
1828 };
1829 
1830 class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
1831  public:
LogSoftmax()1832   explicit LogSoftmax()
1833       : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
GetVersion(const OperatorSignature & op_signature) const1834   int GetVersion(const OperatorSignature& op_signature) const override {
1835     const string& input_name = op_signature.op->inputs[0];
1836     const Array& input_array = op_signature.model->GetArray(input_name);
1837     // Version 2 supports signed int8 input types.
1838     if (input_array.data_type == ArrayDataType::kInt8) {
1839       return 2;
1840     }
1841     return 1;
1842   }
1843 };
1844 
1845 class SquaredDifference
1846     : public BuiltinOperator<
1847           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1848           ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1849  public:
1850   using BuiltinOperator::BuiltinOperator;
1851 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1852   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1853       const TocoOperator& op,
1854       flatbuffers::FlatBufferBuilder* builder) const override {
1855     return ::tflite::CreateSquaredDifferenceOptions(*builder);
1856   }
1857 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1858   void ReadOptions(const TfLiteOptions& options,
1859                    TocoOperator* op) const override {}
1860 
GetVersion(const OperatorSignature & op_signature) const1861   int GetVersion(const OperatorSignature& op_signature) const override {
1862     return 1;
1863   }
1864 };
1865 
1866 class MirrorPad
1867     : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1868                              ::tflite::BuiltinOptions_MirrorPadOptions> {
1869  public:
1870   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1871   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1872       const TocoOperator& op,
1873       flatbuffers::FlatBufferBuilder* builder) const override {
1874     return ::tflite::CreateMirrorPadOptions(
1875         *builder, op.mode == MirrorPadMode::kReflect
1876                       ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1877                       : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1878   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1879   void ReadOptions(const TfLiteOptions& options,
1880                    TocoOperator* op) const override {
1881     op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1882                    ? MirrorPadMode::kReflect
1883                    : MirrorPadMode::kSymmetric;
1884   }
1885 
GetVersion(const OperatorSignature & op) const1886   int GetVersion(const OperatorSignature& op) const override { return 1; }
1887 };
1888 
1889 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1890                                       ::tflite::BuiltinOptions_UniqueOptions> {
1891  public:
1892   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1893   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1894       const TocoOperator& op,
1895       flatbuffers::FlatBufferBuilder* builder) const override {
1896     const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1897     return ::tflite::CreateUniqueOptions(
1898         *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1899                       ? ::tflite::TensorType::TensorType_INT64
1900                       : ::tflite::TensorType_INT32);
1901   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1902   void ReadOptions(const TfLiteOptions& options,
1903                    TocoOperator* op) const override {
1904     UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1905     unique_op->idx_out_type =
1906         options.idx_out_type() == ::tflite::TensorType_INT64
1907             ? toco::ArrayDataType::kInt64
1908             : toco::ArrayDataType::kInt32;
1909   }
1910 
GetVersion(const OperatorSignature & op_signature) const1911   int GetVersion(const OperatorSignature& op_signature) const override {
1912     return 1;
1913   }
1914 };
1915 
1916 class UnidirectionalSequenceRnn
1917     : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1918                              ::tflite::SequenceRNNOptions,
1919                              ::tflite::BuiltinOptions_SequenceRNNOptions> {
1920  public:
1921   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1922   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1923       const TocoOperator& op,
1924       flatbuffers::FlatBufferBuilder* builder) const override {
1925     return ::tflite::CreateSequenceRNNOptions(
1926         *builder, /*time_major=*/true,
1927         /*fused_activation_function=*/
1928         ::tflite::ActivationFunctionType_TANH);
1929   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1930   void ReadOptions(const TfLiteOptions& options,
1931                    TocoOperator* op) const override {
1932     // Only support tanh activation, so check that tflite type is tanh.
1933     DCHECK(options.fused_activation_function() ==
1934            ::tflite::ActivationFunctionType_TANH);
1935   }
1936 
GetVersion(const OperatorSignature & op_signature) const1937   int GetVersion(const OperatorSignature& op_signature) const override {
1938     return 1;
1939   }
1940 
GetMutatingInputVariables(const Operator & op) const1941   std::vector<bool> GetMutatingInputVariables(
1942       const Operator& op) const override {
1943     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1944     mutating_input_variables[4] = true;
1945     return mutating_input_variables;
1946   }
1947 };
1948 
1949 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1950                                      ::tflite::BuiltinOptions_WhereOptions> {
1951  public:
1952   using BuiltinOperator::BuiltinOperator;
1953 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1954   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1955       const TocoOperator& op,
1956       flatbuffers::FlatBufferBuilder* builder) const override {
1957     return ::tflite::CreateWhereOptions(*builder);
1958   }
1959 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1960   void ReadOptions(const TfLiteOptions& options,
1961                    TocoOperator* op) const override {}
1962 
GetVersion(const OperatorSignature & op_signature) const1963   int GetVersion(const OperatorSignature& op_signature) const override {
1964     return 1;
1965   }
1966 };
1967 
WriteFlexOpOptions(const string & tensorflow_node_def)1968 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1969     const string& tensorflow_node_def) {
1970   auto fbb = absl::make_unique<flexbuffers::Builder>();
1971 
1972   ::tensorflow::NodeDef node_def;
1973   if (!node_def.ParseFromString(tensorflow_node_def)) {
1974     LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1975     return {};
1976   }
1977 
1978   fbb->Vector([&]() {
1979     fbb->String(node_def.op());
1980     fbb->String(tensorflow_node_def);
1981   });
1982   fbb->Finish();
1983   LOG(INFO) << "Writing flex op: " << node_def.op();
1984   return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1985 }
1986 
1987 class TensorFlowUnsupported : public BaseOperator {
1988  public:
TensorFlowUnsupported(const string & name,OperatorType type,bool enable_select_tf_ops)1989   TensorFlowUnsupported(const string& name, OperatorType type,
1990                         bool enable_select_tf_ops)
1991       : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1992 
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1993   Options Serialize(const Operator& op,
1994                     flatbuffers::FlatBufferBuilder* builder) const override {
1995     auto fbb =
1996         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1997     if (fbb) {
1998       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1999     } else {
2000       return Options::Custom(0);
2001     }
2002   }
2003 
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const2004   std::unique_ptr<Operator> Deserialize(
2005       const BuiltinOptions* builtin_options,
2006       const CustomOptions* custom_options) const override {
2007     // Deserializing Flex ops doesn't work now.
2008     // TODO(ycling): Revisit and decide if we should fix the flow for importing
2009     // TFLite models with Flex ops.
2010     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
2011     if (custom_options) {
2012       auto flexbuffer_map =
2013           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
2014               .AsMap();
2015       ReadOptions(flexbuffer_map, op.get());
2016     }
2017     return std::unique_ptr<Operator>(op.release());
2018   }
2019 
WriteOptions(const TensorFlowUnsupportedOperator & op) const2020   std::unique_ptr<flexbuffers::Builder> WriteOptions(
2021       const TensorFlowUnsupportedOperator& op) const {
2022     if (enable_select_tf_ops_) {
2023       return WriteFlexOpOptions(op.tensorflow_node_def);
2024     }
2025     auto fbb = absl::make_unique<flexbuffers::Builder>();
2026 
2027     ::tensorflow::NodeDef node_def;
2028     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
2029       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
2030       return std::unique_ptr<flexbuffers::Builder>();
2031     }
2032 
2033     if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
2034       fbb->Vector([&]() {
2035         fbb->String(node_def.op());
2036         fbb->String(op.tensorflow_node_def);
2037       });
2038       fbb->Finish();
2039       LOG(INFO) << "Writing flex op: " << node_def.op();
2040       return std::unique_ptr<flexbuffers::Builder>(fbb.release());
2041     }
2042 
2043     bool has_valid_attr = false;
2044     size_t map_start = fbb->StartMap();
2045     for (const auto& pair : node_def.attr()) {
2046       const char* key = pair.first.c_str();
2047       const auto& attr = pair.second;
2048       switch (attr.value_case()) {
2049         case ::tensorflow::AttrValue::kS:
2050           fbb->String(key, attr.s());
2051           has_valid_attr = true;
2052           break;
2053         case ::tensorflow::AttrValue::kI:
2054           fbb->Int(key, attr.i());
2055           has_valid_attr = true;
2056           break;
2057         case ::tensorflow::AttrValue::kF:
2058           fbb->Float(key, attr.f());
2059           has_valid_attr = true;
2060           break;
2061         case ::tensorflow::AttrValue::kB:
2062           fbb->Bool(key, attr.b());
2063           has_valid_attr = true;
2064           break;
2065         case tensorflow::AttrValue::kList:
2066           if (attr.list().s_size() > 0) {
2067             auto start = fbb->StartVector(key);
2068             for (const string& v : attr.list().s()) {
2069               fbb->Add(v);
2070             }
2071             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2072             has_valid_attr = true;
2073           } else if (attr.list().i_size() > 0) {
2074             auto start = fbb->StartVector(key);
2075             for (const int64_t v : attr.list().i()) {
2076               fbb->Add(v);
2077             }
2078             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2079             has_valid_attr = true;
2080           } else if (attr.list().f_size() > 0) {
2081             auto start = fbb->StartVector(key);
2082             for (const float v : attr.list().f()) {
2083               fbb->Add(v);
2084             }
2085             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2086             has_valid_attr = true;
2087           } else {
2088             LOG(WARNING)
2089                 << "Ignoring unsupported type in list attribute with key '"
2090                 << key << "'";
2091           }
2092           break;
2093         default:
2094           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
2095                        << key << "'";
2096           break;
2097       }
2098     }
2099     if (!has_valid_attr) {
2100       return std::unique_ptr<flexbuffers::Builder>();
2101     }
2102     fbb->EndMap(map_start);
2103     fbb->Finish();
2104     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
2105   }
2106 
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const2107   void ReadOptions(const flexbuffers::Map& m,
2108                    TensorFlowUnsupportedOperator* op) const {
2109     ::tensorflow::NodeDef node_def;
2110     auto attr = node_def.mutable_attr();
2111 
2112     const auto& keys = m.Keys();
2113     for (size_t i = 0; i < keys.size(); ++i) {
2114       const auto key = keys[i].AsKey();
2115       const auto& value = m[key];
2116       // TODO(wvo): hack to make this code compile with 2 different API
2117       // versions.
2118       // Please remove once OS/internal versions are in sync.
2119       // See hardcoded values in the switch below.
2120       switch (value.GetType()) {
2121         case 5:  // flexbuffers::FBT_STRING:
2122           (*attr)[key].set_s(value.AsString().c_str());
2123           break;
2124         case 1:  // flexbuffers::FBT_INT:
2125           (*attr)[key].set_i(value.AsInt64());
2126           break;
2127         case 3:  // flexbuffers::FBT_FLOAT:
2128           (*attr)[key].set_f(value.AsFloat());
2129           break;
2130         case 26:  // flexbuffers::FBT_BOOL:
2131           (*attr)[key].set_b(value.AsBool());
2132           if (string(key) == "_output_quantized") {
2133             op->quantized = value.AsBool();
2134           }
2135           if (string(key) == "_support_output_type_float_in_quantized_op") {
2136             op->support_output_type_float_in_quantized_op = value.AsBool();
2137           }
2138           break;
2139         case 11: {  // flexbuffers::FBT_VECTOR_INT: {
2140           auto* list = (*attr)[key].mutable_list();
2141           const auto& vector = value.AsTypedVector();
2142           for (size_t i = 0; i < vector.size(); i++) {
2143             list->add_i(vector[i].AsInt64());
2144           }
2145           break;
2146         }
2147         case 13: {  // flexbuffers::FBT_VECTOR_FLOAT: {
2148           auto* list = (*attr)[key].mutable_list();
2149           const auto& vector = value.AsTypedVector();
2150           for (size_t i = 0; i < vector.size(); i++) {
2151             list->add_f(vector[i].AsFloat());
2152           }
2153           break;
2154         }
2155         case 15: {  // flexbuffers::FBT_VECTOR_STRING: {
2156           auto* list = (*attr)[key].mutable_list();
2157           const auto& vector = value.AsTypedVector();
2158           for (size_t i = 0; i < vector.size(); i++) {
2159             list->add_s(vector[i].AsString().str());
2160           }
2161           break;
2162         }
2163         default:
2164           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
2165                        << key << "'";
2166           break;
2167       }
2168     }
2169     node_def.SerializeToString(&op->tensorflow_node_def);
2170   }
2171 
GetVersion(const OperatorSignature & op_signature) const2172   int GetVersion(const OperatorSignature& op_signature) const override {
2173     // TODO(ycling): Design and implement a way to plumb the version of
2174     // custom ops.
2175     return 1;
2176   }
2177 
2178  private:
2179   const bool enable_select_tf_ops_;
2180 };
2181 
2182 class Dequantize
2183     : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
2184                              ::tflite::BuiltinOptions_DequantizeOptions> {
2185  public:
2186   using BuiltinOperator::BuiltinOperator;
2187 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const2188   flatbuffers::Offset<TfLiteOptions> WriteOptions(
2189       const TocoOperator& op,
2190       flatbuffers::FlatBufferBuilder* builder) const override {
2191     return ::tflite::CreateDequantizeOptions(*builder);
2192   }
2193 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const2194   void ReadOptions(const TfLiteOptions& options,
2195                    TocoOperator* op) const override {}
2196 
GetVersion(const OperatorSignature & op_signature) const2197   int GetVersion(const OperatorSignature& op_signature) const override {
2198     const string& input_name = op_signature.op->inputs[0];
2199     const Array& input_array = op_signature.model->GetArray(input_name);
2200     // Version 2 supports signed int8 input types.
2201     if (input_array.data_type == ArrayDataType::kInt8) {
2202       return 2;
2203     }
2204     return 1;
2205   }
2206 };
2207 
2208 class ReverseSequence
2209     : public BuiltinOperator<ReverseSequenceOperator,
2210                              ::tflite::ReverseSequenceOptions,
2211                              ::tflite::BuiltinOptions_ReverseSequenceOptions> {
2212  public:
2213   using BuiltinOperator::BuiltinOperator;
2214 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const2215   flatbuffers::Offset<TfLiteOptions> WriteOptions(
2216       const TocoOperator& op,
2217       flatbuffers::FlatBufferBuilder* builder) const override {
2218     return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
2219                                                   op.batch_dim);
2220   }
2221 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const2222   void ReadOptions(const TfLiteOptions& options,
2223                    TocoOperator* op) const override {
2224     op->seq_dim = options.seq_dim();
2225     op->batch_dim = options.batch_dim();
2226   }
2227 
GetVersion(const OperatorSignature & op_signature) const2228   int GetVersion(const OperatorSignature& op_signature) const override {
2229     return 1;
2230   }
2231 };
2232 
2233 class Equal : public SimpleOperator<TensorFlowEqualOperator> {
2234  public:
Equal()2235   explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {}
GetVersion(const OperatorSignature & op_signature) const2236   int GetVersion(const OperatorSignature& op_signature) const override {
2237     const string& input_name = op_signature.op->inputs[0];
2238     const Array& input_array = op_signature.model->GetArray(input_name);
2239     // Version 2 supports signed int8 input types.
2240     if (input_array.data_type == ArrayDataType::kInt8) {
2241       return 2;
2242     }
2243     return 1;
2244   }
2245 };
2246 
2247 class NotEqual : public SimpleOperator<TensorFlowNotEqualOperator> {
2248  public:
NotEqual()2249   explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {}
GetVersion(const OperatorSignature & op_signature) const2250   int GetVersion(const OperatorSignature& op_signature) const override {
2251     const string& input_name = op_signature.op->inputs[0];
2252     const Array& input_array = op_signature.model->GetArray(input_name);
2253     // Version 2 supports signed int8 input types.
2254     if (input_array.data_type == ArrayDataType::kInt8) {
2255       return 2;
2256     }
2257     return 1;
2258   }
2259 };
2260 
2261 class Greater : public SimpleOperator<TensorFlowGreaterOperator> {
2262  public:
Greater()2263   explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {}
GetVersion(const OperatorSignature & op_signature) const2264   int GetVersion(const OperatorSignature& op_signature) const override {
2265     const string& input_name = op_signature.op->inputs[0];
2266     const Array& input_array = op_signature.model->GetArray(input_name);
2267     // Version 2 supports signed int8 input types.
2268     if (input_array.data_type == ArrayDataType::kInt8) {
2269       return 2;
2270     }
2271     return 1;
2272   }
2273 };
2274 
2275 class GreaterEqual : public SimpleOperator<TensorFlowGreaterEqualOperator> {
2276  public:
GreaterEqual()2277   explicit GreaterEqual()
2278       : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {}
GetVersion(const OperatorSignature & op_signature) const2279   int GetVersion(const OperatorSignature& op_signature) const override {
2280     const string& input_name = op_signature.op->inputs[0];
2281     const Array& input_array = op_signature.model->GetArray(input_name);
2282     // Version 2 supports signed int8 input types.
2283     if (input_array.data_type == ArrayDataType::kInt8) {
2284       return 2;
2285     }
2286     return 1;
2287   }
2288 };
2289 
2290 class Less : public SimpleOperator<TensorFlowLessOperator> {
2291  public:
Less()2292   explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {}
GetVersion(const OperatorSignature & op_signature) const2293   int GetVersion(const OperatorSignature& op_signature) const override {
2294     const string& input_name = op_signature.op->inputs[0];
2295     const Array& input_array = op_signature.model->GetArray(input_name);
2296     // Version 2 supports signed int8 input types.
2297     if (input_array.data_type == ArrayDataType::kInt8) {
2298       return 2;
2299     }
2300     return 1;
2301   }
2302 };
2303 
2304 class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
2305  public:
LessEqual()2306   explicit LessEqual()
2307       : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {}
GetVersion(const OperatorSignature & op_signature) const2308   int GetVersion(const OperatorSignature& op_signature) const override {
2309     const string& input_name = op_signature.op->inputs[0];
2310     const Array& input_array = op_signature.model->GetArray(input_name);
2311     // Version 2 supports signed int8 input types.
2312     if (input_array.data_type == ArrayDataType::kInt8) {
2313       return 2;
2314     }
2315     return 1;
2316   }
2317 };
2318 
2319 class Select : public SimpleOperator<SelectOperator> {
2320  public:
Select()2321   explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
GetVersion(const OperatorSignature & op_signature) const2322   int GetVersion(const OperatorSignature& op_signature) const override {
2323     const string& input_name = op_signature.op->inputs[0];
2324     const Array& input_array = op_signature.model->GetArray(input_name);
2325     // Version 2 supports signed int8 input types.
2326     if (input_array.data_type == ArrayDataType::kInt8) {
2327       return 2;
2328     }
2329     return 1;
2330   }
2331 };
2332 
2333 namespace {
2334 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)2335 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
2336     bool enable_select_tf_ops = false) {
2337   std::vector<std::unique_ptr<BaseOperator>> ops;
2338   using tensorflow::MakeUnique;
2339   // Builtin Operators.
2340   ops.push_back(
2341       MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
2342   ops.push_back(
2343       MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
2344   ops.push_back(
2345       MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
2346   ops.push_back(
2347       MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
2348   ops.push_back(MakeUnique<AveragePool>(
2349       ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
2350   ops.push_back(
2351       MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
2352                                  OperatorType::kSpaceToBatchND));
2353   ops.push_back(
2354       MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
2355                                  OperatorType::kBatchToSpaceND));
2356   ops.push_back(MakeUnique<Concatenation>(
2357       ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
2358   ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
2359                                         OperatorType::kConv));
2360   ops.push_back(MakeUnique<DepthwiseConvolution>(
2361       ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
2362       OperatorType::kDepthwiseConv));
2363   ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
2364                                        OperatorType::kDequantize));
2365   ops.push_back(
2366       MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
2367                                  OperatorType::kFullyConnected));
2368   ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
2369                                    OperatorType::kGather));
2370   ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
2371                                      OperatorType::kGatherNd));
2372   ops.push_back(
2373       MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
2374                                   OperatorType::kL2Normalization));
2375   ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
2376                                    OperatorType::kL2Pool));
2377   ops.push_back(MakeUnique<LocalResponseNormalization>(
2378       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
2379       OperatorType::kLocalResponseNormalization));
2380   ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
2381                                     OperatorType::kMaxPool));
2382   ops.push_back(
2383       MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
2384 
2385   ops.push_back(
2386       MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
2387   ops.push_back(
2388       MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
2389   ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
2390                                     OperatorType::kReshape));
2391   ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
2392                                     OperatorType::kSoftmax));
2393   ops.push_back(MakeUnique<SpaceToDepth>(
2394       ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
2395   ops.push_back(
2396       MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
2397   ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
2398                                       OperatorType::kTranspose));
2399   ops.push_back(
2400       MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
2401   ops.push_back(
2402       MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
2403   ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
2404                                        OperatorType::kReduceProd));
2405   ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
2406                                       OperatorType::kReduceMax));
2407   ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
2408                                       OperatorType::kReduceMin));
2409   ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
2410                                       OperatorType::kAny));
2411   ops.push_back(
2412       MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
2413                                  OperatorType::kResizeBilinear));
2414   ops.push_back(MakeUnique<ResizeNearestNeighbor>(
2415       ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
2416       OperatorType::kResizeNearestNeighbor));
2417   ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
2418                                     OperatorType::kSqueeze));
2419   ops.push_back(
2420       MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
2421   ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
2422                                    OperatorType::kSplitV));
2423   ops.push_back(MakeUnique<StridedSlice>(
2424       ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
2425   ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
2426                                     OperatorType::kTopK_V2));
2427   ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
2428                                  OperatorType::kLstmCell));
2429   ops.push_back(
2430       MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
2431   ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
2432                                    OperatorType::kArgMax));
2433   ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
2434                                    OperatorType::kArgMin));
2435   ops.push_back(
2436       MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
2437   ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
2438                                        OperatorType::kExpandDims));
2439   ops.push_back(MakeUnique<TransposeConv>(
2440       ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
2441   ops.push_back(MakeUnique<SparseToDense>(
2442       ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
2443   ops.push_back(
2444       MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
2445   ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
2446                                       OperatorType::kFakeQuant));
2447   ops.push_back(
2448       MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
2449   ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
2450       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
2451       OperatorType::kUnidirectionalSequenceLstm));
2452   ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
2453       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
2454       OperatorType::kBidirectionalSequenceLstm));
2455   ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
2456       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
2457       OperatorType::kBidirectionalSequenceRnn));
2458   ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
2459                                    OperatorType::kOneHot));
2460   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
2461                                    OperatorType::kUnpack));
2462   ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
2463                                       OperatorType::kLeakyRelu));
2464   ops.push_back(MakeUnique<SquaredDifference>(
2465       ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
2466       OperatorType::kSquaredDifference));
2467   ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
2468                                       OperatorType::kMirrorPad));
2469   ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
2470                                    OperatorType::kUnique));
2471   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
2472       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
2473       OperatorType::kUnidirectionalSequenceRnn));
2474   ops.push_back(
2475       MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
2476   ops.push_back(
2477       MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
2478                                   OperatorType::kReverseSequence));
2479 
2480   // Custom Operators.
2481   ops.push_back(
2482       MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
2483   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
2484       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
2485   ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
2486                                                   OperatorType::kUnsupported,
2487                                                   enable_select_tf_ops));
2488 
2489   // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
2490   // been modified to also export builtins. As TOCO evolved we added warnings
2491   // when custom ops are exported but SimpleOperator bypasses thoses. To
2492   // prevent user confusion we are settling on using SimpleOperator only for
2493   // builtins.
2494   ops.push_back(
2495       MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
2496   ops.push_back(
2497       MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
2498   ops.push_back(
2499       MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
2500   ops.push_back(
2501       MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
2502   ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2503       "RELU_N1_TO_1", OperatorType::kRelu1));
2504   ops.push_back(MakeUnique<Relu6>());
2505   ops.push_back(
2506       MakeUnique<SimpleOperator<PReluOperator>>("PRELU", OperatorType::kPRelu));
2507   ops.push_back(MakeUnique<Logistic>());
2508   ops.push_back(MakeUnique<Tanh>());
2509   ops.push_back(
2510       MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
2511   ops.push_back(
2512       MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
2513   ops.push_back(MakeUnique<LogSoftmax>());
2514   ops.push_back(MakeUnique<Maximum>());  //  Element-wise Maximum
2515   ops.push_back(MakeUnique<Minimum>());  //  Element-wise Minimum
2516   ops.push_back(MakeUnique<Greater>());
2517   ops.push_back(MakeUnique<GreaterEqual>());
2518   ops.push_back(MakeUnique<Less>());
2519   ops.push_back(MakeUnique<LessEqual>());
2520   ops.push_back(MakeUnique<Equal>());
2521   ops.push_back(MakeUnique<NotEqual>());
2522   ops.push_back(
2523       MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
2524   ops.push_back(MakeUnique<Select>());
2525   ops.push_back(MakeUnique<Slice>());
2526   ops.push_back(
2527       MakeUnique<SimpleOperator<PowOperator>>("POW", OperatorType::kPow));
2528   ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2529       "LOGICAL_OR", OperatorType::kLogicalOr));
2530   ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2531       "LOGICAL_AND", OperatorType::kLogicalAnd));
2532   ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2533       "LOGICAL_NOT", OperatorType::kLogicalNot));
2534   ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2535       "FLOOR_DIV", OperatorType::kFloorDiv));
2536   ops.emplace_back(new SimpleOperator<FloorModOperator>(
2537       "FLOOR_MOD", OperatorType::kFloorMod));
2538   ops.emplace_back(
2539       new SimpleOperator<RangeOperator>("RANGE", OperatorType::kRange));
2540   // Element-wise operator
2541   ops.push_back(
2542       MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
2543   ops.push_back(
2544       MakeUnique<SimpleOperator<LogOperator>>("LOG", OperatorType::kLog));
2545   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2546       "SQRT", OperatorType::kSqrt));
2547   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2548       "RSQRT", OperatorType::kRsqrt));
2549   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2550       "SQUARE", OperatorType::kSquare));
2551   ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2552       "ZEROS_LIKE", OperatorType::kZerosLike));
2553   ops.push_back(
2554       MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
2555   ops.push_back(
2556       MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
2557   ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2558       "REVERSE_V2", OperatorType::kReverseV2));
2559   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2560       "RANK", OperatorType::kRank));
2561   return ops;
2562 }
2563 }  // namespace
2564 
BuildOperatorByTypeMap(bool enable_select_tf_ops)2565 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2566     bool enable_select_tf_ops) {
2567   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2568 
2569   std::vector<std::unique_ptr<BaseOperator>> ops =
2570       BuildOperatorList(enable_select_tf_ops);
2571   for (auto& op : ops) {
2572     result[op->type()] = std::move(op);
2573   }
2574 
2575   return result;
2576 }
2577 
BuildOperatorByNameMap(bool enable_select_tf_ops)2578 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2579     bool enable_select_tf_ops) {
2580   std::map<string, std::unique_ptr<BaseOperator>> result;
2581 
2582   std::vector<std::unique_ptr<BaseOperator>> ops =
2583       BuildOperatorList(enable_select_tf_ops);
2584   for (auto& op : ops) {
2585     result[op->name()] = std::move(op);
2586   }
2587 
2588   return result;
2589 }
2590 
ShouldExportAsFlexOp(bool enable_select_tf_ops,const string & tensorflow_op_name)2591 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2592                           const string& tensorflow_op_name) {
2593   // If Flex ops aren't allow at all, simply return false.
2594   if (!enable_select_tf_ops) {
2595     return false;
2596   }
2597   // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2598   // it and it has been whitelisted, export the op as an Flex op. Otherwise,
2599   // export it as a regular custom op.
2600   const tensorflow::OpDef* op_def = nullptr;
2601   if (!tensorflow::OpRegistry::Global()
2602            ->LookUpOpDef(tensorflow_op_name, &op_def)
2603            .ok()) {
2604     return false;
2605   }
2606 
2607   if (!IsWhitelistedFlexOp(tensorflow_op_name)) {
2608     LOG(WARNING) << "Op " << tensorflow_op_name
2609                  << " is a valid TensorFlow op but has not been whitelisted for"
2610                     " the TensorFlow Lite flex op set.";
2611     return false;
2612   }
2613 
2614   return true;
2615 }
2616 
2617 }  // namespace tflite
2618 
2619 }  // namespace toco
2620