• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16 
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
18 
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/util/ptr_util.h"
24 
25 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
26 // graph_transformation module.
27 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
30 #include "tensorflow/lite/toco/model.h"
31 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
32 #include "tensorflow/lite/toco/tflite/custom_operator.h"
33 #include "tensorflow/lite/toco/tflite/simple_operator.h"
34 #include "tensorflow/lite/toco/tflite/types.h"
35 #include "tensorflow/lite/tools/versioning/op_version.h"
36 
37 namespace toco {
38 
39 namespace tflite {
40 
41 // LINT.IfChange
42 
GetTensorType(const ArrayDataType type)43 ::tflite::TensorType GetTensorType(const ArrayDataType type) {
44   const std::map<ArrayDataType, ::tflite::TensorType> tensor_type_map = {
45       {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
46       {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
47       {ArrayDataType::kInt8, ::tflite::TensorType_INT8},
48       {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
49       {ArrayDataType::kInt16, ::tflite::TensorType_INT16},
50       {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
51       {ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
52       {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
53       {ArrayDataType::kUint64, ::tflite::TensorType_UINT64},
54       {ArrayDataType::kString, ::tflite::TensorType_STRING},
55       {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
56       {ArrayDataType::kComplex128, ::tflite::TensorType_COMPLEX128},
57       {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
58       {ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};
59 
60   auto it = tensor_type_map.find(type);
61   if (it != tensor_type_map.end()) {
62     return it->second;
63   }
64   return static_cast<::tflite::TensorType>(-1);
65 }
66 
// Builds the ::tflite::OpSignature consumed by the TFLite op-versioning logic
// for builtin operator `op`.
//
// The signature records the TFLite tensor type of each input and output of
// op_signature.op, in the operator's own order. A name for which the model has
// no array is recorded as static_cast<::tflite::TensorType>(-1). This keeps
// positions aligned with op->inputs / op->outputs.
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Translates a list of array names into their TFLite tensor types. The
  // input and output lists share this one implementation.
  const auto collect_types = [&op_signature](const auto& names) {
    std::vector<::tflite::TensorType> types;
    types.reserve(names.size());
    for (const auto& name : names) {
      if (op_signature.model->HasArray(name)) {
        types.push_back(
            GetTensorType(op_signature.model->GetArray(name).data_type));
      } else {
        types.push_back(static_cast<::tflite::TensorType>(-1));
      }
    }
    return types;
  };
  // The vectors are passed as temporaries, so they are moved rather than
  // copied into the signature. Elements of a braced initializer list are
  // evaluated left to right, so inputs are collected before outputs.
  return ::tflite::OpSignature{op, collect_types(op_signature.op->inputs),
                               collect_types(op_signature.op->outputs)};
}
88 
89 class AveragePool
90     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
91                              ::tflite::BuiltinOptions_Pool2DOptions> {
92  public:
93   using BuiltinOperator::BuiltinOperator;
94 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const95   flatbuffers::Offset<TfLiteOptions> WriteOptions(
96       const TocoOperator& op,
97       flatbuffers::FlatBufferBuilder* builder) const override {
98     auto padding = Padding::Serialize(op.padding.type);
99     auto activation_function =
100         ActivationFunction::Serialize(op.fused_activation_function);
101     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
102                                          op.stride_height, op.kwidth,
103                                          op.kheight, activation_function);
104   }
105 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const106   void ReadOptions(const TfLiteOptions& options,
107                    TocoOperator* op) const override {
108     op->padding.type = Padding::Deserialize(options.padding());
109     op->stride_width = options.stride_w();
110     op->stride_height = options.stride_h();
111     op->kwidth = options.filter_width();
112     op->kheight = options.filter_height();
113     op->fused_activation_function =
114         ActivationFunction::Deserialize(options.fused_activation_function());
115   }
116 };
117 
118 class Convolution
119     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
120                              ::tflite::BuiltinOptions_Conv2DOptions> {
121  public:
122   using BuiltinOperator::BuiltinOperator;
123 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const124   flatbuffers::Offset<TfLiteOptions> WriteOptions(
125       const TocoOperator& op,
126       flatbuffers::FlatBufferBuilder* builder) const override {
127     auto padding = Padding::Serialize(op.padding.type);
128     auto activation_function =
129         ActivationFunction::Serialize(op.fused_activation_function);
130     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
131                                          op.stride_height, activation_function,
132                                          op.dilation_width_factor,
133                                          op.dilation_height_factor);
134   }
135 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const136   void ReadOptions(const TfLiteOptions& options,
137                    TocoOperator* op) const override {
138     op->padding.type = Padding::Deserialize(options.padding());
139     op->stride_width = options.stride_w();
140     op->stride_height = options.stride_h();
141     op->dilation_width_factor = options.dilation_w_factor();
142     op->dilation_height_factor = options.dilation_h_factor();
143     op->fused_activation_function =
144         ActivationFunction::Deserialize(options.fused_activation_function());
145   }
146 };
147 
148 class DepthwiseConvolution
149     : public BuiltinOperator<DepthwiseConvOperator,
150                              ::tflite::DepthwiseConv2DOptions,
151                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
152  public:
153   using BuiltinOperator::BuiltinOperator;
154 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const155   flatbuffers::Offset<TfLiteOptions> WriteOptions(
156       const TocoOperator& op,
157       flatbuffers::FlatBufferBuilder* builder) const override {
158     auto padding = Padding::Serialize(op.padding.type);
159     auto activation_function =
160         ActivationFunction::Serialize(op.fused_activation_function);
161     return ::tflite::CreateDepthwiseConv2DOptions(
162         *builder, padding, op.stride_width, op.stride_height,
163         op.depth_multiplier, activation_function, op.dilation_width_factor,
164         op.dilation_height_factor);
165   }
166 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const167   void ReadOptions(const TfLiteOptions& options,
168                    TocoOperator* op) const override {
169     op->padding.type = Padding::Deserialize(options.padding());
170     op->stride_width = options.stride_w();
171     op->stride_height = options.stride_h();
172     op->depth_multiplier = options.depth_multiplier();
173     op->fused_activation_function =
174         ActivationFunction::Deserialize(options.fused_activation_function());
175     op->dilation_width_factor = options.dilation_w_factor();
176     op->dilation_height_factor = options.dilation_h_factor();
177   }
178 
GetVersion(const OperatorSignature & op_signature) const179   int GetVersion(const OperatorSignature& op_signature) const override {
180     const auto& conv_op =
181         static_cast<const DepthwiseConvOperator&>(*op_signature.op);
182     ::tflite::OpSignature op_sig =
183         GetVersioningOpSig(builtin_op(), op_signature);
184     op_sig.options.depthwise_conv_2d.dilation_w_factor =
185         conv_op.dilation_width_factor;
186     op_sig.options.depthwise_conv_2d.dilation_h_factor =
187         conv_op.dilation_height_factor;
188     return ::tflite::GetBuiltinOperatorVersion(op_sig);
189   }
190 };
191 
192 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
193                                    ::tflite::BuiltinOptions_AddOptions> {
194  public:
195   using BuiltinOperator::BuiltinOperator;
196 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const197   flatbuffers::Offset<TfLiteOptions> WriteOptions(
198       const TocoOperator& op,
199       flatbuffers::FlatBufferBuilder* builder) const override {
200     auto activation_function =
201         ActivationFunction::Serialize(op.fused_activation_function);
202     return ::tflite::CreateAddOptions(*builder, activation_function);
203   }
204 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const205   void ReadOptions(const TfLiteOptions& options,
206                    TocoOperator* op) const override {
207     op->fused_activation_function =
208         ActivationFunction::Deserialize(options.fused_activation_function());
209   }
210 };
211 
212 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
213                                     ::tflite::BuiltinOptions_AddNOptions> {
214  public:
215   using BuiltinOperator::BuiltinOperator;
216 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const217   flatbuffers::Offset<TfLiteOptions> WriteOptions(
218       const TocoOperator& op,
219       flatbuffers::FlatBufferBuilder* builder) const override {
220     return ::tflite::CreateAddNOptions(*builder);
221   }
222 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const223   void ReadOptions(const TfLiteOptions& options,
224                    TocoOperator* op) const override {}
225 };
226 
227 class SpaceToBatchND
228     : public BuiltinOperator<SpaceToBatchNDOperator,
229                              ::tflite::SpaceToBatchNDOptions,
230                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
231  public:
232   using BuiltinOperator::BuiltinOperator;
233 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const234   flatbuffers::Offset<TfLiteOptions> WriteOptions(
235       const TocoOperator& op,
236       flatbuffers::FlatBufferBuilder* builder) const override {
237     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
238   }
239 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const240   void ReadOptions(const TfLiteOptions& options,
241                    TocoOperator* op) const override {}
242 
GetVersion(const OperatorSignature & op_signature) const243   int GetVersion(const OperatorSignature& op_signature) const override {
244     const std::string& input_name = op_signature.op->inputs[0];
245     const Array& input_array = op_signature.model->GetArray(input_name);
246     ::tflite::OpSignature op_sig =
247         GetVersioningOpSig(builtin_op(), op_signature);
248     op_sig.options.single_input_op.num_dims =
249         input_array.shape().dimensions_count();
250     return ::tflite::GetBuiltinOperatorVersion(op_sig);
251   }
252 };
253 
254 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
255                                    ::tflite::BuiltinOptions_SubOptions> {
256  public:
257   using BuiltinOperator::BuiltinOperator;
258 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const259   flatbuffers::Offset<TfLiteOptions> WriteOptions(
260       const TocoOperator& op,
261       flatbuffers::FlatBufferBuilder* builder) const override {
262     auto activation_function =
263         ActivationFunction::Serialize(op.fused_activation_function);
264     return ::tflite::CreateSubOptions(*builder, activation_function);
265   }
266 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const267   void ReadOptions(const TfLiteOptions& options,
268                    TocoOperator* op) const override {
269     op->fused_activation_function =
270         ActivationFunction::Deserialize(options.fused_activation_function());
271   }
272 
GetVersion(const OperatorSignature & op_signature) const273   int GetVersion(const OperatorSignature& op_signature) const override {
274     const std::string& input1_name = op_signature.op->inputs[0];
275     const std::string& input2_name = op_signature.op->inputs[1];
276     const Array& input1_array = op_signature.model->GetArray(input1_name);
277     const Array& input2_array = op_signature.model->GetArray(input2_name);
278     ::tflite::OpSignature op_sig =
279         GetVersioningOpSig(builtin_op(), op_signature);
280     if (input1_array.has_shape() && input2_array.has_shape()) {
281       op_sig.options.addsub.num_dims =
282           std::max(input1_array.shape().dimensions_count(),
283                    input2_array.shape().dimensions_count());
284       op_sig.options.addsub.need_broadcast =
285           (input1_array.shape() != input2_array.shape());
286     }
287     return ::tflite::GetBuiltinOperatorVersion(op_sig);
288   }
289 };
290 
291 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
292                                    ::tflite::BuiltinOptions_DivOptions> {
293  public:
294   using BuiltinOperator::BuiltinOperator;
295 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const296   flatbuffers::Offset<TfLiteOptions> WriteOptions(
297       const TocoOperator& op,
298       flatbuffers::FlatBufferBuilder* builder) const override {
299     auto activation_function =
300         ActivationFunction::Serialize(op.fused_activation_function);
301     return ::tflite::CreateDivOptions(*builder, activation_function);
302   }
303 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const304   void ReadOptions(const TfLiteOptions& options,
305                    TocoOperator* op) const override {
306     op->fused_activation_function =
307         ActivationFunction::Deserialize(options.fused_activation_function());
308   }
309 
GetVersion(const OperatorSignature & op_signature) const310   int GetVersion(const OperatorSignature& op_signature) const override {
311     const std::string& input1_name = op_signature.op->inputs[0];
312     const std::string& input2_name = op_signature.op->inputs[1];
313     const Array& input1_array = op_signature.model->GetArray(input1_name);
314     const Array& input2_array = op_signature.model->GetArray(input2_name);
315     ::tflite::OpSignature op_sig =
316         GetVersioningOpSig(builtin_op(), op_signature);
317     if (input1_array.has_shape() && input2_array.has_shape()) {
318       op_sig.options.broadcast.num_dims =
319           std::max(input1_array.shape().dimensions_count(),
320                    input2_array.shape().dimensions_count());
321       op_sig.options.broadcast.need_broadcast =
322           (input1_array.shape() != input2_array.shape());
323     }
324     return ::tflite::GetBuiltinOperatorVersion(op_sig);
325   }
326 };
327 
328 class BatchToSpaceND
329     : public BuiltinOperator<BatchToSpaceNDOperator,
330                              ::tflite::BatchToSpaceNDOptions,
331                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
332  public:
333   using BuiltinOperator::BuiltinOperator;
334 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const335   flatbuffers::Offset<TfLiteOptions> WriteOptions(
336       const TocoOperator& op,
337       flatbuffers::FlatBufferBuilder* builder) const override {
338     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
339   }
340 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const341   void ReadOptions(const TfLiteOptions& options,
342                    TocoOperator* op) const override {}
343 
GetVersion(const OperatorSignature & op_signature) const344   int GetVersion(const OperatorSignature& op_signature) const override {
345     const std::string& input_name = op_signature.op->inputs[0];
346     const Array& input_array = op_signature.model->GetArray(input_name);
347     ::tflite::OpSignature op_sig =
348         GetVersioningOpSig(builtin_op(), op_signature);
349     op_sig.options.single_input_op.num_dims =
350         input_array.shape().dimensions_count();
351     return ::tflite::GetBuiltinOperatorVersion(op_sig);
352   }
353 };
354 
355 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
356                                     ::tflite::BuiltinOptions_CastOptions> {
357  public:
358   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const359   flatbuffers::Offset<TfLiteOptions> WriteOptions(
360       const TocoOperator& op,
361       flatbuffers::FlatBufferBuilder* builder) const override {
362     return ::tflite::CreateCastOptions(*builder,
363                                        DataType::Serialize(op.src_data_type),
364                                        DataType::Serialize(op.dst_data_type));
365   }
366 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const367   void ReadOptions(const TfLiteOptions& options,
368                    TocoOperator* op) const override {
369     op->src_data_type = DataType::Deserialize(options.in_data_type());
370     op->dst_data_type = DataType::Deserialize(options.out_data_type());
371   }
372 };
373 
374 class Concatenation
375     : public BuiltinOperator<ConcatenationOperator,
376                              ::tflite::ConcatenationOptions,
377                              ::tflite::BuiltinOptions_ConcatenationOptions> {
378  public:
379   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const380   flatbuffers::Offset<TfLiteOptions> WriteOptions(
381       const TocoOperator& op,
382       flatbuffers::FlatBufferBuilder* builder) const override {
383     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
384   }
385 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const386   void ReadOptions(const TfLiteOptions& options,
387                    TocoOperator* op) const override {
388     op->axis = options.axis();
389   }
390 };
391 
392 class DepthToSpace
393     : public BuiltinOperator<DepthToSpaceOperator,
394                              ::tflite::DepthToSpaceOptions,
395                              ::tflite::BuiltinOptions_DepthToSpaceOptions> {
396  public:
397   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const398   flatbuffers::Offset<TfLiteOptions> WriteOptions(
399       const TocoOperator& op,
400       flatbuffers::FlatBufferBuilder* builder) const override {
401     return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
402   }
403 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const404   void ReadOptions(const TfLiteOptions& options,
405                    TocoOperator* op) const override {
406     op->block_size = options.block_size();
407   }
408 };
409 
410 class FakeQuant
411     : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
412                              ::tflite::BuiltinOptions_FakeQuantOptions> {
413  public:
414   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const415   flatbuffers::Offset<TfLiteOptions> WriteOptions(
416       const TocoOperator& op,
417       flatbuffers::FlatBufferBuilder* builder) const override {
418     return ::tflite::CreateFakeQuantOptions(
419         *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
420   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const421   void ReadOptions(const TfLiteOptions& options,
422                    TocoOperator* op) const override {
423     auto* minmax = new MinMax;
424     minmax->min = options.min();
425     minmax->max = options.max();
426     op->minmax.reset(minmax);
427     op->num_bits = options.num_bits();
428     op->narrow_range = options.narrow_range();
429   }
GetVersion(const OperatorSignature & op_signature) const430   int GetVersion(const OperatorSignature& op_signature) const override {
431     const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
432     ::tflite::OpSignature op_sig =
433         GetVersioningOpSig(builtin_op(), op_signature);
434     op_sig.options.fakequant.narrow_range = fq_op.narrow_range;
435     return ::tflite::GetBuiltinOperatorVersion(op_sig);
436   }
437 };
438 
439 class FullyConnected
440     : public BuiltinOperator<FullyConnectedOperator,
441                              ::tflite::FullyConnectedOptions,
442                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
443  public:
444   using BuiltinOperator::BuiltinOperator;
445 
GetWeightFormat(FullyConnectedWeightsFormat fmt) const446   ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
447       FullyConnectedWeightsFormat fmt) const {
448     switch (fmt) {
449       case FullyConnectedWeightsFormat::kDefault:
450         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
451       case FullyConnectedWeightsFormat::kShuffled4x16Int8:
452         return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
453       default:
454         LOG(ERROR) << "Unhandled FC weights format";
455         return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
456     }
457   }
458 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const459   flatbuffers::Offset<TfLiteOptions> WriteOptions(
460       const TocoOperator& op,
461       flatbuffers::FlatBufferBuilder* builder) const override {
462     auto activation_function =
463         ActivationFunction::Serialize(op.fused_activation_function);
464     return ::tflite::CreateFullyConnectedOptions(
465         *builder, activation_function, GetWeightFormat(op.weights_format));
466   }
467 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const468   void ReadOptions(const TfLiteOptions& options,
469                    TocoOperator* op) const override {
470     op->fused_activation_function =
471         ActivationFunction::Deserialize(options.fused_activation_function());
472     switch (options.weights_format()) {
473       case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
474         op->weights_format = FullyConnectedWeightsFormat::kDefault;
475         break;
476       case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
477         op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
478         break;
479       default:
480         LOG(ERROR) << "Unhandled FC weights format";
481         op->weights_format = FullyConnectedWeightsFormat::kDefault;
482     }
483   }
484 
GetVersion(const OperatorSignature & op_signature) const485   int GetVersion(const OperatorSignature& op_signature) const override {
486     const auto& fc_op =
487         static_cast<const FullyConnectedOperator&>(*op_signature.op);
488     ::tflite::OpSignature op_sig =
489         GetVersioningOpSig(builtin_op(), op_signature);
490     op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
491     op_sig.options.fully_connected.weights_format =
492         GetWeightFormat(fc_op.weights_format);
493     op_sig.options.fully_connected.sparse_weight = false;
494     return ::tflite::GetBuiltinOperatorVersion(op_sig);
495   }
496 };
497 
498 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
499                                       ::tflite::BuiltinOptions_GatherOptions> {
500  public:
501   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const502   flatbuffers::Offset<TfLiteOptions> WriteOptions(
503       const TocoOperator& op,
504       flatbuffers::FlatBufferBuilder* builder) const override {
505     int axis = op.axis ? op.axis.value() : 0;
506     return ::tflite::CreateGatherOptions(*builder, axis);
507   }
508 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const509   void ReadOptions(const TfLiteOptions& options,
510                    TocoOperator* op) const override {
511     op->axis = {options.axis()};
512   }
513 };
514 
515 class GatherNd
516     : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
517                              ::tflite::BuiltinOptions_GatherNdOptions> {
518  public:
519   using BuiltinOperator::BuiltinOperator;
520 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const521   flatbuffers::Offset<TfLiteOptions> WriteOptions(
522       const TocoOperator& op,
523       flatbuffers::FlatBufferBuilder* builder) const override {
524     return ::tflite::CreateGatherNdOptions(*builder);
525   }
526 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const527   void ReadOptions(const TfLiteOptions& options,
528                    TocoOperator* op) const override {}
529 };
530 
531 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
532                                     ::tflite::BuiltinOptions_SVDFOptions> {
533  public:
534   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const535   flatbuffers::Offset<TfLiteOptions> WriteOptions(
536       const TocoOperator& op,
537       flatbuffers::FlatBufferBuilder* builder) const override {
538     auto activation_function =
539         ActivationFunction::Serialize(op.fused_activation_function);
540     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
541   }
542 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const543   void ReadOptions(const TfLiteOptions& options,
544                    TocoOperator* op) const override {
545     op->fused_activation_function =
546         ActivationFunction::Deserialize(options.fused_activation_function());
547     op->rank = options.rank();
548   }
549 };
550 
551 class L2Normalization
552     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
553                              ::tflite::BuiltinOptions_L2NormOptions> {
554  public:
555   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const556   flatbuffers::Offset<TfLiteOptions> WriteOptions(
557       const TocoOperator& op,
558       flatbuffers::FlatBufferBuilder* builder) const override {
559     auto activation_function =
560         ActivationFunction::Serialize(op.fused_activation_function);
561     return ::tflite::CreateL2NormOptions(*builder, activation_function);
562   }
563 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const564   void ReadOptions(const TfLiteOptions& options,
565                    TocoOperator* op) const override {
566     op->fused_activation_function =
567         ActivationFunction::Deserialize(options.fused_activation_function());
568   }
569 };
570 
571 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
572                                       ::tflite::BuiltinOptions_Pool2DOptions> {
573  public:
574   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const575   flatbuffers::Offset<TfLiteOptions> WriteOptions(
576       const TocoOperator& op,
577       flatbuffers::FlatBufferBuilder* builder) const override {
578     auto padding = Padding::Serialize(op.padding.type);
579     auto activation_function =
580         ActivationFunction::Serialize(op.fused_activation_function);
581     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
582                                          op.stride_height, op.kwidth,
583                                          op.kheight, activation_function);
584   }
585 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const586   void ReadOptions(const TfLiteOptions& options,
587                    TocoOperator* op) const override {
588     op->padding.type = Padding::Deserialize(options.padding());
589     op->stride_width = options.stride_w();
590     op->stride_height = options.stride_h();
591     op->kwidth = options.filter_width();
592     op->kheight = options.filter_height();
593     op->fused_activation_function =
594         ActivationFunction::Deserialize(options.fused_activation_function());
595   }
596 };
597 
598 class LocalResponseNormalization
599     : public BuiltinOperator<
600           LocalResponseNormalizationOperator,
601           ::tflite::LocalResponseNormalizationOptions,
602           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
603  public:
604   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const605   flatbuffers::Offset<TfLiteOptions> WriteOptions(
606       const TocoOperator& op,
607       flatbuffers::FlatBufferBuilder* builder) const override {
608     return ::tflite::CreateLocalResponseNormalizationOptions(
609         *builder, op.range, op.bias, op.alpha, op.beta);
610   }
611 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const612   void ReadOptions(const TfLiteOptions& options,
613                    TocoOperator* op) const override {
614     op->range = options.radius();
615     op->bias = options.bias();
616     op->alpha = options.alpha();
617     op->beta = options.beta();
618   }
619 };
620 
621 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
622                                        ::tflite::BuiltinOptions_Pool2DOptions> {
623  public:
624   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const625   flatbuffers::Offset<TfLiteOptions> WriteOptions(
626       const TocoOperator& op,
627       flatbuffers::FlatBufferBuilder* builder) const override {
628     auto padding = Padding::Serialize(op.padding.type);
629     auto activation_function =
630         ActivationFunction::Serialize(op.fused_activation_function);
631     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
632                                          op.stride_height, op.kwidth,
633                                          op.kheight, activation_function);
634   }
635 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const636   void ReadOptions(const TfLiteOptions& options,
637                    TocoOperator* op) const override {
638     op->padding.type = Padding::Deserialize(options.padding());
639     op->stride_width = options.stride_w();
640     op->stride_height = options.stride_h();
641     op->kwidth = options.filter_width();
642     op->kheight = options.filter_height();
643     op->fused_activation_function =
644         ActivationFunction::Deserialize(options.fused_activation_function());
645   }
646 };
647 
648 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
649                                    ::tflite::BuiltinOptions_MulOptions> {
650  public:
651   using BuiltinOperator::BuiltinOperator;
652 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const653   flatbuffers::Offset<TfLiteOptions> WriteOptions(
654       const TocoOperator& op,
655       flatbuffers::FlatBufferBuilder* builder) const override {
656     auto activation_function =
657         ActivationFunction::Serialize(op.fused_activation_function);
658     return ::tflite::CreateMulOptions(*builder, activation_function);
659   }
660 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const661   void ReadOptions(const TfLiteOptions& options,
662                    TocoOperator* op) const override {
663     op->fused_activation_function =
664         ActivationFunction::Deserialize(options.fused_activation_function());
665   }
666 
GetVersion(const OperatorSignature & op_signature) const667   int GetVersion(const OperatorSignature& op_signature) const override {
668     const std::string& input1_name = op_signature.op->inputs[0];
669     const std::string& input2_name = op_signature.op->inputs[1];
670     const std::string& output_name = op_signature.op->outputs[0];
671     const Array& input1_array = op_signature.model->GetArray(input1_name);
672     const Array& input2_array = op_signature.model->GetArray(input2_name);
673     const Array& output_array = op_signature.model->GetArray(output_name);
674     const auto& input1_quant = input1_array.quantization_params;
675     const auto& input2_quant = input2_array.quantization_params;
676     const auto& output_quant = output_array.quantization_params;
677     const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
678     const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
679     const float output_scale = output_quant ? output_quant->scale : 0.0f;
680     ::tflite::OpSignature op_sig =
681         GetVersioningOpSig(builtin_op(), op_signature);
682     op_sig.options.mul.input1_scale = input1_scale;
683     op_sig.options.mul.input2_scale = input2_scale;
684     op_sig.options.mul.output_scale = output_scale;
685     return ::tflite::GetBuiltinOperatorVersion(op_sig);
686   }
687 };
688 
689 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
690                                    ::tflite::BuiltinOptions_PadOptions> {
691  public:
692   using BuiltinOperator::BuiltinOperator;
693 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const694   flatbuffers::Offset<TfLiteOptions> WriteOptions(
695       const TocoOperator& op,
696       flatbuffers::FlatBufferBuilder* builder) const override {
697     return ::tflite::CreatePadOptions(*builder);
698   }
699 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const700   void ReadOptions(const TfLiteOptions& options,
701                    TocoOperator* op) const override {}
702 };
703 
704 class Tile
705     : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
706                              ::tflite::BuiltinOptions_TileOptions> {
707   using BuiltinOperator::BuiltinOperator;
708 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const709   flatbuffers::Offset<TfLiteOptions> WriteOptions(
710       const TocoOperator& op,
711       flatbuffers::FlatBufferBuilder* builder) const override {
712     return ::tflite::CreateTileOptions(*builder);
713   }
714 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const715   void ReadOptions(const TfLiteOptions& options,
716                    TocoOperator* op) const override {}
717 };
718 
719 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
720                                      ::tflite::BuiltinOptions_PadV2Options> {
721  public:
722   using BuiltinOperator::BuiltinOperator;
723 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const724   flatbuffers::Offset<TfLiteOptions> WriteOptions(
725       const TocoOperator& op,
726       flatbuffers::FlatBufferBuilder* builder) const override {
727     return ::tflite::CreatePadV2Options(*builder);
728   }
729 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const730   void ReadOptions(const TfLiteOptions& options,
731                    TocoOperator* op) const override {}
732 };
733 
734 class Reshape
735     : public BuiltinOperator<TensorFlowReshapeOperator,
736                              ::tflite::ReshapeOptions,
737                              ::tflite::BuiltinOptions_ReshapeOptions> {
738  public:
739   using BuiltinOperator::BuiltinOperator;
740 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const741   flatbuffers::Offset<TfLiteOptions> WriteOptions(
742       const TocoOperator& op,
743       flatbuffers::FlatBufferBuilder* builder) const override {
744     return ::tflite::CreateReshapeOptions(*builder,
745                                           builder->CreateVector(op.shape));
746   }
747 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const748   void ReadOptions(const TfLiteOptions& options,
749                    TocoOperator* op) const override {
750     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
751                      options.new_shape()->end());
752   }
753 };
754 
755 class Softmax
756     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
757                              ::tflite::BuiltinOptions_SoftmaxOptions> {
758  public:
759   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const760   flatbuffers::Offset<TfLiteOptions> WriteOptions(
761       const TocoOperator& op,
762       flatbuffers::FlatBufferBuilder* builder) const override {
763     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
764   }
765 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const766   void ReadOptions(const TfLiteOptions& options,
767                    TocoOperator* op) const override {
768     op->beta = options.beta();
769   }
770 };
771 
772 class SpaceToDepth
773     : public BuiltinOperator<SpaceToDepthOperator,
774                              ::tflite::SpaceToDepthOptions,
775                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
776  public:
777   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const778   flatbuffers::Offset<TfLiteOptions> WriteOptions(
779       const TocoOperator& op,
780       flatbuffers::FlatBufferBuilder* builder) const override {
781     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
782   }
783 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const784   void ReadOptions(const TfLiteOptions& options,
785                    TocoOperator* op) const override {
786     op->block_size = options.block_size();
787   }
788 };
789 
790 class Transpose
791     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
792                              ::tflite::BuiltinOptions_TransposeOptions> {
793  public:
794   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const795   flatbuffers::Offset<TfLiteOptions> WriteOptions(
796       const TocoOperator& op,
797       flatbuffers::FlatBufferBuilder* builder) const override {
798     return ::tflite::CreateTransposeOptions(*builder);
799   }
800 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const801   void ReadOptions(const TfLiteOptions& options,
802                    TocoOperator* op) const override {}
803 };
804 
805 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
806                                     ::tflite::BuiltinOptions_LSTMOptions> {
807  public:
808   using BuiltinOperator::BuiltinOperator;
809 
GetKernelType(LstmCellOperator::KernelType type) const810   ::tflite::LSTMKernelType GetKernelType(
811       LstmCellOperator::KernelType type) const {
812     switch (type) {
813       case LstmCellOperator::KERNEL_BASIC:
814         return ::tflite::LSTMKernelType_BASIC;
815         break;
816       case LstmCellOperator::KERNEL_FULL:
817         return ::tflite::LSTMKernelType_FULL;
818         break;
819       default:
820         LOG(ERROR) << "Unhandled Kernel Type";
821         return static_cast<::tflite::LSTMKernelType>(-1);
822     }
823   }
824 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const825   flatbuffers::Offset<TfLiteOptions> WriteOptions(
826       const TocoOperator& op,
827       flatbuffers::FlatBufferBuilder* builder) const override {
828     ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
829 
830     // Current toco converter only supports tanh, no clip.
831     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
832                                        ::tflite::ActivationFunctionType_TANH,
833                                        /*cell_clip=*/0.0,
834                                        /*proj_clip=*/0.0, kernel_type);
835   }
836 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const837   void ReadOptions(const TfLiteOptions& options,
838                    TocoOperator* op) const override {
839     // Only support tanh activation, so check that tflite type is tanh.
840     CHECK(options.fused_activation_function() ==
841           ::tflite::ActivationFunctionType_TANH);
842 
843     switch (options.kernel_type()) {
844       case ::tflite::LSTMKernelType_BASIC:
845         op->kernel_type = LstmCellOperator::KERNEL_BASIC;
846         break;
847       case ::tflite::LSTMKernelType_FULL:
848         op->kernel_type = LstmCellOperator::KERNEL_FULL;
849         break;
850     }
851   }
852 
GetVersion(const OperatorSignature & op_signature) const853   int GetVersion(const OperatorSignature& op_signature) const override {
854     const auto& lstm_op =
855         static_cast<const LstmCellOperator&>(*op_signature.op);
856     ::tflite::OpSignature op_sig =
857         GetVersioningOpSig(builtin_op(), op_signature);
858     op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
859     return ::tflite::GetBuiltinOperatorVersion(op_sig);
860   }
861 
GetMutatingInputVariables(const Operator & op) const862   std::vector<bool> GetMutatingInputVariables(
863       const Operator& op) const override {
864     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
865 
866     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
867     switch (lstm_op.kernel_type) {
868       case LstmCellOperator::KERNEL_FULL: {
869         mutating_input_variables[kInputActivationStateTensor] = true;
870         mutating_input_variables[kInputCellStateTensor] = true;
871         break;
872       }
873       case LstmCellOperator::KERNEL_BASIC: {
874         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
875         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
876         break;
877       }
878     }
879     return mutating_input_variables;
880   }
881 };
882 
883 class UnidirectionalSequenceLstm
884     : public BuiltinOperator<
885           UnidirectionalSequenceLstmOperator,
886           ::tflite::UnidirectionalSequenceLSTMOptions,
887           ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
888  public:
889   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const890   flatbuffers::Offset<TfLiteOptions> WriteOptions(
891       const TocoOperator& op,
892       flatbuffers::FlatBufferBuilder* builder) const override {
893     // Current toco converter only supports tanh, no clip.
894     return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
895         *builder, /*fused_activation_function=*/
896         ::tflite::ActivationFunctionType_TANH,
897         /*cell_clip=*/0.0,
898         /*proj_clip=*/0.0,
899         /*time_major=*/true);
900   }
901 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const902   void ReadOptions(const TfLiteOptions& options,
903                    TocoOperator* op) const override {
904     // Only support tanh activation, so check that tflite type is tanh.
905     DCHECK(options.fused_activation_function() ==
906            ::tflite::ActivationFunctionType_TANH);
907   }
908 
GetMutatingInputVariables(const Operator & op) const909   std::vector<bool> GetMutatingInputVariables(
910       const Operator& op) const override {
911     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
912     mutating_input_variables[kInputActivationStateTensor] = true;
913     mutating_input_variables[kInputCellStateTensor] = true;
914     return mutating_input_variables;
915   }
916 };
917 
918 class BidirectionalSequenceLstm
919     : public BuiltinOperator<
920           BidirectionalSequenceLstmOperator,
921           ::tflite::BidirectionalSequenceLSTMOptions,
922           ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
923  public:
924   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const925   flatbuffers::Offset<TfLiteOptions> WriteOptions(
926       const TocoOperator& op,
927       flatbuffers::FlatBufferBuilder* builder) const override {
928     // Current toco converter only supports tanh, no clip.
929     return ::tflite::CreateBidirectionalSequenceLSTMOptions(
930         *builder, /*fused_activation_function=*/
931         ::tflite::ActivationFunctionType_TANH,
932         /*cell_clip=*/0.0,
933         /*proj_clip=*/0.0,
934         /*merge_outputs=*/op.merge_outputs,
935         /*time_major=*/true);
936   }
937 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const938   void ReadOptions(const TfLiteOptions& options,
939                    TocoOperator* op) const override {
940     // Only support tanh activation, so check that tflite type is tanh.
941     DCHECK(options.fused_activation_function() ==
942            ::tflite::ActivationFunctionType_TANH);
943     op->merge_outputs = options.merge_outputs();
944   }
945 
GetMutatingInputVariables(const Operator & op) const946   std::vector<bool> GetMutatingInputVariables(
947       const Operator& op) const override {
948     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
949     // Forward input activation state.
950     mutating_input_variables[35] = true;
951     // Forward input cell state.
952     mutating_input_variables[36] = true;
953     // Backward input activation state.
954     mutating_input_variables[37] = true;
955     // Backward input cell state.
956     mutating_input_variables[38] = true;
957     return mutating_input_variables;
958   }
959 };
960 
961 class BidirectionalSequenceRnn
962     : public BuiltinOperator<
963           BidirectionalSequenceRnnOperator,
964           ::tflite::BidirectionalSequenceRNNOptions,
965           ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
966  public:
967   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const968   flatbuffers::Offset<TfLiteOptions> WriteOptions(
969       const TocoOperator& op,
970       flatbuffers::FlatBufferBuilder* builder) const override {
971     // Current toco converter only supports tanh, no clip.
972     return ::tflite::CreateBidirectionalSequenceRNNOptions(
973         *builder, /*time_major=*/true,
974         /*fused_activation_function=*/
975         ::tflite::ActivationFunctionType_TANH,
976         /*merge_outputs=*/op.merge_outputs);
977   }
978 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const979   void ReadOptions(const TfLiteOptions& options,
980                    TocoOperator* op) const override {
981     // Only support tanh activation, so check that tflite type is tanh.
982     DCHECK(options.fused_activation_function() ==
983            ::tflite::ActivationFunctionType_TANH);
984     op->merge_outputs = options.merge_outputs();
985   }
986 
GetMutatingInputVariables(const Operator & op) const987   std::vector<bool> GetMutatingInputVariables(
988       const Operator& op) const override {
989     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
990     // Forward hidden state.
991     mutating_input_variables[4] = true;
992     // Backward hidden state.
993     mutating_input_variables[8] = true;
994     return mutating_input_variables;
995   }
996 };
997 
998 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
999                                     ::tflite::BuiltinOptions_ReducerOptions> {
1000  public:
1001   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1002   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1003       const TocoOperator& op,
1004       flatbuffers::FlatBufferBuilder* builder) const override {
1005     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1006   }
1007 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1008   void ReadOptions(const TfLiteOptions& options,
1009                    TocoOperator* op) const override {
1010     op->keep_dims = options.keep_dims();
1011   }
1012 };
1013 
1014 class Sum
1015     : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1016                              ::tflite::BuiltinOptions_ReducerOptions> {
1017  public:
1018   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1019   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1020       const TocoOperator& op,
1021       flatbuffers::FlatBufferBuilder* builder) const override {
1022     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1023   }
1024 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1025   void ReadOptions(const TfLiteOptions& options,
1026                    TocoOperator* op) const override {
1027     op->keep_dims = options.keep_dims();
1028   }
1029 };
1030 
1031 class ReduceMax
1032     : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1033                              ::tflite::BuiltinOptions_ReducerOptions> {
1034  public:
1035   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1036   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1037       const TocoOperator& op,
1038       flatbuffers::FlatBufferBuilder* builder) const override {
1039     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1040   }
1041 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1042   void ReadOptions(const TfLiteOptions& options,
1043                    TocoOperator* op) const override {
1044     op->keep_dims = options.keep_dims();
1045   }
1046 };
1047 
1048 class ReduceMin
1049     : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1050                              ::tflite::BuiltinOptions_ReducerOptions> {
1051  public:
1052   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1053   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1054       const TocoOperator& op,
1055       flatbuffers::FlatBufferBuilder* builder) const override {
1056     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1057   }
1058 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1059   void ReadOptions(const TfLiteOptions& options,
1060                    TocoOperator* op) const override {
1061     op->keep_dims = options.keep_dims();
1062   }
1063 };
1064 
1065 class ReduceProd
1066     : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1067                              ::tflite::BuiltinOptions_ReducerOptions> {
1068  public:
1069   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1070   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1071       const TocoOperator& op,
1072       flatbuffers::FlatBufferBuilder* builder) const override {
1073     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1074   }
1075 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1076   void ReadOptions(const TfLiteOptions& options,
1077                    TocoOperator* op) const override {
1078     op->keep_dims = options.keep_dims();
1079   }
1080 };
1081 
1082 class ReduceAny
1083     : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1084                              ::tflite::BuiltinOptions_ReducerOptions> {
1085  public:
1086   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1087   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1088       const TocoOperator& op,
1089       flatbuffers::FlatBufferBuilder* builder) const override {
1090     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1091   }
1092 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1093   void ReadOptions(const TfLiteOptions& options,
1094                    TocoOperator* op) const override {
1095     op->keep_dims = options.keep_dims();
1096   }
1097 };
1098 
1099 class ResizeBilinear
1100     : public BuiltinOperator<ResizeBilinearOperator,
1101                              ::tflite::ResizeBilinearOptions,
1102                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1103  public:
1104   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1105   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1106       const TocoOperator& op,
1107       flatbuffers::FlatBufferBuilder* builder) const override {
1108     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1109                                                  op.half_pixel_centers);
1110   }
1111 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1112   void ReadOptions(const TfLiteOptions& options,
1113                    TocoOperator* op) const override {
1114     op->align_corners = options.align_corners();
1115     op->half_pixel_centers = options.half_pixel_centers();
1116   }
1117 
GetVersion(const OperatorSignature & op_signature) const1118   int GetVersion(const OperatorSignature& op_signature) const override {
1119     const auto& resize_bilinear_op =
1120         static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1121     ::tflite::OpSignature op_sig =
1122         GetVersioningOpSig(builtin_op(), op_signature);
1123     op_sig.options.resize.half_pixel_centers =
1124         resize_bilinear_op.half_pixel_centers;
1125     op_sig.options.resize.align_corners = resize_bilinear_op.align_corners;
1126     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1127   }
1128 };
1129 
1130 class ResizeNearestNeighbor
1131     : public BuiltinOperator<
1132           ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1133           ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1134  public:
1135   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1136   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1137       const TocoOperator& op,
1138       flatbuffers::FlatBufferBuilder* builder) const override {
1139     return ::tflite::CreateResizeNearestNeighborOptions(
1140         *builder, op.align_corners, op.half_pixel_centers);
1141   }
1142 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1143   void ReadOptions(const TfLiteOptions& options,
1144                    TocoOperator* op) const override {
1145     op->align_corners = options.align_corners();
1146     op->half_pixel_centers = options.half_pixel_centers();
1147   }
1148 
GetVersion(const OperatorSignature & op_signature) const1149   int GetVersion(const OperatorSignature& op_signature) const override {
1150     const auto& resize_nn_op =
1151         static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
1152     ::tflite::OpSignature op_sig =
1153         GetVersioningOpSig(builtin_op(), op_signature);
1154     op_sig.options.resize.half_pixel_centers = resize_nn_op.half_pixel_centers;
1155     op_sig.options.resize.align_corners = resize_nn_op.align_corners;
1156     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1157   }
1158 };
1159 
1160 class Squeeze
1161     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1162                              ::tflite::BuiltinOptions_SqueezeOptions> {
1163  public:
1164   using BuiltinOperator::BuiltinOperator;
1165 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1166   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1167       const TocoOperator& op,
1168       flatbuffers::FlatBufferBuilder* builder) const override {
1169     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1170     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1171   }
1172 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1173   void ReadOptions(const TfLiteOptions& options,
1174                    TocoOperator* op) const override {
1175     op->squeeze_dims.insert(op->squeeze_dims.end(),
1176                             options.squeeze_dims()->begin(),
1177                             options.squeeze_dims()->end());
1178   }
1179 };
1180 
1181 class Split
1182     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1183                              ::tflite::BuiltinOptions_SplitOptions> {
1184  public:
1185   using BuiltinOperator::BuiltinOperator;
1186 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1187   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1188       const TocoOperator& op,
1189       flatbuffers::FlatBufferBuilder* builder) const override {
1190     return ::tflite::CreateSplitOptions(*builder, op.num_split);
1191   }
1192 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1193   void ReadOptions(const TfLiteOptions& options,
1194                    TocoOperator* op) const override {
1195     op->num_split = options.num_splits();
1196   }
1197 };
1198 
1199 class SplitV
1200     : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1201                              ::tflite::BuiltinOptions_SplitVOptions> {
1202  public:
1203   using BuiltinOperator::BuiltinOperator;
1204 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1205   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1206       const TocoOperator& op,
1207       flatbuffers::FlatBufferBuilder* builder) const override {
1208     return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1209   }
1210 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1211   void ReadOptions(const TfLiteOptions& options,
1212                    TocoOperator* op) const override {
1213     op->num_split = options.num_splits();
1214   }
1215 };
1216 
1217 class StridedSlice
1218     : public BuiltinOperator<StridedSliceOperator,
1219                              ::tflite::StridedSliceOptions,
1220                              ::tflite::BuiltinOptions_StridedSliceOptions> {
1221  public:
1222   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1223   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1224       const TocoOperator& op,
1225       flatbuffers::FlatBufferBuilder* builder) const override {
1226     return ::tflite::CreateStridedSliceOptions(
1227         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1228         op.new_axis_mask, op.shrink_axis_mask);
1229   }
1230 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1231   void ReadOptions(const TfLiteOptions& options,
1232                    TocoOperator* op) const override {
1233     op->begin_mask = options.begin_mask();
1234     op->end_mask = options.end_mask();
1235     op->ellipsis_mask = options.ellipsis_mask();
1236     op->new_axis_mask = options.new_axis_mask();
1237     op->shrink_axis_mask = options.shrink_axis_mask();
1238   }
1239 
GetVersion(const OperatorSignature & op_signature) const1240   int GetVersion(const OperatorSignature& op_signature) const override {
1241     const auto& ss_op =
1242         static_cast<const StridedSliceOperator&>(*op_signature.op);
1243     ::tflite::OpSignature op_sig =
1244         GetVersioningOpSig(builtin_op(), op_signature);
1245     op_sig.options.single_input_op.num_dims = ss_op.start_indices.size();
1246     return ::tflite::GetBuiltinOperatorVersion(op_sig);
1247   }
1248 };
1249 
1250 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1251                                        ::tflite::BuiltinOptions_TopKV2Options> {
1252  public:
1253   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1254   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1255       const TocoOperator& op,
1256       flatbuffers::FlatBufferBuilder* builder) const override {
1257     return ::tflite::CreateTopKV2Options(*builder);
1258   }
1259 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1260   void ReadOptions(const TfLiteOptions& options,
1261                    TocoOperator* op) const override {}
1262 };
1263 
1264 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1265                                       ::tflite::BuiltinOptions_ArgMaxOptions> {
1266  public:
1267   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1268   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1269       const TocoOperator& op,
1270       flatbuffers::FlatBufferBuilder* builder) const override {
1271     return ::tflite::CreateArgMaxOptions(
1272         *builder, DataType::Serialize(op.output_data_type));
1273   }
1274 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1275   void ReadOptions(const TfLiteOptions& options,
1276                    TocoOperator* op) const override {
1277     op->output_data_type = DataType::Deserialize(options.output_type());
1278   }
1279 };
1280 
1281 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1282                                       ::tflite::BuiltinOptions_ArgMinOptions> {
1283  public:
1284   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1285   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1286       const TocoOperator& op,
1287       flatbuffers::FlatBufferBuilder* builder) const override {
1288     return ::tflite::CreateArgMinOptions(
1289         *builder, DataType::Serialize(op.output_data_type));
1290   }
1291 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1292   void ReadOptions(const TfLiteOptions& options,
1293                    TocoOperator* op) const override {
1294     op->output_data_type = DataType::Deserialize(options.output_type());
1295   }
1296 };
1297 
1298 class TransposeConv
1299     : public BuiltinOperator<TransposeConvOperator,
1300                              ::tflite::TransposeConvOptions,
1301                              ::tflite::BuiltinOptions_TransposeConvOptions> {
1302  public:
1303   using BuiltinOperator::BuiltinOperator;
1304 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1305   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1306       const TocoOperator& op,
1307       flatbuffers::FlatBufferBuilder* builder) const override {
1308     auto padding = Padding::Serialize(op.padding.type);
1309     return ::tflite::CreateTransposeConvOptions(
1310         *builder, padding, op.stride_width, op.stride_height);
1311   }
1312 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1313   void ReadOptions(const TfLiteOptions& options,
1314                    TocoOperator* op) const override {
1315     op->padding.type = Padding::Deserialize(options.padding());
1316     op->stride_width = options.stride_w();
1317     op->stride_height = options.stride_h();
1318   }
1319 };
1320 
1321 class SparseToDense
1322     : public BuiltinOperator<SparseToDenseOperator,
1323                              ::tflite::SparseToDenseOptions,
1324                              ::tflite::BuiltinOptions_SparseToDenseOptions> {
1325  public:
1326   using BuiltinOperator::BuiltinOperator;
1327 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1328   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1329       const TocoOperator& op,
1330       flatbuffers::FlatBufferBuilder* builder) const override {
1331     return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1332   }
1333 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1334   void ReadOptions(const TfLiteOptions& options,
1335                    TocoOperator* op) const override {
1336     op->validate_indices = options.validate_indices();
1337   }
1338 };
1339 
1340 class ExpandDims
1341     : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1342                              ::tflite::BuiltinOptions_ExpandDimsOptions> {
1343  public:
1344   using BuiltinOperator::BuiltinOperator;
1345 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1346   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1347       const TocoOperator& op,
1348       flatbuffers::FlatBufferBuilder* builder) const override {
1349     return ::tflite::CreateExpandDimsOptions(*builder);
1350   }
1351 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1352   void ReadOptions(const TfLiteOptions& options,
1353                    TocoOperator* op) const override {}
1354 };
1355 
1356 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1357                                     ::tflite::BuiltinOptions_PackOptions> {
1358  public:
1359   using BuiltinOperator::BuiltinOperator;
1360 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1361   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1362       const TocoOperator& op,
1363       flatbuffers::FlatBufferBuilder* builder) const override {
1364     return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1365   }
1366 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1367   void ReadOptions(const TfLiteOptions& options,
1368                    TocoOperator* op) const override {
1369     op->values_count = options.values_count();
1370     op->axis = options.axis();
1371   }
1372 };
1373 
1374 class Shape
1375     : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1376                              ::tflite::BuiltinOptions_ShapeOptions> {
1377  public:
1378   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1379   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1380       const TocoOperator& op,
1381       flatbuffers::FlatBufferBuilder* builder) const override {
1382     return ::tflite::CreateShapeOptions(
1383         *builder, DataType::Serialize(op.output_data_type));
1384   }
1385 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1386   void ReadOptions(const TfLiteOptions& options,
1387                    TocoOperator* op) const override {
1388     op->output_data_type = DataType::Deserialize(options.out_type());
1389   }
1390 };
1391 
1392 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1393                                       ::tflite::BuiltinOptions_OneHotOptions> {
1394  public:
1395   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1396   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1397       const TocoOperator& op,
1398       flatbuffers::FlatBufferBuilder* builder) const override {
1399     return ::tflite::CreateOneHotOptions(*builder, op.axis);
1400   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1401   void ReadOptions(const TfLiteOptions& options,
1402                    TocoOperator* op) const override {
1403     op->axis = options.axis();
1404   }
1405 };
1406 
1407 class CTCBeamSearchDecoder
1408     : public CustomOperator<CTCBeamSearchDecoderOperator> {
1409  public:
1410   using CustomOperator::CustomOperator;
1411 
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1412   void WriteOptions(const TocoOperator& op,
1413                     flexbuffers::Builder* fbb) const override {
1414     fbb->Int("beam_width", op.beam_width);
1415     fbb->Int("top_paths", op.top_paths);
1416     fbb->Bool("merge_repeated", op.merge_repeated);
1417   }
1418 
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1419   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1420     op->beam_width = m["beam_width"].AsInt32();
1421     op->top_paths = m["top_paths"].AsInt32();
1422     op->merge_repeated = m["merge_repeated"].AsBool();
1423   }
1424 
GetVersion(const OperatorSignature & op_signature) const1425   int GetVersion(const OperatorSignature& op_signature) const override {
1426     return 1;
1427   }
1428 };
1429 
1430 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1431                                       ::tflite::BuiltinOptions_UnpackOptions> {
1432  public:
1433   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1434   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1435       const TocoOperator& op,
1436       flatbuffers::FlatBufferBuilder* builder) const override {
1437     return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1438   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1439   void ReadOptions(const TfLiteOptions& options,
1440                    TocoOperator* op) const override {
1441     op->num = options.num();
1442     op->axis = options.axis();
1443   }
1444 
GetVersion(const OperatorSignature & op_signature) const1445   int GetVersion(const OperatorSignature& op_signature) const override {
1446     const std::string& input_name = op_signature.op->inputs[0];
1447     const Array& input_array = op_signature.model->GetArray(input_name);
1448     // If the op take int8/uint8 input, it is version 2.
1449     if (input_array.data_type == ArrayDataType::kInt8 ||
1450         input_array.data_type == ArrayDataType::kUint8) {
1451       return 2;
1452     }
1453     // If the op take bool input, it is version 3.
1454     if (input_array.data_type == ArrayDataType::kBool) {
1455       return 3;
1456     }
1457     return 1;
1458   }
1459 };
1460 
1461 class LeakyRelu
1462     : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1463                              ::tflite::BuiltinOptions_LeakyReluOptions> {
1464  public:
1465   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1466   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1467       const TocoOperator& op,
1468       flatbuffers::FlatBufferBuilder* builder) const override {
1469     return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1470   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1471   void ReadOptions(const TfLiteOptions& options,
1472                    TocoOperator* op) const override {
1473     op->alpha = options.alpha();
1474   }
1475 };
1476 
1477 class SquaredDifference
1478     : public BuiltinOperator<
1479           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1480           ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1481  public:
1482   using BuiltinOperator::BuiltinOperator;
1483 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1484   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1485       const TocoOperator& op,
1486       flatbuffers::FlatBufferBuilder* builder) const override {
1487     return ::tflite::CreateSquaredDifferenceOptions(*builder);
1488   }
1489 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1490   void ReadOptions(const TfLiteOptions& options,
1491                    TocoOperator* op) const override {}
1492 };
1493 
1494 class MirrorPad
1495     : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1496                              ::tflite::BuiltinOptions_MirrorPadOptions> {
1497  public:
1498   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1499   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1500       const TocoOperator& op,
1501       flatbuffers::FlatBufferBuilder* builder) const override {
1502     return ::tflite::CreateMirrorPadOptions(
1503         *builder, op.mode == MirrorPadMode::kReflect
1504                       ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1505                       : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1506   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1507   void ReadOptions(const TfLiteOptions& options,
1508                    TocoOperator* op) const override {
1509     op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1510                    ? MirrorPadMode::kReflect
1511                    : MirrorPadMode::kSymmetric;
1512   }
1513 };
1514 
1515 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1516                                       ::tflite::BuiltinOptions_UniqueOptions> {
1517  public:
1518   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1519   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1520       const TocoOperator& op,
1521       flatbuffers::FlatBufferBuilder* builder) const override {
1522     const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1523     return ::tflite::CreateUniqueOptions(
1524         *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1525                       ? ::tflite::TensorType::TensorType_INT64
1526                       : ::tflite::TensorType_INT32);
1527   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1528   void ReadOptions(const TfLiteOptions& options,
1529                    TocoOperator* op) const override {
1530     UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1531     unique_op->idx_out_type =
1532         options.idx_out_type() == ::tflite::TensorType_INT64
1533             ? toco::ArrayDataType::kInt64
1534             : toco::ArrayDataType::kInt32;
1535   }
1536 };
1537 
1538 class UnidirectionalSequenceRnn
1539     : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1540                              ::tflite::SequenceRNNOptions,
1541                              ::tflite::BuiltinOptions_SequenceRNNOptions> {
1542  public:
1543   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1544   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1545       const TocoOperator& op,
1546       flatbuffers::FlatBufferBuilder* builder) const override {
1547     return ::tflite::CreateSequenceRNNOptions(
1548         *builder, /*time_major=*/true,
1549         /*fused_activation_function=*/
1550         ::tflite::ActivationFunctionType_TANH);
1551   }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1552   void ReadOptions(const TfLiteOptions& options,
1553                    TocoOperator* op) const override {
1554     // Only support tanh activation, so check that tflite type is tanh.
1555     DCHECK(options.fused_activation_function() ==
1556            ::tflite::ActivationFunctionType_TANH);
1557   }
1558 
GetMutatingInputVariables(const Operator & op) const1559   std::vector<bool> GetMutatingInputVariables(
1560       const Operator& op) const override {
1561     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1562     mutating_input_variables[4] = true;
1563     return mutating_input_variables;
1564   }
1565 };
1566 
1567 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1568                                      ::tflite::BuiltinOptions_WhereOptions> {
1569  public:
1570   using BuiltinOperator::BuiltinOperator;
1571 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1572   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1573       const TocoOperator& op,
1574       flatbuffers::FlatBufferBuilder* builder) const override {
1575     return ::tflite::CreateWhereOptions(*builder);
1576   }
1577 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1578   void ReadOptions(const TfLiteOptions& options,
1579                    TocoOperator* op) const override {}
1580 };
1581 
WriteFlexOpOptions(const std::string & tensorflow_node_def)1582 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1583     const std::string& tensorflow_node_def) {
1584   auto fbb = absl::make_unique<flexbuffers::Builder>();
1585 
1586   ::tensorflow::NodeDef node_def;
1587   if (!node_def.ParseFromString(tensorflow_node_def)) {
1588     LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1589     return {};
1590   }
1591 
1592   fbb->Vector([&]() {
1593     fbb->String(node_def.op());
1594     fbb->String(tensorflow_node_def);
1595   });
1596   fbb->Finish();
1597   LOG(INFO) << "Writing flex op: " << node_def.op();
1598   return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1599 }
1600 
1601 class TensorFlowUnsupported : public BaseOperator {
1602  public:
TensorFlowUnsupported(const std::string & name,OperatorType type,bool enable_select_tf_ops)1603   TensorFlowUnsupported(const std::string& name, OperatorType type,
1604                         bool enable_select_tf_ops)
1605       : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1606 
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1607   Options Serialize(const Operator& op,
1608                     flatbuffers::FlatBufferBuilder* builder) const override {
1609     auto fbb =
1610         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1611     if (fbb) {
1612       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1613     } else {
1614       return Options::Custom(0);
1615     }
1616   }
1617 
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1618   std::unique_ptr<Operator> Deserialize(
1619       const BuiltinOptions* builtin_options,
1620       const CustomOptions* custom_options) const override {
1621     // Deserializing Flex ops doesn't work now.
1622     // TODO(ycling): Revisit and decide if we should fix the flow for importing
1623     // TFLite models with Flex ops.
1624     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
1625     if (custom_options) {
1626       auto flexbuffer_map =
1627           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1628               .AsMap();
1629       ReadOptions(flexbuffer_map, op.get());
1630     }
1631     return std::unique_ptr<Operator>(op.release());
1632   }
1633 
WriteOptions(const TensorFlowUnsupportedOperator & op) const1634   std::unique_ptr<flexbuffers::Builder> WriteOptions(
1635       const TensorFlowUnsupportedOperator& op) const {
1636     if (enable_select_tf_ops_) {
1637       return WriteFlexOpOptions(op.tensorflow_node_def);
1638     }
1639     auto fbb = absl::make_unique<flexbuffers::Builder>();
1640 
1641     ::tensorflow::NodeDef node_def;
1642     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1643       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1644       return std::unique_ptr<flexbuffers::Builder>();
1645     }
1646 
1647     if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1648       fbb->Vector([&]() {
1649         fbb->String(node_def.op());
1650         fbb->String(op.tensorflow_node_def);
1651       });
1652       fbb->Finish();
1653       LOG(INFO) << "Writing flex op: " << node_def.op();
1654       return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1655     }
1656 
1657     bool has_valid_attr = false;
1658     size_t map_start = fbb->StartMap();
1659     for (const auto& pair : node_def.attr()) {
1660       const char* key = pair.first.c_str();
1661       const auto& attr = pair.second;
1662       switch (attr.value_case()) {
1663         case ::tensorflow::AttrValue::kS:
1664           fbb->String(key, attr.s());
1665           has_valid_attr = true;
1666           break;
1667         case ::tensorflow::AttrValue::kI:
1668           fbb->Int(key, attr.i());
1669           has_valid_attr = true;
1670           break;
1671         case ::tensorflow::AttrValue::kF:
1672           fbb->Float(key, attr.f());
1673           has_valid_attr = true;
1674           break;
1675         case ::tensorflow::AttrValue::kB:
1676           fbb->Bool(key, attr.b());
1677           has_valid_attr = true;
1678           break;
1679         case tensorflow::AttrValue::kList:
1680           if (attr.list().s_size() > 0) {
1681             auto start = fbb->StartVector(key);
1682             for (const std::string& v : attr.list().s()) {
1683               fbb->Add(v);
1684             }
1685             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1686             has_valid_attr = true;
1687           } else if (attr.list().i_size() > 0) {
1688             auto start = fbb->StartVector(key);
1689             for (const int64_t v : attr.list().i()) {
1690               fbb->Add(v);
1691             }
1692             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1693             has_valid_attr = true;
1694           } else if (attr.list().f_size() > 0) {
1695             auto start = fbb->StartVector(key);
1696             for (const float v : attr.list().f()) {
1697               fbb->Add(v);
1698             }
1699             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1700             has_valid_attr = true;
1701           } else {
1702             LOG(WARNING)
1703                 << "Ignoring unsupported type in list attribute with key '"
1704                 << key << "'";
1705           }
1706           break;
1707         default:
1708           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1709                        << key << "'";
1710           break;
1711       }
1712     }
1713     if (!has_valid_attr) {
1714       return std::unique_ptr<flexbuffers::Builder>();
1715     }
1716     fbb->EndMap(map_start);
1717     fbb->Finish();
1718     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1719   }
1720 
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1721   void ReadOptions(const flexbuffers::Map& m,
1722                    TensorFlowUnsupportedOperator* op) const {
1723     ::tensorflow::NodeDef node_def;
1724     auto attr = node_def.mutable_attr();
1725 
1726     const auto& keys = m.Keys();
1727     for (size_t i = 0; i < keys.size(); ++i) {
1728       const auto key = keys[i].AsKey();
1729       const auto& value = m[key];
1730       switch (value.GetType()) {
1731         case flexbuffers::FBT_STRING:
1732           (*attr)[key].set_s(value.AsString().c_str());
1733           break;
1734         case flexbuffers::FBT_INT:
1735           (*attr)[key].set_i(value.AsInt64());
1736           break;
1737         case flexbuffers::FBT_FLOAT:
1738           (*attr)[key].set_f(value.AsFloat());
1739           break;
1740         case flexbuffers::FBT_BOOL:
1741           (*attr)[key].set_b(value.AsBool());
1742           if (std::string(key) == "_output_quantized") {
1743             op->quantized = value.AsBool();
1744           }
1745           if (std::string(key) ==
1746               "_support_output_type_float_in_quantized_op") {
1747             op->support_output_type_float_in_quantized_op = value.AsBool();
1748           }
1749           break;
1750         case flexbuffers::FBT_VECTOR_INT: {
1751           auto* list = (*attr)[key].mutable_list();
1752           const auto& vector = value.AsTypedVector();
1753           for (size_t i = 0; i < vector.size(); i++) {
1754             list->add_i(vector[i].AsInt64());
1755           }
1756           break;
1757         }
1758         case flexbuffers::FBT_VECTOR_FLOAT: {
1759           auto* list = (*attr)[key].mutable_list();
1760           const auto& vector = value.AsTypedVector();
1761           for (size_t i = 0; i < vector.size(); i++) {
1762             list->add_f(vector[i].AsFloat());
1763           }
1764           break;
1765         }
1766         case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
1767           auto* list = (*attr)[key].mutable_list();
1768           const auto& vector = value.AsTypedVector();
1769           for (size_t i = 0; i < vector.size(); i++) {
1770             list->add_s(vector[i].AsString().str());
1771           }
1772           break;
1773         }
1774         default:
1775           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1776                        << key << "'";
1777           break;
1778       }
1779     }
1780     node_def.SerializeToString(&op->tensorflow_node_def);
1781   }
1782 
GetVersion(const OperatorSignature & op_signature) const1783   int GetVersion(const OperatorSignature& op_signature) const override {
1784     // TODO(ycling): Design and implement a way to plumb the version of
1785     // custom ops.
1786     return 1;
1787   }
1788 
1789  private:
1790   const bool enable_select_tf_ops_;
1791 };
1792 
1793 class Dequantize
1794     : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1795                              ::tflite::BuiltinOptions_DequantizeOptions> {
1796  public:
1797   using BuiltinOperator::BuiltinOperator;
1798 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1799   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1800       const TocoOperator& op,
1801       flatbuffers::FlatBufferBuilder* builder) const override {
1802     return ::tflite::CreateDequantizeOptions(*builder);
1803   }
1804 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1805   void ReadOptions(const TfLiteOptions& options,
1806                    TocoOperator* op) const override {}
1807 };
1808 
1809 class ReverseSequence
1810     : public BuiltinOperator<ReverseSequenceOperator,
1811                              ::tflite::ReverseSequenceOptions,
1812                              ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1813  public:
1814   using BuiltinOperator::BuiltinOperator;
1815 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1816   flatbuffers::Offset<TfLiteOptions> WriteOptions(
1817       const TocoOperator& op,
1818       flatbuffers::FlatBufferBuilder* builder) const override {
1819     return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1820                                                   op.batch_dim);
1821   }
1822 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1823   void ReadOptions(const TfLiteOptions& options,
1824                    TocoOperator* op) const override {
1825     op->seq_dim = options.seq_dim();
1826     op->batch_dim = options.batch_dim();
1827   }
1828 };
1829 
1830 namespace {
1831 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1832 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1833     bool enable_select_tf_ops = false) {
1834   std::vector<std::unique_ptr<BaseOperator>> ops;
1835   using tensorflow::MakeUnique;
1836   // Builtin Operators.
1837   ops.push_back(
1838       MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1839   ops.push_back(
1840       MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1841   ops.push_back(
1842       MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1843   ops.push_back(
1844       MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1845   ops.push_back(MakeUnique<AveragePool>(
1846       ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1847   ops.push_back(
1848       MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1849                                  OperatorType::kSpaceToBatchND));
1850   ops.push_back(
1851       MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1852                                  OperatorType::kBatchToSpaceND));
1853   ops.push_back(MakeUnique<Concatenation>(
1854       ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1855   ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1856                                         OperatorType::kConv));
1857   ops.push_back(MakeUnique<DepthwiseConvolution>(
1858       ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1859       OperatorType::kDepthwiseConv));
1860   ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1861                                        OperatorType::kDequantize));
1862   ops.push_back(
1863       MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1864                                  OperatorType::kFullyConnected));
1865   ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1866                                    OperatorType::kGather));
1867   ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1868                                      OperatorType::kGatherNd));
1869   ops.push_back(
1870       MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1871                                   OperatorType::kL2Normalization));
1872   ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1873                                    OperatorType::kL2Pool));
1874   ops.push_back(MakeUnique<LocalResponseNormalization>(
1875       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1876       OperatorType::kLocalResponseNormalization));
1877   ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1878                                     OperatorType::kMaxPool));
1879   ops.push_back(
1880       MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1881 
1882   ops.push_back(
1883       MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1884   ops.push_back(
1885       MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1886   ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1887                                     OperatorType::kReshape));
1888   ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1889                                     OperatorType::kSoftmax));
1890   ops.push_back(MakeUnique<SpaceToDepth>(
1891       ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1892   ops.push_back(MakeUnique<DepthToSpace>(
1893       ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1894   ops.push_back(
1895       MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1896   ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1897                                       OperatorType::kTranspose));
1898   ops.push_back(
1899       MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1900   ops.push_back(
1901       MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1902   ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1903                                        OperatorType::kReduceProd));
1904   ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1905                                       OperatorType::kReduceMax));
1906   ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1907                                       OperatorType::kReduceMin));
1908   ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1909                                       OperatorType::kAny));
1910   ops.push_back(
1911       MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1912                                  OperatorType::kResizeBilinear));
1913   ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1914       ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1915       OperatorType::kResizeNearestNeighbor));
1916   ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1917                                     OperatorType::kSqueeze));
1918   ops.push_back(
1919       MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1920   ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1921                                    OperatorType::kSplitV));
1922   ops.push_back(MakeUnique<StridedSlice>(
1923       ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1924   ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1925                                     OperatorType::kTopK_V2));
1926   ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1927                                  OperatorType::kLstmCell));
1928   ops.push_back(
1929       MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1930   ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1931                                    OperatorType::kArgMax));
1932   ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1933                                    OperatorType::kArgMin));
1934   ops.push_back(
1935       MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1936   ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1937                                        OperatorType::kExpandDims));
1938   ops.push_back(MakeUnique<TransposeConv>(
1939       ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1940   ops.push_back(MakeUnique<SparseToDense>(
1941       ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1942   ops.push_back(
1943       MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1944   ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1945                                       OperatorType::kFakeQuant));
1946   ops.push_back(
1947       MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1948   ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1949       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1950       OperatorType::kUnidirectionalSequenceLstm));
1951   ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1952       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1953       OperatorType::kBidirectionalSequenceLstm));
1954   ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1955       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1956       OperatorType::kBidirectionalSequenceRnn));
1957   ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1958                                    OperatorType::kOneHot));
1959   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1960                                    OperatorType::kUnpack));
1961   ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1962                                       OperatorType::kLeakyRelu));
1963   ops.push_back(MakeUnique<SquaredDifference>(
1964       ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1965       OperatorType::kSquaredDifference));
1966   ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1967                                       OperatorType::kMirrorPad));
1968   ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1969                                    OperatorType::kUnique));
1970   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1971       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1972       OperatorType::kUnidirectionalSequenceRnn));
1973   ops.push_back(
1974       MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1975   ops.push_back(
1976       MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1977                                   OperatorType::kReverseSequence));
1978   ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1979       ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1980   ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1981       ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1982   // Custom Operators.
1983   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1984       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1985   ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1986                                                   OperatorType::kUnsupported,
1987                                                   enable_select_tf_ops));
1988 
  // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
  // been modified to also export builtins. As TOCO evolved, we added warnings
  // when custom ops are exported, but SimpleOperator bypasses those warnings.
  // To prevent user confusion, we are settling on using SimpleOperator only
  // for builtins.
