1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16
17 #include <map>
18
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/util/ptr_util.h"
24
25 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
26 // graph_transformation module.
27 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
30 #include "tensorflow/lite/toco/model.h"
31 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
32 #include "tensorflow/lite/toco/tflite/custom_operator.h"
33 #include "tensorflow/lite/toco/tflite/simple_operator.h"
34 #include "tensorflow/lite/toco/tflite/types.h"
35 #include "tensorflow/lite/tools/versioning/op_version.h"
36
37 namespace toco {
38
39 namespace tflite {
40
41 // LINT.IfChange
42
GetTensorType(const ArrayDataType type)43 ::tflite::TensorType GetTensorType(const ArrayDataType type) {
44 const std::map<ArrayDataType, ::tflite::TensorType> tensor_type_map = {
45 {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
46 {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
47 {ArrayDataType::kInt8, ::tflite::TensorType_INT8},
48 {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
49 {ArrayDataType::kInt16, ::tflite::TensorType_INT16},
50 {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
51 {ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
52 {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
53 {ArrayDataType::kUint64, ::tflite::TensorType_UINT64},
54 {ArrayDataType::kString, ::tflite::TensorType_STRING},
55 {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
56 {ArrayDataType::kComplex128, ::tflite::TensorType_COMPLEX128},
57 {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
58 {ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};
59
60 auto it = tensor_type_map.find(type);
61 if (it != tensor_type_map.end()) {
62 return it->second;
63 }
64 return static_cast<::tflite::TensorType>(-1);
65 }
66
// Builds the ::tflite::OpSignature consumed by the TFLite op-versioning code
// (::tflite::GetBuiltinOperatorVersion) for builtin operator `op`.
//
// The signature records, in order, the TFLite tensor type of every input and
// output array of op_signature.op. Arrays that the model does not contain are
// recorded as static_cast<::tflite::TensorType>(-1), as are arrays whose type
// GetTensorType cannot map. Per-op `options` are left for the caller to fill.
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Shared by inputs and outputs: maps a list of array names to tensor types.
  const auto to_tensor_types =
      [&op_signature](const std::vector<std::string>& array_names) {
        std::vector<::tflite::TensorType> types;
        types.reserve(array_names.size());
        for (const std::string& array_name : array_names) {
          ::tflite::TensorType type = static_cast<::tflite::TensorType>(-1);
          if (op_signature.model->HasArray(array_name)) {
            type = GetTensorType(
                op_signature.model->GetArray(array_name).data_type);
          }
          types.push_back(type);
        }
        return types;
      };
  // The lambda results are prvalues, so they initialize the aggregate members
  // without an extra vector copy.
  return ::tflite::OpSignature{op, to_tensor_types(op_signature.op->inputs),
                               to_tensor_types(op_signature.op->outputs)};
}
88
89 class AveragePool
90 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
91 ::tflite::BuiltinOptions_Pool2DOptions> {
92 public:
93 using BuiltinOperator::BuiltinOperator;
94
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const95 flatbuffers::Offset<TfLiteOptions> WriteOptions(
96 const TocoOperator& op,
97 flatbuffers::FlatBufferBuilder* builder) const override {
98 auto padding = Padding::Serialize(op.padding.type);
99 auto activation_function =
100 ActivationFunction::Serialize(op.fused_activation_function);
101 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
102 op.stride_height, op.kwidth,
103 op.kheight, activation_function);
104 }
105
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const106 void ReadOptions(const TfLiteOptions& options,
107 TocoOperator* op) const override {
108 op->padding.type = Padding::Deserialize(options.padding());
109 op->stride_width = options.stride_w();
110 op->stride_height = options.stride_h();
111 op->kwidth = options.filter_width();
112 op->kheight = options.filter_height();
113 op->fused_activation_function =
114 ActivationFunction::Deserialize(options.fused_activation_function());
115 }
116 };
117
118 class Convolution
119 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
120 ::tflite::BuiltinOptions_Conv2DOptions> {
121 public:
122 using BuiltinOperator::BuiltinOperator;
123
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const124 flatbuffers::Offset<TfLiteOptions> WriteOptions(
125 const TocoOperator& op,
126 flatbuffers::FlatBufferBuilder* builder) const override {
127 auto padding = Padding::Serialize(op.padding.type);
128 auto activation_function =
129 ActivationFunction::Serialize(op.fused_activation_function);
130 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
131 op.stride_height, activation_function,
132 op.dilation_width_factor,
133 op.dilation_height_factor);
134 }
135
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const136 void ReadOptions(const TfLiteOptions& options,
137 TocoOperator* op) const override {
138 op->padding.type = Padding::Deserialize(options.padding());
139 op->stride_width = options.stride_w();
140 op->stride_height = options.stride_h();
141 op->dilation_width_factor = options.dilation_w_factor();
142 op->dilation_height_factor = options.dilation_h_factor();
143 op->fused_activation_function =
144 ActivationFunction::Deserialize(options.fused_activation_function());
145 }
146 };
147
148 class DepthwiseConvolution
149 : public BuiltinOperator<DepthwiseConvOperator,
150 ::tflite::DepthwiseConv2DOptions,
151 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
152 public:
153 using BuiltinOperator::BuiltinOperator;
154
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const155 flatbuffers::Offset<TfLiteOptions> WriteOptions(
156 const TocoOperator& op,
157 flatbuffers::FlatBufferBuilder* builder) const override {
158 auto padding = Padding::Serialize(op.padding.type);
159 auto activation_function =
160 ActivationFunction::Serialize(op.fused_activation_function);
161 return ::tflite::CreateDepthwiseConv2DOptions(
162 *builder, padding, op.stride_width, op.stride_height,
163 op.depth_multiplier, activation_function, op.dilation_width_factor,
164 op.dilation_height_factor);
165 }
166
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const167 void ReadOptions(const TfLiteOptions& options,
168 TocoOperator* op) const override {
169 op->padding.type = Padding::Deserialize(options.padding());
170 op->stride_width = options.stride_w();
171 op->stride_height = options.stride_h();
172 op->depth_multiplier = options.depth_multiplier();
173 op->fused_activation_function =
174 ActivationFunction::Deserialize(options.fused_activation_function());
175 op->dilation_width_factor = options.dilation_w_factor();
176 op->dilation_height_factor = options.dilation_h_factor();
177 }
178
GetVersion(const OperatorSignature & op_signature) const179 int GetVersion(const OperatorSignature& op_signature) const override {
180 const auto& conv_op =
181 static_cast<const DepthwiseConvOperator&>(*op_signature.op);
182 ::tflite::OpSignature op_sig =
183 GetVersioningOpSig(builtin_op(), op_signature);
184 op_sig.options.depthwise_conv_2d.dilation_w_factor =
185 conv_op.dilation_width_factor;
186 op_sig.options.depthwise_conv_2d.dilation_h_factor =
187 conv_op.dilation_height_factor;
188 return ::tflite::GetBuiltinOperatorVersion(op_sig);
189 }
190 };
191
192 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
193 ::tflite::BuiltinOptions_AddOptions> {
194 public:
195 using BuiltinOperator::BuiltinOperator;
196
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const197 flatbuffers::Offset<TfLiteOptions> WriteOptions(
198 const TocoOperator& op,
199 flatbuffers::FlatBufferBuilder* builder) const override {
200 auto activation_function =
201 ActivationFunction::Serialize(op.fused_activation_function);
202 return ::tflite::CreateAddOptions(*builder, activation_function);
203 }
204
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const205 void ReadOptions(const TfLiteOptions& options,
206 TocoOperator* op) const override {
207 op->fused_activation_function =
208 ActivationFunction::Deserialize(options.fused_activation_function());
209 }
210 };
211
212 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
213 ::tflite::BuiltinOptions_AddNOptions> {
214 public:
215 using BuiltinOperator::BuiltinOperator;
216
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const217 flatbuffers::Offset<TfLiteOptions> WriteOptions(
218 const TocoOperator& op,
219 flatbuffers::FlatBufferBuilder* builder) const override {
220 return ::tflite::CreateAddNOptions(*builder);
221 }
222
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const223 void ReadOptions(const TfLiteOptions& options,
224 TocoOperator* op) const override {}
225 };
226
227 class SpaceToBatchND
228 : public BuiltinOperator<SpaceToBatchNDOperator,
229 ::tflite::SpaceToBatchNDOptions,
230 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
231 public:
232 using BuiltinOperator::BuiltinOperator;
233
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const234 flatbuffers::Offset<TfLiteOptions> WriteOptions(
235 const TocoOperator& op,
236 flatbuffers::FlatBufferBuilder* builder) const override {
237 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
238 }
239
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const240 void ReadOptions(const TfLiteOptions& options,
241 TocoOperator* op) const override {}
242
GetVersion(const OperatorSignature & op_signature) const243 int GetVersion(const OperatorSignature& op_signature) const override {
244 const std::string& input_name = op_signature.op->inputs[0];
245 const Array& input_array = op_signature.model->GetArray(input_name);
246 ::tflite::OpSignature op_sig =
247 GetVersioningOpSig(builtin_op(), op_signature);
248 op_sig.options.single_input_op.num_dims =
249 input_array.shape().dimensions_count();
250 return ::tflite::GetBuiltinOperatorVersion(op_sig);
251 }
252 };
253
254 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
255 ::tflite::BuiltinOptions_SubOptions> {
256 public:
257 using BuiltinOperator::BuiltinOperator;
258
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const259 flatbuffers::Offset<TfLiteOptions> WriteOptions(
260 const TocoOperator& op,
261 flatbuffers::FlatBufferBuilder* builder) const override {
262 auto activation_function =
263 ActivationFunction::Serialize(op.fused_activation_function);
264 return ::tflite::CreateSubOptions(*builder, activation_function);
265 }
266
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const267 void ReadOptions(const TfLiteOptions& options,
268 TocoOperator* op) const override {
269 op->fused_activation_function =
270 ActivationFunction::Deserialize(options.fused_activation_function());
271 }
272
GetVersion(const OperatorSignature & op_signature) const273 int GetVersion(const OperatorSignature& op_signature) const override {
274 const std::string& input1_name = op_signature.op->inputs[0];
275 const std::string& input2_name = op_signature.op->inputs[1];
276 const Array& input1_array = op_signature.model->GetArray(input1_name);
277 const Array& input2_array = op_signature.model->GetArray(input2_name);
278 ::tflite::OpSignature op_sig =
279 GetVersioningOpSig(builtin_op(), op_signature);
280 if (input1_array.has_shape() && input2_array.has_shape()) {
281 op_sig.options.addsub.num_dims =
282 std::max(input1_array.shape().dimensions_count(),
283 input2_array.shape().dimensions_count());
284 op_sig.options.addsub.need_broadcast =
285 (input1_array.shape() != input2_array.shape());
286 }
287 return ::tflite::GetBuiltinOperatorVersion(op_sig);
288 }
289 };
290
291 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
292 ::tflite::BuiltinOptions_DivOptions> {
293 public:
294 using BuiltinOperator::BuiltinOperator;
295
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const296 flatbuffers::Offset<TfLiteOptions> WriteOptions(
297 const TocoOperator& op,
298 flatbuffers::FlatBufferBuilder* builder) const override {
299 auto activation_function =
300 ActivationFunction::Serialize(op.fused_activation_function);
301 return ::tflite::CreateDivOptions(*builder, activation_function);
302 }
303
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const304 void ReadOptions(const TfLiteOptions& options,
305 TocoOperator* op) const override {
306 op->fused_activation_function =
307 ActivationFunction::Deserialize(options.fused_activation_function());
308 }
309
GetVersion(const OperatorSignature & op_signature) const310 int GetVersion(const OperatorSignature& op_signature) const override {
311 const std::string& input1_name = op_signature.op->inputs[0];
312 const std::string& input2_name = op_signature.op->inputs[1];
313 const Array& input1_array = op_signature.model->GetArray(input1_name);
314 const Array& input2_array = op_signature.model->GetArray(input2_name);
315 ::tflite::OpSignature op_sig =
316 GetVersioningOpSig(builtin_op(), op_signature);
317 if (input1_array.has_shape() && input2_array.has_shape()) {
318 op_sig.options.broadcast.num_dims =
319 std::max(input1_array.shape().dimensions_count(),
320 input2_array.shape().dimensions_count());
321 op_sig.options.broadcast.need_broadcast =
322 (input1_array.shape() != input2_array.shape());
323 }
324 return ::tflite::GetBuiltinOperatorVersion(op_sig);
325 }
326 };
327
328 class BatchToSpaceND
329 : public BuiltinOperator<BatchToSpaceNDOperator,
330 ::tflite::BatchToSpaceNDOptions,
331 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
332 public:
333 using BuiltinOperator::BuiltinOperator;
334
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const335 flatbuffers::Offset<TfLiteOptions> WriteOptions(
336 const TocoOperator& op,
337 flatbuffers::FlatBufferBuilder* builder) const override {
338 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
339 }
340
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const341 void ReadOptions(const TfLiteOptions& options,
342 TocoOperator* op) const override {}
343
GetVersion(const OperatorSignature & op_signature) const344 int GetVersion(const OperatorSignature& op_signature) const override {
345 const std::string& input_name = op_signature.op->inputs[0];
346 const Array& input_array = op_signature.model->GetArray(input_name);
347 ::tflite::OpSignature op_sig =
348 GetVersioningOpSig(builtin_op(), op_signature);
349 op_sig.options.single_input_op.num_dims =
350 input_array.shape().dimensions_count();
351 return ::tflite::GetBuiltinOperatorVersion(op_sig);
352 }
353 };
354
355 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
356 ::tflite::BuiltinOptions_CastOptions> {
357 public:
358 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const359 flatbuffers::Offset<TfLiteOptions> WriteOptions(
360 const TocoOperator& op,
361 flatbuffers::FlatBufferBuilder* builder) const override {
362 return ::tflite::CreateCastOptions(*builder,
363 DataType::Serialize(op.src_data_type),
364 DataType::Serialize(op.dst_data_type));
365 }
366
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const367 void ReadOptions(const TfLiteOptions& options,
368 TocoOperator* op) const override {
369 op->src_data_type = DataType::Deserialize(options.in_data_type());
370 op->dst_data_type = DataType::Deserialize(options.out_data_type());
371 }
372 };
373
374 class Concatenation
375 : public BuiltinOperator<ConcatenationOperator,
376 ::tflite::ConcatenationOptions,
377 ::tflite::BuiltinOptions_ConcatenationOptions> {
378 public:
379 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const380 flatbuffers::Offset<TfLiteOptions> WriteOptions(
381 const TocoOperator& op,
382 flatbuffers::FlatBufferBuilder* builder) const override {
383 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
384 }
385
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const386 void ReadOptions(const TfLiteOptions& options,
387 TocoOperator* op) const override {
388 op->axis = options.axis();
389 }
390 };
391
392 class DepthToSpace
393 : public BuiltinOperator<DepthToSpaceOperator,
394 ::tflite::DepthToSpaceOptions,
395 ::tflite::BuiltinOptions_DepthToSpaceOptions> {
396 public:
397 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const398 flatbuffers::Offset<TfLiteOptions> WriteOptions(
399 const TocoOperator& op,
400 flatbuffers::FlatBufferBuilder* builder) const override {
401 return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
402 }
403
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const404 void ReadOptions(const TfLiteOptions& options,
405 TocoOperator* op) const override {
406 op->block_size = options.block_size();
407 }
408 };
409
410 class FakeQuant
411 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
412 ::tflite::BuiltinOptions_FakeQuantOptions> {
413 public:
414 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const415 flatbuffers::Offset<TfLiteOptions> WriteOptions(
416 const TocoOperator& op,
417 flatbuffers::FlatBufferBuilder* builder) const override {
418 return ::tflite::CreateFakeQuantOptions(
419 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
420 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const421 void ReadOptions(const TfLiteOptions& options,
422 TocoOperator* op) const override {
423 auto* minmax = new MinMax;
424 minmax->min = options.min();
425 minmax->max = options.max();
426 op->minmax.reset(minmax);
427 op->num_bits = options.num_bits();
428 op->narrow_range = options.narrow_range();
429 }
GetVersion(const OperatorSignature & op_signature) const430 int GetVersion(const OperatorSignature& op_signature) const override {
431 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
432 ::tflite::OpSignature op_sig =
433 GetVersioningOpSig(builtin_op(), op_signature);
434 op_sig.options.fakequant.narrow_range = fq_op.narrow_range;
435 return ::tflite::GetBuiltinOperatorVersion(op_sig);
436 }
437 };
438
439 class FullyConnected
440 : public BuiltinOperator<FullyConnectedOperator,
441 ::tflite::FullyConnectedOptions,
442 ::tflite::BuiltinOptions_FullyConnectedOptions> {
443 public:
444 using BuiltinOperator::BuiltinOperator;
445
GetWeightFormat(FullyConnectedWeightsFormat fmt) const446 ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
447 FullyConnectedWeightsFormat fmt) const {
448 switch (fmt) {
449 case FullyConnectedWeightsFormat::kDefault:
450 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
451 case FullyConnectedWeightsFormat::kShuffled4x16Int8:
452 return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
453 default:
454 LOG(ERROR) << "Unhandled FC weights format";
455 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
456 }
457 }
458
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const459 flatbuffers::Offset<TfLiteOptions> WriteOptions(
460 const TocoOperator& op,
461 flatbuffers::FlatBufferBuilder* builder) const override {
462 auto activation_function =
463 ActivationFunction::Serialize(op.fused_activation_function);
464 return ::tflite::CreateFullyConnectedOptions(
465 *builder, activation_function, GetWeightFormat(op.weights_format));
466 }
467
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const468 void ReadOptions(const TfLiteOptions& options,
469 TocoOperator* op) const override {
470 op->fused_activation_function =
471 ActivationFunction::Deserialize(options.fused_activation_function());
472 switch (options.weights_format()) {
473 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
474 op->weights_format = FullyConnectedWeightsFormat::kDefault;
475 break;
476 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
477 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
478 break;
479 default:
480 LOG(ERROR) << "Unhandled FC weights format";
481 op->weights_format = FullyConnectedWeightsFormat::kDefault;
482 }
483 }
484
GetVersion(const OperatorSignature & op_signature) const485 int GetVersion(const OperatorSignature& op_signature) const override {
486 const auto& fc_op =
487 static_cast<const FullyConnectedOperator&>(*op_signature.op);
488 ::tflite::OpSignature op_sig =
489 GetVersioningOpSig(builtin_op(), op_signature);
490 op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
491 op_sig.options.fully_connected.weights_format =
492 GetWeightFormat(fc_op.weights_format);
493 op_sig.options.fully_connected.sparse_weight = false;
494 return ::tflite::GetBuiltinOperatorVersion(op_sig);
495 }
496 };
497
498 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
499 ::tflite::BuiltinOptions_GatherOptions> {
500 public:
501 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const502 flatbuffers::Offset<TfLiteOptions> WriteOptions(
503 const TocoOperator& op,
504 flatbuffers::FlatBufferBuilder* builder) const override {
505 int axis = op.axis ? op.axis.value() : 0;
506 return ::tflite::CreateGatherOptions(*builder, axis);
507 }
508
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const509 void ReadOptions(const TfLiteOptions& options,
510 TocoOperator* op) const override {
511 op->axis = {options.axis()};
512 }
513 };
514
515 class GatherNd
516 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
517 ::tflite::BuiltinOptions_GatherNdOptions> {
518 public:
519 using BuiltinOperator::BuiltinOperator;
520
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const521 flatbuffers::Offset<TfLiteOptions> WriteOptions(
522 const TocoOperator& op,
523 flatbuffers::FlatBufferBuilder* builder) const override {
524 return ::tflite::CreateGatherNdOptions(*builder);
525 }
526
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const527 void ReadOptions(const TfLiteOptions& options,
528 TocoOperator* op) const override {}
529 };
530
531 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
532 ::tflite::BuiltinOptions_SVDFOptions> {
533 public:
534 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const535 flatbuffers::Offset<TfLiteOptions> WriteOptions(
536 const TocoOperator& op,
537 flatbuffers::FlatBufferBuilder* builder) const override {
538 auto activation_function =
539 ActivationFunction::Serialize(op.fused_activation_function);
540 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
541 }
542
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const543 void ReadOptions(const TfLiteOptions& options,
544 TocoOperator* op) const override {
545 op->fused_activation_function =
546 ActivationFunction::Deserialize(options.fused_activation_function());
547 op->rank = options.rank();
548 }
549 };
550
551 class L2Normalization
552 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
553 ::tflite::BuiltinOptions_L2NormOptions> {
554 public:
555 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const556 flatbuffers::Offset<TfLiteOptions> WriteOptions(
557 const TocoOperator& op,
558 flatbuffers::FlatBufferBuilder* builder) const override {
559 auto activation_function =
560 ActivationFunction::Serialize(op.fused_activation_function);
561 return ::tflite::CreateL2NormOptions(*builder, activation_function);
562 }
563
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const564 void ReadOptions(const TfLiteOptions& options,
565 TocoOperator* op) const override {
566 op->fused_activation_function =
567 ActivationFunction::Deserialize(options.fused_activation_function());
568 }
569 };
570
571 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
572 ::tflite::BuiltinOptions_Pool2DOptions> {
573 public:
574 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const575 flatbuffers::Offset<TfLiteOptions> WriteOptions(
576 const TocoOperator& op,
577 flatbuffers::FlatBufferBuilder* builder) const override {
578 auto padding = Padding::Serialize(op.padding.type);
579 auto activation_function =
580 ActivationFunction::Serialize(op.fused_activation_function);
581 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
582 op.stride_height, op.kwidth,
583 op.kheight, activation_function);
584 }
585
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const586 void ReadOptions(const TfLiteOptions& options,
587 TocoOperator* op) const override {
588 op->padding.type = Padding::Deserialize(options.padding());
589 op->stride_width = options.stride_w();
590 op->stride_height = options.stride_h();
591 op->kwidth = options.filter_width();
592 op->kheight = options.filter_height();
593 op->fused_activation_function =
594 ActivationFunction::Deserialize(options.fused_activation_function());
595 }
596 };
597
598 class LocalResponseNormalization
599 : public BuiltinOperator<
600 LocalResponseNormalizationOperator,
601 ::tflite::LocalResponseNormalizationOptions,
602 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
603 public:
604 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const605 flatbuffers::Offset<TfLiteOptions> WriteOptions(
606 const TocoOperator& op,
607 flatbuffers::FlatBufferBuilder* builder) const override {
608 return ::tflite::CreateLocalResponseNormalizationOptions(
609 *builder, op.range, op.bias, op.alpha, op.beta);
610 }
611
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const612 void ReadOptions(const TfLiteOptions& options,
613 TocoOperator* op) const override {
614 op->range = options.radius();
615 op->bias = options.bias();
616 op->alpha = options.alpha();
617 op->beta = options.beta();
618 }
619 };
620
621 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
622 ::tflite::BuiltinOptions_Pool2DOptions> {
623 public:
624 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const625 flatbuffers::Offset<TfLiteOptions> WriteOptions(
626 const TocoOperator& op,
627 flatbuffers::FlatBufferBuilder* builder) const override {
628 auto padding = Padding::Serialize(op.padding.type);
629 auto activation_function =
630 ActivationFunction::Serialize(op.fused_activation_function);
631 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
632 op.stride_height, op.kwidth,
633 op.kheight, activation_function);
634 }
635
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const636 void ReadOptions(const TfLiteOptions& options,
637 TocoOperator* op) const override {
638 op->padding.type = Padding::Deserialize(options.padding());
639 op->stride_width = options.stride_w();
640 op->stride_height = options.stride_h();
641 op->kwidth = options.filter_width();
642 op->kheight = options.filter_height();
643 op->fused_activation_function =
644 ActivationFunction::Deserialize(options.fused_activation_function());
645 }
646 };
647
648 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
649 ::tflite::BuiltinOptions_MulOptions> {
650 public:
651 using BuiltinOperator::BuiltinOperator;
652
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const653 flatbuffers::Offset<TfLiteOptions> WriteOptions(
654 const TocoOperator& op,
655 flatbuffers::FlatBufferBuilder* builder) const override {
656 auto activation_function =
657 ActivationFunction::Serialize(op.fused_activation_function);
658 return ::tflite::CreateMulOptions(*builder, activation_function);
659 }
660
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const661 void ReadOptions(const TfLiteOptions& options,
662 TocoOperator* op) const override {
663 op->fused_activation_function =
664 ActivationFunction::Deserialize(options.fused_activation_function());
665 }
666
GetVersion(const OperatorSignature & op_signature) const667 int GetVersion(const OperatorSignature& op_signature) const override {
668 const std::string& input1_name = op_signature.op->inputs[0];
669 const std::string& input2_name = op_signature.op->inputs[1];
670 const std::string& output_name = op_signature.op->outputs[0];
671 const Array& input1_array = op_signature.model->GetArray(input1_name);
672 const Array& input2_array = op_signature.model->GetArray(input2_name);
673 const Array& output_array = op_signature.model->GetArray(output_name);
674 const auto& input1_quant = input1_array.quantization_params;
675 const auto& input2_quant = input2_array.quantization_params;
676 const auto& output_quant = output_array.quantization_params;
677 const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
678 const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
679 const float output_scale = output_quant ? output_quant->scale : 0.0f;
680 ::tflite::OpSignature op_sig =
681 GetVersioningOpSig(builtin_op(), op_signature);
682 op_sig.options.mul.input1_scale = input1_scale;
683 op_sig.options.mul.input2_scale = input2_scale;
684 op_sig.options.mul.output_scale = output_scale;
685 return ::tflite::GetBuiltinOperatorVersion(op_sig);
686 }
687 };
688
689 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
690 ::tflite::BuiltinOptions_PadOptions> {
691 public:
692 using BuiltinOperator::BuiltinOperator;
693
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const694 flatbuffers::Offset<TfLiteOptions> WriteOptions(
695 const TocoOperator& op,
696 flatbuffers::FlatBufferBuilder* builder) const override {
697 return ::tflite::CreatePadOptions(*builder);
698 }
699
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const700 void ReadOptions(const TfLiteOptions& options,
701 TocoOperator* op) const override {}
702 };
703
704 class Tile
705 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
706 ::tflite::BuiltinOptions_TileOptions> {
707 using BuiltinOperator::BuiltinOperator;
708
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const709 flatbuffers::Offset<TfLiteOptions> WriteOptions(
710 const TocoOperator& op,
711 flatbuffers::FlatBufferBuilder* builder) const override {
712 return ::tflite::CreateTileOptions(*builder);
713 }
714
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const715 void ReadOptions(const TfLiteOptions& options,
716 TocoOperator* op) const override {}
717 };
718
719 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
720 ::tflite::BuiltinOptions_PadV2Options> {
721 public:
722 using BuiltinOperator::BuiltinOperator;
723
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const724 flatbuffers::Offset<TfLiteOptions> WriteOptions(
725 const TocoOperator& op,
726 flatbuffers::FlatBufferBuilder* builder) const override {
727 return ::tflite::CreatePadV2Options(*builder);
728 }
729
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const730 void ReadOptions(const TfLiteOptions& options,
731 TocoOperator* op) const override {}
732 };
733
734 class Reshape
735 : public BuiltinOperator<TensorFlowReshapeOperator,
736 ::tflite::ReshapeOptions,
737 ::tflite::BuiltinOptions_ReshapeOptions> {
738 public:
739 using BuiltinOperator::BuiltinOperator;
740
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const741 flatbuffers::Offset<TfLiteOptions> WriteOptions(
742 const TocoOperator& op,
743 flatbuffers::FlatBufferBuilder* builder) const override {
744 return ::tflite::CreateReshapeOptions(*builder,
745 builder->CreateVector(op.shape));
746 }
747
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const748 void ReadOptions(const TfLiteOptions& options,
749 TocoOperator* op) const override {
750 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
751 options.new_shape()->end());
752 }
753 };
754
755 class Softmax
756 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
757 ::tflite::BuiltinOptions_SoftmaxOptions> {
758 public:
759 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const760 flatbuffers::Offset<TfLiteOptions> WriteOptions(
761 const TocoOperator& op,
762 flatbuffers::FlatBufferBuilder* builder) const override {
763 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
764 }
765
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const766 void ReadOptions(const TfLiteOptions& options,
767 TocoOperator* op) const override {
768 op->beta = options.beta();
769 }
770 };
771
772 class SpaceToDepth
773 : public BuiltinOperator<SpaceToDepthOperator,
774 ::tflite::SpaceToDepthOptions,
775 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
776 public:
777 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const778 flatbuffers::Offset<TfLiteOptions> WriteOptions(
779 const TocoOperator& op,
780 flatbuffers::FlatBufferBuilder* builder) const override {
781 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
782 }
783
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const784 void ReadOptions(const TfLiteOptions& options,
785 TocoOperator* op) const override {
786 op->block_size = options.block_size();
787 }
788 };
789
790 class Transpose
791 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
792 ::tflite::BuiltinOptions_TransposeOptions> {
793 public:
794 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const795 flatbuffers::Offset<TfLiteOptions> WriteOptions(
796 const TocoOperator& op,
797 flatbuffers::FlatBufferBuilder* builder) const override {
798 return ::tflite::CreateTransposeOptions(*builder);
799 }
800
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const801 void ReadOptions(const TfLiteOptions& options,
802 TocoOperator* op) const override {}
803 };
804
805 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
806 ::tflite::BuiltinOptions_LSTMOptions> {
807 public:
808 using BuiltinOperator::BuiltinOperator;
809
GetKernelType(LstmCellOperator::KernelType type) const810 ::tflite::LSTMKernelType GetKernelType(
811 LstmCellOperator::KernelType type) const {
812 switch (type) {
813 case LstmCellOperator::KERNEL_BASIC:
814 return ::tflite::LSTMKernelType_BASIC;
815 break;
816 case LstmCellOperator::KERNEL_FULL:
817 return ::tflite::LSTMKernelType_FULL;
818 break;
819 default:
820 LOG(ERROR) << "Unhandled Kernel Type";
821 return static_cast<::tflite::LSTMKernelType>(-1);
822 }
823 }
824
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const825 flatbuffers::Offset<TfLiteOptions> WriteOptions(
826 const TocoOperator& op,
827 flatbuffers::FlatBufferBuilder* builder) const override {
828 ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
829
830 // Current toco converter only supports tanh, no clip.
831 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
832 ::tflite::ActivationFunctionType_TANH,
833 /*cell_clip=*/0.0,
834 /*proj_clip=*/0.0, kernel_type);
835 }
836
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const837 void ReadOptions(const TfLiteOptions& options,
838 TocoOperator* op) const override {
839 // Only support tanh activation, so check that tflite type is tanh.
840 CHECK(options.fused_activation_function() ==
841 ::tflite::ActivationFunctionType_TANH);
842
843 switch (options.kernel_type()) {
844 case ::tflite::LSTMKernelType_BASIC:
845 op->kernel_type = LstmCellOperator::KERNEL_BASIC;
846 break;
847 case ::tflite::LSTMKernelType_FULL:
848 op->kernel_type = LstmCellOperator::KERNEL_FULL;
849 break;
850 }
851 }
852
GetVersion(const OperatorSignature & op_signature) const853 int GetVersion(const OperatorSignature& op_signature) const override {
854 const auto& lstm_op =
855 static_cast<const LstmCellOperator&>(*op_signature.op);
856 ::tflite::OpSignature op_sig =
857 GetVersioningOpSig(builtin_op(), op_signature);
858 op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
859 return ::tflite::GetBuiltinOperatorVersion(op_sig);
860 }
861
GetMutatingInputVariables(const Operator & op) const862 std::vector<bool> GetMutatingInputVariables(
863 const Operator& op) const override {
864 const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
865
866 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
867 switch (lstm_op.kernel_type) {
868 case LstmCellOperator::KERNEL_FULL: {
869 mutating_input_variables[kInputActivationStateTensor] = true;
870 mutating_input_variables[kInputCellStateTensor] = true;
871 break;
872 }
873 case LstmCellOperator::KERNEL_BASIC: {
874 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
875 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
876 break;
877 }
878 }
879 return mutating_input_variables;
880 }
881 };
882
883 class UnidirectionalSequenceLstm
884 : public BuiltinOperator<
885 UnidirectionalSequenceLstmOperator,
886 ::tflite::UnidirectionalSequenceLSTMOptions,
887 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
888 public:
889 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const890 flatbuffers::Offset<TfLiteOptions> WriteOptions(
891 const TocoOperator& op,
892 flatbuffers::FlatBufferBuilder* builder) const override {
893 // Current toco converter only supports tanh, no clip.
894 return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
895 *builder, /*fused_activation_function=*/
896 ::tflite::ActivationFunctionType_TANH,
897 /*cell_clip=*/0.0,
898 /*proj_clip=*/0.0,
899 /*time_major=*/true);
900 }
901
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const902 void ReadOptions(const TfLiteOptions& options,
903 TocoOperator* op) const override {
904 // Only support tanh activation, so check that tflite type is tanh.
905 DCHECK(options.fused_activation_function() ==
906 ::tflite::ActivationFunctionType_TANH);
907 }
908
GetMutatingInputVariables(const Operator & op) const909 std::vector<bool> GetMutatingInputVariables(
910 const Operator& op) const override {
911 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
912 mutating_input_variables[kInputActivationStateTensor] = true;
913 mutating_input_variables[kInputCellStateTensor] = true;
914 return mutating_input_variables;
915 }
916 };
917
918 class BidirectionalSequenceLstm
919 : public BuiltinOperator<
920 BidirectionalSequenceLstmOperator,
921 ::tflite::BidirectionalSequenceLSTMOptions,
922 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
923 public:
924 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const925 flatbuffers::Offset<TfLiteOptions> WriteOptions(
926 const TocoOperator& op,
927 flatbuffers::FlatBufferBuilder* builder) const override {
928 // Current toco converter only supports tanh, no clip.
929 return ::tflite::CreateBidirectionalSequenceLSTMOptions(
930 *builder, /*fused_activation_function=*/
931 ::tflite::ActivationFunctionType_TANH,
932 /*cell_clip=*/0.0,
933 /*proj_clip=*/0.0,
934 /*merge_outputs=*/op.merge_outputs,
935 /*time_major=*/true);
936 }
937
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const938 void ReadOptions(const TfLiteOptions& options,
939 TocoOperator* op) const override {
940 // Only support tanh activation, so check that tflite type is tanh.
941 DCHECK(options.fused_activation_function() ==
942 ::tflite::ActivationFunctionType_TANH);
943 op->merge_outputs = options.merge_outputs();
944 }
945
GetMutatingInputVariables(const Operator & op) const946 std::vector<bool> GetMutatingInputVariables(
947 const Operator& op) const override {
948 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
949 // Forward input activation state.
950 mutating_input_variables[35] = true;
951 // Forward input cell state.
952 mutating_input_variables[36] = true;
953 // Backward input activation state.
954 mutating_input_variables[37] = true;
955 // Backward input cell state.
956 mutating_input_variables[38] = true;
957 return mutating_input_variables;
958 }
959 };
960
961 class BidirectionalSequenceRnn
962 : public BuiltinOperator<
963 BidirectionalSequenceRnnOperator,
964 ::tflite::BidirectionalSequenceRNNOptions,
965 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
966 public:
967 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const968 flatbuffers::Offset<TfLiteOptions> WriteOptions(
969 const TocoOperator& op,
970 flatbuffers::FlatBufferBuilder* builder) const override {
971 // Current toco converter only supports tanh, no clip.
972 return ::tflite::CreateBidirectionalSequenceRNNOptions(
973 *builder, /*time_major=*/true,
974 /*fused_activation_function=*/
975 ::tflite::ActivationFunctionType_TANH,
976 /*merge_outputs=*/op.merge_outputs);
977 }
978
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const979 void ReadOptions(const TfLiteOptions& options,
980 TocoOperator* op) const override {
981 // Only support tanh activation, so check that tflite type is tanh.
982 DCHECK(options.fused_activation_function() ==
983 ::tflite::ActivationFunctionType_TANH);
984 op->merge_outputs = options.merge_outputs();
985 }
986
GetMutatingInputVariables(const Operator & op) const987 std::vector<bool> GetMutatingInputVariables(
988 const Operator& op) const override {
989 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
990 // Forward hidden state.
991 mutating_input_variables[4] = true;
992 // Backward hidden state.
993 mutating_input_variables[8] = true;
994 return mutating_input_variables;
995 }
996 };
997
998 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
999 ::tflite::BuiltinOptions_ReducerOptions> {
1000 public:
1001 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1002 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1003 const TocoOperator& op,
1004 flatbuffers::FlatBufferBuilder* builder) const override {
1005 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1006 }
1007
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1008 void ReadOptions(const TfLiteOptions& options,
1009 TocoOperator* op) const override {
1010 op->keep_dims = options.keep_dims();
1011 }
1012 };
1013
1014 class Sum
1015 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1016 ::tflite::BuiltinOptions_ReducerOptions> {
1017 public:
1018 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1019 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1020 const TocoOperator& op,
1021 flatbuffers::FlatBufferBuilder* builder) const override {
1022 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1023 }
1024
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1025 void ReadOptions(const TfLiteOptions& options,
1026 TocoOperator* op) const override {
1027 op->keep_dims = options.keep_dims();
1028 }
1029 };
1030
1031 class ReduceMax
1032 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1033 ::tflite::BuiltinOptions_ReducerOptions> {
1034 public:
1035 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1036 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1037 const TocoOperator& op,
1038 flatbuffers::FlatBufferBuilder* builder) const override {
1039 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1040 }
1041
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1042 void ReadOptions(const TfLiteOptions& options,
1043 TocoOperator* op) const override {
1044 op->keep_dims = options.keep_dims();
1045 }
1046 };
1047
1048 class ReduceMin
1049 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1050 ::tflite::BuiltinOptions_ReducerOptions> {
1051 public:
1052 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1053 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1054 const TocoOperator& op,
1055 flatbuffers::FlatBufferBuilder* builder) const override {
1056 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1057 }
1058
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1059 void ReadOptions(const TfLiteOptions& options,
1060 TocoOperator* op) const override {
1061 op->keep_dims = options.keep_dims();
1062 }
1063 };
1064
1065 class ReduceProd
1066 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1067 ::tflite::BuiltinOptions_ReducerOptions> {
1068 public:
1069 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1070 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1071 const TocoOperator& op,
1072 flatbuffers::FlatBufferBuilder* builder) const override {
1073 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1074 }
1075
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1076 void ReadOptions(const TfLiteOptions& options,
1077 TocoOperator* op) const override {
1078 op->keep_dims = options.keep_dims();
1079 }
1080 };
1081
1082 class ReduceAny
1083 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1084 ::tflite::BuiltinOptions_ReducerOptions> {
1085 public:
1086 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1087 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1088 const TocoOperator& op,
1089 flatbuffers::FlatBufferBuilder* builder) const override {
1090 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1091 }
1092
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1093 void ReadOptions(const TfLiteOptions& options,
1094 TocoOperator* op) const override {
1095 op->keep_dims = options.keep_dims();
1096 }
1097 };
1098
1099 class ResizeBilinear
1100 : public BuiltinOperator<ResizeBilinearOperator,
1101 ::tflite::ResizeBilinearOptions,
1102 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1103 public:
1104 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1105 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1106 const TocoOperator& op,
1107 flatbuffers::FlatBufferBuilder* builder) const override {
1108 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1109 op.half_pixel_centers);
1110 }
1111
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1112 void ReadOptions(const TfLiteOptions& options,
1113 TocoOperator* op) const override {
1114 op->align_corners = options.align_corners();
1115 op->half_pixel_centers = options.half_pixel_centers();
1116 }
1117
GetVersion(const OperatorSignature & op_signature) const1118 int GetVersion(const OperatorSignature& op_signature) const override {
1119 const auto& resize_bilinear_op =
1120 static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1121 ::tflite::OpSignature op_sig =
1122 GetVersioningOpSig(builtin_op(), op_signature);
1123 op_sig.options.resize.half_pixel_centers =
1124 resize_bilinear_op.half_pixel_centers;
1125 op_sig.options.resize.align_corners = resize_bilinear_op.align_corners;
1126 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1127 }
1128 };
1129
1130 class ResizeNearestNeighbor
1131 : public BuiltinOperator<
1132 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1133 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1134 public:
1135 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1136 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1137 const TocoOperator& op,
1138 flatbuffers::FlatBufferBuilder* builder) const override {
1139 return ::tflite::CreateResizeNearestNeighborOptions(
1140 *builder, op.align_corners, op.half_pixel_centers);
1141 }
1142
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1143 void ReadOptions(const TfLiteOptions& options,
1144 TocoOperator* op) const override {
1145 op->align_corners = options.align_corners();
1146 op->half_pixel_centers = options.half_pixel_centers();
1147 }
1148
GetVersion(const OperatorSignature & op_signature) const1149 int GetVersion(const OperatorSignature& op_signature) const override {
1150 const auto& resize_nn_op =
1151 static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
1152 ::tflite::OpSignature op_sig =
1153 GetVersioningOpSig(builtin_op(), op_signature);
1154 op_sig.options.resize.half_pixel_centers = resize_nn_op.half_pixel_centers;
1155 op_sig.options.resize.align_corners = resize_nn_op.align_corners;
1156 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1157 }
1158 };
1159
1160 class Squeeze
1161 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1162 ::tflite::BuiltinOptions_SqueezeOptions> {
1163 public:
1164 using BuiltinOperator::BuiltinOperator;
1165
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1166 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1167 const TocoOperator& op,
1168 flatbuffers::FlatBufferBuilder* builder) const override {
1169 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1170 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1171 }
1172
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1173 void ReadOptions(const TfLiteOptions& options,
1174 TocoOperator* op) const override {
1175 op->squeeze_dims.insert(op->squeeze_dims.end(),
1176 options.squeeze_dims()->begin(),
1177 options.squeeze_dims()->end());
1178 }
1179 };
1180
1181 class Split
1182 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1183 ::tflite::BuiltinOptions_SplitOptions> {
1184 public:
1185 using BuiltinOperator::BuiltinOperator;
1186
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1187 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1188 const TocoOperator& op,
1189 flatbuffers::FlatBufferBuilder* builder) const override {
1190 return ::tflite::CreateSplitOptions(*builder, op.num_split);
1191 }
1192
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1193 void ReadOptions(const TfLiteOptions& options,
1194 TocoOperator* op) const override {
1195 op->num_split = options.num_splits();
1196 }
1197 };
1198
1199 class SplitV
1200 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1201 ::tflite::BuiltinOptions_SplitVOptions> {
1202 public:
1203 using BuiltinOperator::BuiltinOperator;
1204
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1205 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1206 const TocoOperator& op,
1207 flatbuffers::FlatBufferBuilder* builder) const override {
1208 return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1209 }
1210
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1211 void ReadOptions(const TfLiteOptions& options,
1212 TocoOperator* op) const override {
1213 op->num_split = options.num_splits();
1214 }
1215 };
1216
1217 class StridedSlice
1218 : public BuiltinOperator<StridedSliceOperator,
1219 ::tflite::StridedSliceOptions,
1220 ::tflite::BuiltinOptions_StridedSliceOptions> {
1221 public:
1222 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1223 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1224 const TocoOperator& op,
1225 flatbuffers::FlatBufferBuilder* builder) const override {
1226 return ::tflite::CreateStridedSliceOptions(
1227 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1228 op.new_axis_mask, op.shrink_axis_mask);
1229 }
1230
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1231 void ReadOptions(const TfLiteOptions& options,
1232 TocoOperator* op) const override {
1233 op->begin_mask = options.begin_mask();
1234 op->end_mask = options.end_mask();
1235 op->ellipsis_mask = options.ellipsis_mask();
1236 op->new_axis_mask = options.new_axis_mask();
1237 op->shrink_axis_mask = options.shrink_axis_mask();
1238 }
1239
GetVersion(const OperatorSignature & op_signature) const1240 int GetVersion(const OperatorSignature& op_signature) const override {
1241 const auto& ss_op =
1242 static_cast<const StridedSliceOperator&>(*op_signature.op);
1243 ::tflite::OpSignature op_sig =
1244 GetVersioningOpSig(builtin_op(), op_signature);
1245 op_sig.options.single_input_op.num_dims = ss_op.start_indices.size();
1246 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1247 }
1248 };
1249
1250 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1251 ::tflite::BuiltinOptions_TopKV2Options> {
1252 public:
1253 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1254 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1255 const TocoOperator& op,
1256 flatbuffers::FlatBufferBuilder* builder) const override {
1257 return ::tflite::CreateTopKV2Options(*builder);
1258 }
1259
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1260 void ReadOptions(const TfLiteOptions& options,
1261 TocoOperator* op) const override {}
1262 };
1263
1264 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1265 ::tflite::BuiltinOptions_ArgMaxOptions> {
1266 public:
1267 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1268 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1269 const TocoOperator& op,
1270 flatbuffers::FlatBufferBuilder* builder) const override {
1271 return ::tflite::CreateArgMaxOptions(
1272 *builder, DataType::Serialize(op.output_data_type));
1273 }
1274
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1275 void ReadOptions(const TfLiteOptions& options,
1276 TocoOperator* op) const override {
1277 op->output_data_type = DataType::Deserialize(options.output_type());
1278 }
1279 };
1280
1281 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1282 ::tflite::BuiltinOptions_ArgMinOptions> {
1283 public:
1284 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1285 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1286 const TocoOperator& op,
1287 flatbuffers::FlatBufferBuilder* builder) const override {
1288 return ::tflite::CreateArgMinOptions(
1289 *builder, DataType::Serialize(op.output_data_type));
1290 }
1291
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1292 void ReadOptions(const TfLiteOptions& options,
1293 TocoOperator* op) const override {
1294 op->output_data_type = DataType::Deserialize(options.output_type());
1295 }
1296 };
1297
1298 class TransposeConv
1299 : public BuiltinOperator<TransposeConvOperator,
1300 ::tflite::TransposeConvOptions,
1301 ::tflite::BuiltinOptions_TransposeConvOptions> {
1302 public:
1303 using BuiltinOperator::BuiltinOperator;
1304
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1305 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1306 const TocoOperator& op,
1307 flatbuffers::FlatBufferBuilder* builder) const override {
1308 auto padding = Padding::Serialize(op.padding.type);
1309 return ::tflite::CreateTransposeConvOptions(
1310 *builder, padding, op.stride_width, op.stride_height);
1311 }
1312
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1313 void ReadOptions(const TfLiteOptions& options,
1314 TocoOperator* op) const override {
1315 op->padding.type = Padding::Deserialize(options.padding());
1316 op->stride_width = options.stride_w();
1317 op->stride_height = options.stride_h();
1318 }
1319 };
1320
1321 class SparseToDense
1322 : public BuiltinOperator<SparseToDenseOperator,
1323 ::tflite::SparseToDenseOptions,
1324 ::tflite::BuiltinOptions_SparseToDenseOptions> {
1325 public:
1326 using BuiltinOperator::BuiltinOperator;
1327
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1328 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1329 const TocoOperator& op,
1330 flatbuffers::FlatBufferBuilder* builder) const override {
1331 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1332 }
1333
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1334 void ReadOptions(const TfLiteOptions& options,
1335 TocoOperator* op) const override {
1336 op->validate_indices = options.validate_indices();
1337 }
1338 };
1339
1340 class ExpandDims
1341 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1342 ::tflite::BuiltinOptions_ExpandDimsOptions> {
1343 public:
1344 using BuiltinOperator::BuiltinOperator;
1345
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1346 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1347 const TocoOperator& op,
1348 flatbuffers::FlatBufferBuilder* builder) const override {
1349 return ::tflite::CreateExpandDimsOptions(*builder);
1350 }
1351
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1352 void ReadOptions(const TfLiteOptions& options,
1353 TocoOperator* op) const override {}
1354 };
1355
1356 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1357 ::tflite::BuiltinOptions_PackOptions> {
1358 public:
1359 using BuiltinOperator::BuiltinOperator;
1360
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1361 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1362 const TocoOperator& op,
1363 flatbuffers::FlatBufferBuilder* builder) const override {
1364 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1365 }
1366
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1367 void ReadOptions(const TfLiteOptions& options,
1368 TocoOperator* op) const override {
1369 op->values_count = options.values_count();
1370 op->axis = options.axis();
1371 }
1372 };
1373
1374 class Shape
1375 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1376 ::tflite::BuiltinOptions_ShapeOptions> {
1377 public:
1378 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1379 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1380 const TocoOperator& op,
1381 flatbuffers::FlatBufferBuilder* builder) const override {
1382 return ::tflite::CreateShapeOptions(
1383 *builder, DataType::Serialize(op.output_data_type));
1384 }
1385
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1386 void ReadOptions(const TfLiteOptions& options,
1387 TocoOperator* op) const override {
1388 op->output_data_type = DataType::Deserialize(options.out_type());
1389 }
1390 };
1391
1392 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1393 ::tflite::BuiltinOptions_OneHotOptions> {
1394 public:
1395 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1396 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1397 const TocoOperator& op,
1398 flatbuffers::FlatBufferBuilder* builder) const override {
1399 return ::tflite::CreateOneHotOptions(*builder, op.axis);
1400 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1401 void ReadOptions(const TfLiteOptions& options,
1402 TocoOperator* op) const override {
1403 op->axis = options.axis();
1404 }
1405 };
1406
1407 class CTCBeamSearchDecoder
1408 : public CustomOperator<CTCBeamSearchDecoderOperator> {
1409 public:
1410 using CustomOperator::CustomOperator;
1411
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1412 void WriteOptions(const TocoOperator& op,
1413 flexbuffers::Builder* fbb) const override {
1414 fbb->Int("beam_width", op.beam_width);
1415 fbb->Int("top_paths", op.top_paths);
1416 fbb->Bool("merge_repeated", op.merge_repeated);
1417 }
1418
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1419 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1420 op->beam_width = m["beam_width"].AsInt32();
1421 op->top_paths = m["top_paths"].AsInt32();
1422 op->merge_repeated = m["merge_repeated"].AsBool();
1423 }
1424
GetVersion(const OperatorSignature & op_signature) const1425 int GetVersion(const OperatorSignature& op_signature) const override {
1426 return 1;
1427 }
1428 };
1429
1430 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1431 ::tflite::BuiltinOptions_UnpackOptions> {
1432 public:
1433 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1434 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1435 const TocoOperator& op,
1436 flatbuffers::FlatBufferBuilder* builder) const override {
1437 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1438 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1439 void ReadOptions(const TfLiteOptions& options,
1440 TocoOperator* op) const override {
1441 op->num = options.num();
1442 op->axis = options.axis();
1443 }
1444
GetVersion(const OperatorSignature & op_signature) const1445 int GetVersion(const OperatorSignature& op_signature) const override {
1446 const std::string& input_name = op_signature.op->inputs[0];
1447 const Array& input_array = op_signature.model->GetArray(input_name);
1448 // If the op take int8/uint8 input, it is version 2.
1449 if (input_array.data_type == ArrayDataType::kInt8 ||
1450 input_array.data_type == ArrayDataType::kUint8) {
1451 return 2;
1452 }
1453 // If the op take bool input, it is version 3.
1454 if (input_array.data_type == ArrayDataType::kBool) {
1455 return 3;
1456 }
1457 return 1;
1458 }
1459 };
1460
1461 class LeakyRelu
1462 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1463 ::tflite::BuiltinOptions_LeakyReluOptions> {
1464 public:
1465 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1466 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1467 const TocoOperator& op,
1468 flatbuffers::FlatBufferBuilder* builder) const override {
1469 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1470 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1471 void ReadOptions(const TfLiteOptions& options,
1472 TocoOperator* op) const override {
1473 op->alpha = options.alpha();
1474 }
1475 };
1476
1477 class SquaredDifference
1478 : public BuiltinOperator<
1479 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1480 ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1481 public:
1482 using BuiltinOperator::BuiltinOperator;
1483
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1484 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1485 const TocoOperator& op,
1486 flatbuffers::FlatBufferBuilder* builder) const override {
1487 return ::tflite::CreateSquaredDifferenceOptions(*builder);
1488 }
1489
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1490 void ReadOptions(const TfLiteOptions& options,
1491 TocoOperator* op) const override {}
1492 };
1493
1494 class MirrorPad
1495 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1496 ::tflite::BuiltinOptions_MirrorPadOptions> {
1497 public:
1498 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1499 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1500 const TocoOperator& op,
1501 flatbuffers::FlatBufferBuilder* builder) const override {
1502 return ::tflite::CreateMirrorPadOptions(
1503 *builder, op.mode == MirrorPadMode::kReflect
1504 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1505 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1506 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1507 void ReadOptions(const TfLiteOptions& options,
1508 TocoOperator* op) const override {
1509 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1510 ? MirrorPadMode::kReflect
1511 : MirrorPadMode::kSymmetric;
1512 }
1513 };
1514
1515 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1516 ::tflite::BuiltinOptions_UniqueOptions> {
1517 public:
1518 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1519 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1520 const TocoOperator& op,
1521 flatbuffers::FlatBufferBuilder* builder) const override {
1522 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1523 return ::tflite::CreateUniqueOptions(
1524 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1525 ? ::tflite::TensorType::TensorType_INT64
1526 : ::tflite::TensorType_INT32);
1527 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1528 void ReadOptions(const TfLiteOptions& options,
1529 TocoOperator* op) const override {
1530 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1531 unique_op->idx_out_type =
1532 options.idx_out_type() == ::tflite::TensorType_INT64
1533 ? toco::ArrayDataType::kInt64
1534 : toco::ArrayDataType::kInt32;
1535 }
1536 };
1537
1538 class UnidirectionalSequenceRnn
1539 : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1540 ::tflite::SequenceRNNOptions,
1541 ::tflite::BuiltinOptions_SequenceRNNOptions> {
1542 public:
1543 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1544 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1545 const TocoOperator& op,
1546 flatbuffers::FlatBufferBuilder* builder) const override {
1547 return ::tflite::CreateSequenceRNNOptions(
1548 *builder, /*time_major=*/true,
1549 /*fused_activation_function=*/
1550 ::tflite::ActivationFunctionType_TANH);
1551 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1552 void ReadOptions(const TfLiteOptions& options,
1553 TocoOperator* op) const override {
1554 // Only support tanh activation, so check that tflite type is tanh.
1555 DCHECK(options.fused_activation_function() ==
1556 ::tflite::ActivationFunctionType_TANH);
1557 }
1558
GetMutatingInputVariables(const Operator & op) const1559 std::vector<bool> GetMutatingInputVariables(
1560 const Operator& op) const override {
1561 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1562 mutating_input_variables[4] = true;
1563 return mutating_input_variables;
1564 }
1565 };
1566
1567 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1568 ::tflite::BuiltinOptions_WhereOptions> {
1569 public:
1570 using BuiltinOperator::BuiltinOperator;
1571
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1572 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1573 const TocoOperator& op,
1574 flatbuffers::FlatBufferBuilder* builder) const override {
1575 return ::tflite::CreateWhereOptions(*builder);
1576 }
1577
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1578 void ReadOptions(const TfLiteOptions& options,
1579 TocoOperator* op) const override {}
1580 };
1581
WriteFlexOpOptions(const std::string & tensorflow_node_def)1582 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1583 const std::string& tensorflow_node_def) {
1584 auto fbb = absl::make_unique<flexbuffers::Builder>();
1585
1586 ::tensorflow::NodeDef node_def;
1587 if (!node_def.ParseFromString(tensorflow_node_def)) {
1588 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1589 return {};
1590 }
1591
1592 fbb->Vector([&]() {
1593 fbb->String(node_def.op());
1594 fbb->String(tensorflow_node_def);
1595 });
1596 fbb->Finish();
1597 LOG(INFO) << "Writing flex op: " << node_def.op();
1598 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1599 }
1600
1601 class TensorFlowUnsupported : public BaseOperator {
1602 public:
TensorFlowUnsupported(const std::string & name,OperatorType type,bool enable_select_tf_ops)1603 TensorFlowUnsupported(const std::string& name, OperatorType type,
1604 bool enable_select_tf_ops)
1605 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1606
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1607 Options Serialize(const Operator& op,
1608 flatbuffers::FlatBufferBuilder* builder) const override {
1609 auto fbb =
1610 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1611 if (fbb) {
1612 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1613 } else {
1614 return Options::Custom(0);
1615 }
1616 }
1617
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1618 std::unique_ptr<Operator> Deserialize(
1619 const BuiltinOptions* builtin_options,
1620 const CustomOptions* custom_options) const override {
1621 // Deserializing Flex ops doesn't work now.
1622 // TODO(ycling): Revisit and decide if we should fix the flow for importing
1623 // TFLite models with Flex ops.
1624 auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
1625 if (custom_options) {
1626 auto flexbuffer_map =
1627 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1628 .AsMap();
1629 ReadOptions(flexbuffer_map, op.get());
1630 }
1631 return std::unique_ptr<Operator>(op.release());
1632 }
1633
WriteOptions(const TensorFlowUnsupportedOperator & op) const1634 std::unique_ptr<flexbuffers::Builder> WriteOptions(
1635 const TensorFlowUnsupportedOperator& op) const {
1636 if (enable_select_tf_ops_) {
1637 return WriteFlexOpOptions(op.tensorflow_node_def);
1638 }
1639 auto fbb = absl::make_unique<flexbuffers::Builder>();
1640
1641 ::tensorflow::NodeDef node_def;
1642 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1643 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1644 return std::unique_ptr<flexbuffers::Builder>();
1645 }
1646
1647 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1648 fbb->Vector([&]() {
1649 fbb->String(node_def.op());
1650 fbb->String(op.tensorflow_node_def);
1651 });
1652 fbb->Finish();
1653 LOG(INFO) << "Writing flex op: " << node_def.op();
1654 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1655 }
1656
1657 bool has_valid_attr = false;
1658 size_t map_start = fbb->StartMap();
1659 for (const auto& pair : node_def.attr()) {
1660 const char* key = pair.first.c_str();
1661 const auto& attr = pair.second;
1662 switch (attr.value_case()) {
1663 case ::tensorflow::AttrValue::kS:
1664 fbb->String(key, attr.s());
1665 has_valid_attr = true;
1666 break;
1667 case ::tensorflow::AttrValue::kI:
1668 fbb->Int(key, attr.i());
1669 has_valid_attr = true;
1670 break;
1671 case ::tensorflow::AttrValue::kF:
1672 fbb->Float(key, attr.f());
1673 has_valid_attr = true;
1674 break;
1675 case ::tensorflow::AttrValue::kB:
1676 fbb->Bool(key, attr.b());
1677 has_valid_attr = true;
1678 break;
1679 case tensorflow::AttrValue::kList:
1680 if (attr.list().s_size() > 0) {
1681 auto start = fbb->StartVector(key);
1682 for (const std::string& v : attr.list().s()) {
1683 fbb->Add(v);
1684 }
1685 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1686 has_valid_attr = true;
1687 } else if (attr.list().i_size() > 0) {
1688 auto start = fbb->StartVector(key);
1689 for (const int64_t v : attr.list().i()) {
1690 fbb->Add(v);
1691 }
1692 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1693 has_valid_attr = true;
1694 } else if (attr.list().f_size() > 0) {
1695 auto start = fbb->StartVector(key);
1696 for (const float v : attr.list().f()) {
1697 fbb->Add(v);
1698 }
1699 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1700 has_valid_attr = true;
1701 } else {
1702 LOG(WARNING)
1703 << "Ignoring unsupported type in list attribute with key '"
1704 << key << "'";
1705 }
1706 break;
1707 default:
1708 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1709 << key << "'";
1710 break;
1711 }
1712 }
1713 if (!has_valid_attr) {
1714 return std::unique_ptr<flexbuffers::Builder>();
1715 }
1716 fbb->EndMap(map_start);
1717 fbb->Finish();
1718 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1719 }
1720
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1721 void ReadOptions(const flexbuffers::Map& m,
1722 TensorFlowUnsupportedOperator* op) const {
1723 ::tensorflow::NodeDef node_def;
1724 auto attr = node_def.mutable_attr();
1725
1726 const auto& keys = m.Keys();
1727 for (size_t i = 0; i < keys.size(); ++i) {
1728 const auto key = keys[i].AsKey();
1729 const auto& value = m[key];
1730 switch (value.GetType()) {
1731 case flexbuffers::FBT_STRING:
1732 (*attr)[key].set_s(value.AsString().c_str());
1733 break;
1734 case flexbuffers::FBT_INT:
1735 (*attr)[key].set_i(value.AsInt64());
1736 break;
1737 case flexbuffers::FBT_FLOAT:
1738 (*attr)[key].set_f(value.AsFloat());
1739 break;
1740 case flexbuffers::FBT_BOOL:
1741 (*attr)[key].set_b(value.AsBool());
1742 if (std::string(key) == "_output_quantized") {
1743 op->quantized = value.AsBool();
1744 }
1745 if (std::string(key) ==
1746 "_support_output_type_float_in_quantized_op") {
1747 op->support_output_type_float_in_quantized_op = value.AsBool();
1748 }
1749 break;
1750 case flexbuffers::FBT_VECTOR_INT: {
1751 auto* list = (*attr)[key].mutable_list();
1752 const auto& vector = value.AsTypedVector();
1753 for (size_t i = 0; i < vector.size(); i++) {
1754 list->add_i(vector[i].AsInt64());
1755 }
1756 break;
1757 }
1758 case flexbuffers::FBT_VECTOR_FLOAT: {
1759 auto* list = (*attr)[key].mutable_list();
1760 const auto& vector = value.AsTypedVector();
1761 for (size_t i = 0; i < vector.size(); i++) {
1762 list->add_f(vector[i].AsFloat());
1763 }
1764 break;
1765 }
1766 case 15 /* TO_DO(wvo): flexbuffers::FBT_VECTOR_STRING_DEPRECATED*/: {
1767 auto* list = (*attr)[key].mutable_list();
1768 const auto& vector = value.AsTypedVector();
1769 for (size_t i = 0; i < vector.size(); i++) {
1770 list->add_s(vector[i].AsString().str());
1771 }
1772 break;
1773 }
1774 default:
1775 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1776 << key << "'";
1777 break;
1778 }
1779 }
1780 node_def.SerializeToString(&op->tensorflow_node_def);
1781 }
1782
GetVersion(const OperatorSignature & op_signature) const1783 int GetVersion(const OperatorSignature& op_signature) const override {
1784 // TODO(ycling): Design and implement a way to plumb the version of
1785 // custom ops.
1786 return 1;
1787 }
1788
1789 private:
1790 const bool enable_select_tf_ops_;
1791 };
1792
1793 class Dequantize
1794 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1795 ::tflite::BuiltinOptions_DequantizeOptions> {
1796 public:
1797 using BuiltinOperator::BuiltinOperator;
1798
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1799 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1800 const TocoOperator& op,
1801 flatbuffers::FlatBufferBuilder* builder) const override {
1802 return ::tflite::CreateDequantizeOptions(*builder);
1803 }
1804
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1805 void ReadOptions(const TfLiteOptions& options,
1806 TocoOperator* op) const override {}
1807 };
1808
1809 class ReverseSequence
1810 : public BuiltinOperator<ReverseSequenceOperator,
1811 ::tflite::ReverseSequenceOptions,
1812 ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1813 public:
1814 using BuiltinOperator::BuiltinOperator;
1815
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1816 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1817 const TocoOperator& op,
1818 flatbuffers::FlatBufferBuilder* builder) const override {
1819 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1820 op.batch_dim);
1821 }
1822
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1823 void ReadOptions(const TfLiteOptions& options,
1824 TocoOperator* op) const override {
1825 op->seq_dim = options.seq_dim();
1826 op->batch_dim = options.batch_dim();
1827 }
1828 };
1829
1830 namespace {
1831 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1832 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1833 bool enable_select_tf_ops = false) {
1834 std::vector<std::unique_ptr<BaseOperator>> ops;
1835 using tensorflow::MakeUnique;
1836 // Builtin Operators.
1837 ops.push_back(
1838 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1839 ops.push_back(
1840 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1841 ops.push_back(
1842 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1843 ops.push_back(
1844 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1845 ops.push_back(MakeUnique<AveragePool>(
1846 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1847 ops.push_back(
1848 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1849 OperatorType::kSpaceToBatchND));
1850 ops.push_back(
1851 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1852 OperatorType::kBatchToSpaceND));
1853 ops.push_back(MakeUnique<Concatenation>(
1854 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1855 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1856 OperatorType::kConv));
1857 ops.push_back(MakeUnique<DepthwiseConvolution>(
1858 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1859 OperatorType::kDepthwiseConv));
1860 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1861 OperatorType::kDequantize));
1862 ops.push_back(
1863 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1864 OperatorType::kFullyConnected));
1865 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1866 OperatorType::kGather));
1867 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1868 OperatorType::kGatherNd));
1869 ops.push_back(
1870 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1871 OperatorType::kL2Normalization));
1872 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1873 OperatorType::kL2Pool));
1874 ops.push_back(MakeUnique<LocalResponseNormalization>(
1875 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1876 OperatorType::kLocalResponseNormalization));
1877 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1878 OperatorType::kMaxPool));
1879 ops.push_back(
1880 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1881
1882 ops.push_back(
1883 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1884 ops.push_back(
1885 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1886 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1887 OperatorType::kReshape));
1888 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1889 OperatorType::kSoftmax));
1890 ops.push_back(MakeUnique<SpaceToDepth>(
1891 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1892 ops.push_back(MakeUnique<DepthToSpace>(
1893 ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1894 ops.push_back(
1895 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1896 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1897 OperatorType::kTranspose));
1898 ops.push_back(
1899 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1900 ops.push_back(
1901 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1902 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1903 OperatorType::kReduceProd));
1904 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1905 OperatorType::kReduceMax));
1906 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1907 OperatorType::kReduceMin));
1908 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1909 OperatorType::kAny));
1910 ops.push_back(
1911 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1912 OperatorType::kResizeBilinear));
1913 ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1914 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1915 OperatorType::kResizeNearestNeighbor));
1916 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1917 OperatorType::kSqueeze));
1918 ops.push_back(
1919 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1920 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1921 OperatorType::kSplitV));
1922 ops.push_back(MakeUnique<StridedSlice>(
1923 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1924 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1925 OperatorType::kTopK_V2));
1926 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1927 OperatorType::kLstmCell));
1928 ops.push_back(
1929 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1930 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1931 OperatorType::kArgMax));
1932 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1933 OperatorType::kArgMin));
1934 ops.push_back(
1935 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1936 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1937 OperatorType::kExpandDims));
1938 ops.push_back(MakeUnique<TransposeConv>(
1939 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1940 ops.push_back(MakeUnique<SparseToDense>(
1941 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1942 ops.push_back(
1943 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1944 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1945 OperatorType::kFakeQuant));
1946 ops.push_back(
1947 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1948 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1949 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1950 OperatorType::kUnidirectionalSequenceLstm));
1951 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1952 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1953 OperatorType::kBidirectionalSequenceLstm));
1954 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1955 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1956 OperatorType::kBidirectionalSequenceRnn));
1957 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1958 OperatorType::kOneHot));
1959 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1960 OperatorType::kUnpack));
1961 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1962 OperatorType::kLeakyRelu));
1963 ops.push_back(MakeUnique<SquaredDifference>(
1964 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1965 OperatorType::kSquaredDifference));
1966 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1967 OperatorType::kMirrorPad));
1968 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1969 OperatorType::kUnique));
1970 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1971 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1972 OperatorType::kUnidirectionalSequenceRnn));
1973 ops.push_back(
1974 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1975 ops.push_back(
1976 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1977 OperatorType::kReverseSequence));
1978 ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1979 ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1980 ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1981 ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1982 // Custom Operators.
1983 ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1984 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1985 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1986 OperatorType::kUnsupported,
1987 enable_select_tf_ops));
1988
1989 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1990 // been modified to also export builtins. As TOCO evolved we added warnings
1991 // when custom ops are exported but SimpleOperator bypasses thoses. To
1992 // prevent user confusion we are settling on using SimpleOperator only for
1993 // builtins.
1994 ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1995 ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1996 ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1997 ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1998 ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1999 ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
2000 ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
2001 ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
2002 ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
2003 ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
2004 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2005 ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
2006 ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
2007 ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
2008 ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
2009 ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
2010 ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
2011 ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
2012 ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
2013 ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
2014 ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
2015 ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
2016 ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
2017 ::tflite::BuiltinOperator_COS, OperatorType::kCos));
2018 ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
2019 ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
2020 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
2021 ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
2022 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
2023 ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
2024 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
2025 ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
2026 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
2027 ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
2028 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
2029 ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
2030 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
2031 ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
2032 ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
2033 ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
2034 ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
2035 ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
2036 ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
2037 ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
2038 ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
2039 ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
2040 ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
2041 ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
2042 ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
2043 ::tflite::BuiltinOperator_POW, OperatorType::kPow));
2044 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2045 ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
2046 ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2047 ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
2048 ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2049 ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
2050 ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2051 ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
2052 ops.emplace_back(new SimpleOperator<FloorModOperator>(
2053 ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
2054 ops.emplace_back(new SimpleOperator<RangeOperator>(
2055 ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
2056 // Element-wise operator
2057 ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
2058 ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
2059 ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
2060 ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
2061 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2062 ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
2063 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2064 ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
2065 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2066 ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
2067 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2068 ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
2069 ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
2070 ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
2071 ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
2072 ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
2073 ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
2074 ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
2075 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2076 ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2077 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2078 ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2079 ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2080 ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2081 ops.emplace_back(MakeUnique<SimpleOperator<ScatterNdOperator>>(
2082 ::tflite::BuiltinOperator_SCATTER_ND, OperatorType::kScatterNd));
2083 return ops;
2084 }
2085 } // namespace
2086
2087 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2088
BuildOperatorByTypeMap(bool enable_select_tf_ops)2089 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2090 bool enable_select_tf_ops) {
2091 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2092
2093 std::vector<std::unique_ptr<BaseOperator>> ops =
2094 BuildOperatorList(enable_select_tf_ops);
2095 for (auto& op : ops) {
2096 result[op->type()] = std::move(op);
2097 }
2098
2099 return result;
2100 }
2101
BuildOperatorByNameMap(bool enable_select_tf_ops)2102 std::map<std::string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2103 bool enable_select_tf_ops) {
2104 std::map<std::string, std::unique_ptr<BaseOperator>> result;
2105
2106 std::vector<std::unique_ptr<BaseOperator>> ops =
2107 BuildOperatorList(enable_select_tf_ops);
2108 for (auto& op : ops) {
2109 result[op->name()] = std::move(op);
2110 }
2111
2112 return result;
2113 }
2114
ShouldExportAsFlexOp(bool enable_select_tf_ops,const std::string & tensorflow_op_name)2115 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2116 const std::string& tensorflow_op_name) {
2117 // If Flex ops aren't allow at all, simply return false.
2118 if (!enable_select_tf_ops) {
2119 return false;
2120 }
2121 // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2122 // it and it has been allowlisted, export the op as an Flex op. Otherwise,
2123 // export it as a regular custom op.
2124 const tensorflow::OpDef* op_def = nullptr;
2125 if (!tensorflow::OpRegistry::Global()
2126 ->LookUpOpDef(tensorflow_op_name, &op_def)
2127 .ok()) {
2128 return false;
2129 }
2130
2131 if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
2132 LOG(WARNING) << "Op " << tensorflow_op_name
2133 << " is a valid TensorFlow op but has not been allowlisted for"
2134 " the TensorFlow Lite flex op set.";
2135 return false;
2136 }
2137
2138 return true;
2139 }
2140
2141 } // namespace tflite
2142
2143 } // namespace toco
2144