• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16 
17 #include <map>
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/util/ptr_util.h"
24 
25 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
26 // graph_transformation module.
27 #include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
30 #include "tensorflow/lite/toco/model.h"
31 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
32 #include "tensorflow/lite/toco/tflite/custom_operator.h"
33 #include "tensorflow/lite/toco/tflite/simple_operator.h"
34 #include "tensorflow/lite/toco/tflite/types.h"
35 #include "tensorflow/lite/tools/versioning/op_version.h"
36 
37 namespace toco {
38 
39 namespace tflite {
40 
41 // LINT.IfChange
42 
GetTensorType(const ArrayDataType type)43 ::tflite::TensorType GetTensorType(const ArrayDataType type) {
44   const std::map<ArrayDataType, ::tflite::TensorType> tensor_type_map = {
45       {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
46       {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
47       {ArrayDataType::kInt8, ::tflite::TensorType_INT8},
48       {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
49       {ArrayDataType::kInt16, ::tflite::TensorType_INT16},
50       {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
51       {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
52       {ArrayDataType::kString, ::tflite::TensorType_STRING},
53       {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
54       {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}};
55 
56   auto it = tensor_type_map.find(type);
57   if (it != tensor_type_map.end()) {
58     return it->second;
59   }
60   return static_cast<::tflite::TensorType>(-1);
61 }
62 
// Builds the ::tflite::OpSignature that the shared TFLite op-versioning logic
// (::tflite::GetBuiltinOperatorVersion) consumes for a toco operator.
//
// Each input and output name of `op_signature.op` is translated to a TFLite
// tensor type via GetTensorType(). A name that is not present in
// `op_signature.model`, or whose data type has no TFLite equivalent, is
// recorded as static_cast<::tflite::TensorType>(-1).
//
// Only the builtin op code and the input/output types are set here; callers
// fill in any op-specific fields of the returned signature's `options`.
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  std::vector<::tflite::TensorType> input_types;
  input_types.reserve(op_signature.op->inputs.size());
  // Bind each array name by const reference; copying the name on every
  // iteration is unnecessary work.
  for (const auto& input_name : op_signature.op->inputs) {
    ::tflite::TensorType input_type = static_cast<::tflite::TensorType>(-1);
    if (op_signature.model->HasArray(input_name)) {
      input_type =
          GetTensorType(op_signature.model->GetArray(input_name).data_type);
    }
    input_types.push_back(input_type);
  }

  std::vector<::tflite::TensorType> output_types;
  output_types.reserve(op_signature.op->outputs.size());
  for (const auto& output_name : op_signature.op->outputs) {
    ::tflite::TensorType output_type = static_cast<::tflite::TensorType>(-1);
    if (op_signature.model->HasArray(output_name)) {
      output_type =
          GetTensorType(op_signature.model->GetArray(output_name).data_type);
    }
    output_types.push_back(output_type);
  }
  return ::tflite::OpSignature{op, input_types, output_types};
}
84 
85 class AveragePool
86     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
87                              ::tflite::BuiltinOptions_Pool2DOptions> {
88  public:
89   using BuiltinOperator::BuiltinOperator;
90 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const91   flatbuffers::Offset<TfLiteOptions> WriteOptions(
92       const TocoOperator& op,
93       flatbuffers::FlatBufferBuilder* builder) const override {
94     auto padding = Padding::Serialize(op.padding.type);
95     auto activation_function =
96         ActivationFunction::Serialize(op.fused_activation_function);
97     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
98                                          op.stride_height, op.kwidth,
99                                          op.kheight, activation_function);
100   }
101 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const102   void ReadOptions(const TfLiteOptions& options,
103                    TocoOperator* op) const override {
104     op->padding.type = Padding::Deserialize(options.padding());
105     op->stride_width = options.stride_w();
106     op->stride_height = options.stride_h();
107     op->kwidth = options.filter_width();
108     op->kheight = options.filter_height();
109     op->fused_activation_function =
110         ActivationFunction::Deserialize(options.fused_activation_function());
111   }
112 };
113 
114 class Convolution
115     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
116                              ::tflite::BuiltinOptions_Conv2DOptions> {
117  public:
118   using BuiltinOperator::BuiltinOperator;
119 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const120   flatbuffers::Offset<TfLiteOptions> WriteOptions(
121       const TocoOperator& op,
122       flatbuffers::FlatBufferBuilder* builder) const override {
123     auto padding = Padding::Serialize(op.padding.type);
124     auto activation_function =
125         ActivationFunction::Serialize(op.fused_activation_function);
126     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
127                                          op.stride_height, activation_function,
128                                          op.dilation_width_factor,
129                                          op.dilation_height_factor);
130   }
131 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const132   void ReadOptions(const TfLiteOptions& options,
133                    TocoOperator* op) const override {
134     op->padding.type = Padding::Deserialize(options.padding());
135     op->stride_width = options.stride_w();
136     op->stride_height = options.stride_h();
137     op->dilation_width_factor = options.dilation_w_factor();
138     op->dilation_height_factor = options.dilation_h_factor();
139     op->fused_activation_function =
140         ActivationFunction::Deserialize(options.fused_activation_function());
141   }
142 };
143 
144 class DepthwiseConvolution
145     : public BuiltinOperator<DepthwiseConvOperator,
146                              ::tflite::DepthwiseConv2DOptions,
147                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
148  public:
149   using BuiltinOperator::BuiltinOperator;
150 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const151   flatbuffers::Offset<TfLiteOptions> WriteOptions(
152       const TocoOperator& op,
153       flatbuffers::FlatBufferBuilder* builder) const override {
154     auto padding = Padding::Serialize(op.padding.type);
155     auto activation_function =
156         ActivationFunction::Serialize(op.fused_activation_function);
157     return ::tflite::CreateDepthwiseConv2DOptions(
158         *builder, padding, op.stride_width, op.stride_height,
159         op.depth_multiplier, activation_function, op.dilation_width_factor,
160         op.dilation_height_factor);
161   }
162 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const163   void ReadOptions(const TfLiteOptions& options,
164                    TocoOperator* op) const override {
165     op->padding.type = Padding::Deserialize(options.padding());
166     op->stride_width = options.stride_w();
167     op->stride_height = options.stride_h();
168     op->depth_multiplier = options.depth_multiplier();
169     op->fused_activation_function =
170         ActivationFunction::Deserialize(options.fused_activation_function());
171     op->dilation_width_factor = options.dilation_w_factor();
172     op->dilation_height_factor = options.dilation_h_factor();
173   }
174 
GetVersion(const OperatorSignature & op_signature) const175   int GetVersion(const OperatorSignature& op_signature) const override {
176     const auto& conv_op =
177         static_cast<const DepthwiseConvOperator&>(*op_signature.op);
178     ::tflite::OpSignature op_sig =
179         GetVersioningOpSig(builtin_op(), op_signature);
180     op_sig.options.depthwise_conv_2d.dilation_w_factor =
181         conv_op.dilation_width_factor;
182     op_sig.options.depthwise_conv_2d.dilation_h_factor =
183         conv_op.dilation_height_factor;
184     return ::tflite::GetBuiltinOperatorVersion(op_sig);
185   }
186 };
187 
188 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
189                                    ::tflite::BuiltinOptions_AddOptions> {
190  public:
191   using BuiltinOperator::BuiltinOperator;
192 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const193   flatbuffers::Offset<TfLiteOptions> WriteOptions(
194       const TocoOperator& op,
195       flatbuffers::FlatBufferBuilder* builder) const override {
196     auto activation_function =
197         ActivationFunction::Serialize(op.fused_activation_function);
198     return ::tflite::CreateAddOptions(*builder, activation_function);
199   }
200 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const201   void ReadOptions(const TfLiteOptions& options,
202                    TocoOperator* op) const override {
203     op->fused_activation_function =
204         ActivationFunction::Deserialize(options.fused_activation_function());
205   }
206 };
207 
208 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
209                                     ::tflite::BuiltinOptions_AddNOptions> {
210  public:
211   using BuiltinOperator::BuiltinOperator;
212 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const213   flatbuffers::Offset<TfLiteOptions> WriteOptions(
214       const TocoOperator& op,
215       flatbuffers::FlatBufferBuilder* builder) const override {
216     return ::tflite::CreateAddNOptions(*builder);
217   }
218 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const219   void ReadOptions(const TfLiteOptions& options,
220                    TocoOperator* op) const override {}
221 };
222 
223 class SpaceToBatchND
224     : public BuiltinOperator<SpaceToBatchNDOperator,
225                              ::tflite::SpaceToBatchNDOptions,
226                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
227  public:
228   using BuiltinOperator::BuiltinOperator;
229 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const230   flatbuffers::Offset<TfLiteOptions> WriteOptions(
231       const TocoOperator& op,
232       flatbuffers::FlatBufferBuilder* builder) const override {
233     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
234   }
235 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const236   void ReadOptions(const TfLiteOptions& options,
237                    TocoOperator* op) const override {}
238 };
239 
240 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
241                                    ::tflite::BuiltinOptions_SubOptions> {
242  public:
243   using BuiltinOperator::BuiltinOperator;
244 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const245   flatbuffers::Offset<TfLiteOptions> WriteOptions(
246       const TocoOperator& op,
247       flatbuffers::FlatBufferBuilder* builder) const override {
248     auto activation_function =
249         ActivationFunction::Serialize(op.fused_activation_function);
250     return ::tflite::CreateSubOptions(*builder, activation_function);
251   }
252 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const253   void ReadOptions(const TfLiteOptions& options,
254                    TocoOperator* op) const override {
255     op->fused_activation_function =
256         ActivationFunction::Deserialize(options.fused_activation_function());
257   }
258 };
259 
260 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
261                                    ::tflite::BuiltinOptions_DivOptions> {
262  public:
263   using BuiltinOperator::BuiltinOperator;
264 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const265   flatbuffers::Offset<TfLiteOptions> WriteOptions(
266       const TocoOperator& op,
267       flatbuffers::FlatBufferBuilder* builder) const override {
268     auto activation_function =
269         ActivationFunction::Serialize(op.fused_activation_function);
270     return ::tflite::CreateDivOptions(*builder, activation_function);
271   }
272 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const273   void ReadOptions(const TfLiteOptions& options,
274                    TocoOperator* op) const override {
275     op->fused_activation_function =
276         ActivationFunction::Deserialize(options.fused_activation_function());
277   }
278 };
279 
280 class BatchToSpaceND
281     : public BuiltinOperator<BatchToSpaceNDOperator,
282                              ::tflite::BatchToSpaceNDOptions,
283                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
284  public:
285   using BuiltinOperator::BuiltinOperator;
286 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const287   flatbuffers::Offset<TfLiteOptions> WriteOptions(
288       const TocoOperator& op,
289       flatbuffers::FlatBufferBuilder* builder) const override {
290     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
291   }
292 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const293   void ReadOptions(const TfLiteOptions& options,
294                    TocoOperator* op) const override {}
295 };
296 
297 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
298                                     ::tflite::BuiltinOptions_CastOptions> {
299  public:
300   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const301   flatbuffers::Offset<TfLiteOptions> WriteOptions(
302       const TocoOperator& op,
303       flatbuffers::FlatBufferBuilder* builder) const override {
304     return ::tflite::CreateCastOptions(*builder,
305                                        DataType::Serialize(op.src_data_type),
306                                        DataType::Serialize(op.dst_data_type));
307   }
308 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const309   void ReadOptions(const TfLiteOptions& options,
310                    TocoOperator* op) const override {
311     op->src_data_type = DataType::Deserialize(options.in_data_type());
312     op->dst_data_type = DataType::Deserialize(options.out_data_type());
313   }
314 };
315 
316 class Concatenation
317     : public BuiltinOperator<ConcatenationOperator,
318                              ::tflite::ConcatenationOptions,
319                              ::tflite::BuiltinOptions_ConcatenationOptions> {
320  public:
321   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const322   flatbuffers::Offset<TfLiteOptions> WriteOptions(
323       const TocoOperator& op,
324       flatbuffers::FlatBufferBuilder* builder) const override {
325     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
326   }
327 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const328   void ReadOptions(const TfLiteOptions& options,
329                    TocoOperator* op) const override {
330     op->axis = options.axis();
331   }
332 };
333 
334 class DepthToSpace
335     : public BuiltinOperator<DepthToSpaceOperator,
336                              ::tflite::DepthToSpaceOptions,
337                              ::tflite::BuiltinOptions_DepthToSpaceOptions> {
338  public:
339   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const340   flatbuffers::Offset<TfLiteOptions> WriteOptions(
341       const TocoOperator& op,
342       flatbuffers::FlatBufferBuilder* builder) const override {
343     return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
344   }
345 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const346   void ReadOptions(const TfLiteOptions& options,
347                    TocoOperator* op) const override {
348     op->block_size = options.block_size();
349   }
350 };
351 
352 class FakeQuant
353     : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
354                              ::tflite::BuiltinOptions_FakeQuantOptions> {
355  public:
356   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const357   flatbuffers::Offset<TfLiteOptions> WriteOptions(
358       const TocoOperator& op,
359       flatbuffers::FlatBufferBuilder* builder) const override {
360     return ::tflite::CreateFakeQuantOptions(
361         *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
362   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const363   void ReadOptions(const TfLiteOptions& options,
364                    TocoOperator* op) const override {
365     auto* minmax = new MinMax;
366     minmax->min = options.min();
367     minmax->max = options.max();
368     op->minmax.reset(minmax);
369     op->num_bits = options.num_bits();
370     op->narrow_range = options.narrow_range();
371   }
GetVersion(const OperatorSignature & op_signature) const372   int GetVersion(const OperatorSignature& op_signature) const override {
373     const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
374     ::tflite::OpSignature op_sig =
375         GetVersioningOpSig(builtin_op(), op_signature);
376     op_sig.options.fakequant.narrow_range = fq_op.narrow_range;
377     return ::tflite::GetBuiltinOperatorVersion(op_sig);
378   }
379 };
380 
381 class FullyConnected
382     : public BuiltinOperator<FullyConnectedOperator,
383                              ::tflite::FullyConnectedOptions,
384                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
385  public:
386   using BuiltinOperator::BuiltinOperator;
387 
GetWeightFormat(FullyConnectedWeightsFormat fmt) const388   ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
389       FullyConnectedWeightsFormat fmt) const {
390     switch (fmt) {
391       case FullyConnectedWeightsFormat::kDefault:
392         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
393       case FullyConnectedWeightsFormat::kShuffled4x16Int8:
394         return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
395       default:
396         LOG(ERROR) << "Unhandled FC weights format";
397         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
398     }
399   }
400 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const401   flatbuffers::Offset<TfLiteOptions> WriteOptions(
402       const TocoOperator& op,
403       flatbuffers::FlatBufferBuilder* builder) const override {
404     auto activation_function =
405         ActivationFunction::Serialize(op.fused_activation_function);
406     return ::tflite::CreateFullyConnectedOptions(
407         *builder, activation_function, GetWeightFormat(op.weights_format));
408   }
409 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const410   void ReadOptions(const TfLiteOptions& options,
411                    TocoOperator* op) const override {
412     op->fused_activation_function =
413         ActivationFunction::Deserialize(options.fused_activation_function());
414     switch (options.weights_format()) {
415       case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
416         op->weights_format = FullyConnectedWeightsFormat::kDefault;
417         break;
418       case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
419         op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
420         break;
421       default:
422         LOG(ERROR) << "Unhandled FC weights format";
423         op->weights_format = FullyConnectedWeightsFormat::kDefault;
424     }
425   }
426 
GetVersion(const OperatorSignature & op_signature) const427   int GetVersion(const OperatorSignature& op_signature) const override {
428     const auto& fc_op =
429         static_cast<const FullyConnectedOperator&>(*op_signature.op);
430     ::tflite::OpSignature op_sig =
431         GetVersioningOpSig(builtin_op(), op_signature);
432     op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
433     op_sig.options.fully_connected.weights_format =
434         GetWeightFormat(fc_op.weights_format);
435     return ::tflite::GetBuiltinOperatorVersion(op_sig);
436   }
437 };
438 
439 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
440                                       ::tflite::BuiltinOptions_GatherOptions> {
441  public:
442   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const443   flatbuffers::Offset<TfLiteOptions> WriteOptions(
444       const TocoOperator& op,
445       flatbuffers::FlatBufferBuilder* builder) const override {
446     int axis = op.axis ? op.axis.value() : 0;
447     return ::tflite::CreateGatherOptions(*builder, axis);
448   }
449 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const450   void ReadOptions(const TfLiteOptions& options,
451                    TocoOperator* op) const override {
452     op->axis = {options.axis()};
453   }
454 };
455 
456 class GatherNd
457     : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
458                              ::tflite::BuiltinOptions_GatherNdOptions> {
459  public:
460   using BuiltinOperator::BuiltinOperator;
461 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const462   flatbuffers::Offset<TfLiteOptions> WriteOptions(
463       const TocoOperator& op,
464       flatbuffers::FlatBufferBuilder* builder) const override {
465     return ::tflite::CreateGatherNdOptions(*builder);
466   }
467 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const468   void ReadOptions(const TfLiteOptions& options,
469                    TocoOperator* op) const override {}
470 };
471 
472 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
473                                     ::tflite::BuiltinOptions_SVDFOptions> {
474  public:
475   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const476   flatbuffers::Offset<TfLiteOptions> WriteOptions(
477       const TocoOperator& op,
478       flatbuffers::FlatBufferBuilder* builder) const override {
479     auto activation_function =
480         ActivationFunction::Serialize(op.fused_activation_function);
481     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
482   }
483 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const484   void ReadOptions(const TfLiteOptions& options,
485                    TocoOperator* op) const override {
486     op->fused_activation_function =
487         ActivationFunction::Deserialize(options.fused_activation_function());
488     op->rank = options.rank();
489   }
490 };
491 
492 class L2Normalization
493     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
494                              ::tflite::BuiltinOptions_L2NormOptions> {
495  public:
496   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const497   flatbuffers::Offset<TfLiteOptions> WriteOptions(
498       const TocoOperator& op,
499       flatbuffers::FlatBufferBuilder* builder) const override {
500     auto activation_function =
501         ActivationFunction::Serialize(op.fused_activation_function);
502     return ::tflite::CreateL2NormOptions(*builder, activation_function);
503   }
504 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const505   void ReadOptions(const TfLiteOptions& options,
506                    TocoOperator* op) const override {
507     op->fused_activation_function =
508         ActivationFunction::Deserialize(options.fused_activation_function());
509   }
510 };
511 
512 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
513                                       ::tflite::BuiltinOptions_Pool2DOptions> {
514  public:
515   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const516   flatbuffers::Offset<TfLiteOptions> WriteOptions(
517       const TocoOperator& op,
518       flatbuffers::FlatBufferBuilder* builder) const override {
519     auto padding = Padding::Serialize(op.padding.type);
520     auto activation_function =
521         ActivationFunction::Serialize(op.fused_activation_function);
522     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
523                                          op.stride_height, op.kwidth,
524                                          op.kheight, activation_function);
525   }
526 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const527   void ReadOptions(const TfLiteOptions& options,
528                    TocoOperator* op) const override {
529     op->padding.type = Padding::Deserialize(options.padding());
530     op->stride_width = options.stride_w();
531     op->stride_height = options.stride_h();
532     op->kwidth = options.filter_width();
533     op->kheight = options.filter_height();
534     op->fused_activation_function =
535         ActivationFunction::Deserialize(options.fused_activation_function());
536   }
537 };
538 
539 class LocalResponseNormalization
540     : public BuiltinOperator<
541           LocalResponseNormalizationOperator,
542           ::tflite::LocalResponseNormalizationOptions,
543           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
544  public:
545   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const546   flatbuffers::Offset<TfLiteOptions> WriteOptions(
547       const TocoOperator& op,
548       flatbuffers::FlatBufferBuilder* builder) const override {
549     return ::tflite::CreateLocalResponseNormalizationOptions(
550         *builder, op.range, op.bias, op.alpha, op.beta);
551   }
552 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const553   void ReadOptions(const TfLiteOptions& options,
554                    TocoOperator* op) const override {
555     op->range = options.radius();
556     op->bias = options.bias();
557     op->alpha = options.alpha();
558     op->beta = options.beta();
559   }
560 };
561 
562 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
563                                        ::tflite::BuiltinOptions_Pool2DOptions> {
564  public:
565   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const566   flatbuffers::Offset<TfLiteOptions> WriteOptions(
567       const TocoOperator& op,
568       flatbuffers::FlatBufferBuilder* builder) const override {
569     auto padding = Padding::Serialize(op.padding.type);
570     auto activation_function =
571         ActivationFunction::Serialize(op.fused_activation_function);
572     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
573                                          op.stride_height, op.kwidth,
574                                          op.kheight, activation_function);
575   }
576 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const577   void ReadOptions(const TfLiteOptions& options,
578                    TocoOperator* op) const override {
579     op->padding.type = Padding::Deserialize(options.padding());
580     op->stride_width = options.stride_w();
581     op->stride_height = options.stride_h();
582     op->kwidth = options.filter_width();
583     op->kheight = options.filter_height();
584     op->fused_activation_function =
585         ActivationFunction::Deserialize(options.fused_activation_function());
586   }
587 };
588 
589 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
590                                    ::tflite::BuiltinOptions_MulOptions> {
591  public:
592   using BuiltinOperator::BuiltinOperator;
593 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const594   flatbuffers::Offset<TfLiteOptions> WriteOptions(
595       const TocoOperator& op,
596       flatbuffers::FlatBufferBuilder* builder) const override {
597     auto activation_function =
598         ActivationFunction::Serialize(op.fused_activation_function);
599     return ::tflite::CreateMulOptions(*builder, activation_function);
600   }
601 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const602   void ReadOptions(const TfLiteOptions& options,
603                    TocoOperator* op) const override {
604     op->fused_activation_function =
605         ActivationFunction::Deserialize(options.fused_activation_function());
606   }
607 
GetVersion(const OperatorSignature & op_signature) const608   int GetVersion(const OperatorSignature& op_signature) const override {
609     const string& input1_name = op_signature.op->inputs[0];
610     const string& input2_name = op_signature.op->inputs[1];
611     const string& output_name = op_signature.op->outputs[0];
612     const Array& input1_array = op_signature.model->GetArray(input1_name);
613     const Array& input2_array = op_signature.model->GetArray(input2_name);
614     const Array& output_array = op_signature.model->GetArray(output_name);
615     const auto& input1_quant = input1_array.quantization_params;
616     const auto& input2_quant = input2_array.quantization_params;
617     const auto& output_quant = output_array.quantization_params;
618     const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
619     const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
620     const float output_scale = output_quant ? output_quant->scale : 0.0f;
621     ::tflite::OpSignature op_sig =
622         GetVersioningOpSig(builtin_op(), op_signature);
623     op_sig.options.mul.input1_scale = input1_scale;
624     op_sig.options.mul.input2_scale = input2_scale;
625     op_sig.options.mul.output_scale = output_scale;
626     return ::tflite::GetBuiltinOperatorVersion(op_sig);
627   }
628 };
629 
630 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
631                                    ::tflite::BuiltinOptions_PadOptions> {
632  public:
633   using BuiltinOperator::BuiltinOperator;
634 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const635   flatbuffers::Offset<TfLiteOptions> WriteOptions(
636       const TocoOperator& op,
637       flatbuffers::FlatBufferBuilder* builder) const override {
638     return ::tflite::CreatePadOptions(*builder);
639   }
640 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const641   void ReadOptions(const TfLiteOptions& options,
642                    TocoOperator* op) const override {}
643 };
644 
645 class Tile
646     : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
647                              ::tflite::BuiltinOptions_TileOptions> {
648   using BuiltinOperator::BuiltinOperator;
649 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const650   flatbuffers::Offset<TfLiteOptions> WriteOptions(
651       const TocoOperator& op,
652       flatbuffers::FlatBufferBuilder* builder) const override {
653     return ::tflite::CreateTileOptions(*builder);
654   }
655 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const656   void ReadOptions(const TfLiteOptions& options,
657                    TocoOperator* op) const override {}
658 };
659 
660 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
661                                      ::tflite::BuiltinOptions_PadV2Options> {
662  public:
663   using BuiltinOperator::BuiltinOperator;
664 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const665   flatbuffers::Offset<TfLiteOptions> WriteOptions(
666       const TocoOperator& op,
667       flatbuffers::FlatBufferBuilder* builder) const override {
668     return ::tflite::CreatePadV2Options(*builder);
669   }
670 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const671   void ReadOptions(const TfLiteOptions& options,
672                    TocoOperator* op) const override {}
673 };
674 
675 class Reshape
676     : public BuiltinOperator<TensorFlowReshapeOperator,
677                              ::tflite::ReshapeOptions,
678                              ::tflite::BuiltinOptions_ReshapeOptions> {
679  public:
680   using BuiltinOperator::BuiltinOperator;
681 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const682   flatbuffers::Offset<TfLiteOptions> WriteOptions(
683       const TocoOperator& op,
684       flatbuffers::FlatBufferBuilder* builder) const override {
685     return ::tflite::CreateReshapeOptions(*builder,
686                                           builder->CreateVector(op.shape));
687   }
688 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const689   void ReadOptions(const TfLiteOptions& options,
690                    TocoOperator* op) const override {
691     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
692                      options.new_shape()->end());
693   }
694 };
695 
696 class Softmax
697     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
698                              ::tflite::BuiltinOptions_SoftmaxOptions> {
699  public:
700   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const701   flatbuffers::Offset<TfLiteOptions> WriteOptions(
702       const TocoOperator& op,
703       flatbuffers::FlatBufferBuilder* builder) const override {
704     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
705   }
706 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const707   void ReadOptions(const TfLiteOptions& options,
708                    TocoOperator* op) const override {
709     op->beta = options.beta();
710   }
711 };
712 
713 class SpaceToDepth
714     : public BuiltinOperator<SpaceToDepthOperator,
715                              ::tflite::SpaceToDepthOptions,
716                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
717  public:
718   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const719   flatbuffers::Offset<TfLiteOptions> WriteOptions(
720       const TocoOperator& op,
721       flatbuffers::FlatBufferBuilder* builder) const override {
722     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
723   }
724 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const725   void ReadOptions(const TfLiteOptions& options,
726                    TocoOperator* op) const override {
727     op->block_size = options.block_size();
728   }
729 };
730 
731 class Transpose
732     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
733                              ::tflite::BuiltinOptions_TransposeOptions> {
734  public:
735   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const736   flatbuffers::Offset<TfLiteOptions> WriteOptions(
737       const TocoOperator& op,
738       flatbuffers::FlatBufferBuilder* builder) const override {
739     return ::tflite::CreateTransposeOptions(*builder);
740   }
741 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const742   void ReadOptions(const TfLiteOptions& options,
743                    TocoOperator* op) const override {}
744 };
745 
746 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
747                                     ::tflite::BuiltinOptions_LSTMOptions> {
748  public:
749   using BuiltinOperator::BuiltinOperator;
750 
GetKernelType(LstmCellOperator::KernelType type) const751   ::tflite::LSTMKernelType GetKernelType(
752       LstmCellOperator::KernelType type) const {
753     switch (type) {
754       case LstmCellOperator::KERNEL_BASIC:
755         return ::tflite::LSTMKernelType_BASIC;
756         break;
757       case LstmCellOperator::KERNEL_FULL:
758         return ::tflite::LSTMKernelType_FULL;
759         break;
760       default:
761         LOG(ERROR) << "Unhandled Kernel Type";
762         return static_cast<::tflite::LSTMKernelType>(-1);
763     }
764   }
765 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const766   flatbuffers::Offset<TfLiteOptions> WriteOptions(
767       const TocoOperator& op,
768       flatbuffers::FlatBufferBuilder* builder) const override {
769     ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
770 
771     // Current toco converter only supports tanh, no clip.
772     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
773                                        ::tflite::ActivationFunctionType_TANH,
774                                        /*cell_clip=*/0.0,
775                                        /*proj_clip=*/0.0, kernel_type);
776   }
777 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const778   void ReadOptions(const TfLiteOptions& options,
779                    TocoOperator* op) const override {
780     // Only support tanh activation, so check that tflite type is tanh.
781     CHECK(options.fused_activation_function() ==
782           ::tflite::ActivationFunctionType_TANH);
783 
784     switch (options.kernel_type()) {
785       case ::tflite::LSTMKernelType_BASIC:
786         op->kernel_type = LstmCellOperator::KERNEL_BASIC;
787         break;
788       case ::tflite::LSTMKernelType_FULL:
789         op->kernel_type = LstmCellOperator::KERNEL_FULL;
790         break;
791     }
792   }
793 
GetVersion(const OperatorSignature & op_signature) const794   int GetVersion(const OperatorSignature& op_signature) const override {
795     const auto& lstm_op =
796         static_cast<const LstmCellOperator&>(*op_signature.op);
797     ::tflite::OpSignature op_sig =
798         GetVersioningOpSig(builtin_op(), op_signature);
799     op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
800     return ::tflite::GetBuiltinOperatorVersion(op_sig);
801   }
802 
GetMutatingInputVariables(const Operator & op) const803   std::vector<bool> GetMutatingInputVariables(
804       const Operator& op) const override {
805     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
806 
807     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
808     switch (lstm_op.kernel_type) {
809       case LstmCellOperator::KERNEL_FULL: {
810         mutating_input_variables[kInputActivationStateTensor] = true;
811         mutating_input_variables[kInputCellStateTensor] = true;
812         break;
813       }
814       case LstmCellOperator::KERNEL_BASIC: {
815         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
816         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
817         break;
818       }
819     }
820     return mutating_input_variables;
821   }
822 };
823 
824 class UnidirectionalSequenceLstm
825     : public BuiltinOperator<
826           UnidirectionalSequenceLstmOperator,
827           ::tflite::UnidirectionalSequenceLSTMOptions,
828           ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
829  public:
830   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const831   flatbuffers::Offset<TfLiteOptions> WriteOptions(
832       const TocoOperator& op,
833       flatbuffers::FlatBufferBuilder* builder) const override {
834     // Current toco converter only supports tanh, no clip.
835     return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
836         *builder, /*fused_activation_function=*/
837         ::tflite::ActivationFunctionType_TANH,
838         /*cell_clip=*/0.0,
839         /*proj_clip=*/0.0,
840         /*time_major=*/true);
841   }
842 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const843   void ReadOptions(const TfLiteOptions& options,
844                    TocoOperator* op) const override {
845     // Only support tanh activation, so check that tflite type is tanh.
846     DCHECK(options.fused_activation_function() ==
847            ::tflite::ActivationFunctionType_TANH);
848   }
849 
GetMutatingInputVariables(const Operator & op) const850   std::vector<bool> GetMutatingInputVariables(
851       const Operator& op) const override {
852     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
853     mutating_input_variables[kInputActivationStateTensor] = true;
854     mutating_input_variables[kInputCellStateTensor] = true;
855     return mutating_input_variables;
856   }
857 };
858 
859 class BidirectionalSequenceLstm
860     : public BuiltinOperator<
861           BidirectionalSequenceLstmOperator,
862           ::tflite::BidirectionalSequenceLSTMOptions,
863           ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
864  public:
865   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const866   flatbuffers::Offset<TfLiteOptions> WriteOptions(
867       const TocoOperator& op,
868       flatbuffers::FlatBufferBuilder* builder) const override {
869     // Current toco converter only supports tanh, no clip.
870     return ::tflite::CreateBidirectionalSequenceLSTMOptions(
871         *builder, /*fused_activation_function=*/
872         ::tflite::ActivationFunctionType_TANH,
873         /*cell_clip=*/0.0,
874         /*proj_clip=*/0.0,
875         /*merge_outputs=*/op.merge_outputs,
876         /*time_major=*/true);
877   }
878 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const879   void ReadOptions(const TfLiteOptions& options,
880                    TocoOperator* op) const override {
881     // Only support tanh activation, so check that tflite type is tanh.
882     DCHECK(options.fused_activation_function() ==
883            ::tflite::ActivationFunctionType_TANH);
884     op->merge_outputs = options.merge_outputs();
885   }
886 
GetMutatingInputVariables(const Operator & op) const887   std::vector<bool> GetMutatingInputVariables(
888       const Operator& op) const override {
889     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
890     // Forward input activation state.
891     mutating_input_variables[35] = true;
892     // Forward input cell state.
893     mutating_input_variables[36] = true;
894     // Backward input activation state.
895     mutating_input_variables[37] = true;
896     // Backward input cell state.
897     mutating_input_variables[38] = true;
898     return mutating_input_variables;
899   }
900 };
901 
902 class BidirectionalSequenceRnn
903     : public BuiltinOperator<
904           BidirectionalSequenceRnnOperator,
905           ::tflite::BidirectionalSequenceRNNOptions,
906           ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
907  public:
908   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const909   flatbuffers::Offset<TfLiteOptions> WriteOptions(
910       const TocoOperator& op,
911       flatbuffers::FlatBufferBuilder* builder) const override {
912     // Current toco converter only supports tanh, no clip.
913     return ::tflite::CreateBidirectionalSequenceRNNOptions(
914         *builder, /*time_major=*/true,
915         /*fused_activation_function=*/
916         ::tflite::ActivationFunctionType_TANH,
917         /*merge_outputs=*/op.merge_outputs);
918   }
919 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const920   void ReadOptions(const TfLiteOptions& options,
921                    TocoOperator* op) const override {
922     // Only support tanh activation, so check that tflite type is tanh.
923     DCHECK(options.fused_activation_function() ==
924            ::tflite::ActivationFunctionType_TANH);
925     op->merge_outputs = options.merge_outputs();
926   }
927 
GetMutatingInputVariables(const Operator & op) const928   std::vector<bool> GetMutatingInputVariables(
929       const Operator& op) const override {
930     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
931     // Forward hidden state.
932     mutating_input_variables[4] = true;
933     // Backward hidden state.
934     mutating_input_variables[8] = true;
935     return mutating_input_variables;
936   }
937 };
938 
939 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
940                                     ::tflite::BuiltinOptions_ReducerOptions> {
941  public:
942   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const943   flatbuffers::Offset<TfLiteOptions> WriteOptions(
944       const TocoOperator& op,
945       flatbuffers::FlatBufferBuilder* builder) const override {
946     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
947   }
948 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const949   void ReadOptions(const TfLiteOptions& options,
950                    TocoOperator* op) const override {
951     op->keep_dims = options.keep_dims();
952   }
953 };
954 
955 class Sum
956     : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
957                              ::tflite::BuiltinOptions_ReducerOptions> {
958  public:
959   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const960   flatbuffers::Offset<TfLiteOptions> WriteOptions(
961       const TocoOperator& op,
962       flatbuffers::FlatBufferBuilder* builder) const override {
963     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
964   }
965 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const966   void ReadOptions(const TfLiteOptions& options,
967                    TocoOperator* op) const override {
968     op->keep_dims = options.keep_dims();
969   }
970 };
971 
972 class ReduceMax
973     : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
974                              ::tflite::BuiltinOptions_ReducerOptions> {
975  public:
976   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const977   flatbuffers::Offset<TfLiteOptions> WriteOptions(
978       const TocoOperator& op,
979       flatbuffers::FlatBufferBuilder* builder) const override {
980     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
981   }
982 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const983   void ReadOptions(const TfLiteOptions& options,
984                    TocoOperator* op) const override {
985     op->keep_dims = options.keep_dims();
986   }
987 };
988 
989 class ReduceMin
990     : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
991                              ::tflite::BuiltinOptions_ReducerOptions> {
992  public:
993   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const994   flatbuffers::Offset<TfLiteOptions> WriteOptions(
995       const TocoOperator& op,
996       flatbuffers::FlatBufferBuilder* builder) const override {
997     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
998   }
999 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1000   void ReadOptions(const TfLiteOptions& options,
1001                    TocoOperator* op) const override {
1002     op->keep_dims = options.keep_dims();
1003   }
1004 };
1005 
1006 class ReduceProd
1007     : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1008                              ::tflite::BuiltinOptions_ReducerOptions> {
1009  public:
1010   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1011   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1012       const TocoOperator& op,
1013       flatbuffers::FlatBufferBuilder* builder) const override {
1014     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1015   }
1016 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1017   void ReadOptions(const TfLiteOptions& options,
1018                    TocoOperator* op) const override {
1019     op->keep_dims = options.keep_dims();
1020   }
1021 };
1022 
1023 class ReduceAny
1024     : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1025                              ::tflite::BuiltinOptions_ReducerOptions> {
1026  public:
1027   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1028   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1029       const TocoOperator& op,
1030       flatbuffers::FlatBufferBuilder* builder) const override {
1031     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1032   }
1033 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1034   void ReadOptions(const TfLiteOptions& options,
1035                    TocoOperator* op) const override {
1036     op->keep_dims = options.keep_dims();
1037   }
1038 };
1039 
1040 class ResizeBilinear
1041     : public BuiltinOperator<ResizeBilinearOperator,
1042                              ::tflite::ResizeBilinearOptions,
1043                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1044  public:
1045   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1046   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1047       const TocoOperator& op,
1048       flatbuffers::FlatBufferBuilder* builder) const override {
1049     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1050                                                  op.half_pixel_centers);
1051   }
1052 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1053   void ReadOptions(const TfLiteOptions& options,
1054                    TocoOperator* op) const override {
1055     op->align_corners = options.align_corners();
1056     op->half_pixel_centers = options.half_pixel_centers();
1057   }
1058 
GetVersion(const OperatorSignature & op_signature) const1059   int GetVersion(const OperatorSignature& op_signature) const override {
1060     const auto& resize_bilinear_op =
1061         static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1062     ::tflite::OpSignature op_sig =
1063         GetVersioningOpSig(builtin_op(), op_signature);
1064     op_sig.options.resize_bilinear.half_pixel_centers =
1065         resize_bilinear_op.half_pixel_centers;
1066     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1067   }
1068 };
1069 
1070 class ResizeNearestNeighbor
1071     : public BuiltinOperator<
1072           ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1073           ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1074  public:
1075   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1076   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1077       const TocoOperator& op,
1078       flatbuffers::FlatBufferBuilder* builder) const override {
1079     return ::tflite::CreateResizeNearestNeighborOptions(*builder,
1080                                                         op.align_corners);
1081   }
1082 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1083   void ReadOptions(const TfLiteOptions& options,
1084                    TocoOperator* op) const override {
1085     op->align_corners = options.align_corners();
1086   }
1087 };
1088 
1089 class Squeeze
1090     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1091                              ::tflite::BuiltinOptions_SqueezeOptions> {
1092  public:
1093   using BuiltinOperator::BuiltinOperator;
1094 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1095   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1096       const TocoOperator& op,
1097       flatbuffers::FlatBufferBuilder* builder) const override {
1098     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1099     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1100   }
1101 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1102   void ReadOptions(const TfLiteOptions& options,
1103                    TocoOperator* op) const override {
1104     op->squeeze_dims.insert(op->squeeze_dims.end(),
1105                             options.squeeze_dims()->begin(),
1106                             options.squeeze_dims()->end());
1107   }
1108 };
1109 
1110 class Split
1111     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1112                              ::tflite::BuiltinOptions_SplitOptions> {
1113  public:
1114   using BuiltinOperator::BuiltinOperator;
1115 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1116   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1117       const TocoOperator& op,
1118       flatbuffers::FlatBufferBuilder* builder) const override {
1119     return ::tflite::CreateSplitOptions(*builder, op.num_split);
1120   }
1121 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1122   void ReadOptions(const TfLiteOptions& options,
1123                    TocoOperator* op) const override {
1124     op->num_split = options.num_splits();
1125   }
1126 };
1127 
1128 class SplitV
1129     : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1130                              ::tflite::BuiltinOptions_SplitVOptions> {
1131  public:
1132   using BuiltinOperator::BuiltinOperator;
1133 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1134   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1135       const TocoOperator& op,
1136       flatbuffers::FlatBufferBuilder* builder) const override {
1137     return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1138   }
1139 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1140   void ReadOptions(const TfLiteOptions& options,
1141                    TocoOperator* op) const override {
1142     op->num_split = options.num_splits();
1143   }
1144 };
1145 
1146 class StridedSlice
1147     : public BuiltinOperator<StridedSliceOperator,
1148                              ::tflite::StridedSliceOptions,
1149                              ::tflite::BuiltinOptions_StridedSliceOptions> {
1150  public:
1151   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1152   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1153       const TocoOperator& op,
1154       flatbuffers::FlatBufferBuilder* builder) const override {
1155     return ::tflite::CreateStridedSliceOptions(
1156         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1157         op.new_axis_mask, op.shrink_axis_mask);
1158   }
1159 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1160   void ReadOptions(const TfLiteOptions& options,
1161                    TocoOperator* op) const override {
1162     op->begin_mask = options.begin_mask();
1163     op->end_mask = options.end_mask();
1164     op->ellipsis_mask = options.ellipsis_mask();
1165     op->new_axis_mask = options.new_axis_mask();
1166     op->shrink_axis_mask = options.shrink_axis_mask();
1167   }
1168 };
1169 
1170 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1171                                        ::tflite::BuiltinOptions_TopKV2Options> {
1172  public:
1173   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1174   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1175       const TocoOperator& op,
1176       flatbuffers::FlatBufferBuilder* builder) const override {
1177     return ::tflite::CreateTopKV2Options(*builder);
1178   }
1179 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1180   void ReadOptions(const TfLiteOptions& options,
1181                    TocoOperator* op) const override {}
1182 };
1183 
1184 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1185                                       ::tflite::BuiltinOptions_ArgMaxOptions> {
1186  public:
1187   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1188   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1189       const TocoOperator& op,
1190       flatbuffers::FlatBufferBuilder* builder) const override {
1191     return ::tflite::CreateArgMaxOptions(
1192         *builder, DataType::Serialize(op.output_data_type));
1193   }
1194 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1195   void ReadOptions(const TfLiteOptions& options,
1196                    TocoOperator* op) const override {
1197     op->output_data_type = DataType::Deserialize(options.output_type());
1198   }
1199 };
1200 
1201 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1202                                       ::tflite::BuiltinOptions_ArgMinOptions> {
1203  public:
1204   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1205   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1206       const TocoOperator& op,
1207       flatbuffers::FlatBufferBuilder* builder) const override {
1208     return ::tflite::CreateArgMinOptions(
1209         *builder, DataType::Serialize(op.output_data_type));
1210   }
1211 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1212   void ReadOptions(const TfLiteOptions& options,
1213                    TocoOperator* op) const override {
1214     op->output_data_type = DataType::Deserialize(options.output_type());
1215   }
1216 };
1217 
1218 class TransposeConv
1219     : public BuiltinOperator<TransposeConvOperator,
1220                              ::tflite::TransposeConvOptions,
1221                              ::tflite::BuiltinOptions_TransposeConvOptions> {
1222  public:
1223   using BuiltinOperator::BuiltinOperator;
1224 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1225   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1226       const TocoOperator& op,
1227       flatbuffers::FlatBufferBuilder* builder) const override {
1228     auto padding = Padding::Serialize(op.padding.type);
1229     return ::tflite::CreateTransposeConvOptions(
1230         *builder, padding, op.stride_width, op.stride_height);
1231   }
1232 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1233   void ReadOptions(const TfLiteOptions& options,
1234                    TocoOperator* op) const override {
1235     op->padding.type = Padding::Deserialize(options.padding());
1236     op->stride_width = options.stride_w();
1237     op->stride_height = options.stride_h();
1238   }
1239 };
1240 
1241 class SparseToDense
1242     : public BuiltinOperator<SparseToDenseOperator,
1243                              ::tflite::SparseToDenseOptions,
1244                              ::tflite::BuiltinOptions_SparseToDenseOptions> {
1245  public:
1246   using BuiltinOperator::BuiltinOperator;
1247 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1248   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1249       const TocoOperator& op,
1250       flatbuffers::FlatBufferBuilder* builder) const override {
1251     return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1252   }
1253 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1254   void ReadOptions(const TfLiteOptions& options,
1255                    TocoOperator* op) const override {
1256     op->validate_indices = options.validate_indices();
1257   }
1258 };
1259 
1260 class ExpandDims
1261     : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1262                              ::tflite::BuiltinOptions_ExpandDimsOptions> {
1263  public:
1264   using BuiltinOperator::BuiltinOperator;
1265 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1266   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1267       const TocoOperator& op,
1268       flatbuffers::FlatBufferBuilder* builder) const override {
1269     return ::tflite::CreateExpandDimsOptions(*builder);
1270   }
1271 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1272   void ReadOptions(const TfLiteOptions& options,
1273                    TocoOperator* op) const override {}
1274 };
1275 
1276 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1277                                     ::tflite::BuiltinOptions_PackOptions> {
1278  public:
1279   using BuiltinOperator::BuiltinOperator;
1280 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1281   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1282       const TocoOperator& op,
1283       flatbuffers::FlatBufferBuilder* builder) const override {
1284     return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1285   }
1286 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1287   void ReadOptions(const TfLiteOptions& options,
1288                    TocoOperator* op) const override {
1289     op->values_count = options.values_count();
1290     op->axis = options.axis();
1291   }
1292 };
1293 
1294 class Shape
1295     : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1296                              ::tflite::BuiltinOptions_ShapeOptions> {
1297  public:
1298   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1299   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1300       const TocoOperator& op,
1301       flatbuffers::FlatBufferBuilder* builder) const override {
1302     return ::tflite::CreateShapeOptions(
1303         *builder, DataType::Serialize(op.output_data_type));
1304   }
1305 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1306   void ReadOptions(const TfLiteOptions& options,
1307                    TocoOperator* op) const override {
1308     op->output_data_type = DataType::Deserialize(options.out_type());
1309   }
1310 };
1311 
1312 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1313                                       ::tflite::BuiltinOptions_OneHotOptions> {
1314  public:
1315   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1316   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1317       const TocoOperator& op,
1318       flatbuffers::FlatBufferBuilder* builder) const override {
1319     return ::tflite::CreateOneHotOptions(*builder, op.axis);
1320   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1321   void ReadOptions(const TfLiteOptions& options,
1322                    TocoOperator* op) const override {
1323     op->axis = options.axis();
1324   }
1325 };
1326 
1327 class CTCBeamSearchDecoder
1328     : public CustomOperator<CTCBeamSearchDecoderOperator> {
1329  public:
1330   using CustomOperator::CustomOperator;
1331 
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1332   void WriteOptions(const TocoOperator& op,
1333                     flexbuffers::Builder* fbb) const override {
1334     fbb->Int("beam_width", op.beam_width);
1335     fbb->Int("top_paths", op.top_paths);
1336     fbb->Bool("merge_repeated", op.merge_repeated);
1337   }
1338 
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1339   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1340     op->beam_width = m["beam_width"].AsInt32();
1341     op->top_paths = m["top_paths"].AsInt32();
1342     op->merge_repeated = m["merge_repeated"].AsBool();
1343   }
1344 
GetVersion(const OperatorSignature & op_signature) const1345   int GetVersion(const OperatorSignature& op_signature) const override {
1346     return 1;
1347   }
1348 };
1349 
1350 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1351                                       ::tflite::BuiltinOptions_UnpackOptions> {
1352  public:
1353   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1354   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1355       const TocoOperator& op,
1356       flatbuffers::FlatBufferBuilder* builder) const override {
1357     return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1358   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1359   void ReadOptions(const TfLiteOptions& options,
1360                    TocoOperator* op) const override {
1361     op->num = options.num();
1362     op->axis = options.axis();
1363   }
1364 
GetVersion(const OperatorSignature & op_signature) const1365   int GetVersion(const OperatorSignature& op_signature) const override {
1366     const string& input_name = op_signature.op->inputs[0];
1367     const Array& input_array = op_signature.model->GetArray(input_name);
1368     // If the op take int8/uint8 input, it is version 2.
1369     if (input_array.data_type == ArrayDataType::kInt8 ||
1370         input_array.data_type == ArrayDataType::kUint8) {
1371       return 2;
1372     }
1373     // If the op take bool input, it is version 3.
1374     if (input_array.data_type == ArrayDataType::kBool) {
1375       return 3;
1376     }
1377     return 1;
1378   }
1379 };
1380 
1381 class LeakyRelu
1382     : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1383                              ::tflite::BuiltinOptions_LeakyReluOptions> {
1384  public:
1385   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1386   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1387       const TocoOperator& op,
1388       flatbuffers::FlatBufferBuilder* builder) const override {
1389     return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1390   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1391   void ReadOptions(const TfLiteOptions& options,
1392                    TocoOperator* op) const override {
1393     op->alpha = options.alpha();
1394   }
1395 };
1396 
1397 class SquaredDifference
1398     : public BuiltinOperator<
1399           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1400           ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1401  public:
1402   using BuiltinOperator::BuiltinOperator;
1403 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1404   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1405       const TocoOperator& op,
1406       flatbuffers::FlatBufferBuilder* builder) const override {
1407     return ::tflite::CreateSquaredDifferenceOptions(*builder);
1408   }
1409 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1410   void ReadOptions(const TfLiteOptions& options,
1411                    TocoOperator* op) const override {}
1412 };
1413 
1414 class MirrorPad
1415     : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1416                              ::tflite::BuiltinOptions_MirrorPadOptions> {
1417  public:
1418   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1419   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1420       const TocoOperator& op,
1421       flatbuffers::FlatBufferBuilder* builder) const override {
1422     return ::tflite::CreateMirrorPadOptions(
1423         *builder, op.mode == MirrorPadMode::kReflect
1424                       ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1425                       : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1426   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1427   void ReadOptions(const TfLiteOptions& options,
1428                    TocoOperator* op) const override {
1429     op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1430                    ? MirrorPadMode::kReflect
1431                    : MirrorPadMode::kSymmetric;
1432   }
1433 };
1434 
1435 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1436                                       ::tflite::BuiltinOptions_UniqueOptions> {
1437  public:
1438   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1439   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1440       const TocoOperator& op,
1441       flatbuffers::FlatBufferBuilder* builder) const override {
1442     const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1443     return ::tflite::CreateUniqueOptions(
1444         *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1445                       ? ::tflite::TensorType::TensorType_INT64
1446                       : ::tflite::TensorType_INT32);
1447   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1448   void ReadOptions(const TfLiteOptions& options,
1449                    TocoOperator* op) const override {
1450     UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1451     unique_op->idx_out_type =
1452         options.idx_out_type() == ::tflite::TensorType_INT64
1453             ? toco::ArrayDataType::kInt64
1454             : toco::ArrayDataType::kInt32;
1455   }
1456 };
1457 
1458 class UnidirectionalSequenceRnn
1459     : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1460                              ::tflite::SequenceRNNOptions,
1461                              ::tflite::BuiltinOptions_SequenceRNNOptions> {
1462  public:
1463   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1464   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1465       const TocoOperator& op,
1466       flatbuffers::FlatBufferBuilder* builder) const override {
1467     return ::tflite::CreateSequenceRNNOptions(
1468         *builder, /*time_major=*/true,
1469         /*fused_activation_function=*/
1470         ::tflite::ActivationFunctionType_TANH);
1471   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1472   void ReadOptions(const TfLiteOptions& options,
1473                    TocoOperator* op) const override {
1474     // Only support tanh activation, so check that tflite type is tanh.
1475     DCHECK(options.fused_activation_function() ==
1476            ::tflite::ActivationFunctionType_TANH);
1477   }
1478 
GetMutatingInputVariables(const Operator & op) const1479   std::vector<bool> GetMutatingInputVariables(
1480       const Operator& op) const override {
1481     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1482     mutating_input_variables[4] = true;
1483     return mutating_input_variables;
1484   }
1485 };
1486 
1487 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1488                                      ::tflite::BuiltinOptions_WhereOptions> {
1489  public:
1490   using BuiltinOperator::BuiltinOperator;
1491 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1492   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1493       const TocoOperator& op,
1494       flatbuffers::FlatBufferBuilder* builder) const override {
1495     return ::tflite::CreateWhereOptions(*builder);
1496   }
1497 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1498   void ReadOptions(const TfLiteOptions& options,
1499                    TocoOperator* op) const override {}
1500 };
1501 
WriteFlexOpOptions(const string & tensorflow_node_def)1502 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1503     const string& tensorflow_node_def) {
1504   auto fbb = absl::make_unique<flexbuffers::Builder>();
1505 
1506   ::tensorflow::NodeDef node_def;
1507   if (!node_def.ParseFromString(tensorflow_node_def)) {
1508     LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1509     return {};
1510   }
1511 
1512   fbb->Vector([&]() {
1513     fbb->String(node_def.op());
1514     fbb->String(tensorflow_node_def);
1515   });
1516   fbb->Finish();
1517   LOG(INFO) << "Writing flex op: " << node_def.op();
1518   return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1519 }
1520 
1521 class TensorFlowUnsupported : public BaseOperator {
1522  public:
TensorFlowUnsupported(const string & name,OperatorType type,bool enable_select_tf_ops)1523   TensorFlowUnsupported(const string& name, OperatorType type,
1524                         bool enable_select_tf_ops)
1525       : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1526 
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1527   Options Serialize(const Operator& op,
1528                     flatbuffers::FlatBufferBuilder* builder) const override {
1529     auto fbb =
1530         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1531     if (fbb) {
1532       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1533     } else {
1534       return Options::Custom(0);
1535     }
1536   }
1537 
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1538   std::unique_ptr<Operator> Deserialize(
1539       const BuiltinOptions* builtin_options,
1540       const CustomOptions* custom_options) const override {
1541     // Deserializing Flex ops doesn't work now.
1542     // TODO(ycling): Revisit and decide if we should fix the flow for importing
1543     // TFLite models with Flex ops.
1544     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
1545     if (custom_options) {
1546       auto flexbuffer_map =
1547           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1548               .AsMap();
1549       ReadOptions(flexbuffer_map, op.get());
1550     }
1551     return std::unique_ptr<Operator>(op.release());
1552   }
1553 
WriteOptions(const TensorFlowUnsupportedOperator & op) const1554   std::unique_ptr<flexbuffers::Builder> WriteOptions(
1555       const TensorFlowUnsupportedOperator& op) const {
1556     if (enable_select_tf_ops_) {
1557       return WriteFlexOpOptions(op.tensorflow_node_def);
1558     }
1559     auto fbb = absl::make_unique<flexbuffers::Builder>();
1560 
1561     ::tensorflow::NodeDef node_def;
1562     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1563       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1564       return std::unique_ptr<flexbuffers::Builder>();
1565     }
1566 
1567     if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1568       fbb->Vector([&]() {
1569         fbb->String(node_def.op());
1570         fbb->String(op.tensorflow_node_def);
1571       });
1572       fbb->Finish();
1573       LOG(INFO) << "Writing flex op: " << node_def.op();
1574       return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1575     }
1576 
1577     bool has_valid_attr = false;
1578     size_t map_start = fbb->StartMap();
1579     for (const auto& pair : node_def.attr()) {
1580       const char* key = pair.first.c_str();
1581       const auto& attr = pair.second;
1582       switch (attr.value_case()) {
1583         case ::tensorflow::AttrValue::kS:
1584           fbb->String(key, attr.s());
1585           has_valid_attr = true;
1586           break;
1587         case ::tensorflow::AttrValue::kI:
1588           fbb->Int(key, attr.i());
1589           has_valid_attr = true;
1590           break;
1591         case ::tensorflow::AttrValue::kF:
1592           fbb->Float(key, attr.f());
1593           has_valid_attr = true;
1594           break;
1595         case ::tensorflow::AttrValue::kB:
1596           fbb->Bool(key, attr.b());
1597           has_valid_attr = true;
1598           break;
1599         case tensorflow::AttrValue::kList:
1600           if (attr.list().s_size() > 0) {
1601             auto start = fbb->StartVector(key);
1602             for (const string& v : attr.list().s()) {
1603               fbb->Add(v);
1604             }
1605             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1606             has_valid_attr = true;
1607           } else if (attr.list().i_size() > 0) {
1608             auto start = fbb->StartVector(key);
1609             for (const int64_t v : attr.list().i()) {
1610               fbb->Add(v);
1611             }
1612             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1613             has_valid_attr = true;
1614           } else if (attr.list().f_size() > 0) {
1615             auto start = fbb->StartVector(key);
1616             for (const float v : attr.list().f()) {
1617               fbb->Add(v);
1618             }
1619             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1620             has_valid_attr = true;
1621           } else {
1622             LOG(WARNING)
1623                 << "Ignoring unsupported type in list attribute with key '"
1624                 << key << "'";
1625           }
1626           break;
1627         default:
1628           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1629                        << key << "'";
1630           break;
1631       }
1632     }
1633     if (!has_valid_attr) {
1634       return std::unique_ptr<flexbuffers::Builder>();
1635     }
1636     fbb->EndMap(map_start);
1637     fbb->Finish();
1638     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1639   }
1640 
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1641   void ReadOptions(const flexbuffers::Map& m,
1642                    TensorFlowUnsupportedOperator* op) const {
1643     ::tensorflow::NodeDef node_def;
1644     auto attr = node_def.mutable_attr();
1645 
1646     const auto& keys = m.Keys();
1647     for (size_t i = 0; i < keys.size(); ++i) {
1648       const auto key = keys[i].AsKey();
1649       const auto& value = m[key];
1650       // TODO(wvo): hack to make this code compile with 2 different API
1651       // versions.
1652       // Please remove once OS/internal versions are in sync.
1653       // See hardcoded values in the switch below.
1654       switch (value.GetType()) {
1655         case 5:  // flexbuffers::FBT_STRING:
1656           (*attr)[key].set_s(value.AsString().c_str());
1657           break;
1658         case 1:  // flexbuffers::FBT_INT:
1659           (*attr)[key].set_i(value.AsInt64());
1660           break;
1661         case 3:  // flexbuffers::FBT_FLOAT:
1662           (*attr)[key].set_f(value.AsFloat());
1663           break;
1664         case 26:  // flexbuffers::FBT_BOOL:
1665           (*attr)[key].set_b(value.AsBool());
1666           if (string(key) == "_output_quantized") {
1667             op->quantized = value.AsBool();
1668           }
1669           if (string(key) == "_support_output_type_float_in_quantized_op") {
1670             op->support_output_type_float_in_quantized_op = value.AsBool();
1671           }
1672           break;
1673         case 11: {  // flexbuffers::FBT_VECTOR_INT: {
1674           auto* list = (*attr)[key].mutable_list();
1675           const auto& vector = value.AsTypedVector();
1676           for (size_t i = 0; i < vector.size(); i++) {
1677             list->add_i(vector[i].AsInt64());
1678           }
1679           break;
1680         }
1681         case 13: {  // flexbuffers::FBT_VECTOR_FLOAT: {
1682           auto* list = (*attr)[key].mutable_list();
1683           const auto& vector = value.AsTypedVector();
1684           for (size_t i = 0; i < vector.size(); i++) {
1685             list->add_f(vector[i].AsFloat());
1686           }
1687           break;
1688         }
1689         case 15: {  // flexbuffers::FBT_VECTOR_STRING: {
1690           auto* list = (*attr)[key].mutable_list();
1691           const auto& vector = value.AsTypedVector();
1692           for (size_t i = 0; i < vector.size(); i++) {
1693             list->add_s(vector[i].AsString().str());
1694           }
1695           break;
1696         }
1697         default:
1698           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1699                        << key << "'";
1700           break;
1701       }
1702     }
1703     node_def.SerializeToString(&op->tensorflow_node_def);
1704   }
1705 
GetVersion(const OperatorSignature & op_signature) const1706   int GetVersion(const OperatorSignature& op_signature) const override {
1707     // TODO(ycling): Design and implement a way to plumb the version of
1708     // custom ops.
1709     return 1;
1710   }
1711 
1712  private:
1713   const bool enable_select_tf_ops_;
1714 };
1715 
1716 class Dequantize
1717     : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1718                              ::tflite::BuiltinOptions_DequantizeOptions> {
1719  public:
1720   using BuiltinOperator::BuiltinOperator;
1721 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1722   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1723       const TocoOperator& op,
1724       flatbuffers::FlatBufferBuilder* builder) const override {
1725     return ::tflite::CreateDequantizeOptions(*builder);
1726   }
1727 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1728   void ReadOptions(const TfLiteOptions& options,
1729                    TocoOperator* op) const override {}
1730 };
1731 
1732 class ReverseSequence
1733     : public BuiltinOperator<ReverseSequenceOperator,
1734                              ::tflite::ReverseSequenceOptions,
1735                              ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1736  public:
1737   using BuiltinOperator::BuiltinOperator;
1738 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1739   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1740       const TocoOperator& op,
1741       flatbuffers::FlatBufferBuilder* builder) const override {
1742     return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1743                                                   op.batch_dim);
1744   }
1745 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1746   void ReadOptions(const TfLiteOptions& options,
1747                    TocoOperator* op) const override {
1748     op->seq_dim = options.seq_dim();
1749     op->batch_dim = options.batch_dim();
1750   }
1751 };
1752 
1753 namespace {
1754 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1755 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1756     bool enable_select_tf_ops = false) {
1757   std::vector<std::unique_ptr<BaseOperator>> ops;
1758   using tensorflow::MakeUnique;
1759   // Builtin Operators.
1760   ops.push_back(
1761       MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1762   ops.push_back(
1763       MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1764   ops.push_back(
1765       MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1766   ops.push_back(
1767       MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1768   ops.push_back(MakeUnique<AveragePool>(
1769       ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1770   ops.push_back(
1771       MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1772                                  OperatorType::kSpaceToBatchND));
1773   ops.push_back(
1774       MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1775                                  OperatorType::kBatchToSpaceND));
1776   ops.push_back(MakeUnique<Concatenation>(
1777       ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1778   ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1779                                         OperatorType::kConv));
1780   ops.push_back(MakeUnique<DepthwiseConvolution>(
1781       ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1782       OperatorType::kDepthwiseConv));
1783   ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1784                                        OperatorType::kDequantize));
1785   ops.push_back(
1786       MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1787                                  OperatorType::kFullyConnected));
1788   ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1789                                    OperatorType::kGather));
1790   ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1791                                      OperatorType::kGatherNd));
1792   ops.push_back(
1793       MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1794                                   OperatorType::kL2Normalization));
1795   ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1796                                    OperatorType::kL2Pool));
1797   ops.push_back(MakeUnique<LocalResponseNormalization>(
1798       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1799       OperatorType::kLocalResponseNormalization));
1800   ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1801                                     OperatorType::kMaxPool));
1802   ops.push_back(
1803       MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1804 
1805   ops.push_back(
1806       MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1807   ops.push_back(
1808       MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1809   ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1810                                     OperatorType::kReshape));
1811   ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1812                                     OperatorType::kSoftmax));
1813   ops.push_back(MakeUnique<SpaceToDepth>(
1814       ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1815   ops.push_back(MakeUnique<DepthToSpace>(
1816       ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1817   ops.push_back(
1818       MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1819   ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1820                                       OperatorType::kTranspose));
1821   ops.push_back(
1822       MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1823   ops.push_back(
1824       MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1825   ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1826                                        OperatorType::kReduceProd));
1827   ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1828                                       OperatorType::kReduceMax));
1829   ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1830                                       OperatorType::kReduceMin));
1831   ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1832                                       OperatorType::kAny));
1833   ops.push_back(
1834       MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1835                                  OperatorType::kResizeBilinear));
1836   ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1837       ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1838       OperatorType::kResizeNearestNeighbor));
1839   ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1840                                     OperatorType::kSqueeze));
1841   ops.push_back(
1842       MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1843   ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1844                                    OperatorType::kSplitV));
1845   ops.push_back(MakeUnique<StridedSlice>(
1846       ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1847   ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1848                                     OperatorType::kTopK_V2));
1849   ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1850                                  OperatorType::kLstmCell));
1851   ops.push_back(
1852       MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1853   ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1854                                    OperatorType::kArgMax));
1855   ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1856                                    OperatorType::kArgMin));
1857   ops.push_back(
1858       MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1859   ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1860                                        OperatorType::kExpandDims));
1861   ops.push_back(MakeUnique<TransposeConv>(
1862       ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1863   ops.push_back(MakeUnique<SparseToDense>(
1864       ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1865   ops.push_back(
1866       MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1867   ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1868                                       OperatorType::kFakeQuant));
1869   ops.push_back(
1870       MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1871   ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1872       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1873       OperatorType::kUnidirectionalSequenceLstm));
1874   ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1875       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1876       OperatorType::kBidirectionalSequenceLstm));
1877   ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1878       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1879       OperatorType::kBidirectionalSequenceRnn));
1880   ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1881                                    OperatorType::kOneHot));
1882   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1883                                    OperatorType::kUnpack));
1884   ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1885                                       OperatorType::kLeakyRelu));
1886   ops.push_back(MakeUnique<SquaredDifference>(
1887       ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1888       OperatorType::kSquaredDifference));
1889   ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1890                                       OperatorType::kMirrorPad));
1891   ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1892                                    OperatorType::kUnique));
1893   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1894       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1895       OperatorType::kUnidirectionalSequenceRnn));
1896   ops.push_back(
1897       MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1898   ops.push_back(
1899       MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1900                                   OperatorType::kReverseSequence));
1901   ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1902       ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1903   ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1904       ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1905   // Custom Operators.
1906   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1907       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1908   ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1909                                                   OperatorType::kUnsupported,
1910                                                   enable_select_tf_ops));
1911 
1912   // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1913   // been modified to also export builtins. As TOCO evolved we added warnings
1914   // when custom ops are exported but SimpleOperator bypasses thoses. To
1915   // prevent user confusion we are settling on using SimpleOperator only for
1916   // builtins.
1917   ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1918       ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1919   ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1920       ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1921   ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1922       ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
1923   ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
1924       ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
1925   ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
1926       ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
1927   ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
1928       ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
1929   ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
1930       ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
1931   ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
1932       ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
1933   ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
1934       ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
1935   ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
1936       ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
1937   ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
1938       ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
1939   ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
1940       ::tflite::BuiltinOperator_COS, OperatorType::kCos));
1941   ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
1942       ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
1943   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
1944       ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
1945   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
1946       ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
1947   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
1948       ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
1949   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
1950       ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
1951   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
1952       ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
1953   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
1954       ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
1955   ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
1956       ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
1957   ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
1958       ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
1959   ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
1960       ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
1961   ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
1962       ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
1963   ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
1964       ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
1965   ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
1966       ::tflite::BuiltinOperator_POW, OperatorType::kPow));
1967   ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
1968       ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
1969   ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
1970       ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
1971   ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
1972       ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
1973   ops.emplace_back(new SimpleOperator<FloorDivOperator>(
1974       ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
1975   ops.emplace_back(new SimpleOperator<FloorModOperator>(
1976       ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
1977   ops.emplace_back(new SimpleOperator<RangeOperator>(
1978       ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
1979   // Element-wise operator
1980   ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
1981       ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
1982   ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
1983       ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
1984   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
1985       ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
1986   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
1987       ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
1988   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
1989       ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
1990   ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
1991       ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
1992   ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
1993       ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
1994   ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
1995       ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
1996   ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
1997       ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
1998   ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
1999       ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2000   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2001       ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2002   ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2003       ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2004   return ops;
2005 }
2006 }  // namespace
2007 
2008 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2009 
BuildOperatorByTypeMap(bool enable_select_tf_ops)2010 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2011     bool enable_select_tf_ops) {
2012   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2013 
2014   std::vector<std::unique_ptr<BaseOperator>> ops =
2015       BuildOperatorList(enable_select_tf_ops);
2016   for (auto& op : ops) {
2017     result[op->type()] = std::move(op);
2018   }
2019 
2020   return result;
2021 }
2022 
BuildOperatorByNameMap(bool enable_select_tf_ops)2023 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2024     bool enable_select_tf_ops) {
2025   std::map<string, std::unique_ptr<BaseOperator>> result;
2026 
2027   std::vector<std::unique_ptr<BaseOperator>> ops =
2028       BuildOperatorList(enable_select_tf_ops);
2029   for (auto& op : ops) {
2030     result[op->name()] = std::move(op);
2031   }
2032 
2033   return result;
2034 }
2035 
ShouldExportAsFlexOp(bool enable_select_tf_ops,const string & tensorflow_op_name)2036 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2037                           const string& tensorflow_op_name) {
2038   // If Flex ops aren't allow at all, simply return false.
2039   if (!enable_select_tf_ops) {
2040     return false;
2041   }
2042   // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2043   // it and it has been whitelisted, export the op as an Flex op. Otherwise,
2044   // export it as a regular custom op.
2045   const tensorflow::OpDef* op_def = nullptr;
2046   if (!tensorflow::OpRegistry::Global()
2047            ->LookUpOpDef(tensorflow_op_name, &op_def)
2048            .ok()) {
2049     return false;
2050   }
2051 
2052   if (!::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name)) {
2053     LOG(WARNING) << "Op " << tensorflow_op_name
2054                  << " is a valid TensorFlow op but has not been whitelisted for"
2055                     " the TensorFlow Lite flex op set.";
2056     return false;
2057   }
2058 
2059   return true;
2060 }
2061 
2062 }  // namespace tflite
2063 
2064 }  // namespace toco
2065