1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16
17 #include <map>
18
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/util/ptr_util.h"
24
25 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
26 // graph_transformation module.
27 #include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
30 #include "tensorflow/lite/toco/model.h"
31 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
32 #include "tensorflow/lite/toco/tflite/custom_operator.h"
33 #include "tensorflow/lite/toco/tflite/simple_operator.h"
34 #include "tensorflow/lite/toco/tflite/types.h"
35 #include "tensorflow/lite/tools/versioning/op_version.h"
36
37 namespace toco {
38
39 namespace tflite {
40
41 // LINT.IfChange
42
GetTensorType(const ArrayDataType type)43 ::tflite::TensorType GetTensorType(const ArrayDataType type) {
44 const std::map<ArrayDataType, ::tflite::TensorType> tensor_type_map = {
45 {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
46 {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
47 {ArrayDataType::kInt8, ::tflite::TensorType_INT8},
48 {ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
49 {ArrayDataType::kInt16, ::tflite::TensorType_INT16},
50 {ArrayDataType::kInt32, ::tflite::TensorType_INT32},
51 {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
52 {ArrayDataType::kString, ::tflite::TensorType_STRING},
53 {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
54 {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}};
55
56 auto it = tensor_type_map.find(type);
57 if (it != tensor_type_map.end()) {
58 return it->second;
59 }
60 return static_cast<::tflite::TensorType>(-1);
61 }
62
// Builds the ::tflite::OpSignature consumed by the op-versioning library for
// builtin `op`. It records the TFLite tensor type of every input and output
// array of the TOCO operator in `op_signature`, in order. Arrays that are not
// present in the model are recorded as static_cast<::tflite::TensorType>(-1).
::tflite::OpSignature GetVersioningOpSig(
    const ::tflite::BuiltinOperator op, const OperatorSignature& op_signature) {
  // Shared lookup for inputs and outputs. Array names are taken by const
  // reference to avoid copying each string.
  const auto tensor_type_of =
      [&op_signature](const string& array_name) -> ::tflite::TensorType {
    if (!op_signature.model->HasArray(array_name)) {
      return static_cast<::tflite::TensorType>(-1);
    }
    return GetTensorType(op_signature.model->GetArray(array_name).data_type);
  };

  std::vector<::tflite::TensorType> input_types;
  input_types.reserve(op_signature.op->inputs.size());
  for (const auto& input_name : op_signature.op->inputs) {
    input_types.push_back(tensor_type_of(input_name));
  }

  std::vector<::tflite::TensorType> output_types;
  output_types.reserve(op_signature.op->outputs.size());
  for (const auto& output_name : op_signature.op->outputs) {
    output_types.push_back(tensor_type_of(output_name));
  }

  return ::tflite::OpSignature{op, input_types, output_types};
}
84
85 class AveragePool
86 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
87 ::tflite::BuiltinOptions_Pool2DOptions> {
88 public:
89 using BuiltinOperator::BuiltinOperator;
90
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const91 flatbuffers::Offset<TfLiteOptions> WriteOptions(
92 const TocoOperator& op,
93 flatbuffers::FlatBufferBuilder* builder) const override {
94 auto padding = Padding::Serialize(op.padding.type);
95 auto activation_function =
96 ActivationFunction::Serialize(op.fused_activation_function);
97 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
98 op.stride_height, op.kwidth,
99 op.kheight, activation_function);
100 }
101
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const102 void ReadOptions(const TfLiteOptions& options,
103 TocoOperator* op) const override {
104 op->padding.type = Padding::Deserialize(options.padding());
105 op->stride_width = options.stride_w();
106 op->stride_height = options.stride_h();
107 op->kwidth = options.filter_width();
108 op->kheight = options.filter_height();
109 op->fused_activation_function =
110 ActivationFunction::Deserialize(options.fused_activation_function());
111 }
112 };
113
114 class Convolution
115 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
116 ::tflite::BuiltinOptions_Conv2DOptions> {
117 public:
118 using BuiltinOperator::BuiltinOperator;
119
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const120 flatbuffers::Offset<TfLiteOptions> WriteOptions(
121 const TocoOperator& op,
122 flatbuffers::FlatBufferBuilder* builder) const override {
123 auto padding = Padding::Serialize(op.padding.type);
124 auto activation_function =
125 ActivationFunction::Serialize(op.fused_activation_function);
126 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
127 op.stride_height, activation_function,
128 op.dilation_width_factor,
129 op.dilation_height_factor);
130 }
131
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const132 void ReadOptions(const TfLiteOptions& options,
133 TocoOperator* op) const override {
134 op->padding.type = Padding::Deserialize(options.padding());
135 op->stride_width = options.stride_w();
136 op->stride_height = options.stride_h();
137 op->dilation_width_factor = options.dilation_w_factor();
138 op->dilation_height_factor = options.dilation_h_factor();
139 op->fused_activation_function =
140 ActivationFunction::Deserialize(options.fused_activation_function());
141 }
142 };
143
144 class DepthwiseConvolution
145 : public BuiltinOperator<DepthwiseConvOperator,
146 ::tflite::DepthwiseConv2DOptions,
147 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
148 public:
149 using BuiltinOperator::BuiltinOperator;
150
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const151 flatbuffers::Offset<TfLiteOptions> WriteOptions(
152 const TocoOperator& op,
153 flatbuffers::FlatBufferBuilder* builder) const override {
154 auto padding = Padding::Serialize(op.padding.type);
155 auto activation_function =
156 ActivationFunction::Serialize(op.fused_activation_function);
157 return ::tflite::CreateDepthwiseConv2DOptions(
158 *builder, padding, op.stride_width, op.stride_height,
159 op.depth_multiplier, activation_function, op.dilation_width_factor,
160 op.dilation_height_factor);
161 }
162
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const163 void ReadOptions(const TfLiteOptions& options,
164 TocoOperator* op) const override {
165 op->padding.type = Padding::Deserialize(options.padding());
166 op->stride_width = options.stride_w();
167 op->stride_height = options.stride_h();
168 op->depth_multiplier = options.depth_multiplier();
169 op->fused_activation_function =
170 ActivationFunction::Deserialize(options.fused_activation_function());
171 op->dilation_width_factor = options.dilation_w_factor();
172 op->dilation_height_factor = options.dilation_h_factor();
173 }
174
GetVersion(const OperatorSignature & op_signature) const175 int GetVersion(const OperatorSignature& op_signature) const override {
176 const auto& conv_op =
177 static_cast<const DepthwiseConvOperator&>(*op_signature.op);
178 ::tflite::OpSignature op_sig =
179 GetVersioningOpSig(builtin_op(), op_signature);
180 op_sig.options.depthwise_conv_2d.dilation_w_factor =
181 conv_op.dilation_width_factor;
182 op_sig.options.depthwise_conv_2d.dilation_h_factor =
183 conv_op.dilation_height_factor;
184 return ::tflite::GetBuiltinOperatorVersion(op_sig);
185 }
186 };
187
188 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
189 ::tflite::BuiltinOptions_AddOptions> {
190 public:
191 using BuiltinOperator::BuiltinOperator;
192
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const193 flatbuffers::Offset<TfLiteOptions> WriteOptions(
194 const TocoOperator& op,
195 flatbuffers::FlatBufferBuilder* builder) const override {
196 auto activation_function =
197 ActivationFunction::Serialize(op.fused_activation_function);
198 return ::tflite::CreateAddOptions(*builder, activation_function);
199 }
200
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const201 void ReadOptions(const TfLiteOptions& options,
202 TocoOperator* op) const override {
203 op->fused_activation_function =
204 ActivationFunction::Deserialize(options.fused_activation_function());
205 }
206 };
207
208 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
209 ::tflite::BuiltinOptions_AddNOptions> {
210 public:
211 using BuiltinOperator::BuiltinOperator;
212
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const213 flatbuffers::Offset<TfLiteOptions> WriteOptions(
214 const TocoOperator& op,
215 flatbuffers::FlatBufferBuilder* builder) const override {
216 return ::tflite::CreateAddNOptions(*builder);
217 }
218
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const219 void ReadOptions(const TfLiteOptions& options,
220 TocoOperator* op) const override {}
221 };
222
223 class SpaceToBatchND
224 : public BuiltinOperator<SpaceToBatchNDOperator,
225 ::tflite::SpaceToBatchNDOptions,
226 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
227 public:
228 using BuiltinOperator::BuiltinOperator;
229
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const230 flatbuffers::Offset<TfLiteOptions> WriteOptions(
231 const TocoOperator& op,
232 flatbuffers::FlatBufferBuilder* builder) const override {
233 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
234 }
235
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const236 void ReadOptions(const TfLiteOptions& options,
237 TocoOperator* op) const override {}
238 };
239
240 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
241 ::tflite::BuiltinOptions_SubOptions> {
242 public:
243 using BuiltinOperator::BuiltinOperator;
244
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const245 flatbuffers::Offset<TfLiteOptions> WriteOptions(
246 const TocoOperator& op,
247 flatbuffers::FlatBufferBuilder* builder) const override {
248 auto activation_function =
249 ActivationFunction::Serialize(op.fused_activation_function);
250 return ::tflite::CreateSubOptions(*builder, activation_function);
251 }
252
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const253 void ReadOptions(const TfLiteOptions& options,
254 TocoOperator* op) const override {
255 op->fused_activation_function =
256 ActivationFunction::Deserialize(options.fused_activation_function());
257 }
258 };
259
260 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
261 ::tflite::BuiltinOptions_DivOptions> {
262 public:
263 using BuiltinOperator::BuiltinOperator;
264
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const265 flatbuffers::Offset<TfLiteOptions> WriteOptions(
266 const TocoOperator& op,
267 flatbuffers::FlatBufferBuilder* builder) const override {
268 auto activation_function =
269 ActivationFunction::Serialize(op.fused_activation_function);
270 return ::tflite::CreateDivOptions(*builder, activation_function);
271 }
272
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const273 void ReadOptions(const TfLiteOptions& options,
274 TocoOperator* op) const override {
275 op->fused_activation_function =
276 ActivationFunction::Deserialize(options.fused_activation_function());
277 }
278 };
279
280 class BatchToSpaceND
281 : public BuiltinOperator<BatchToSpaceNDOperator,
282 ::tflite::BatchToSpaceNDOptions,
283 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
284 public:
285 using BuiltinOperator::BuiltinOperator;
286
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const287 flatbuffers::Offset<TfLiteOptions> WriteOptions(
288 const TocoOperator& op,
289 flatbuffers::FlatBufferBuilder* builder) const override {
290 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
291 }
292
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const293 void ReadOptions(const TfLiteOptions& options,
294 TocoOperator* op) const override {}
295 };
296
297 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
298 ::tflite::BuiltinOptions_CastOptions> {
299 public:
300 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const301 flatbuffers::Offset<TfLiteOptions> WriteOptions(
302 const TocoOperator& op,
303 flatbuffers::FlatBufferBuilder* builder) const override {
304 return ::tflite::CreateCastOptions(*builder,
305 DataType::Serialize(op.src_data_type),
306 DataType::Serialize(op.dst_data_type));
307 }
308
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const309 void ReadOptions(const TfLiteOptions& options,
310 TocoOperator* op) const override {
311 op->src_data_type = DataType::Deserialize(options.in_data_type());
312 op->dst_data_type = DataType::Deserialize(options.out_data_type());
313 }
314 };
315
316 class Concatenation
317 : public BuiltinOperator<ConcatenationOperator,
318 ::tflite::ConcatenationOptions,
319 ::tflite::BuiltinOptions_ConcatenationOptions> {
320 public:
321 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const322 flatbuffers::Offset<TfLiteOptions> WriteOptions(
323 const TocoOperator& op,
324 flatbuffers::FlatBufferBuilder* builder) const override {
325 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
326 }
327
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const328 void ReadOptions(const TfLiteOptions& options,
329 TocoOperator* op) const override {
330 op->axis = options.axis();
331 }
332 };
333
334 class DepthToSpace
335 : public BuiltinOperator<DepthToSpaceOperator,
336 ::tflite::DepthToSpaceOptions,
337 ::tflite::BuiltinOptions_DepthToSpaceOptions> {
338 public:
339 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const340 flatbuffers::Offset<TfLiteOptions> WriteOptions(
341 const TocoOperator& op,
342 flatbuffers::FlatBufferBuilder* builder) const override {
343 return ::tflite::CreateDepthToSpaceOptions(*builder, op.block_size);
344 }
345
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const346 void ReadOptions(const TfLiteOptions& options,
347 TocoOperator* op) const override {
348 op->block_size = options.block_size();
349 }
350 };
351
352 class FakeQuant
353 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
354 ::tflite::BuiltinOptions_FakeQuantOptions> {
355 public:
356 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const357 flatbuffers::Offset<TfLiteOptions> WriteOptions(
358 const TocoOperator& op,
359 flatbuffers::FlatBufferBuilder* builder) const override {
360 return ::tflite::CreateFakeQuantOptions(
361 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
362 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const363 void ReadOptions(const TfLiteOptions& options,
364 TocoOperator* op) const override {
365 auto* minmax = new MinMax;
366 minmax->min = options.min();
367 minmax->max = options.max();
368 op->minmax.reset(minmax);
369 op->num_bits = options.num_bits();
370 op->narrow_range = options.narrow_range();
371 }
GetVersion(const OperatorSignature & op_signature) const372 int GetVersion(const OperatorSignature& op_signature) const override {
373 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
374 ::tflite::OpSignature op_sig =
375 GetVersioningOpSig(builtin_op(), op_signature);
376 op_sig.options.fakequant.narrow_range = fq_op.narrow_range;
377 return ::tflite::GetBuiltinOperatorVersion(op_sig);
378 }
379 };
380
381 class FullyConnected
382 : public BuiltinOperator<FullyConnectedOperator,
383 ::tflite::FullyConnectedOptions,
384 ::tflite::BuiltinOptions_FullyConnectedOptions> {
385 public:
386 using BuiltinOperator::BuiltinOperator;
387
GetWeightFormat(FullyConnectedWeightsFormat fmt) const388 ::tflite::FullyConnectedOptionsWeightsFormat GetWeightFormat(
389 FullyConnectedWeightsFormat fmt) const {
390 switch (fmt) {
391 case FullyConnectedWeightsFormat::kDefault:
392 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
393 case FullyConnectedWeightsFormat::kShuffled4x16Int8:
394 return ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
395 default:
396 LOG(ERROR) << "Unhandled FC weights format";
397 return ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
398 }
399 }
400
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const401 flatbuffers::Offset<TfLiteOptions> WriteOptions(
402 const TocoOperator& op,
403 flatbuffers::FlatBufferBuilder* builder) const override {
404 auto activation_function =
405 ActivationFunction::Serialize(op.fused_activation_function);
406 return ::tflite::CreateFullyConnectedOptions(
407 *builder, activation_function, GetWeightFormat(op.weights_format));
408 }
409
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const410 void ReadOptions(const TfLiteOptions& options,
411 TocoOperator* op) const override {
412 op->fused_activation_function =
413 ActivationFunction::Deserialize(options.fused_activation_function());
414 switch (options.weights_format()) {
415 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
416 op->weights_format = FullyConnectedWeightsFormat::kDefault;
417 break;
418 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
419 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
420 break;
421 default:
422 LOG(ERROR) << "Unhandled FC weights format";
423 op->weights_format = FullyConnectedWeightsFormat::kDefault;
424 }
425 }
426
GetVersion(const OperatorSignature & op_signature) const427 int GetVersion(const OperatorSignature& op_signature) const override {
428 const auto& fc_op =
429 static_cast<const FullyConnectedOperator&>(*op_signature.op);
430 ::tflite::OpSignature op_sig =
431 GetVersioningOpSig(builtin_op(), op_signature);
432 op_sig.options.fully_connected.keep_num_dims = fc_op.keep_num_dims;
433 op_sig.options.fully_connected.weights_format =
434 GetWeightFormat(fc_op.weights_format);
435 return ::tflite::GetBuiltinOperatorVersion(op_sig);
436 }
437 };
438
439 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
440 ::tflite::BuiltinOptions_GatherOptions> {
441 public:
442 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const443 flatbuffers::Offset<TfLiteOptions> WriteOptions(
444 const TocoOperator& op,
445 flatbuffers::FlatBufferBuilder* builder) const override {
446 int axis = op.axis ? op.axis.value() : 0;
447 return ::tflite::CreateGatherOptions(*builder, axis);
448 }
449
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const450 void ReadOptions(const TfLiteOptions& options,
451 TocoOperator* op) const override {
452 op->axis = {options.axis()};
453 }
454 };
455
456 class GatherNd
457 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
458 ::tflite::BuiltinOptions_GatherNdOptions> {
459 public:
460 using BuiltinOperator::BuiltinOperator;
461
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const462 flatbuffers::Offset<TfLiteOptions> WriteOptions(
463 const TocoOperator& op,
464 flatbuffers::FlatBufferBuilder* builder) const override {
465 return ::tflite::CreateGatherNdOptions(*builder);
466 }
467
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const468 void ReadOptions(const TfLiteOptions& options,
469 TocoOperator* op) const override {}
470 };
471
472 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
473 ::tflite::BuiltinOptions_SVDFOptions> {
474 public:
475 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const476 flatbuffers::Offset<TfLiteOptions> WriteOptions(
477 const TocoOperator& op,
478 flatbuffers::FlatBufferBuilder* builder) const override {
479 auto activation_function =
480 ActivationFunction::Serialize(op.fused_activation_function);
481 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
482 }
483
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const484 void ReadOptions(const TfLiteOptions& options,
485 TocoOperator* op) const override {
486 op->fused_activation_function =
487 ActivationFunction::Deserialize(options.fused_activation_function());
488 op->rank = options.rank();
489 }
490 };
491
492 class L2Normalization
493 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
494 ::tflite::BuiltinOptions_L2NormOptions> {
495 public:
496 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const497 flatbuffers::Offset<TfLiteOptions> WriteOptions(
498 const TocoOperator& op,
499 flatbuffers::FlatBufferBuilder* builder) const override {
500 auto activation_function =
501 ActivationFunction::Serialize(op.fused_activation_function);
502 return ::tflite::CreateL2NormOptions(*builder, activation_function);
503 }
504
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const505 void ReadOptions(const TfLiteOptions& options,
506 TocoOperator* op) const override {
507 op->fused_activation_function =
508 ActivationFunction::Deserialize(options.fused_activation_function());
509 }
510 };
511
512 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
513 ::tflite::BuiltinOptions_Pool2DOptions> {
514 public:
515 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const516 flatbuffers::Offset<TfLiteOptions> WriteOptions(
517 const TocoOperator& op,
518 flatbuffers::FlatBufferBuilder* builder) const override {
519 auto padding = Padding::Serialize(op.padding.type);
520 auto activation_function =
521 ActivationFunction::Serialize(op.fused_activation_function);
522 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
523 op.stride_height, op.kwidth,
524 op.kheight, activation_function);
525 }
526
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const527 void ReadOptions(const TfLiteOptions& options,
528 TocoOperator* op) const override {
529 op->padding.type = Padding::Deserialize(options.padding());
530 op->stride_width = options.stride_w();
531 op->stride_height = options.stride_h();
532 op->kwidth = options.filter_width();
533 op->kheight = options.filter_height();
534 op->fused_activation_function =
535 ActivationFunction::Deserialize(options.fused_activation_function());
536 }
537 };
538
539 class LocalResponseNormalization
540 : public BuiltinOperator<
541 LocalResponseNormalizationOperator,
542 ::tflite::LocalResponseNormalizationOptions,
543 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
544 public:
545 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const546 flatbuffers::Offset<TfLiteOptions> WriteOptions(
547 const TocoOperator& op,
548 flatbuffers::FlatBufferBuilder* builder) const override {
549 return ::tflite::CreateLocalResponseNormalizationOptions(
550 *builder, op.range, op.bias, op.alpha, op.beta);
551 }
552
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const553 void ReadOptions(const TfLiteOptions& options,
554 TocoOperator* op) const override {
555 op->range = options.radius();
556 op->bias = options.bias();
557 op->alpha = options.alpha();
558 op->beta = options.beta();
559 }
560 };
561
562 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
563 ::tflite::BuiltinOptions_Pool2DOptions> {
564 public:
565 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const566 flatbuffers::Offset<TfLiteOptions> WriteOptions(
567 const TocoOperator& op,
568 flatbuffers::FlatBufferBuilder* builder) const override {
569 auto padding = Padding::Serialize(op.padding.type);
570 auto activation_function =
571 ActivationFunction::Serialize(op.fused_activation_function);
572 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
573 op.stride_height, op.kwidth,
574 op.kheight, activation_function);
575 }
576
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const577 void ReadOptions(const TfLiteOptions& options,
578 TocoOperator* op) const override {
579 op->padding.type = Padding::Deserialize(options.padding());
580 op->stride_width = options.stride_w();
581 op->stride_height = options.stride_h();
582 op->kwidth = options.filter_width();
583 op->kheight = options.filter_height();
584 op->fused_activation_function =
585 ActivationFunction::Deserialize(options.fused_activation_function());
586 }
587 };
588
589 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
590 ::tflite::BuiltinOptions_MulOptions> {
591 public:
592 using BuiltinOperator::BuiltinOperator;
593
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const594 flatbuffers::Offset<TfLiteOptions> WriteOptions(
595 const TocoOperator& op,
596 flatbuffers::FlatBufferBuilder* builder) const override {
597 auto activation_function =
598 ActivationFunction::Serialize(op.fused_activation_function);
599 return ::tflite::CreateMulOptions(*builder, activation_function);
600 }
601
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const602 void ReadOptions(const TfLiteOptions& options,
603 TocoOperator* op) const override {
604 op->fused_activation_function =
605 ActivationFunction::Deserialize(options.fused_activation_function());
606 }
607
GetVersion(const OperatorSignature & op_signature) const608 int GetVersion(const OperatorSignature& op_signature) const override {
609 const string& input1_name = op_signature.op->inputs[0];
610 const string& input2_name = op_signature.op->inputs[1];
611 const string& output_name = op_signature.op->outputs[0];
612 const Array& input1_array = op_signature.model->GetArray(input1_name);
613 const Array& input2_array = op_signature.model->GetArray(input2_name);
614 const Array& output_array = op_signature.model->GetArray(output_name);
615 const auto& input1_quant = input1_array.quantization_params;
616 const auto& input2_quant = input2_array.quantization_params;
617 const auto& output_quant = output_array.quantization_params;
618 const float input1_scale = input1_quant ? input1_quant->scale : 0.0f;
619 const float input2_scale = input2_quant ? input2_quant->scale : 0.0f;
620 const float output_scale = output_quant ? output_quant->scale : 0.0f;
621 ::tflite::OpSignature op_sig =
622 GetVersioningOpSig(builtin_op(), op_signature);
623 op_sig.options.mul.input1_scale = input1_scale;
624 op_sig.options.mul.input2_scale = input2_scale;
625 op_sig.options.mul.output_scale = output_scale;
626 return ::tflite::GetBuiltinOperatorVersion(op_sig);
627 }
628 };
629
630 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
631 ::tflite::BuiltinOptions_PadOptions> {
632 public:
633 using BuiltinOperator::BuiltinOperator;
634
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const635 flatbuffers::Offset<TfLiteOptions> WriteOptions(
636 const TocoOperator& op,
637 flatbuffers::FlatBufferBuilder* builder) const override {
638 return ::tflite::CreatePadOptions(*builder);
639 }
640
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const641 void ReadOptions(const TfLiteOptions& options,
642 TocoOperator* op) const override {}
643 };
644
645 class Tile
646 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
647 ::tflite::BuiltinOptions_TileOptions> {
648 using BuiltinOperator::BuiltinOperator;
649
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const650 flatbuffers::Offset<TfLiteOptions> WriteOptions(
651 const TocoOperator& op,
652 flatbuffers::FlatBufferBuilder* builder) const override {
653 return ::tflite::CreateTileOptions(*builder);
654 }
655
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const656 void ReadOptions(const TfLiteOptions& options,
657 TocoOperator* op) const override {}
658 };
659
660 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
661 ::tflite::BuiltinOptions_PadV2Options> {
662 public:
663 using BuiltinOperator::BuiltinOperator;
664
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const665 flatbuffers::Offset<TfLiteOptions> WriteOptions(
666 const TocoOperator& op,
667 flatbuffers::FlatBufferBuilder* builder) const override {
668 return ::tflite::CreatePadV2Options(*builder);
669 }
670
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const671 void ReadOptions(const TfLiteOptions& options,
672 TocoOperator* op) const override {}
673 };
674
675 class Reshape
676 : public BuiltinOperator<TensorFlowReshapeOperator,
677 ::tflite::ReshapeOptions,
678 ::tflite::BuiltinOptions_ReshapeOptions> {
679 public:
680 using BuiltinOperator::BuiltinOperator;
681
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const682 flatbuffers::Offset<TfLiteOptions> WriteOptions(
683 const TocoOperator& op,
684 flatbuffers::FlatBufferBuilder* builder) const override {
685 return ::tflite::CreateReshapeOptions(*builder,
686 builder->CreateVector(op.shape));
687 }
688
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const689 void ReadOptions(const TfLiteOptions& options,
690 TocoOperator* op) const override {
691 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
692 options.new_shape()->end());
693 }
694 };
695
696 class Softmax
697 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
698 ::tflite::BuiltinOptions_SoftmaxOptions> {
699 public:
700 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const701 flatbuffers::Offset<TfLiteOptions> WriteOptions(
702 const TocoOperator& op,
703 flatbuffers::FlatBufferBuilder* builder) const override {
704 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
705 }
706
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const707 void ReadOptions(const TfLiteOptions& options,
708 TocoOperator* op) const override {
709 op->beta = options.beta();
710 }
711 };
712
713 class SpaceToDepth
714 : public BuiltinOperator<SpaceToDepthOperator,
715 ::tflite::SpaceToDepthOptions,
716 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
717 public:
718 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const719 flatbuffers::Offset<TfLiteOptions> WriteOptions(
720 const TocoOperator& op,
721 flatbuffers::FlatBufferBuilder* builder) const override {
722 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
723 }
724
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const725 void ReadOptions(const TfLiteOptions& options,
726 TocoOperator* op) const override {
727 op->block_size = options.block_size();
728 }
729 };
730
731 class Transpose
732 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
733 ::tflite::BuiltinOptions_TransposeOptions> {
734 public:
735 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const736 flatbuffers::Offset<TfLiteOptions> WriteOptions(
737 const TocoOperator& op,
738 flatbuffers::FlatBufferBuilder* builder) const override {
739 return ::tflite::CreateTransposeOptions(*builder);
740 }
741
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const742 void ReadOptions(const TfLiteOptions& options,
743 TocoOperator* op) const override {}
744 };
745
746 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
747 ::tflite::BuiltinOptions_LSTMOptions> {
748 public:
749 using BuiltinOperator::BuiltinOperator;
750
GetKernelType(LstmCellOperator::KernelType type) const751 ::tflite::LSTMKernelType GetKernelType(
752 LstmCellOperator::KernelType type) const {
753 switch (type) {
754 case LstmCellOperator::KERNEL_BASIC:
755 return ::tflite::LSTMKernelType_BASIC;
756 break;
757 case LstmCellOperator::KERNEL_FULL:
758 return ::tflite::LSTMKernelType_FULL;
759 break;
760 default:
761 LOG(ERROR) << "Unhandled Kernel Type";
762 return static_cast<::tflite::LSTMKernelType>(-1);
763 }
764 }
765
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const766 flatbuffers::Offset<TfLiteOptions> WriteOptions(
767 const TocoOperator& op,
768 flatbuffers::FlatBufferBuilder* builder) const override {
769 ::tflite::LSTMKernelType kernel_type = GetKernelType(op.kernel_type);
770
771 // Current toco converter only supports tanh, no clip.
772 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
773 ::tflite::ActivationFunctionType_TANH,
774 /*cell_clip=*/0.0,
775 /*proj_clip=*/0.0, kernel_type);
776 }
777
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const778 void ReadOptions(const TfLiteOptions& options,
779 TocoOperator* op) const override {
780 // Only support tanh activation, so check that tflite type is tanh.
781 CHECK(options.fused_activation_function() ==
782 ::tflite::ActivationFunctionType_TANH);
783
784 switch (options.kernel_type()) {
785 case ::tflite::LSTMKernelType_BASIC:
786 op->kernel_type = LstmCellOperator::KERNEL_BASIC;
787 break;
788 case ::tflite::LSTMKernelType_FULL:
789 op->kernel_type = LstmCellOperator::KERNEL_FULL;
790 break;
791 }
792 }
793
GetVersion(const OperatorSignature & op_signature) const794 int GetVersion(const OperatorSignature& op_signature) const override {
795 const auto& lstm_op =
796 static_cast<const LstmCellOperator&>(*op_signature.op);
797 ::tflite::OpSignature op_sig =
798 GetVersioningOpSig(builtin_op(), op_signature);
799 op_sig.options.lstm.kernel_type = GetKernelType(lstm_op.kernel_type);
800 return ::tflite::GetBuiltinOperatorVersion(op_sig);
801 }
802
GetMutatingInputVariables(const Operator & op) const803 std::vector<bool> GetMutatingInputVariables(
804 const Operator& op) const override {
805 const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
806
807 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
808 switch (lstm_op.kernel_type) {
809 case LstmCellOperator::KERNEL_FULL: {
810 mutating_input_variables[kInputActivationStateTensor] = true;
811 mutating_input_variables[kInputCellStateTensor] = true;
812 break;
813 }
814 case LstmCellOperator::KERNEL_BASIC: {
815 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
816 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
817 break;
818 }
819 }
820 return mutating_input_variables;
821 }
822 };
823
824 class UnidirectionalSequenceLstm
825 : public BuiltinOperator<
826 UnidirectionalSequenceLstmOperator,
827 ::tflite::UnidirectionalSequenceLSTMOptions,
828 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
829 public:
830 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const831 flatbuffers::Offset<TfLiteOptions> WriteOptions(
832 const TocoOperator& op,
833 flatbuffers::FlatBufferBuilder* builder) const override {
834 // Current toco converter only supports tanh, no clip.
835 return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
836 *builder, /*fused_activation_function=*/
837 ::tflite::ActivationFunctionType_TANH,
838 /*cell_clip=*/0.0,
839 /*proj_clip=*/0.0,
840 /*time_major=*/true);
841 }
842
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const843 void ReadOptions(const TfLiteOptions& options,
844 TocoOperator* op) const override {
845 // Only support tanh activation, so check that tflite type is tanh.
846 DCHECK(options.fused_activation_function() ==
847 ::tflite::ActivationFunctionType_TANH);
848 }
849
GetMutatingInputVariables(const Operator & op) const850 std::vector<bool> GetMutatingInputVariables(
851 const Operator& op) const override {
852 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
853 mutating_input_variables[kInputActivationStateTensor] = true;
854 mutating_input_variables[kInputCellStateTensor] = true;
855 return mutating_input_variables;
856 }
857 };
858
859 class BidirectionalSequenceLstm
860 : public BuiltinOperator<
861 BidirectionalSequenceLstmOperator,
862 ::tflite::BidirectionalSequenceLSTMOptions,
863 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
864 public:
865 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const866 flatbuffers::Offset<TfLiteOptions> WriteOptions(
867 const TocoOperator& op,
868 flatbuffers::FlatBufferBuilder* builder) const override {
869 // Current toco converter only supports tanh, no clip.
870 return ::tflite::CreateBidirectionalSequenceLSTMOptions(
871 *builder, /*fused_activation_function=*/
872 ::tflite::ActivationFunctionType_TANH,
873 /*cell_clip=*/0.0,
874 /*proj_clip=*/0.0,
875 /*merge_outputs=*/op.merge_outputs,
876 /*time_major=*/true);
877 }
878
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const879 void ReadOptions(const TfLiteOptions& options,
880 TocoOperator* op) const override {
881 // Only support tanh activation, so check that tflite type is tanh.
882 DCHECK(options.fused_activation_function() ==
883 ::tflite::ActivationFunctionType_TANH);
884 op->merge_outputs = options.merge_outputs();
885 }
886
GetMutatingInputVariables(const Operator & op) const887 std::vector<bool> GetMutatingInputVariables(
888 const Operator& op) const override {
889 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
890 // Forward input activation state.
891 mutating_input_variables[35] = true;
892 // Forward input cell state.
893 mutating_input_variables[36] = true;
894 // Backward input activation state.
895 mutating_input_variables[37] = true;
896 // Backward input cell state.
897 mutating_input_variables[38] = true;
898 return mutating_input_variables;
899 }
900 };
901
902 class BidirectionalSequenceRnn
903 : public BuiltinOperator<
904 BidirectionalSequenceRnnOperator,
905 ::tflite::BidirectionalSequenceRNNOptions,
906 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
907 public:
908 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const909 flatbuffers::Offset<TfLiteOptions> WriteOptions(
910 const TocoOperator& op,
911 flatbuffers::FlatBufferBuilder* builder) const override {
912 // Current toco converter only supports tanh, no clip.
913 return ::tflite::CreateBidirectionalSequenceRNNOptions(
914 *builder, /*time_major=*/true,
915 /*fused_activation_function=*/
916 ::tflite::ActivationFunctionType_TANH,
917 /*merge_outputs=*/op.merge_outputs);
918 }
919
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const920 void ReadOptions(const TfLiteOptions& options,
921 TocoOperator* op) const override {
922 // Only support tanh activation, so check that tflite type is tanh.
923 DCHECK(options.fused_activation_function() ==
924 ::tflite::ActivationFunctionType_TANH);
925 op->merge_outputs = options.merge_outputs();
926 }
927
GetMutatingInputVariables(const Operator & op) const928 std::vector<bool> GetMutatingInputVariables(
929 const Operator& op) const override {
930 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
931 // Forward hidden state.
932 mutating_input_variables[4] = true;
933 // Backward hidden state.
934 mutating_input_variables[8] = true;
935 return mutating_input_variables;
936 }
937 };
938
939 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
940 ::tflite::BuiltinOptions_ReducerOptions> {
941 public:
942 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const943 flatbuffers::Offset<TfLiteOptions> WriteOptions(
944 const TocoOperator& op,
945 flatbuffers::FlatBufferBuilder* builder) const override {
946 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
947 }
948
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const949 void ReadOptions(const TfLiteOptions& options,
950 TocoOperator* op) const override {
951 op->keep_dims = options.keep_dims();
952 }
953 };
954
955 class Sum
956 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
957 ::tflite::BuiltinOptions_ReducerOptions> {
958 public:
959 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const960 flatbuffers::Offset<TfLiteOptions> WriteOptions(
961 const TocoOperator& op,
962 flatbuffers::FlatBufferBuilder* builder) const override {
963 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
964 }
965
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const966 void ReadOptions(const TfLiteOptions& options,
967 TocoOperator* op) const override {
968 op->keep_dims = options.keep_dims();
969 }
970 };
971
972 class ReduceMax
973 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
974 ::tflite::BuiltinOptions_ReducerOptions> {
975 public:
976 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const977 flatbuffers::Offset<TfLiteOptions> WriteOptions(
978 const TocoOperator& op,
979 flatbuffers::FlatBufferBuilder* builder) const override {
980 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
981 }
982
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const983 void ReadOptions(const TfLiteOptions& options,
984 TocoOperator* op) const override {
985 op->keep_dims = options.keep_dims();
986 }
987 };
988
989 class ReduceMin
990 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
991 ::tflite::BuiltinOptions_ReducerOptions> {
992 public:
993 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const994 flatbuffers::Offset<TfLiteOptions> WriteOptions(
995 const TocoOperator& op,
996 flatbuffers::FlatBufferBuilder* builder) const override {
997 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
998 }
999
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1000 void ReadOptions(const TfLiteOptions& options,
1001 TocoOperator* op) const override {
1002 op->keep_dims = options.keep_dims();
1003 }
1004 };
1005
1006 class ReduceProd
1007 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1008 ::tflite::BuiltinOptions_ReducerOptions> {
1009 public:
1010 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1011 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1012 const TocoOperator& op,
1013 flatbuffers::FlatBufferBuilder* builder) const override {
1014 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1015 }
1016
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1017 void ReadOptions(const TfLiteOptions& options,
1018 TocoOperator* op) const override {
1019 op->keep_dims = options.keep_dims();
1020 }
1021 };
1022
1023 class ReduceAny
1024 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1025 ::tflite::BuiltinOptions_ReducerOptions> {
1026 public:
1027 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1028 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1029 const TocoOperator& op,
1030 flatbuffers::FlatBufferBuilder* builder) const override {
1031 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1032 }
1033
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1034 void ReadOptions(const TfLiteOptions& options,
1035 TocoOperator* op) const override {
1036 op->keep_dims = options.keep_dims();
1037 }
1038 };
1039
1040 class ResizeBilinear
1041 : public BuiltinOperator<ResizeBilinearOperator,
1042 ::tflite::ResizeBilinearOptions,
1043 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1044 public:
1045 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1046 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1047 const TocoOperator& op,
1048 flatbuffers::FlatBufferBuilder* builder) const override {
1049 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners,
1050 op.half_pixel_centers);
1051 }
1052
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1053 void ReadOptions(const TfLiteOptions& options,
1054 TocoOperator* op) const override {
1055 op->align_corners = options.align_corners();
1056 op->half_pixel_centers = options.half_pixel_centers();
1057 }
1058
GetVersion(const OperatorSignature & op_signature) const1059 int GetVersion(const OperatorSignature& op_signature) const override {
1060 const auto& resize_bilinear_op =
1061 static_cast<const ResizeBilinearOperator&>(*op_signature.op);
1062 ::tflite::OpSignature op_sig =
1063 GetVersioningOpSig(builtin_op(), op_signature);
1064 op_sig.options.resize_bilinear.half_pixel_centers =
1065 resize_bilinear_op.half_pixel_centers;
1066 return ::tflite::GetBuiltinOperatorVersion(op_sig);
1067 }
1068 };
1069
1070 class ResizeNearestNeighbor
1071 : public BuiltinOperator<
1072 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1073 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1074 public:
1075 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1076 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1077 const TocoOperator& op,
1078 flatbuffers::FlatBufferBuilder* builder) const override {
1079 return ::tflite::CreateResizeNearestNeighborOptions(*builder,
1080 op.align_corners);
1081 }
1082
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1083 void ReadOptions(const TfLiteOptions& options,
1084 TocoOperator* op) const override {
1085 op->align_corners = options.align_corners();
1086 }
1087 };
1088
1089 class Squeeze
1090 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1091 ::tflite::BuiltinOptions_SqueezeOptions> {
1092 public:
1093 using BuiltinOperator::BuiltinOperator;
1094
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1095 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1096 const TocoOperator& op,
1097 flatbuffers::FlatBufferBuilder* builder) const override {
1098 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1099 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1100 }
1101
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1102 void ReadOptions(const TfLiteOptions& options,
1103 TocoOperator* op) const override {
1104 op->squeeze_dims.insert(op->squeeze_dims.end(),
1105 options.squeeze_dims()->begin(),
1106 options.squeeze_dims()->end());
1107 }
1108 };
1109
1110 class Split
1111 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1112 ::tflite::BuiltinOptions_SplitOptions> {
1113 public:
1114 using BuiltinOperator::BuiltinOperator;
1115
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1116 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1117 const TocoOperator& op,
1118 flatbuffers::FlatBufferBuilder* builder) const override {
1119 return ::tflite::CreateSplitOptions(*builder, op.num_split);
1120 }
1121
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1122 void ReadOptions(const TfLiteOptions& options,
1123 TocoOperator* op) const override {
1124 op->num_split = options.num_splits();
1125 }
1126 };
1127
1128 class SplitV
1129 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1130 ::tflite::BuiltinOptions_SplitVOptions> {
1131 public:
1132 using BuiltinOperator::BuiltinOperator;
1133
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1134 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1135 const TocoOperator& op,
1136 flatbuffers::FlatBufferBuilder* builder) const override {
1137 return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1138 }
1139
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1140 void ReadOptions(const TfLiteOptions& options,
1141 TocoOperator* op) const override {
1142 op->num_split = options.num_splits();
1143 }
1144 };
1145
1146 class StridedSlice
1147 : public BuiltinOperator<StridedSliceOperator,
1148 ::tflite::StridedSliceOptions,
1149 ::tflite::BuiltinOptions_StridedSliceOptions> {
1150 public:
1151 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1152 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1153 const TocoOperator& op,
1154 flatbuffers::FlatBufferBuilder* builder) const override {
1155 return ::tflite::CreateStridedSliceOptions(
1156 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1157 op.new_axis_mask, op.shrink_axis_mask);
1158 }
1159
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1160 void ReadOptions(const TfLiteOptions& options,
1161 TocoOperator* op) const override {
1162 op->begin_mask = options.begin_mask();
1163 op->end_mask = options.end_mask();
1164 op->ellipsis_mask = options.ellipsis_mask();
1165 op->new_axis_mask = options.new_axis_mask();
1166 op->shrink_axis_mask = options.shrink_axis_mask();
1167 }
1168 };
1169
1170 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1171 ::tflite::BuiltinOptions_TopKV2Options> {
1172 public:
1173 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1174 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1175 const TocoOperator& op,
1176 flatbuffers::FlatBufferBuilder* builder) const override {
1177 return ::tflite::CreateTopKV2Options(*builder);
1178 }
1179
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1180 void ReadOptions(const TfLiteOptions& options,
1181 TocoOperator* op) const override {}
1182 };
1183
1184 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1185 ::tflite::BuiltinOptions_ArgMaxOptions> {
1186 public:
1187 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1188 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1189 const TocoOperator& op,
1190 flatbuffers::FlatBufferBuilder* builder) const override {
1191 return ::tflite::CreateArgMaxOptions(
1192 *builder, DataType::Serialize(op.output_data_type));
1193 }
1194
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1195 void ReadOptions(const TfLiteOptions& options,
1196 TocoOperator* op) const override {
1197 op->output_data_type = DataType::Deserialize(options.output_type());
1198 }
1199 };
1200
1201 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1202 ::tflite::BuiltinOptions_ArgMinOptions> {
1203 public:
1204 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1205 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1206 const TocoOperator& op,
1207 flatbuffers::FlatBufferBuilder* builder) const override {
1208 return ::tflite::CreateArgMinOptions(
1209 *builder, DataType::Serialize(op.output_data_type));
1210 }
1211
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1212 void ReadOptions(const TfLiteOptions& options,
1213 TocoOperator* op) const override {
1214 op->output_data_type = DataType::Deserialize(options.output_type());
1215 }
1216 };
1217
1218 class TransposeConv
1219 : public BuiltinOperator<TransposeConvOperator,
1220 ::tflite::TransposeConvOptions,
1221 ::tflite::BuiltinOptions_TransposeConvOptions> {
1222 public:
1223 using BuiltinOperator::BuiltinOperator;
1224
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1225 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1226 const TocoOperator& op,
1227 flatbuffers::FlatBufferBuilder* builder) const override {
1228 auto padding = Padding::Serialize(op.padding.type);
1229 return ::tflite::CreateTransposeConvOptions(
1230 *builder, padding, op.stride_width, op.stride_height);
1231 }
1232
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1233 void ReadOptions(const TfLiteOptions& options,
1234 TocoOperator* op) const override {
1235 op->padding.type = Padding::Deserialize(options.padding());
1236 op->stride_width = options.stride_w();
1237 op->stride_height = options.stride_h();
1238 }
1239 };
1240
1241 class SparseToDense
1242 : public BuiltinOperator<SparseToDenseOperator,
1243 ::tflite::SparseToDenseOptions,
1244 ::tflite::BuiltinOptions_SparseToDenseOptions> {
1245 public:
1246 using BuiltinOperator::BuiltinOperator;
1247
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1248 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1249 const TocoOperator& op,
1250 flatbuffers::FlatBufferBuilder* builder) const override {
1251 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1252 }
1253
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1254 void ReadOptions(const TfLiteOptions& options,
1255 TocoOperator* op) const override {
1256 op->validate_indices = options.validate_indices();
1257 }
1258 };
1259
1260 class ExpandDims
1261 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1262 ::tflite::BuiltinOptions_ExpandDimsOptions> {
1263 public:
1264 using BuiltinOperator::BuiltinOperator;
1265
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1266 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1267 const TocoOperator& op,
1268 flatbuffers::FlatBufferBuilder* builder) const override {
1269 return ::tflite::CreateExpandDimsOptions(*builder);
1270 }
1271
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1272 void ReadOptions(const TfLiteOptions& options,
1273 TocoOperator* op) const override {}
1274 };
1275
1276 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1277 ::tflite::BuiltinOptions_PackOptions> {
1278 public:
1279 using BuiltinOperator::BuiltinOperator;
1280
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1281 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1282 const TocoOperator& op,
1283 flatbuffers::FlatBufferBuilder* builder) const override {
1284 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1285 }
1286
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1287 void ReadOptions(const TfLiteOptions& options,
1288 TocoOperator* op) const override {
1289 op->values_count = options.values_count();
1290 op->axis = options.axis();
1291 }
1292 };
1293
1294 class Shape
1295 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1296 ::tflite::BuiltinOptions_ShapeOptions> {
1297 public:
1298 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1299 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1300 const TocoOperator& op,
1301 flatbuffers::FlatBufferBuilder* builder) const override {
1302 return ::tflite::CreateShapeOptions(
1303 *builder, DataType::Serialize(op.output_data_type));
1304 }
1305
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1306 void ReadOptions(const TfLiteOptions& options,
1307 TocoOperator* op) const override {
1308 op->output_data_type = DataType::Deserialize(options.out_type());
1309 }
1310 };
1311
1312 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1313 ::tflite::BuiltinOptions_OneHotOptions> {
1314 public:
1315 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1316 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1317 const TocoOperator& op,
1318 flatbuffers::FlatBufferBuilder* builder) const override {
1319 return ::tflite::CreateOneHotOptions(*builder, op.axis);
1320 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1321 void ReadOptions(const TfLiteOptions& options,
1322 TocoOperator* op) const override {
1323 op->axis = options.axis();
1324 }
1325 };
1326
1327 class CTCBeamSearchDecoder
1328 : public CustomOperator<CTCBeamSearchDecoderOperator> {
1329 public:
1330 using CustomOperator::CustomOperator;
1331
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1332 void WriteOptions(const TocoOperator& op,
1333 flexbuffers::Builder* fbb) const override {
1334 fbb->Int("beam_width", op.beam_width);
1335 fbb->Int("top_paths", op.top_paths);
1336 fbb->Bool("merge_repeated", op.merge_repeated);
1337 }
1338
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1339 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1340 op->beam_width = m["beam_width"].AsInt32();
1341 op->top_paths = m["top_paths"].AsInt32();
1342 op->merge_repeated = m["merge_repeated"].AsBool();
1343 }
1344
GetVersion(const OperatorSignature & op_signature) const1345 int GetVersion(const OperatorSignature& op_signature) const override {
1346 return 1;
1347 }
1348 };
1349
1350 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1351 ::tflite::BuiltinOptions_UnpackOptions> {
1352 public:
1353 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1354 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1355 const TocoOperator& op,
1356 flatbuffers::FlatBufferBuilder* builder) const override {
1357 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1358 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1359 void ReadOptions(const TfLiteOptions& options,
1360 TocoOperator* op) const override {
1361 op->num = options.num();
1362 op->axis = options.axis();
1363 }
1364
GetVersion(const OperatorSignature & op_signature) const1365 int GetVersion(const OperatorSignature& op_signature) const override {
1366 const string& input_name = op_signature.op->inputs[0];
1367 const Array& input_array = op_signature.model->GetArray(input_name);
1368 // If the op take int8/uint8 input, it is version 2.
1369 if (input_array.data_type == ArrayDataType::kInt8 ||
1370 input_array.data_type == ArrayDataType::kUint8) {
1371 return 2;
1372 }
1373 // If the op take bool input, it is version 3.
1374 if (input_array.data_type == ArrayDataType::kBool) {
1375 return 3;
1376 }
1377 return 1;
1378 }
1379 };
1380
1381 class LeakyRelu
1382 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1383 ::tflite::BuiltinOptions_LeakyReluOptions> {
1384 public:
1385 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1386 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1387 const TocoOperator& op,
1388 flatbuffers::FlatBufferBuilder* builder) const override {
1389 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1390 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1391 void ReadOptions(const TfLiteOptions& options,
1392 TocoOperator* op) const override {
1393 op->alpha = options.alpha();
1394 }
1395 };
1396
1397 class SquaredDifference
1398 : public BuiltinOperator<
1399 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1400 ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1401 public:
1402 using BuiltinOperator::BuiltinOperator;
1403
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1404 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1405 const TocoOperator& op,
1406 flatbuffers::FlatBufferBuilder* builder) const override {
1407 return ::tflite::CreateSquaredDifferenceOptions(*builder);
1408 }
1409
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1410 void ReadOptions(const TfLiteOptions& options,
1411 TocoOperator* op) const override {}
1412 };
1413
1414 class MirrorPad
1415 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1416 ::tflite::BuiltinOptions_MirrorPadOptions> {
1417 public:
1418 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1419 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1420 const TocoOperator& op,
1421 flatbuffers::FlatBufferBuilder* builder) const override {
1422 return ::tflite::CreateMirrorPadOptions(
1423 *builder, op.mode == MirrorPadMode::kReflect
1424 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1425 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1426 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1427 void ReadOptions(const TfLiteOptions& options,
1428 TocoOperator* op) const override {
1429 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1430 ? MirrorPadMode::kReflect
1431 : MirrorPadMode::kSymmetric;
1432 }
1433 };
1434
1435 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1436 ::tflite::BuiltinOptions_UniqueOptions> {
1437 public:
1438 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1439 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1440 const TocoOperator& op,
1441 flatbuffers::FlatBufferBuilder* builder) const override {
1442 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1443 return ::tflite::CreateUniqueOptions(
1444 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1445 ? ::tflite::TensorType::TensorType_INT64
1446 : ::tflite::TensorType_INT32);
1447 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1448 void ReadOptions(const TfLiteOptions& options,
1449 TocoOperator* op) const override {
1450 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1451 unique_op->idx_out_type =
1452 options.idx_out_type() == ::tflite::TensorType_INT64
1453 ? toco::ArrayDataType::kInt64
1454 : toco::ArrayDataType::kInt32;
1455 }
1456 };
1457
1458 class UnidirectionalSequenceRnn
1459 : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1460 ::tflite::SequenceRNNOptions,
1461 ::tflite::BuiltinOptions_SequenceRNNOptions> {
1462 public:
1463 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1464 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1465 const TocoOperator& op,
1466 flatbuffers::FlatBufferBuilder* builder) const override {
1467 return ::tflite::CreateSequenceRNNOptions(
1468 *builder, /*time_major=*/true,
1469 /*fused_activation_function=*/
1470 ::tflite::ActivationFunctionType_TANH);
1471 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1472 void ReadOptions(const TfLiteOptions& options,
1473 TocoOperator* op) const override {
1474 // Only support tanh activation, so check that tflite type is tanh.
1475 DCHECK(options.fused_activation_function() ==
1476 ::tflite::ActivationFunctionType_TANH);
1477 }
1478
GetMutatingInputVariables(const Operator & op) const1479 std::vector<bool> GetMutatingInputVariables(
1480 const Operator& op) const override {
1481 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1482 mutating_input_variables[4] = true;
1483 return mutating_input_variables;
1484 }
1485 };
1486
1487 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1488 ::tflite::BuiltinOptions_WhereOptions> {
1489 public:
1490 using BuiltinOperator::BuiltinOperator;
1491
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1492 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1493 const TocoOperator& op,
1494 flatbuffers::FlatBufferBuilder* builder) const override {
1495 return ::tflite::CreateWhereOptions(*builder);
1496 }
1497
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1498 void ReadOptions(const TfLiteOptions& options,
1499 TocoOperator* op) const override {}
1500 };
1501
WriteFlexOpOptions(const string & tensorflow_node_def)1502 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1503 const string& tensorflow_node_def) {
1504 auto fbb = absl::make_unique<flexbuffers::Builder>();
1505
1506 ::tensorflow::NodeDef node_def;
1507 if (!node_def.ParseFromString(tensorflow_node_def)) {
1508 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1509 return {};
1510 }
1511
1512 fbb->Vector([&]() {
1513 fbb->String(node_def.op());
1514 fbb->String(tensorflow_node_def);
1515 });
1516 fbb->Finish();
1517 LOG(INFO) << "Writing flex op: " << node_def.op();
1518 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1519 }
1520
1521 class TensorFlowUnsupported : public BaseOperator {
1522 public:
TensorFlowUnsupported(const string & name,OperatorType type,bool enable_select_tf_ops)1523 TensorFlowUnsupported(const string& name, OperatorType type,
1524 bool enable_select_tf_ops)
1525 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1526
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1527 Options Serialize(const Operator& op,
1528 flatbuffers::FlatBufferBuilder* builder) const override {
1529 auto fbb =
1530 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1531 if (fbb) {
1532 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1533 } else {
1534 return Options::Custom(0);
1535 }
1536 }
1537
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const1538 std::unique_ptr<Operator> Deserialize(
1539 const BuiltinOptions* builtin_options,
1540 const CustomOptions* custom_options) const override {
1541 // Deserializing Flex ops doesn't work now.
1542 // TODO(ycling): Revisit and decide if we should fix the flow for importing
1543 // TFLite models with Flex ops.
1544 auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
1545 if (custom_options) {
1546 auto flexbuffer_map =
1547 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
1548 .AsMap();
1549 ReadOptions(flexbuffer_map, op.get());
1550 }
1551 return std::unique_ptr<Operator>(op.release());
1552 }
1553
WriteOptions(const TensorFlowUnsupportedOperator & op) const1554 std::unique_ptr<flexbuffers::Builder> WriteOptions(
1555 const TensorFlowUnsupportedOperator& op) const {
1556 if (enable_select_tf_ops_) {
1557 return WriteFlexOpOptions(op.tensorflow_node_def);
1558 }
1559 auto fbb = absl::make_unique<flexbuffers::Builder>();
1560
1561 ::tensorflow::NodeDef node_def;
1562 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
1563 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1564 return std::unique_ptr<flexbuffers::Builder>();
1565 }
1566
1567 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
1568 fbb->Vector([&]() {
1569 fbb->String(node_def.op());
1570 fbb->String(op.tensorflow_node_def);
1571 });
1572 fbb->Finish();
1573 LOG(INFO) << "Writing flex op: " << node_def.op();
1574 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1575 }
1576
1577 bool has_valid_attr = false;
1578 size_t map_start = fbb->StartMap();
1579 for (const auto& pair : node_def.attr()) {
1580 const char* key = pair.first.c_str();
1581 const auto& attr = pair.second;
1582 switch (attr.value_case()) {
1583 case ::tensorflow::AttrValue::kS:
1584 fbb->String(key, attr.s());
1585 has_valid_attr = true;
1586 break;
1587 case ::tensorflow::AttrValue::kI:
1588 fbb->Int(key, attr.i());
1589 has_valid_attr = true;
1590 break;
1591 case ::tensorflow::AttrValue::kF:
1592 fbb->Float(key, attr.f());
1593 has_valid_attr = true;
1594 break;
1595 case ::tensorflow::AttrValue::kB:
1596 fbb->Bool(key, attr.b());
1597 has_valid_attr = true;
1598 break;
1599 case tensorflow::AttrValue::kList:
1600 if (attr.list().s_size() > 0) {
1601 auto start = fbb->StartVector(key);
1602 for (const string& v : attr.list().s()) {
1603 fbb->Add(v);
1604 }
1605 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1606 has_valid_attr = true;
1607 } else if (attr.list().i_size() > 0) {
1608 auto start = fbb->StartVector(key);
1609 for (const int64_t v : attr.list().i()) {
1610 fbb->Add(v);
1611 }
1612 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1613 has_valid_attr = true;
1614 } else if (attr.list().f_size() > 0) {
1615 auto start = fbb->StartVector(key);
1616 for (const float v : attr.list().f()) {
1617 fbb->Add(v);
1618 }
1619 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
1620 has_valid_attr = true;
1621 } else {
1622 LOG(WARNING)
1623 << "Ignoring unsupported type in list attribute with key '"
1624 << key << "'";
1625 }
1626 break;
1627 default:
1628 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1629 << key << "'";
1630 break;
1631 }
1632 }
1633 if (!has_valid_attr) {
1634 return std::unique_ptr<flexbuffers::Builder>();
1635 }
1636 fbb->EndMap(map_start);
1637 fbb->Finish();
1638 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1639 }
1640
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const1641 void ReadOptions(const flexbuffers::Map& m,
1642 TensorFlowUnsupportedOperator* op) const {
1643 ::tensorflow::NodeDef node_def;
1644 auto attr = node_def.mutable_attr();
1645
1646 const auto& keys = m.Keys();
1647 for (size_t i = 0; i < keys.size(); ++i) {
1648 const auto key = keys[i].AsKey();
1649 const auto& value = m[key];
1650 // TODO(wvo): hack to make this code compile with 2 different API
1651 // versions.
1652 // Please remove once OS/internal versions are in sync.
1653 // See hardcoded values in the switch below.
1654 switch (value.GetType()) {
1655 case 5: // flexbuffers::FBT_STRING:
1656 (*attr)[key].set_s(value.AsString().c_str());
1657 break;
1658 case 1: // flexbuffers::FBT_INT:
1659 (*attr)[key].set_i(value.AsInt64());
1660 break;
1661 case 3: // flexbuffers::FBT_FLOAT:
1662 (*attr)[key].set_f(value.AsFloat());
1663 break;
1664 case 26: // flexbuffers::FBT_BOOL:
1665 (*attr)[key].set_b(value.AsBool());
1666 if (string(key) == "_output_quantized") {
1667 op->quantized = value.AsBool();
1668 }
1669 if (string(key) == "_support_output_type_float_in_quantized_op") {
1670 op->support_output_type_float_in_quantized_op = value.AsBool();
1671 }
1672 break;
1673 case 11: { // flexbuffers::FBT_VECTOR_INT: {
1674 auto* list = (*attr)[key].mutable_list();
1675 const auto& vector = value.AsTypedVector();
1676 for (size_t i = 0; i < vector.size(); i++) {
1677 list->add_i(vector[i].AsInt64());
1678 }
1679 break;
1680 }
1681 case 13: { // flexbuffers::FBT_VECTOR_FLOAT: {
1682 auto* list = (*attr)[key].mutable_list();
1683 const auto& vector = value.AsTypedVector();
1684 for (size_t i = 0; i < vector.size(); i++) {
1685 list->add_f(vector[i].AsFloat());
1686 }
1687 break;
1688 }
1689 case 15: { // flexbuffers::FBT_VECTOR_STRING: {
1690 auto* list = (*attr)[key].mutable_list();
1691 const auto& vector = value.AsTypedVector();
1692 for (size_t i = 0; i < vector.size(); i++) {
1693 list->add_s(vector[i].AsString().str());
1694 }
1695 break;
1696 }
1697 default:
1698 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
1699 << key << "'";
1700 break;
1701 }
1702 }
1703 node_def.SerializeToString(&op->tensorflow_node_def);
1704 }
1705
GetVersion(const OperatorSignature & op_signature) const1706 int GetVersion(const OperatorSignature& op_signature) const override {
1707 // TODO(ycling): Design and implement a way to plumb the version of
1708 // custom ops.
1709 return 1;
1710 }
1711
1712 private:
1713 const bool enable_select_tf_ops_;
1714 };
1715
1716 class Dequantize
1717 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
1718 ::tflite::BuiltinOptions_DequantizeOptions> {
1719 public:
1720 using BuiltinOperator::BuiltinOperator;
1721
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1722 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1723 const TocoOperator& op,
1724 flatbuffers::FlatBufferBuilder* builder) const override {
1725 return ::tflite::CreateDequantizeOptions(*builder);
1726 }
1727
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1728 void ReadOptions(const TfLiteOptions& options,
1729 TocoOperator* op) const override {}
1730 };
1731
1732 class ReverseSequence
1733 : public BuiltinOperator<ReverseSequenceOperator,
1734 ::tflite::ReverseSequenceOptions,
1735 ::tflite::BuiltinOptions_ReverseSequenceOptions> {
1736 public:
1737 using BuiltinOperator::BuiltinOperator;
1738
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1739 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1740 const TocoOperator& op,
1741 flatbuffers::FlatBufferBuilder* builder) const override {
1742 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
1743 op.batch_dim);
1744 }
1745
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1746 void ReadOptions(const TfLiteOptions& options,
1747 TocoOperator* op) const override {
1748 op->seq_dim = options.seq_dim();
1749 op->batch_dim = options.batch_dim();
1750 }
1751 };
1752
1753 namespace {
1754 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)1755 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
1756 bool enable_select_tf_ops = false) {
1757 std::vector<std::unique_ptr<BaseOperator>> ops;
1758 using tensorflow::MakeUnique;
1759 // Builtin Operators.
1760 ops.push_back(
1761 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
1762 ops.push_back(
1763 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
1764 ops.push_back(
1765 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
1766 ops.push_back(
1767 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
1768 ops.push_back(MakeUnique<AveragePool>(
1769 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
1770 ops.push_back(
1771 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
1772 OperatorType::kSpaceToBatchND));
1773 ops.push_back(
1774 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
1775 OperatorType::kBatchToSpaceND));
1776 ops.push_back(MakeUnique<Concatenation>(
1777 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
1778 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
1779 OperatorType::kConv));
1780 ops.push_back(MakeUnique<DepthwiseConvolution>(
1781 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
1782 OperatorType::kDepthwiseConv));
1783 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
1784 OperatorType::kDequantize));
1785 ops.push_back(
1786 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
1787 OperatorType::kFullyConnected));
1788 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
1789 OperatorType::kGather));
1790 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
1791 OperatorType::kGatherNd));
1792 ops.push_back(
1793 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
1794 OperatorType::kL2Normalization));
1795 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
1796 OperatorType::kL2Pool));
1797 ops.push_back(MakeUnique<LocalResponseNormalization>(
1798 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1799 OperatorType::kLocalResponseNormalization));
1800 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
1801 OperatorType::kMaxPool));
1802 ops.push_back(
1803 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
1804
1805 ops.push_back(
1806 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
1807 ops.push_back(
1808 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
1809 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
1810 OperatorType::kReshape));
1811 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
1812 OperatorType::kSoftmax));
1813 ops.push_back(MakeUnique<SpaceToDepth>(
1814 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
1815 ops.push_back(MakeUnique<DepthToSpace>(
1816 ::tflite::BuiltinOperator_DEPTH_TO_SPACE, OperatorType::kDepthToSpace));
1817 ops.push_back(
1818 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
1819 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
1820 OperatorType::kTranspose));
1821 ops.push_back(
1822 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
1823 ops.push_back(
1824 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
1825 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
1826 OperatorType::kReduceProd));
1827 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
1828 OperatorType::kReduceMax));
1829 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
1830 OperatorType::kReduceMin));
1831 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
1832 OperatorType::kAny));
1833 ops.push_back(
1834 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
1835 OperatorType::kResizeBilinear));
1836 ops.push_back(MakeUnique<ResizeNearestNeighbor>(
1837 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
1838 OperatorType::kResizeNearestNeighbor));
1839 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
1840 OperatorType::kSqueeze));
1841 ops.push_back(
1842 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
1843 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
1844 OperatorType::kSplitV));
1845 ops.push_back(MakeUnique<StridedSlice>(
1846 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
1847 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
1848 OperatorType::kTopK_V2));
1849 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
1850 OperatorType::kLstmCell));
1851 ops.push_back(
1852 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
1853 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
1854 OperatorType::kArgMax));
1855 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
1856 OperatorType::kArgMin));
1857 ops.push_back(
1858 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
1859 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
1860 OperatorType::kExpandDims));
1861 ops.push_back(MakeUnique<TransposeConv>(
1862 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
1863 ops.push_back(MakeUnique<SparseToDense>(
1864 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
1865 ops.push_back(
1866 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
1867 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
1868 OperatorType::kFakeQuant));
1869 ops.push_back(
1870 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
1871 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
1872 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
1873 OperatorType::kUnidirectionalSequenceLstm));
1874 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
1875 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
1876 OperatorType::kBidirectionalSequenceLstm));
1877 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
1878 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
1879 OperatorType::kBidirectionalSequenceRnn));
1880 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
1881 OperatorType::kOneHot));
1882 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
1883 OperatorType::kUnpack));
1884 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
1885 OperatorType::kLeakyRelu));
1886 ops.push_back(MakeUnique<SquaredDifference>(
1887 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
1888 OperatorType::kSquaredDifference));
1889 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
1890 OperatorType::kMirrorPad));
1891 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
1892 OperatorType::kUnique));
1893 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
1894 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
1895 OperatorType::kUnidirectionalSequenceRnn));
1896 ops.push_back(
1897 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
1898 ops.push_back(
1899 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
1900 OperatorType::kReverseSequence));
1901 ops.push_back(MakeUnique<SimpleOperator<MatrixDiagOperator>>(
1902 ::tflite::BuiltinOperator_MATRIX_DIAG, OperatorType::kMatrixDiag));
1903 ops.push_back(MakeUnique<SimpleOperator<MatrixSetDiagOperator>>(
1904 ::tflite::BuiltinOperator_MATRIX_SET_DIAG, OperatorType::kMatrixSetDiag));
1905 // Custom Operators.
1906 ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
1907 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
1908 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
1909 OperatorType::kUnsupported,
1910 enable_select_tf_ops));
1911
1912 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
1913 // been modified to also export builtins. As TOCO evolved we added warnings
1914 // when custom ops are exported but SimpleOperator bypasses thoses. To
1915 // prevent user confusion we are settling on using SimpleOperator only for
1916 // builtins.
1917 ops.push_back(MakeUnique<SimpleOperator<FloorOperator>>(
1918 ::tflite::BuiltinOperator_FLOOR, OperatorType::kFloor));
1919 ops.push_back(MakeUnique<SimpleOperator<CeilOperator>>(
1920 ::tflite::BuiltinOperator_CEIL, OperatorType::kCeil));
1921 ops.push_back(MakeUnique<SimpleOperator<EluOperator>>(
1922 ::tflite::BuiltinOperator_ELU, OperatorType::kElu));
1923 ops.push_back(MakeUnique<SimpleOperator<RoundOperator>>(
1924 ::tflite::BuiltinOperator_ROUND, OperatorType::kRound));
1925 ops.push_back(MakeUnique<SimpleOperator<ReluOperator>>(
1926 ::tflite::BuiltinOperator_RELU, OperatorType::kRelu));
1927 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
1928 ::tflite::BuiltinOperator_RELU_N1_TO_1, OperatorType::kRelu1));
1929 ops.push_back(MakeUnique<SimpleOperator<Relu6Operator>>(
1930 ::tflite::BuiltinOperator_RELU6, OperatorType::kRelu6));
1931 ops.push_back(MakeUnique<SimpleOperator<PReluOperator>>(
1932 ::tflite::BuiltinOperator_PRELU, OperatorType::kPRelu));
1933 ops.push_back(MakeUnique<SimpleOperator<LogisticOperator>>(
1934 ::tflite::BuiltinOperator_LOGISTIC, OperatorType::kLogistic));
1935 ops.push_back(MakeUnique<SimpleOperator<TanhOperator>>(
1936 ::tflite::BuiltinOperator_TANH, OperatorType::kTanh));
1937 ops.push_back(MakeUnique<SimpleOperator<ExpOperator>>(
1938 ::tflite::BuiltinOperator_EXP, OperatorType::kExp));
1939 ops.push_back(MakeUnique<SimpleOperator<CosOperator>>(
1940 ::tflite::BuiltinOperator_COS, OperatorType::kCos));
1941 ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
1942 ::tflite::BuiltinOperator_LOG_SOFTMAX, OperatorType::kLogSoftmax));
1943 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMaximumOperator>>(
1944 ::tflite::BuiltinOperator_MAXIMUM, OperatorType::kMaximum));
1945 ops.push_back(MakeUnique<SimpleOperator<TensorFlowMinimumOperator>>(
1946 ::tflite::BuiltinOperator_MINIMUM, OperatorType::kMinimum));
1947 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterOperator>>(
1948 ::tflite::BuiltinOperator_GREATER, OperatorType::kGreater));
1949 ops.push_back(MakeUnique<SimpleOperator<TensorFlowGreaterEqualOperator>>(
1950 ::tflite::BuiltinOperator_GREATER_EQUAL, OperatorType::kGreaterEqual));
1951 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessOperator>>(
1952 ::tflite::BuiltinOperator_LESS, OperatorType::kLess));
1953 ops.push_back(MakeUnique<SimpleOperator<TensorFlowLessEqualOperator>>(
1954 ::tflite::BuiltinOperator_LESS_EQUAL, OperatorType::kLessEqual));
1955 ops.push_back(MakeUnique<SimpleOperator<TensorFlowEqualOperator>>(
1956 ::tflite::BuiltinOperator_EQUAL, OperatorType::kEqual));
1957 ops.push_back(MakeUnique<SimpleOperator<TensorFlowNotEqualOperator>>(
1958 ::tflite::BuiltinOperator_NOT_EQUAL, OperatorType::kNotEqual));
1959 ops.push_back(MakeUnique<SimpleOperator<NegOperator>>(
1960 ::tflite::BuiltinOperator_NEG, OperatorType::kNeg));
1961 ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
1962 ::tflite::BuiltinOperator_SELECT, OperatorType::kSelect));
1963 ops.push_back(MakeUnique<SimpleOperator<SliceOperator>>(
1964 ::tflite::BuiltinOperator_SLICE, OperatorType::kSlice));
1965 ops.push_back(MakeUnique<SimpleOperator<PowOperator>>(
1966 ::tflite::BuiltinOperator_POW, OperatorType::kPow));
1967 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
1968 ::tflite::BuiltinOperator_LOGICAL_OR, OperatorType::kLogicalOr));
1969 ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
1970 ::tflite::BuiltinOperator_LOGICAL_AND, OperatorType::kLogicalAnd));
1971 ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
1972 ::tflite::BuiltinOperator_LOGICAL_NOT, OperatorType::kLogicalNot));
1973 ops.emplace_back(new SimpleOperator<FloorDivOperator>(
1974 ::tflite::BuiltinOperator_FLOOR_DIV, OperatorType::kFloorDiv));
1975 ops.emplace_back(new SimpleOperator<FloorModOperator>(
1976 ::tflite::BuiltinOperator_FLOOR_MOD, OperatorType::kFloorMod));
1977 ops.emplace_back(new SimpleOperator<RangeOperator>(
1978 ::tflite::BuiltinOperator_RANGE, OperatorType::kRange));
1979 // Element-wise operator
1980 ops.push_back(MakeUnique<SimpleOperator<SinOperator>>(
1981 ::tflite::BuiltinOperator_SIN, OperatorType::kSin));
1982 ops.push_back(MakeUnique<SimpleOperator<LogOperator>>(
1983 ::tflite::BuiltinOperator_LOG, OperatorType::kLog));
1984 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
1985 ::tflite::BuiltinOperator_SQRT, OperatorType::kSqrt));
1986 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
1987 ::tflite::BuiltinOperator_RSQRT, OperatorType::kRsqrt));
1988 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
1989 ::tflite::BuiltinOperator_SQUARE, OperatorType::kSquare));
1990 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
1991 ::tflite::BuiltinOperator_ZEROS_LIKE, OperatorType::kZerosLike));
1992 ops.push_back(MakeUnique<SimpleOperator<AbsOperator>>(
1993 ::tflite::BuiltinOperator_ABS, OperatorType::kAbs));
1994 ops.push_back(MakeUnique<SimpleOperator<HardSwishOperator>>(
1995 ::tflite::BuiltinOperator_HARD_SWISH, OperatorType::kHardSwish));
1996 ops.push_back(MakeUnique<SimpleOperator<FillOperator>>(
1997 ::tflite::BuiltinOperator_FILL, OperatorType::kFill));
1998 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
1999 ::tflite::BuiltinOperator_REVERSE_V2, OperatorType::kReverseV2));
2000 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2001 ::tflite::BuiltinOperator_RANK, OperatorType::kRank));
2002 ops.emplace_back(new SimpleOperator<SegmentSumOperator>(
2003 ::tflite::BuiltinOperator_SEGMENT_SUM, OperatorType::kSegmentSum));
2004 return ops;
2005 }
2006 } // namespace
2007
2008 // LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)
2009
BuildOperatorByTypeMap(bool enable_select_tf_ops)2010 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2011 bool enable_select_tf_ops) {
2012 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2013
2014 std::vector<std::unique_ptr<BaseOperator>> ops =
2015 BuildOperatorList(enable_select_tf_ops);
2016 for (auto& op : ops) {
2017 result[op->type()] = std::move(op);
2018 }
2019
2020 return result;
2021 }
2022
BuildOperatorByNameMap(bool enable_select_tf_ops)2023 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2024 bool enable_select_tf_ops) {
2025 std::map<string, std::unique_ptr<BaseOperator>> result;
2026
2027 std::vector<std::unique_ptr<BaseOperator>> ops =
2028 BuildOperatorList(enable_select_tf_ops);
2029 for (auto& op : ops) {
2030 result[op->name()] = std::move(op);
2031 }
2032
2033 return result;
2034 }
2035
ShouldExportAsFlexOp(bool enable_select_tf_ops,const string & tensorflow_op_name)2036 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2037 const string& tensorflow_op_name) {
2038 // If Flex ops aren't allow at all, simply return false.
2039 if (!enable_select_tf_ops) {
2040 return false;
2041 }
2042 // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2043 // it and it has been whitelisted, export the op as an Flex op. Otherwise,
2044 // export it as a regular custom op.
2045 const tensorflow::OpDef* op_def = nullptr;
2046 if (!tensorflow::OpRegistry::Global()
2047 ->LookUpOpDef(tensorflow_op_name, &op_def)
2048 .ok()) {
2049 return false;
2050 }
2051
2052 if (!::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name)) {
2053 LOG(WARNING) << "Op " << tensorflow_op_name
2054 << " is a valid TensorFlow op but has not been whitelisted for"
2055 " the TensorFlow Lite flex op set.";
2056 return false;
2057 }
2058
2059 return true;
2060 }
2061
2062 } // namespace tflite
2063
2064 } // namespace toco
2065