1994   ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1995       ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1996   ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1997       ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1998   ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1999       ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
2000   ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
2001       ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
2002   ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
2003       ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
2004   ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2005       ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
2006   ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
2007       ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
2008   ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
2009       ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
2010   ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
2011       ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
2012   ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
2013       ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
2014   ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
2015       ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
2016   ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
2017       ::tflite::BuiltinOperator_COS, OperatorType::kCos));
2018   ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
2019       ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
2020   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
2021       ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
2022   ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
2023       ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
2024   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
2025       ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
2026   ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
2027       ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
2028   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
2029       ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
2030   ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
2031       ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
2032   ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
2033       ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
2034   ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
2035       ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
2036   ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
2037       ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
2038   ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
2039       ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
2040   ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
2041       ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
2042   ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
2043       ::tflite::BuiltinOperator_POW, OperatorType::kPow));
2044   ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2045       ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
2046   ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2047       ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
2048   ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2049       ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
2050   ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2051       ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
2052   ops.emplace_back(new SimpleOperator<FloorModOperator>(
2053       ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
2054   ops.emplace_back(new SimpleOperator<RangeOperator>(
2055       ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
2056   // Element-wise operator
2057   ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
2058       ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
2059   ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
2060       ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
2061   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2062       ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
2063   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2064       ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
2065   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2066       ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
2067   ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2068       ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
2069   ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
2070       ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
2071   ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
2072       ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
2073   ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
2074       ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
2075   ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2076       ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2077   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2078       ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2079   ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2080       ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2081   ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
2082       ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
2083   return ops;
2084 }
2085 }  // namespace
2086 
2087 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2088 
BuildOperatorByTypeMap(bool enable_select_tf_ops)2089 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2090     bool enable_select_tf_ops) {
2091   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2092 
2093   std::vector<std::unique_ptr<BaseOperator>> ops =
2094       BuildOperatorList(enable_select_tf_ops);
2095   for (auto& op : ops) {
2096     result[op->type()] = std::move(op);
2097   }
2098 
2099   return result;
2100 }
2101 
BuildOperatorByNameMap(bool enable_select_tf_ops)2102 std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2103     bool enable_select_tf_ops) {
2104   std::map<std::string, std::unique_ptr<BaseOperator>> result;
2105 
2106   std::vector<std::unique_ptr<BaseOperator>> ops =
2107       BuildOperatorList(enable_select_tf_ops);
2108   for (auto& op : ops) {
2109     result[op->name()] = std::move(op);
2110   }
2111 
2112   return result;
2113 }
2114 
ShouldExportAsFlexOp(bool enable_select_tf_ops,const std::string & tensorflow_op_name)2115 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2116                           const std::string& tensorflow_op_name) {
2117   // If Flex ops aren't allow at all, simply return false.
2118   if (!enable_select_tf_ops) {
2119     return false;
2120   }
2121   // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2122   // it and it has been allowlisted, export the op as an Flex op. Otherwise,
2123   // export it as a regular custom op.
2124   const tensorflow::OpDef* op_def = nullptr;
2125   if (!tensorflow::OpRegistry::Global()
2126            ->LookUpOpDef(tensorflow_op_name, &op_def)
2127            .ok()) {
2128     return false;
2129   }
2130 
2131   if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
2132     LOG(WARNING) << "Op " << tensorflow_op_name
2133                  << " is a valid TensorFlow op but has not been allowlisted for"
2134                     " the TensorFlow Lite flex op set.";
2135     return false;
2136   }
2137 
2138   return true;
2139 }
2140 
2141 }  // namespace tflite
2142 
2143 }  // namespace toco
2144