1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/operator.h"
16
17 #include "tensorflow/core/framework/attr_value.pb.h"
18 #include "tensorflow/core/framework/node_def.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/util/ptr_util.h"
22
23 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
24 // graph_transformation module.
25 #include "tensorflow/lite/schema/schema_generated.h"
26 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
29 #include "tensorflow/lite/toco/tflite/custom_operator.h"
30 #include "tensorflow/lite/toco/tflite/simple_operator.h"
31 #include "tensorflow/lite/toco/tflite/types.h"
32 #include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
33
34 namespace toco {
35
36 namespace tflite {
37
38 class AveragePool
39 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
40 ::tflite::BuiltinOptions_Pool2DOptions> {
41 public:
42 using BuiltinOperator::BuiltinOperator;
43
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const44 flatbuffers::Offset<TfLiteOptions> WriteOptions(
45 const TocoOperator& op,
46 flatbuffers::FlatBufferBuilder* builder) const override {
47 auto padding = Padding::Serialize(op.padding.type);
48 auto activation_function =
49 ActivationFunction::Serialize(op.fused_activation_function);
50 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
51 op.stride_height, op.kwidth,
52 op.kheight, activation_function);
53 }
54
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const55 void ReadOptions(const TfLiteOptions& options,
56 TocoOperator* op) const override {
57 op->padding.type = Padding::Deserialize(options.padding());
58 op->stride_width = options.stride_w();
59 op->stride_height = options.stride_h();
60 op->kwidth = options.filter_width();
61 op->kheight = options.filter_height();
62 op->fused_activation_function =
63 ActivationFunction::Deserialize(options.fused_activation_function());
64 }
65
GetVersion(const OperatorSignature & op_signature) const66 int GetVersion(const OperatorSignature& op_signature) const override {
67 const string& input_name = op_signature.op->inputs[0];
68 const Array& input_array = op_signature.model->GetArray(input_name);
69 if (input_array.data_type == ArrayDataType::kInt8) {
70 return 2;
71 }
72 return 1;
73 }
74 };
75
76 class Convolution
77 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
78 ::tflite::BuiltinOptions_Conv2DOptions> {
79 public:
80 using BuiltinOperator::BuiltinOperator;
81
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const82 flatbuffers::Offset<TfLiteOptions> WriteOptions(
83 const TocoOperator& op,
84 flatbuffers::FlatBufferBuilder* builder) const override {
85 auto padding = Padding::Serialize(op.padding.type);
86 auto activation_function =
87 ActivationFunction::Serialize(op.fused_activation_function);
88 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
89 op.stride_height, activation_function,
90 op.dilation_width_factor,
91 op.dilation_height_factor);
92 }
93
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const94 void ReadOptions(const TfLiteOptions& options,
95 TocoOperator* op) const override {
96 op->padding.type = Padding::Deserialize(options.padding());
97 op->stride_width = options.stride_w();
98 op->stride_height = options.stride_h();
99 op->dilation_width_factor = options.dilation_w_factor();
100 op->dilation_height_factor = options.dilation_h_factor();
101 op->fused_activation_function =
102 ActivationFunction::Deserialize(options.fused_activation_function());
103 }
104
GetVersion(const OperatorSignature & op_signature) const105 int GetVersion(const OperatorSignature& op_signature) const override {
106 const string& input_name = op_signature.op->inputs[0];
107 const string& filter_name = op_signature.op->inputs[1];
108 const string& output_name = op_signature.op->outputs[0];
109 const Array& input_array = op_signature.model->GetArray(input_name);
110 const Array& filter_array = op_signature.model->GetArray(filter_name);
111 const Array& output_array = op_signature.model->GetArray(output_name);
112 // If the op has signed int8 inputs and outputs, its version 3.
113 if (input_array.data_type == ArrayDataType::kInt8 &&
114 filter_array.data_type == ArrayDataType::kInt8 &&
115 output_array.data_type == ArrayDataType::kInt8) {
116 return 3;
117 }
118 // If the op is a signed int8 hybrid operation, we need to return
119 // version 2.
120 if (input_array.data_type == ArrayDataType::kFloat &&
121 filter_array.data_type == ArrayDataType::kInt8 &&
122 output_array.data_type == ArrayDataType::kFloat) {
123 return 2;
124 }
125 return 1;
126 }
127 };
128
129 class DepthwiseConvolution
130 : public BuiltinOperator<DepthwiseConvOperator,
131 ::tflite::DepthwiseConv2DOptions,
132 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
133 public:
134 using BuiltinOperator::BuiltinOperator;
135
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const136 flatbuffers::Offset<TfLiteOptions> WriteOptions(
137 const TocoOperator& op,
138 flatbuffers::FlatBufferBuilder* builder) const override {
139 auto padding = Padding::Serialize(op.padding.type);
140 auto activation_function =
141 ActivationFunction::Serialize(op.fused_activation_function);
142 return ::tflite::CreateDepthwiseConv2DOptions(
143 *builder, padding, op.stride_width, op.stride_height,
144 op.depth_multiplier, activation_function, op.dilation_width_factor,
145 op.dilation_height_factor);
146 }
147
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const148 void ReadOptions(const TfLiteOptions& options,
149 TocoOperator* op) const override {
150 op->padding.type = Padding::Deserialize(options.padding());
151 op->stride_width = options.stride_w();
152 op->stride_height = options.stride_h();
153 op->depth_multiplier = options.depth_multiplier();
154 op->fused_activation_function =
155 ActivationFunction::Deserialize(options.fused_activation_function());
156 op->dilation_width_factor = options.dilation_w_factor();
157 op->dilation_height_factor = options.dilation_h_factor();
158 }
159
GetVersion(const OperatorSignature & op_signature) const160 int GetVersion(const OperatorSignature& op_signature) const override {
161 const auto& conv_op =
162 static_cast<const DepthwiseConvOperator&>(*op_signature.op);
163 const string& input_name = op_signature.op->inputs[0];
164 const string& filter_name = op_signature.op->inputs[1];
165 const string& output_name = op_signature.op->outputs[0];
166 const Array& input_array = op_signature.model->GetArray(input_name);
167 const Array& filter_array = op_signature.model->GetArray(filter_name);
168 const Array& output_array = op_signature.model->GetArray(output_name);
169 // If the op has signed int8 inputs and outputs, its version 3.
170 if (input_array.data_type == ArrayDataType::kInt8 &&
171 filter_array.data_type == ArrayDataType::kInt8 &&
172 output_array.data_type == ArrayDataType::kInt8) {
173 return 3;
174 }
175 if (conv_op.dilation_width_factor != 1 ||
176 conv_op.dilation_height_factor != 1) {
177 return 2;
178 }
179 return 1;
180 }
181 };
182
183 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
184 ::tflite::BuiltinOptions_AddOptions> {
185 public:
186 using BuiltinOperator::BuiltinOperator;
187
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const188 flatbuffers::Offset<TfLiteOptions> WriteOptions(
189 const TocoOperator& op,
190 flatbuffers::FlatBufferBuilder* builder) const override {
191 auto activation_function =
192 ActivationFunction::Serialize(op.fused_activation_function);
193 return ::tflite::CreateAddOptions(*builder, activation_function);
194 }
195
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const196 void ReadOptions(const TfLiteOptions& options,
197 TocoOperator* op) const override {
198 op->fused_activation_function =
199 ActivationFunction::Deserialize(options.fused_activation_function());
200 }
201
GetVersion(const OperatorSignature & op_signature) const202 int GetVersion(const OperatorSignature& op_signature) const override {
203 const string& input_name = op_signature.op->inputs[0];
204 const Array& input_array = op_signature.model->GetArray(input_name);
205 // Version 2 supports signed int8 input types.
206 if (input_array.data_type == ArrayDataType::kInt8) {
207 return 2;
208 }
209 return 1;
210 }
211 };
212
213 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
214 ::tflite::BuiltinOptions_AddNOptions> {
215 public:
216 using BuiltinOperator::BuiltinOperator;
217
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const218 flatbuffers::Offset<TfLiteOptions> WriteOptions(
219 const TocoOperator& op,
220 flatbuffers::FlatBufferBuilder* builder) const override {
221 return ::tflite::CreateAddNOptions(*builder);
222 }
223
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const224 void ReadOptions(const TfLiteOptions& options,
225 TocoOperator* op) const override {}
226
GetVersion(const OperatorSignature & op_signature) const227 int GetVersion(const OperatorSignature& op_signature) const override {
228 return 1;
229 }
230 };
231
232 class SpaceToBatchND
233 : public BuiltinOperator<SpaceToBatchNDOperator,
234 ::tflite::SpaceToBatchNDOptions,
235 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
236 public:
237 using BuiltinOperator::BuiltinOperator;
238
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const239 flatbuffers::Offset<TfLiteOptions> WriteOptions(
240 const TocoOperator& op,
241 flatbuffers::FlatBufferBuilder* builder) const override {
242 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
243 }
244
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const245 void ReadOptions(const TfLiteOptions& options,
246 TocoOperator* op) const override {}
247
GetVersion(const OperatorSignature & op_signature) const248 int GetVersion(const OperatorSignature& op_signature) const override {
249 const string& input_name = op_signature.op->inputs[0];
250 const Array& input_array = op_signature.model->GetArray(input_name);
251 // If the op take int8 input, it is version 2.
252 if (input_array.data_type == ArrayDataType::kInt8) {
253 return 2;
254 }
255 return 1;
256 }
257 };
258
259 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
260 ::tflite::BuiltinOptions_SubOptions> {
261 public:
262 using BuiltinOperator::BuiltinOperator;
263
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const264 flatbuffers::Offset<TfLiteOptions> WriteOptions(
265 const TocoOperator& op,
266 flatbuffers::FlatBufferBuilder* builder) const override {
267 auto activation_function =
268 ActivationFunction::Serialize(op.fused_activation_function);
269 return ::tflite::CreateSubOptions(*builder, activation_function);
270 }
271
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const272 void ReadOptions(const TfLiteOptions& options,
273 TocoOperator* op) const override {
274 op->fused_activation_function =
275 ActivationFunction::Deserialize(options.fused_activation_function());
276 }
277
GetVersion(const OperatorSignature & op_signature) const278 int GetVersion(const OperatorSignature& op_signature) const override {
279 const string& input_name = op_signature.op->inputs[0];
280 const Array& input_array = op_signature.model->GetArray(input_name);
281 // If the op take int8 input, it is version 2.
282 if (input_array.data_type == ArrayDataType::kInt8) {
283 return 2;
284 }
285 return 1;
286 }
287 };
288
289 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
290 ::tflite::BuiltinOptions_DivOptions> {
291 public:
292 using BuiltinOperator::BuiltinOperator;
293
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const294 flatbuffers::Offset<TfLiteOptions> WriteOptions(
295 const TocoOperator& op,
296 flatbuffers::FlatBufferBuilder* builder) const override {
297 auto activation_function =
298 ActivationFunction::Serialize(op.fused_activation_function);
299 return ::tflite::CreateDivOptions(*builder, activation_function);
300 }
301
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const302 void ReadOptions(const TfLiteOptions& options,
303 TocoOperator* op) const override {
304 op->fused_activation_function =
305 ActivationFunction::Deserialize(options.fused_activation_function());
306 }
307
GetVersion(const OperatorSignature & op_signature) const308 int GetVersion(const OperatorSignature& op_signature) const override {
309 return 1;
310 }
311 };
312
313 class BatchToSpaceND
314 : public BuiltinOperator<BatchToSpaceNDOperator,
315 ::tflite::BatchToSpaceNDOptions,
316 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
317 public:
318 using BuiltinOperator::BuiltinOperator;
319
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const320 flatbuffers::Offset<TfLiteOptions> WriteOptions(
321 const TocoOperator& op,
322 flatbuffers::FlatBufferBuilder* builder) const override {
323 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
324 }
325
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const326 void ReadOptions(const TfLiteOptions& options,
327 TocoOperator* op) const override {}
328
GetVersion(const OperatorSignature & op_signature) const329 int GetVersion(const OperatorSignature& op_signature) const override {
330 const string& input_name = op_signature.op->inputs[0];
331 const Array& input_array = op_signature.model->GetArray(input_name);
332 // If the op take int8 input, it is version 2.
333 if (input_array.data_type == ArrayDataType::kInt8) {
334 return 2;
335 }
336 return 1;
337 }
338 };
339
340 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
341 ::tflite::BuiltinOptions_CastOptions> {
342 public:
343 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const344 flatbuffers::Offset<TfLiteOptions> WriteOptions(
345 const TocoOperator& op,
346 flatbuffers::FlatBufferBuilder* builder) const override {
347 return ::tflite::CreateCastOptions(*builder,
348 DataType::Serialize(op.src_data_type),
349 DataType::Serialize(op.dst_data_type));
350 }
351
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const352 void ReadOptions(const TfLiteOptions& options,
353 TocoOperator* op) const override {
354 op->src_data_type = DataType::Deserialize(options.in_data_type());
355 op->dst_data_type = DataType::Deserialize(options.out_data_type());
356 }
357
GetVersion(const OperatorSignature & op_signature) const358 int GetVersion(const OperatorSignature& op_signature) const override {
359 return 1;
360 }
361 };
362
363 class Concatenation
364 : public BuiltinOperator<ConcatenationOperator,
365 ::tflite::ConcatenationOptions,
366 ::tflite::BuiltinOptions_ConcatenationOptions> {
367 public:
368 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const369 flatbuffers::Offset<TfLiteOptions> WriteOptions(
370 const TocoOperator& op,
371 flatbuffers::FlatBufferBuilder* builder) const override {
372 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
373 }
374
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const375 void ReadOptions(const TfLiteOptions& options,
376 TocoOperator* op) const override {
377 op->axis = options.axis();
378 }
379
GetVersion(const OperatorSignature & op_signature) const380 int GetVersion(const OperatorSignature& op_signature) const override {
381 const string& input_name = op_signature.op->inputs[0];
382 const Array& input_array = op_signature.model->GetArray(input_name);
383 // If the op take int8 input, it is version 2.
384 if (input_array.data_type == ArrayDataType::kInt8) {
385 return 2;
386 }
387 return 1;
388 }
389 };
390
391 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
392 public:
393 using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const394 void WriteOptions(const TocoOperator& op,
395 flexbuffers::Builder* fbb) const override {
396 fbb->Int("block_size", op.block_size);
397 }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const398 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
399 op->block_size = m["block_size"].AsInt64();
400 }
401
GetVersion(const OperatorSignature & op_signature) const402 int GetVersion(const OperatorSignature& op_signature) const override {
403 return 1;
404 }
405 };
406
407 class FakeQuant
408 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
409 ::tflite::BuiltinOptions_FakeQuantOptions> {
410 public:
411 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const412 flatbuffers::Offset<TfLiteOptions> WriteOptions(
413 const TocoOperator& op,
414 flatbuffers::FlatBufferBuilder* builder) const override {
415 return ::tflite::CreateFakeQuantOptions(
416 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
417 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const418 void ReadOptions(const TfLiteOptions& options,
419 TocoOperator* op) const override {
420 auto* minmax = new MinMax;
421 minmax->min = options.min();
422 minmax->max = options.max();
423 op->minmax.reset(minmax);
424 op->num_bits = options.num_bits();
425 op->narrow_range = options.narrow_range();
426 }
GetVersion(const OperatorSignature & op_signature) const427 int GetVersion(const OperatorSignature& op_signature) const override {
428 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
429 return fq_op.narrow_range ? 2 : 1;
430 }
431 };
432
433 class FullyConnected
434 : public BuiltinOperator<FullyConnectedOperator,
435 ::tflite::FullyConnectedOptions,
436 ::tflite::BuiltinOptions_FullyConnectedOptions> {
437 public:
438 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const439 flatbuffers::Offset<TfLiteOptions> WriteOptions(
440 const TocoOperator& op,
441 flatbuffers::FlatBufferBuilder* builder) const override {
442 auto activation_function =
443 ActivationFunction::Serialize(op.fused_activation_function);
444 ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
445 switch (op.weights_format) {
446 case FullyConnectedWeightsFormat::kDefault:
447 tflite_weights_format =
448 ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
449 break;
450 case FullyConnectedWeightsFormat::kShuffled4x16Int8:
451 tflite_weights_format =
452 ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
453 break;
454 default:
455 LOG(ERROR) << "Unhandled FC weights format";
456 tflite_weights_format =
457 ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
458 }
459 return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
460 tflite_weights_format);
461 }
462
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const463 void ReadOptions(const TfLiteOptions& options,
464 TocoOperator* op) const override {
465 op->fused_activation_function =
466 ActivationFunction::Deserialize(options.fused_activation_function());
467 switch (options.weights_format()) {
468 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
469 op->weights_format = FullyConnectedWeightsFormat::kDefault;
470 break;
471 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
472 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
473 break;
474 default:
475 LOG(ERROR) << "Unhandled FC weights format";
476 op->weights_format = FullyConnectedWeightsFormat::kDefault;
477 }
478 }
479
480 // +-----------------+--------------------+--------------------------+
481 // | | Weight::Default | Weight::Shuffled4x16Int8 |
482 // +-----------------+--------------------+--------------------------+
483 // | Float | 1 | 2 |
484 // | Quantized Uint8 | 1 | 2 |
485 // | Hybrid | 3 | 3 |
486 // | Quantized Int8 | 4 | 4 |
487 // +-----------------+--------------------+--------------------------+
GetVersion(const OperatorSignature & op_signature) const488 int GetVersion(const OperatorSignature& op_signature) const override {
489 const auto& fc_op =
490 static_cast<const FullyConnectedOperator&>(*op_signature.op);
491 const string& input_name = op_signature.op->inputs[0];
492 const string& weights_name = op_signature.op->inputs[1];
493 const string& output_name = op_signature.op->outputs[0];
494 const Array& input_array = op_signature.model->GetArray(input_name);
495 const Array& weights_array = op_signature.model->GetArray(weights_name);
496 const Array& output_array = op_signature.model->GetArray(output_name);
497 // Int8 fully fixed point kernel is at version 4.
498 if (input_array.data_type == ArrayDataType::kInt8 &&
499 weights_array.data_type == ArrayDataType::kInt8 &&
500 output_array.data_type == ArrayDataType::kInt8) {
501 return 4;
502 }
503 // If the op is a signed int8 hybrid operation, we need to return
504 // version 3.
505 if (input_array.data_type == ArrayDataType::kFloat &&
506 weights_array.data_type == ArrayDataType::kInt8 &&
507 output_array.data_type == ArrayDataType::kFloat) {
508 return 3;
509 }
510 // For float and uint8 fixed point kernels, if the weight is
511 // Shuffled4x16Int8, is is version 2.
512 if (fc_op.weights_format ==
513 FullyConnectedWeightsFormat::kShuffled4x16Int8) {
514 return 2;
515 }
516
517 // Otherwise (weight is default), the version is 1.
518 return 1;
519 }
520 };
521
522 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
523 ::tflite::BuiltinOptions_GatherOptions> {
524 public:
525 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const526 flatbuffers::Offset<TfLiteOptions> WriteOptions(
527 const TocoOperator& op,
528 flatbuffers::FlatBufferBuilder* builder) const override {
529 int axis = op.axis ? op.axis.value() : 0;
530 return ::tflite::CreateGatherOptions(*builder, axis);
531 }
532
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const533 void ReadOptions(const TfLiteOptions& options,
534 TocoOperator* op) const override {
535 op->axis = {options.axis()};
536 }
537
GetVersion(const OperatorSignature & op_signature) const538 int GetVersion(const OperatorSignature& op_signature) const override {
539 const string& input_name = op_signature.op->inputs[0];
540 const Array& input_array = op_signature.model->GetArray(input_name);
541 // If the op take int8 input, it is version 2.
542 if (input_array.data_type == ArrayDataType::kInt8) {
543 return 2;
544 }
545 return 1;
546 }
547 };
548
549 class GatherNd
550 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
551 ::tflite::BuiltinOptions_GatherNdOptions> {
552 public:
553 using BuiltinOperator::BuiltinOperator;
554
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const555 flatbuffers::Offset<TfLiteOptions> WriteOptions(
556 const TocoOperator& op,
557 flatbuffers::FlatBufferBuilder* builder) const override {
558 return ::tflite::CreateGatherNdOptions(*builder);
559 }
560
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const561 void ReadOptions(const TfLiteOptions& options,
562 TocoOperator* op) const override {}
563
GetVersion(const OperatorSignature & op_signature) const564 int GetVersion(const OperatorSignature& op_signature) const override {
565 return 1;
566 }
567 };
568
569 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
570 ::tflite::BuiltinOptions_SVDFOptions> {
571 public:
572 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const573 flatbuffers::Offset<TfLiteOptions> WriteOptions(
574 const TocoOperator& op,
575 flatbuffers::FlatBufferBuilder* builder) const override {
576 auto activation_function =
577 ActivationFunction::Serialize(op.fused_activation_function);
578 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
579 }
580
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const581 void ReadOptions(const TfLiteOptions& options,
582 TocoOperator* op) const override {
583 op->fused_activation_function =
584 ActivationFunction::Deserialize(options.fused_activation_function());
585 op->rank = options.rank();
586 }
587
GetVersion(const OperatorSignature & op_signature) const588 int GetVersion(const OperatorSignature& op_signature) const override {
589 const string& input_name = op_signature.op->inputs[0];
590 const string& weights_feature_name = op_signature.op->inputs[1];
591 const string& output_name = op_signature.op->outputs[0];
592 const Array& input_array = op_signature.model->GetArray(input_name);
593 const Array& weights_feature_array =
594 op_signature.model->GetArray(weights_feature_name);
595 const Array& output_array = op_signature.model->GetArray(output_name);
596 // If the op is a signed int8 hybrid operation, we need to return
597 // version 2.
598 if (input_array.data_type == ArrayDataType::kFloat &&
599 weights_feature_array.data_type == ArrayDataType::kInt8 &&
600 output_array.data_type == ArrayDataType::kFloat) {
601 return 2;
602 }
603 return 1;
604 }
605 };
606
607 class L2Normalization
608 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
609 ::tflite::BuiltinOptions_L2NormOptions> {
610 public:
611 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const612 flatbuffers::Offset<TfLiteOptions> WriteOptions(
613 const TocoOperator& op,
614 flatbuffers::FlatBufferBuilder* builder) const override {
615 auto activation_function =
616 ActivationFunction::Serialize(op.fused_activation_function);
617 return ::tflite::CreateL2NormOptions(*builder, activation_function);
618 }
619
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const620 void ReadOptions(const TfLiteOptions& options,
621 TocoOperator* op) const override {
622 op->fused_activation_function =
623 ActivationFunction::Deserialize(options.fused_activation_function());
624 }
625
GetVersion(const OperatorSignature & op_signature) const626 int GetVersion(const OperatorSignature& op_signature) const override {
627 const string& output_name = op_signature.op->outputs[0];
628 const Array& output_array = op_signature.model->GetArray(output_name);
629 // Version 2 supports signed int8 input types.
630 if (output_array.data_type == ArrayDataType::kInt8) {
631 return 2;
632 }
633 return 1;
634 }
635 };
636
637 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
638 ::tflite::BuiltinOptions_Pool2DOptions> {
639 public:
640 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const641 flatbuffers::Offset<TfLiteOptions> WriteOptions(
642 const TocoOperator& op,
643 flatbuffers::FlatBufferBuilder* builder) const override {
644 auto padding = Padding::Serialize(op.padding.type);
645 auto activation_function =
646 ActivationFunction::Serialize(op.fused_activation_function);
647 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
648 op.stride_height, op.kwidth,
649 op.kheight, activation_function);
650 }
651
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const652 void ReadOptions(const TfLiteOptions& options,
653 TocoOperator* op) const override {
654 op->padding.type = Padding::Deserialize(options.padding());
655 op->stride_width = options.stride_w();
656 op->stride_height = options.stride_h();
657 op->kwidth = options.filter_width();
658 op->kheight = options.filter_height();
659 op->fused_activation_function =
660 ActivationFunction::Deserialize(options.fused_activation_function());
661 }
662
GetVersion(const OperatorSignature & op_signature) const663 int GetVersion(const OperatorSignature& op_signature) const override {
664 return 1;
665 }
666 };
667
668 class LocalResponseNormalization
669 : public BuiltinOperator<
670 LocalResponseNormalizationOperator,
671 ::tflite::LocalResponseNormalizationOptions,
672 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
673 public:
674 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const675 flatbuffers::Offset<TfLiteOptions> WriteOptions(
676 const TocoOperator& op,
677 flatbuffers::FlatBufferBuilder* builder) const override {
678 return ::tflite::CreateLocalResponseNormalizationOptions(
679 *builder, op.range, op.bias, op.alpha, op.beta);
680 }
681
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const682 void ReadOptions(const TfLiteOptions& options,
683 TocoOperator* op) const override {
684 op->range = options.radius();
685 op->bias = options.bias();
686 op->alpha = options.alpha();
687 op->beta = options.beta();
688 }
689
GetVersion(const OperatorSignature & op_signature) const690 int GetVersion(const OperatorSignature& op_signature) const override {
691 return 1;
692 }
693 };
694
695 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
696 ::tflite::BuiltinOptions_Pool2DOptions> {
697 public:
698 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const699 flatbuffers::Offset<TfLiteOptions> WriteOptions(
700 const TocoOperator& op,
701 flatbuffers::FlatBufferBuilder* builder) const override {
702 auto padding = Padding::Serialize(op.padding.type);
703 auto activation_function =
704 ActivationFunction::Serialize(op.fused_activation_function);
705 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
706 op.stride_height, op.kwidth,
707 op.kheight, activation_function);
708 }
709
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const710 void ReadOptions(const TfLiteOptions& options,
711 TocoOperator* op) const override {
712 op->padding.type = Padding::Deserialize(options.padding());
713 op->stride_width = options.stride_w();
714 op->stride_height = options.stride_h();
715 op->kwidth = options.filter_width();
716 op->kheight = options.filter_height();
717 op->fused_activation_function =
718 ActivationFunction::Deserialize(options.fused_activation_function());
719 }
720
GetVersion(const OperatorSignature & op_signature) const721 int GetVersion(const OperatorSignature& op_signature) const override {
722 const string& input_name = op_signature.op->inputs[0];
723 const Array& input_array = op_signature.model->GetArray(input_name);
724 if (input_array.data_type == ArrayDataType::kInt8) {
725 return 2;
726 }
727 return 1;
728 }
729 };
730
731 class Maximum : public SimpleOperator<TensorFlowMaximumOperator> {
732 public:
Maximum()733 explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {}
GetVersion(const OperatorSignature & op_signature) const734 int GetVersion(const OperatorSignature& op_signature) const override {
735 const string& input_name = op_signature.op->inputs[0];
736 const Array& input_array = op_signature.model->GetArray(input_name);
737 // Version 2 supports signed int8 input types.
738 if (input_array.data_type == ArrayDataType::kInt8) {
739 return 2;
740 }
741 return 1;
742 }
743 };
744
745 class Minimum : public SimpleOperator<TensorFlowMinimumOperator> {
746 public:
Minimum()747 explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {}
GetVersion(const OperatorSignature & op_signature) const748 int GetVersion(const OperatorSignature& op_signature) const override {
749 const string& input_name = op_signature.op->inputs[0];
750 const Array& input_array = op_signature.model->GetArray(input_name);
751 // Version 2 supports signed int8 input types.
752 if (input_array.data_type == ArrayDataType::kInt8) {
753 return 2;
754 }
755 return 1;
756 }
757 };
758
759 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
760 ::tflite::BuiltinOptions_MulOptions> {
761 public:
762 using BuiltinOperator::BuiltinOperator;
763
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const764 flatbuffers::Offset<TfLiteOptions> WriteOptions(
765 const TocoOperator& op,
766 flatbuffers::FlatBufferBuilder* builder) const override {
767 auto activation_function =
768 ActivationFunction::Serialize(op.fused_activation_function);
769 return ::tflite::CreateMulOptions(*builder, activation_function);
770 }
771
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const772 void ReadOptions(const TfLiteOptions& options,
773 TocoOperator* op) const override {
774 op->fused_activation_function =
775 ActivationFunction::Deserialize(options.fused_activation_function());
776 }
777
GetVersion(const OperatorSignature & op_signature) const778 int GetVersion(const OperatorSignature& op_signature) const override {
779 const string& input_name = op_signature.op->inputs[0];
780 const Array& input_array = op_signature.model->GetArray(input_name);
781 // Version 2 supports signed int8 input types.
782 if (input_array.data_type == ArrayDataType::kInt8) {
783 return 2;
784 }
785 return 1;
786 }
787 };
788
789 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
790 ::tflite::BuiltinOptions_PadOptions> {
791 public:
792 using BuiltinOperator::BuiltinOperator;
793
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const794 flatbuffers::Offset<TfLiteOptions> WriteOptions(
795 const TocoOperator& op,
796 flatbuffers::FlatBufferBuilder* builder) const override {
797 return ::tflite::CreatePadOptions(*builder);
798 }
799
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const800 void ReadOptions(const TfLiteOptions& options,
801 TocoOperator* op) const override {}
802
GetVersion(const OperatorSignature & op_signature) const803 int GetVersion(const OperatorSignature& op_signature) const override {
804 const string& input_name = op_signature.op->inputs[0];
805 const Array& input_array = op_signature.model->GetArray(input_name);
806 // If the op take int8 input, it is version 2.
807 if (input_array.data_type == ArrayDataType::kInt8) {
808 return 2;
809 }
810 return 1;
811 }
812 };
813
814 class Tile
815 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
816 ::tflite::BuiltinOptions_TileOptions> {
817 using BuiltinOperator::BuiltinOperator;
818
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const819 flatbuffers::Offset<TfLiteOptions> WriteOptions(
820 const TocoOperator& op,
821 flatbuffers::FlatBufferBuilder* builder) const override {
822 return ::tflite::CreateTileOptions(*builder);
823 }
824
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const825 void ReadOptions(const TfLiteOptions& options,
826 TocoOperator* op) const override {}
GetVersion(const OperatorSignature & op_signature) const827 int GetVersion(const OperatorSignature& op_signature) const override {
828 return 1;
829 }
830 };
831
832 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
833 ::tflite::BuiltinOptions_PadV2Options> {
834 public:
835 using BuiltinOperator::BuiltinOperator;
836
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const837 flatbuffers::Offset<TfLiteOptions> WriteOptions(
838 const TocoOperator& op,
839 flatbuffers::FlatBufferBuilder* builder) const override {
840 return ::tflite::CreatePadV2Options(*builder);
841 }
842
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const843 void ReadOptions(const TfLiteOptions& options,
844 TocoOperator* op) const override {}
845
GetVersion(const OperatorSignature & op_signature) const846 int GetVersion(const OperatorSignature& op_signature) const override {
847 const string& input_name = op_signature.op->inputs[0];
848 const Array& input_array = op_signature.model->GetArray(input_name);
849 // If the op take int8 input, it is version 2.
850 if (input_array.data_type == ArrayDataType::kInt8) {
851 return 2;
852 }
853 return 1;
854 }
855 };
856
857 class Reshape
858 : public BuiltinOperator<TensorFlowReshapeOperator,
859 ::tflite::ReshapeOptions,
860 ::tflite::BuiltinOptions_ReshapeOptions> {
861 public:
862 using BuiltinOperator::BuiltinOperator;
863
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const864 flatbuffers::Offset<TfLiteOptions> WriteOptions(
865 const TocoOperator& op,
866 flatbuffers::FlatBufferBuilder* builder) const override {
867 return ::tflite::CreateReshapeOptions(*builder,
868 builder->CreateVector(op.shape));
869 }
870
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const871 void ReadOptions(const TfLiteOptions& options,
872 TocoOperator* op) const override {
873 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
874 options.new_shape()->end());
875 }
876
GetVersion(const OperatorSignature & op_signature) const877 int GetVersion(const OperatorSignature& op_signature) const override {
878 return 1;
879 }
880 };
881
882 class Softmax
883 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
884 ::tflite::BuiltinOptions_SoftmaxOptions> {
885 public:
886 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const887 flatbuffers::Offset<TfLiteOptions> WriteOptions(
888 const TocoOperator& op,
889 flatbuffers::FlatBufferBuilder* builder) const override {
890 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
891 }
892
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const893 void ReadOptions(const TfLiteOptions& options,
894 TocoOperator* op) const override {
895 op->beta = options.beta();
896 }
897
GetVersion(const OperatorSignature & op_signature) const898 int GetVersion(const OperatorSignature& op_signature) const override {
899 const string& input_name = op_signature.op->inputs[0];
900 const Array& input_array = op_signature.model->GetArray(input_name);
901 if (input_array.data_type == ArrayDataType::kInt8) {
902 return 2;
903 }
904 return 1;
905 }
906 };
907
908 class SpaceToDepth
909 : public BuiltinOperator<SpaceToDepthOperator,
910 ::tflite::SpaceToDepthOptions,
911 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
912 public:
913 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const914 flatbuffers::Offset<TfLiteOptions> WriteOptions(
915 const TocoOperator& op,
916 flatbuffers::FlatBufferBuilder* builder) const override {
917 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
918 }
919
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const920 void ReadOptions(const TfLiteOptions& options,
921 TocoOperator* op) const override {
922 op->block_size = options.block_size();
923 }
924
GetVersion(const OperatorSignature & op_signature) const925 int GetVersion(const OperatorSignature& op_signature) const override {
926 const string& input_name = op_signature.op->inputs[0];
927 const Array& input_array = op_signature.model->GetArray(input_name);
928 // If the op take int8 input, it is version 2.
929 if (input_array.data_type == ArrayDataType::kInt8) {
930 return 2;
931 }
932 return 1;
933 }
934 };
935
936 class Transpose
937 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
938 ::tflite::BuiltinOptions_TransposeOptions> {
939 public:
940 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const941 flatbuffers::Offset<TfLiteOptions> WriteOptions(
942 const TocoOperator& op,
943 flatbuffers::FlatBufferBuilder* builder) const override {
944 return ::tflite::CreateTransposeOptions(*builder);
945 }
946
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const947 void ReadOptions(const TfLiteOptions& options,
948 TocoOperator* op) const override {}
949
GetVersion(const OperatorSignature & op_signature) const950 int GetVersion(const OperatorSignature& op_signature) const override {
951 const string& input_name = op_signature.op->inputs[0];
952 const Array& input_array = op_signature.model->GetArray(input_name);
953 // If the op take int8 input, it is version 2.
954 if (input_array.data_type == ArrayDataType::kInt8) {
955 return 2;
956 }
957 return 1;
958 }
959 };
960
961 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
962 ::tflite::BuiltinOptions_LSTMOptions> {
963 public:
964 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const965 flatbuffers::Offset<TfLiteOptions> WriteOptions(
966 const TocoOperator& op,
967 flatbuffers::FlatBufferBuilder* builder) const override {
968 ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL;
969 switch (op.kernel_type) {
970 case LstmCellOperator::KERNEL_BASIC:
971 kernel_type = ::tflite::LSTMKernelType_BASIC;
972 break;
973 case LstmCellOperator::KERNEL_FULL:
974 kernel_type = ::tflite::LSTMKernelType_FULL;
975 break;
976 default:
977 return -1;
978 }
979
980 // Current toco converter only supports tanh, no clip.
981 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
982 ::tflite::ActivationFunctionType_TANH,
983 /*cell_clip=*/0.0,
984 /*proj_clip=*/0.0, kernel_type);
985 }
986
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const987 void ReadOptions(const TfLiteOptions& options,
988 TocoOperator* op) const override {
989 // Only support tanh activation, so check that tflite type is tanh.
990 CHECK(options.fused_activation_function() ==
991 ::tflite::ActivationFunctionType_TANH);
992
993 switch (options.kernel_type()) {
994 case ::tflite::LSTMKernelType_BASIC:
995 op->kernel_type = LstmCellOperator::KERNEL_BASIC;
996 break;
997 case ::tflite::LSTMKernelType_FULL:
998 op->kernel_type = LstmCellOperator::KERNEL_FULL;
999 break;
1000 }
1001 }
1002
GetVersion(const OperatorSignature & op_signature) const1003 int GetVersion(const OperatorSignature& op_signature) const override {
1004 const auto& lstm_op =
1005 static_cast<const LstmCellOperator&>(*op_signature.op);
1006 switch (lstm_op.kernel_type) {
1007 case LstmCellOperator::KERNEL_FULL: {
1008 // If the input tensor is float and a weight is int8, this is a version
1009 // 3 hybrid operation.
1010 const string& input_name = op_signature.op->inputs[0];
1011 const string& weights_name = op_signature.op->inputs[2];
1012 const string& output_name = op_signature.op->outputs[0];
1013 const Array& input_array = op_signature.model->GetArray(input_name);
1014 const Array& weights_array = op_signature.model->GetArray(weights_name);
1015 const Array& output_array = op_signature.model->GetArray(output_name);
1016 if (input_array.data_type == ArrayDataType::kFloat &&
1017 weights_array.data_type == ArrayDataType::kInt8 &&
1018 output_array.data_type == ArrayDataType::kFloat) {
1019 return 3;
1020 }
1021 return 1;
1022 }
1023 case LstmCellOperator::KERNEL_BASIC:
1024 // KERNEL_BASIC was added in version 2.
1025 return 2;
1026 }
1027 }
1028
GetMutatingInputVariables(const Operator & op) const1029 std::vector<bool> GetMutatingInputVariables(
1030 const Operator& op) const override {
1031 const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
1032
1033 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1034 switch (lstm_op.kernel_type) {
1035 case LstmCellOperator::KERNEL_FULL: {
1036 mutating_input_variables[kInputActivationStateTensor] = true;
1037 mutating_input_variables[kInputCellStateTensor] = true;
1038 break;
1039 }
1040 case LstmCellOperator::KERNEL_BASIC: {
1041 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
1042 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
1043 break;
1044 }
1045 }
1046 return mutating_input_variables;
1047 }
1048 };
1049
1050 class UnidirectionalSequenceLstm
1051 : public BuiltinOperator<
1052 UnidirectionalSequenceLstmOperator,
1053 ::tflite::UnidirectionalSequenceLSTMOptions,
1054 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
1055 public:
1056 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1057 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1058 const TocoOperator& op,
1059 flatbuffers::FlatBufferBuilder* builder) const override {
1060 // Current toco converter only supports tanh, no clip.
1061 return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
1062 *builder, /*fused_activation_function=*/
1063 ::tflite::ActivationFunctionType_TANH,
1064 /*cell_clip=*/0.0,
1065 /*proj_clip=*/0.0,
1066 /*time_major=*/true);
1067 }
1068
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1069 void ReadOptions(const TfLiteOptions& options,
1070 TocoOperator* op) const override {
1071 // Only support tanh activation, so check that tflite type is tanh.
1072 DCHECK(options.fused_activation_function() ==
1073 ::tflite::ActivationFunctionType_TANH);
1074 }
1075
GetVersion(const OperatorSignature & op_signature) const1076 int GetVersion(const OperatorSignature& op_signature) const override {
1077 // If the input tensor is float and a weight is int8, this is a version
1078 // 2 hybrid operation.
1079 const string& input_name = op_signature.op->inputs[0];
1080 const string& weights_name = op_signature.op->inputs[2];
1081 const string& output_name = op_signature.op->outputs[0];
1082 const Array& input_array = op_signature.model->GetArray(input_name);
1083 const Array& weights_array = op_signature.model->GetArray(weights_name);
1084 const Array& output_array = op_signature.model->GetArray(output_name);
1085 if (input_array.data_type == ArrayDataType::kFloat &&
1086 weights_array.data_type == ArrayDataType::kInt8 &&
1087 output_array.data_type == ArrayDataType::kFloat) {
1088 return 2;
1089 }
1090 return 1;
1091 }
1092
GetMutatingInputVariables(const Operator & op) const1093 std::vector<bool> GetMutatingInputVariables(
1094 const Operator& op) const override {
1095 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1096 mutating_input_variables[kInputActivationStateTensor] = true;
1097 mutating_input_variables[kInputCellStateTensor] = true;
1098 return mutating_input_variables;
1099 }
1100 };
1101
1102 class BidirectionalSequenceLstm
1103 : public BuiltinOperator<
1104 BidirectionalSequenceLstmOperator,
1105 ::tflite::BidirectionalSequenceLSTMOptions,
1106 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
1107 public:
1108 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1109 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1110 const TocoOperator& op,
1111 flatbuffers::FlatBufferBuilder* builder) const override {
1112 // Current toco converter only supports tanh, no clip.
1113 return ::tflite::CreateBidirectionalSequenceLSTMOptions(
1114 *builder, /*fused_activation_function=*/
1115 ::tflite::ActivationFunctionType_TANH,
1116 /*cell_clip=*/0.0,
1117 /*proj_clip=*/0.0,
1118 /*merge_outputs=*/op.merge_outputs,
1119 /*time_major=*/true);
1120 }
1121
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1122 void ReadOptions(const TfLiteOptions& options,
1123 TocoOperator* op) const override {
1124 // Only support tanh activation, so check that tflite type is tanh.
1125 DCHECK(options.fused_activation_function() ==
1126 ::tflite::ActivationFunctionType_TANH);
1127 op->merge_outputs = options.merge_outputs();
1128 }
1129
GetVersion(const OperatorSignature & op_signature) const1130 int GetVersion(const OperatorSignature& op_signature) const override {
1131 return 1;
1132 }
1133
GetMutatingInputVariables(const Operator & op) const1134 std::vector<bool> GetMutatingInputVariables(
1135 const Operator& op) const override {
1136 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1137 // Forward input activation state.
1138 mutating_input_variables[35] = true;
1139 // Forward input cell state.
1140 mutating_input_variables[36] = true;
1141 // Backward input activation state.
1142 mutating_input_variables[37] = true;
1143 // Backward input cell state.
1144 mutating_input_variables[38] = true;
1145 return mutating_input_variables;
1146 }
1147 };
1148
1149 class BidirectionalSequenceRnn
1150 : public BuiltinOperator<
1151 BidirectionalSequenceRnnOperator,
1152 ::tflite::BidirectionalSequenceRNNOptions,
1153 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
1154 public:
1155 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1156 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1157 const TocoOperator& op,
1158 flatbuffers::FlatBufferBuilder* builder) const override {
1159 // Current toco converter only supports tanh, no clip.
1160 return ::tflite::CreateBidirectionalSequenceRNNOptions(
1161 *builder, /*time_major=*/true,
1162 /*fused_activation_function=*/
1163 ::tflite::ActivationFunctionType_TANH,
1164 /*merge_outputs=*/op.merge_outputs);
1165 }
1166
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1167 void ReadOptions(const TfLiteOptions& options,
1168 TocoOperator* op) const override {
1169 // Only support tanh activation, so check that tflite type is tanh.
1170 DCHECK(options.fused_activation_function() ==
1171 ::tflite::ActivationFunctionType_TANH);
1172 op->merge_outputs = options.merge_outputs();
1173 }
1174
GetVersion(const OperatorSignature & op_signature) const1175 int GetVersion(const OperatorSignature& op_signature) const override {
1176 return 1;
1177 }
1178
GetMutatingInputVariables(const Operator & op) const1179 std::vector<bool> GetMutatingInputVariables(
1180 const Operator& op) const override {
1181 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1182 // Forward hidden state.
1183 mutating_input_variables[4] = true;
1184 // Backward hidden state.
1185 mutating_input_variables[8] = true;
1186 return mutating_input_variables;
1187 }
1188 };
1189
1190 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
1191 ::tflite::BuiltinOptions_ReducerOptions> {
1192 public:
1193 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1194 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1195 const TocoOperator& op,
1196 flatbuffers::FlatBufferBuilder* builder) const override {
1197 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1198 }
1199
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1200 void ReadOptions(const TfLiteOptions& options,
1201 TocoOperator* op) const override {
1202 op->keep_dims = options.keep_dims();
1203 }
1204
GetVersion(const OperatorSignature & op_signature) const1205 int GetVersion(const OperatorSignature& op_signature) const override {
1206 return 1;
1207 }
1208 };
1209
1210 class Sum
1211 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
1212 ::tflite::BuiltinOptions_ReducerOptions> {
1213 public:
1214 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1215 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1216 const TocoOperator& op,
1217 flatbuffers::FlatBufferBuilder* builder) const override {
1218 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1219 }
1220
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1221 void ReadOptions(const TfLiteOptions& options,
1222 TocoOperator* op) const override {
1223 op->keep_dims = options.keep_dims();
1224 }
1225
GetVersion(const OperatorSignature & op_signature) const1226 int GetVersion(const OperatorSignature& op_signature) const override {
1227 return 1;
1228 }
1229 };
1230
1231 class ReduceMax
1232 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
1233 ::tflite::BuiltinOptions_ReducerOptions> {
1234 public:
1235 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1236 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1237 const TocoOperator& op,
1238 flatbuffers::FlatBufferBuilder* builder) const override {
1239 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1240 }
1241
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1242 void ReadOptions(const TfLiteOptions& options,
1243 TocoOperator* op) const override {
1244 op->keep_dims = options.keep_dims();
1245 }
1246
GetVersion(const OperatorSignature & op_signature) const1247 int GetVersion(const OperatorSignature& op_signature) const override {
1248 const string& input_name = op_signature.op->inputs[0];
1249 const Array& input_array = op_signature.model->GetArray(input_name);
1250 // If the op take int8 input, it is version 2.
1251 if (input_array.data_type == ArrayDataType::kInt8) {
1252 return 2;
1253 }
1254 return 1;
1255 }
1256 };
1257
1258 class ReduceMin
1259 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
1260 ::tflite::BuiltinOptions_ReducerOptions> {
1261 public:
1262 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1263 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1264 const TocoOperator& op,
1265 flatbuffers::FlatBufferBuilder* builder) const override {
1266 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1267 }
1268
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1269 void ReadOptions(const TfLiteOptions& options,
1270 TocoOperator* op) const override {
1271 op->keep_dims = options.keep_dims();
1272 }
1273
GetVersion(const OperatorSignature & op_signature) const1274 int GetVersion(const OperatorSignature& op_signature) const override {
1275 const string& input_name = op_signature.op->inputs[0];
1276 const Array& input_array = op_signature.model->GetArray(input_name);
1277 // If the op take int8 input, it is version 2.
1278 if (input_array.data_type == ArrayDataType::kInt8) {
1279 return 2;
1280 }
1281 return 1;
1282 }
1283 };
1284
1285 class ReduceProd
1286 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
1287 ::tflite::BuiltinOptions_ReducerOptions> {
1288 public:
1289 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1290 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1291 const TocoOperator& op,
1292 flatbuffers::FlatBufferBuilder* builder) const override {
1293 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1294 }
1295
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1296 void ReadOptions(const TfLiteOptions& options,
1297 TocoOperator* op) const override {
1298 op->keep_dims = options.keep_dims();
1299 }
1300
GetVersion(const OperatorSignature & op_signature) const1301 int GetVersion(const OperatorSignature& op_signature) const override {
1302 return 1;
1303 }
1304 };
1305
1306 class ReduceAny
1307 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
1308 ::tflite::BuiltinOptions_ReducerOptions> {
1309 public:
1310 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1311 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1312 const TocoOperator& op,
1313 flatbuffers::FlatBufferBuilder* builder) const override {
1314 return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
1315 }
1316
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1317 void ReadOptions(const TfLiteOptions& options,
1318 TocoOperator* op) const override {
1319 op->keep_dims = options.keep_dims();
1320 }
1321
GetVersion(const OperatorSignature & op_signature) const1322 int GetVersion(const OperatorSignature& op_signature) const override {
1323 return 1;
1324 }
1325 };
1326
1327 class Relu6 : public SimpleOperator<Relu6Operator> {
1328 public:
Relu6()1329 explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {}
GetVersion(const OperatorSignature & op_signature) const1330 int GetVersion(const OperatorSignature& op_signature) const override {
1331 const string& input_name = op_signature.op->inputs[0];
1332 const Array& input_array = op_signature.model->GetArray(input_name);
1333 // Version 2 supports signed int8 input types.
1334 if (input_array.data_type == ArrayDataType::kInt8) {
1335 return 2;
1336 }
1337 return 1;
1338 }
1339 };
1340
1341 class ResizeBilinear
1342 : public BuiltinOperator<ResizeBilinearOperator,
1343 ::tflite::ResizeBilinearOptions,
1344 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
1345 public:
1346 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1347 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1348 const TocoOperator& op,
1349 flatbuffers::FlatBufferBuilder* builder) const override {
1350 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
1351 }
1352
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1353 void ReadOptions(const TfLiteOptions& options,
1354 TocoOperator* op) const override {
1355 op->align_corners = options.align_corners();
1356 }
1357
GetVersion(const OperatorSignature & op_signature) const1358 int GetVersion(const OperatorSignature& op_signature) const override {
1359 const string& input_name = op_signature.op->inputs[0];
1360 const Array& input_array = op_signature.model->GetArray(input_name);
1361 // If the op takes int8 input, it is version 2.
1362 if (input_array.data_type == ArrayDataType::kInt8) {
1363 return 2;
1364 }
1365 return 1;
1366 }
1367 };
1368
1369 class ResizeNearestNeighbor
1370 : public BuiltinOperator<
1371 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
1372 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
1373 public:
1374 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1375 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1376 const TocoOperator& op,
1377 flatbuffers::FlatBufferBuilder* builder) const override {
1378 return ::tflite::CreateResizeNearestNeighborOptions(*builder,
1379 op.align_corners);
1380 }
1381
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1382 void ReadOptions(const TfLiteOptions& options,
1383 TocoOperator* op) const override {
1384 op->align_corners = options.align_corners();
1385 }
1386
GetVersion(const OperatorSignature & op_signature) const1387 int GetVersion(const OperatorSignature& op_signature) const override {
1388 const string& input_name = op_signature.op->inputs[0];
1389 const Array& input_array = op_signature.model->GetArray(input_name);
1390 // Version 2 supports signed int8 input types.
1391 if (input_array.data_type == ArrayDataType::kInt8) {
1392 return 2;
1393 }
1394 return 1;
1395 }
1396 };
1397
1398 class Squeeze
1399 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
1400 ::tflite::BuiltinOptions_SqueezeOptions> {
1401 public:
1402 using BuiltinOperator::BuiltinOperator;
1403
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1404 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1405 const TocoOperator& op,
1406 flatbuffers::FlatBufferBuilder* builder) const override {
1407 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
1408 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
1409 }
1410
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1411 void ReadOptions(const TfLiteOptions& options,
1412 TocoOperator* op) const override {
1413 op->squeeze_dims.insert(op->squeeze_dims.end(),
1414 options.squeeze_dims()->begin(),
1415 options.squeeze_dims()->end());
1416 }
1417
GetVersion(const OperatorSignature & op_signature) const1418 int GetVersion(const OperatorSignature& op_signature) const override {
1419 return 1;
1420 }
1421 };
1422
1423 class Split
1424 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
1425 ::tflite::BuiltinOptions_SplitOptions> {
1426 public:
1427 using BuiltinOperator::BuiltinOperator;
1428
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1429 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1430 const TocoOperator& op,
1431 flatbuffers::FlatBufferBuilder* builder) const override {
1432 return ::tflite::CreateSplitOptions(*builder, op.num_split);
1433 }
1434
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1435 void ReadOptions(const TfLiteOptions& options,
1436 TocoOperator* op) const override {
1437 op->num_split = options.num_splits();
1438 }
1439
GetVersion(const OperatorSignature & op_signature) const1440 int GetVersion(const OperatorSignature& op_signature) const override {
1441 const string& input_name = op_signature.op->inputs[0];
1442 const Array& input_array = op_signature.model->GetArray(input_name);
1443 // If the op take int8 input, it is version 2, for int32 it's version 3.
1444 if (input_array.data_type == ArrayDataType::kInt8) {
1445 return 2;
1446 } else if (input_array.data_type == ArrayDataType::kInt32) {
1447 return 3;
1448 }
1449 return 1;
1450 }
1451 };
1452
1453 class SplitV
1454 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
1455 ::tflite::BuiltinOptions_SplitVOptions> {
1456 public:
1457 using BuiltinOperator::BuiltinOperator;
1458
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1459 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1460 const TocoOperator& op,
1461 flatbuffers::FlatBufferBuilder* builder) const override {
1462 return ::tflite::CreateSplitVOptions(*builder, op.num_split);
1463 }
1464
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1465 void ReadOptions(const TfLiteOptions& options,
1466 TocoOperator* op) const override {
1467 op->num_split = options.num_splits();
1468 }
1469
GetVersion(const OperatorSignature & op_signature) const1470 int GetVersion(const OperatorSignature& op_signature) const override {
1471 return 1;
1472 }
1473 };
1474
1475 class StridedSlice
1476 : public BuiltinOperator<StridedSliceOperator,
1477 ::tflite::StridedSliceOptions,
1478 ::tflite::BuiltinOptions_StridedSliceOptions> {
1479 public:
1480 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1481 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1482 const TocoOperator& op,
1483 flatbuffers::FlatBufferBuilder* builder) const override {
1484 return ::tflite::CreateStridedSliceOptions(
1485 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
1486 op.new_axis_mask, op.shrink_axis_mask);
1487 }
1488
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1489 void ReadOptions(const TfLiteOptions& options,
1490 TocoOperator* op) const override {
1491 op->begin_mask = options.begin_mask();
1492 op->end_mask = options.end_mask();
1493 op->ellipsis_mask = options.ellipsis_mask();
1494 op->new_axis_mask = options.new_axis_mask();
1495 op->shrink_axis_mask = options.shrink_axis_mask();
1496 }
1497
GetVersion(const OperatorSignature & op_signature) const1498 int GetVersion(const OperatorSignature& op_signature) const override {
1499 const string& input_name = op_signature.op->inputs[0];
1500 const Array& input_array = op_signature.model->GetArray(input_name);
1501 // If the op take int8 input, it is version 2.
1502 if (input_array.data_type == ArrayDataType::kInt8) {
1503 return 2;
1504 }
1505 return 1;
1506 }
1507 };
1508
1509 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
1510 ::tflite::BuiltinOptions_TopKV2Options> {
1511 public:
1512 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1513 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1514 const TocoOperator& op,
1515 flatbuffers::FlatBufferBuilder* builder) const override {
1516 return ::tflite::CreateTopKV2Options(*builder);
1517 }
1518
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1519 void ReadOptions(const TfLiteOptions& options,
1520 TocoOperator* op) const override {}
1521
GetVersion(const OperatorSignature & op_signature) const1522 int GetVersion(const OperatorSignature& op_signature) const override {
1523 const string& input_name = op_signature.op->inputs[0];
1524 const Array& input_array = op_signature.model->GetArray(input_name);
1525 if (input_array.data_type == ArrayDataType::kInt8) {
1526 return 2;
1527 }
1528 return 1;
1529 }
1530 };
1531
1532 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
1533 ::tflite::BuiltinOptions_ArgMaxOptions> {
1534 public:
1535 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1536 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1537 const TocoOperator& op,
1538 flatbuffers::FlatBufferBuilder* builder) const override {
1539 return ::tflite::CreateArgMaxOptions(
1540 *builder, DataType::Serialize(op.output_data_type));
1541 }
1542
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1543 void ReadOptions(const TfLiteOptions& options,
1544 TocoOperator* op) const override {
1545 op->output_data_type = DataType::Deserialize(options.output_type());
1546 }
1547
GetVersion(const OperatorSignature & op_signature) const1548 int GetVersion(const OperatorSignature& op_signature) const override {
1549 const string& input_name = op_signature.op->inputs[0];
1550 const Array& input_array = op_signature.model->GetArray(input_name);
1551 if (input_array.data_type == ArrayDataType::kInt8) {
1552 return 2;
1553 }
1554
1555 return 1;
1556 }
1557 };
1558
1559 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
1560 ::tflite::BuiltinOptions_ArgMinOptions> {
1561 public:
1562 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1563 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1564 const TocoOperator& op,
1565 flatbuffers::FlatBufferBuilder* builder) const override {
1566 return ::tflite::CreateArgMinOptions(
1567 *builder, DataType::Serialize(op.output_data_type));
1568 }
1569
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1570 void ReadOptions(const TfLiteOptions& options,
1571 TocoOperator* op) const override {
1572 op->output_data_type = DataType::Deserialize(options.output_type());
1573 }
1574
GetVersion(const OperatorSignature & op_signature) const1575 int GetVersion(const OperatorSignature& op_signature) const override {
1576 const string& input_name = op_signature.op->inputs[0];
1577 const Array& input_array = op_signature.model->GetArray(input_name);
1578 if (input_array.data_type == ArrayDataType::kInt8) {
1579 return 2;
1580 }
1581
1582 return 1;
1583 }
1584 };
1585
1586 class TransposeConv
1587 : public BuiltinOperator<TransposeConvOperator,
1588 ::tflite::TransposeConvOptions,
1589 ::tflite::BuiltinOptions_TransposeConvOptions> {
1590 public:
1591 using BuiltinOperator::BuiltinOperator;
1592
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1593 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1594 const TocoOperator& op,
1595 flatbuffers::FlatBufferBuilder* builder) const override {
1596 auto padding = Padding::Serialize(op.padding.type);
1597 return ::tflite::CreateTransposeConvOptions(
1598 *builder, padding, op.stride_width, op.stride_height);
1599 }
1600
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1601 void ReadOptions(const TfLiteOptions& options,
1602 TocoOperator* op) const override {
1603 op->padding.type = Padding::Deserialize(options.padding());
1604 op->stride_width = options.stride_w();
1605 op->stride_height = options.stride_h();
1606 }
1607
GetVersion(const OperatorSignature & op_signature) const1608 int GetVersion(const OperatorSignature& op_signature) const override {
1609 return 1;
1610 }
1611 };
1612
1613 class SparseToDense
1614 : public BuiltinOperator<SparseToDenseOperator,
1615 ::tflite::SparseToDenseOptions,
1616 ::tflite::BuiltinOptions_SparseToDenseOptions> {
1617 public:
1618 using BuiltinOperator::BuiltinOperator;
1619
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1620 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1621 const TocoOperator& op,
1622 flatbuffers::FlatBufferBuilder* builder) const override {
1623 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
1624 }
1625
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1626 void ReadOptions(const TfLiteOptions& options,
1627 TocoOperator* op) const override {
1628 op->validate_indices = options.validate_indices();
1629 }
1630
GetVersion(const OperatorSignature & op_signature) const1631 int GetVersion(const OperatorSignature& op_signature) const override {
1632 return 1;
1633 }
1634 };
1635
1636 class ExpandDims
1637 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
1638 ::tflite::BuiltinOptions_ExpandDimsOptions> {
1639 public:
1640 using BuiltinOperator::BuiltinOperator;
1641
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1642 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1643 const TocoOperator& op,
1644 flatbuffers::FlatBufferBuilder* builder) const override {
1645 return ::tflite::CreateExpandDimsOptions(*builder);
1646 }
1647
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1648 void ReadOptions(const TfLiteOptions& options,
1649 TocoOperator* op) const override {}
1650
GetVersion(const OperatorSignature & op_signature) const1651 int GetVersion(const OperatorSignature& op_signature) const override {
1652 return 1;
1653 }
1654 };
1655
1656 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
1657 ::tflite::BuiltinOptions_PackOptions> {
1658 public:
1659 using BuiltinOperator::BuiltinOperator;
1660
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1661 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1662 const TocoOperator& op,
1663 flatbuffers::FlatBufferBuilder* builder) const override {
1664 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
1665 }
1666
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1667 void ReadOptions(const TfLiteOptions& options,
1668 TocoOperator* op) const override {
1669 op->values_count = options.values_count();
1670 op->axis = options.axis();
1671 }
1672
GetVersion(const OperatorSignature & op_signature) const1673 int GetVersion(const OperatorSignature& op_signature) const override {
1674 const string& input_name = op_signature.op->inputs[0];
1675 const Array& input_array = op_signature.model->GetArray(input_name);
1676 // If the op take int8 input, it is version 2.
1677 if (input_array.data_type == ArrayDataType::kInt8) {
1678 return 2;
1679 }
1680 return 1;
1681 }
1682 };
1683
1684 class Shape
1685 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
1686 ::tflite::BuiltinOptions_ShapeOptions> {
1687 public:
1688 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1689 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1690 const TocoOperator& op,
1691 flatbuffers::FlatBufferBuilder* builder) const override {
1692 return ::tflite::CreateShapeOptions(
1693 *builder, DataType::Serialize(op.output_data_type));
1694 }
1695
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1696 void ReadOptions(const TfLiteOptions& options,
1697 TocoOperator* op) const override {
1698 op->output_data_type = DataType::Deserialize(options.out_type());
1699 }
1700
GetVersion(const OperatorSignature & op_signature) const1701 int GetVersion(const OperatorSignature& op_signature) const override {
1702 return 1;
1703 }
1704 };
1705
1706 class Slice : public SimpleOperator<SliceOperator> {
1707 public:
Slice()1708 explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {}
GetVersion(const OperatorSignature & op_signature) const1709 int GetVersion(const OperatorSignature& op_signature) const override {
1710 const string& input_name = op_signature.op->inputs[0];
1711 const Array& input_array = op_signature.model->GetArray(input_name);
1712 // Version 2 supports signed int8 input types.
1713 if (input_array.data_type == ArrayDataType::kInt8) {
1714 return 2;
1715 }
1716 return 1;
1717 }
1718 };
1719
1720 class Tanh : public SimpleOperator<TanhOperator> {
1721 public:
Tanh()1722 explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {}
GetVersion(const OperatorSignature & op_signature) const1723 int GetVersion(const OperatorSignature& op_signature) const override {
1724 const string& input_name = op_signature.op->inputs[0];
1725 const Array& input_array = op_signature.model->GetArray(input_name);
1726 // Version 2 supports signed int8 input types.
1727 if (input_array.data_type == ArrayDataType::kInt8) {
1728 return 2;
1729 }
1730 return 1;
1731 }
1732 };
1733
1734 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
1735 ::tflite::BuiltinOptions_OneHotOptions> {
1736 public:
1737 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1738 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1739 const TocoOperator& op,
1740 flatbuffers::FlatBufferBuilder* builder) const override {
1741 return ::tflite::CreateOneHotOptions(*builder, op.axis);
1742 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1743 void ReadOptions(const TfLiteOptions& options,
1744 TocoOperator* op) const override {
1745 op->axis = options.axis();
1746 }
1747
GetVersion(const OperatorSignature & op_signature) const1748 int GetVersion(const OperatorSignature& op_signature) const override {
1749 return 1;
1750 }
1751 };
1752
1753 class CTCBeamSearchDecoder
1754 : public CustomOperator<CTCBeamSearchDecoderOperator> {
1755 public:
1756 using CustomOperator::CustomOperator;
1757
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const1758 void WriteOptions(const TocoOperator& op,
1759 flexbuffers::Builder* fbb) const override {
1760 fbb->Int("beam_width", op.beam_width);
1761 fbb->Int("top_paths", op.top_paths);
1762 fbb->Bool("merge_repeated", op.merge_repeated);
1763 }
1764
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const1765 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
1766 op->beam_width = m["beam_width"].AsInt32();
1767 op->top_paths = m["top_paths"].AsInt32();
1768 op->merge_repeated = m["merge_repeated"].AsBool();
1769 }
1770
GetVersion(const OperatorSignature & op_signature) const1771 int GetVersion(const OperatorSignature& op_signature) const override {
1772 return 1;
1773 }
1774 };
1775
1776 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
1777 ::tflite::BuiltinOptions_UnpackOptions> {
1778 public:
1779 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1780 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1781 const TocoOperator& op,
1782 flatbuffers::FlatBufferBuilder* builder) const override {
1783 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
1784 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1785 void ReadOptions(const TfLiteOptions& options,
1786 TocoOperator* op) const override {
1787 op->num = options.num();
1788 op->axis = options.axis();
1789 }
1790
GetVersion(const OperatorSignature & op_signature) const1791 int GetVersion(const OperatorSignature& op_signature) const override {
1792 return 1;
1793 }
1794 };
1795
1796 class LeakyRelu
1797 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
1798 ::tflite::BuiltinOptions_LeakyReluOptions> {
1799 public:
1800 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1801 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1802 const TocoOperator& op,
1803 flatbuffers::FlatBufferBuilder* builder) const override {
1804 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
1805 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1806 void ReadOptions(const TfLiteOptions& options,
1807 TocoOperator* op) const override {
1808 op->alpha = options.alpha();
1809 }
1810
GetVersion(const OperatorSignature & op_signature) const1811 int GetVersion(const OperatorSignature& op_signature) const override {
1812 return 1;
1813 }
1814 };
1815
1816 class Logistic : public SimpleOperator<LogisticOperator> {
1817 public:
Logistic()1818 explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {}
GetVersion(const OperatorSignature & op_signature) const1819 int GetVersion(const OperatorSignature& op_signature) const override {
1820 const string& input_name = op_signature.op->inputs[0];
1821 const Array& input_array = op_signature.model->GetArray(input_name);
1822 // Version 2 supports signed int8 input types.
1823 if (input_array.data_type == ArrayDataType::kInt8) {
1824 return 2;
1825 }
1826 return 1;
1827 }
1828 };
1829
1830 class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
1831 public:
LogSoftmax()1832 explicit LogSoftmax()
1833 : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
GetVersion(const OperatorSignature & op_signature) const1834 int GetVersion(const OperatorSignature& op_signature) const override {
1835 const string& input_name = op_signature.op->inputs[0];
1836 const Array& input_array = op_signature.model->GetArray(input_name);
1837 // Version 2 supports signed int8 input types.
1838 if (input_array.data_type == ArrayDataType::kInt8) {
1839 return 2;
1840 }
1841 return 1;
1842 }
1843 };
1844
1845 class SquaredDifference
1846 : public BuiltinOperator<
1847 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
1848 ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
1849 public:
1850 using BuiltinOperator::BuiltinOperator;
1851
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1852 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1853 const TocoOperator& op,
1854 flatbuffers::FlatBufferBuilder* builder) const override {
1855 return ::tflite::CreateSquaredDifferenceOptions(*builder);
1856 }
1857
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1858 void ReadOptions(const TfLiteOptions& options,
1859 TocoOperator* op) const override {}
1860
GetVersion(const OperatorSignature & op_signature) const1861 int GetVersion(const OperatorSignature& op_signature) const override {
1862 return 1;
1863 }
1864 };
1865
1866 class MirrorPad
1867 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
1868 ::tflite::BuiltinOptions_MirrorPadOptions> {
1869 public:
1870 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1871 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1872 const TocoOperator& op,
1873 flatbuffers::FlatBufferBuilder* builder) const override {
1874 return ::tflite::CreateMirrorPadOptions(
1875 *builder, op.mode == MirrorPadMode::kReflect
1876 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1877 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
1878 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1879 void ReadOptions(const TfLiteOptions& options,
1880 TocoOperator* op) const override {
1881 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
1882 ? MirrorPadMode::kReflect
1883 : MirrorPadMode::kSymmetric;
1884 }
1885
GetVersion(const OperatorSignature & op) const1886 int GetVersion(const OperatorSignature& op) const override { return 1; }
1887 };
1888
1889 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
1890 ::tflite::BuiltinOptions_UniqueOptions> {
1891 public:
1892 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1893 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1894 const TocoOperator& op,
1895 flatbuffers::FlatBufferBuilder* builder) const override {
1896 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
1897 return ::tflite::CreateUniqueOptions(
1898 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
1899 ? ::tflite::TensorType::TensorType_INT64
1900 : ::tflite::TensorType_INT32);
1901 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1902 void ReadOptions(const TfLiteOptions& options,
1903 TocoOperator* op) const override {
1904 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
1905 unique_op->idx_out_type =
1906 options.idx_out_type() == ::tflite::TensorType_INT64
1907 ? toco::ArrayDataType::kInt64
1908 : toco::ArrayDataType::kInt32;
1909 }
1910
GetVersion(const OperatorSignature & op_signature) const1911 int GetVersion(const OperatorSignature& op_signature) const override {
1912 return 1;
1913 }
1914 };
1915
1916 class UnidirectionalSequenceRnn
1917 : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
1918 ::tflite::SequenceRNNOptions,
1919 ::tflite::BuiltinOptions_SequenceRNNOptions> {
1920 public:
1921 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1922 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1923 const TocoOperator& op,
1924 flatbuffers::FlatBufferBuilder* builder) const override {
1925 return ::tflite::CreateSequenceRNNOptions(
1926 *builder, /*time_major=*/true,
1927 /*fused_activation_function=*/
1928 ::tflite::ActivationFunctionType_TANH);
1929 }
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1930 void ReadOptions(const TfLiteOptions& options,
1931 TocoOperator* op) const override {
1932 // Only support tanh activation, so check that tflite type is tanh.
1933 DCHECK(options.fused_activation_function() ==
1934 ::tflite::ActivationFunctionType_TANH);
1935 }
1936
GetVersion(const OperatorSignature & op_signature) const1937 int GetVersion(const OperatorSignature& op_signature) const override {
1938 return 1;
1939 }
1940
GetMutatingInputVariables(const Operator & op) const1941 std::vector<bool> GetMutatingInputVariables(
1942 const Operator& op) const override {
1943 std::vector<bool> mutating_input_variables(op.inputs.size(), false);
1944 mutating_input_variables[4] = true;
1945 return mutating_input_variables;
1946 }
1947 };
1948
1949 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
1950 ::tflite::BuiltinOptions_WhereOptions> {
1951 public:
1952 using BuiltinOperator::BuiltinOperator;
1953
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const1954 flatbuffers::Offset<TfLiteOptions> WriteOptions(
1955 const TocoOperator& op,
1956 flatbuffers::FlatBufferBuilder* builder) const override {
1957 return ::tflite::CreateWhereOptions(*builder);
1958 }
1959
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const1960 void ReadOptions(const TfLiteOptions& options,
1961 TocoOperator* op) const override {}
1962
GetVersion(const OperatorSignature & op_signature) const1963 int GetVersion(const OperatorSignature& op_signature) const override {
1964 return 1;
1965 }
1966 };
1967
WriteFlexOpOptions(const string & tensorflow_node_def)1968 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
1969 const string& tensorflow_node_def) {
1970 auto fbb = absl::make_unique<flexbuffers::Builder>();
1971
1972 ::tensorflow::NodeDef node_def;
1973 if (!node_def.ParseFromString(tensorflow_node_def)) {
1974 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
1975 return {};
1976 }
1977
1978 fbb->Vector([&]() {
1979 fbb->String(node_def.op());
1980 fbb->String(tensorflow_node_def);
1981 });
1982 fbb->Finish();
1983 LOG(INFO) << "Writing flex op: " << node_def.op();
1984 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
1985 }
1986
1987 class TensorFlowUnsupported : public BaseOperator {
1988 public:
TensorFlowUnsupported(const string & name,OperatorType type,bool enable_select_tf_ops)1989 TensorFlowUnsupported(const string& name, OperatorType type,
1990 bool enable_select_tf_ops)
1991 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
1992
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const1993 Options Serialize(const Operator& op,
1994 flatbuffers::FlatBufferBuilder* builder) const override {
1995 auto fbb =
1996 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
1997 if (fbb) {
1998 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
1999 } else {
2000 return Options::Custom(0);
2001 }
2002 }
2003
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const2004 std::unique_ptr<Operator> Deserialize(
2005 const BuiltinOptions* builtin_options,
2006 const CustomOptions* custom_options) const override {
2007 // Deserializing Flex ops doesn't work now.
2008 // TODO(ycling): Revisit and decide if we should fix the flow for importing
2009 // TFLite models with Flex ops.
2010 auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
2011 if (custom_options) {
2012 auto flexbuffer_map =
2013 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
2014 .AsMap();
2015 ReadOptions(flexbuffer_map, op.get());
2016 }
2017 return std::unique_ptr<Operator>(op.release());
2018 }
2019
WriteOptions(const TensorFlowUnsupportedOperator & op) const2020 std::unique_ptr<flexbuffers::Builder> WriteOptions(
2021 const TensorFlowUnsupportedOperator& op) const {
2022 if (enable_select_tf_ops_) {
2023 return WriteFlexOpOptions(op.tensorflow_node_def);
2024 }
2025 auto fbb = absl::make_unique<flexbuffers::Builder>();
2026
2027 ::tensorflow::NodeDef node_def;
2028 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
2029 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
2030 return std::unique_ptr<flexbuffers::Builder>();
2031 }
2032
2033 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
2034 fbb->Vector([&]() {
2035 fbb->String(node_def.op());
2036 fbb->String(op.tensorflow_node_def);
2037 });
2038 fbb->Finish();
2039 LOG(INFO) << "Writing flex op: " << node_def.op();
2040 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
2041 }
2042
2043 bool has_valid_attr = false;
2044 size_t map_start = fbb->StartMap();
2045 for (const auto& pair : node_def.attr()) {
2046 const char* key = pair.first.c_str();
2047 const auto& attr = pair.second;
2048 switch (attr.value_case()) {
2049 case ::tensorflow::AttrValue::kS:
2050 fbb->String(key, attr.s());
2051 has_valid_attr = true;
2052 break;
2053 case ::tensorflow::AttrValue::kI:
2054 fbb->Int(key, attr.i());
2055 has_valid_attr = true;
2056 break;
2057 case ::tensorflow::AttrValue::kF:
2058 fbb->Float(key, attr.f());
2059 has_valid_attr = true;
2060 break;
2061 case ::tensorflow::AttrValue::kB:
2062 fbb->Bool(key, attr.b());
2063 has_valid_attr = true;
2064 break;
2065 case tensorflow::AttrValue::kList:
2066 if (attr.list().s_size() > 0) {
2067 auto start = fbb->StartVector(key);
2068 for (const string& v : attr.list().s()) {
2069 fbb->Add(v);
2070 }
2071 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2072 has_valid_attr = true;
2073 } else if (attr.list().i_size() > 0) {
2074 auto start = fbb->StartVector(key);
2075 for (const int64_t v : attr.list().i()) {
2076 fbb->Add(v);
2077 }
2078 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2079 has_valid_attr = true;
2080 } else if (attr.list().f_size() > 0) {
2081 auto start = fbb->StartVector(key);
2082 for (const float v : attr.list().f()) {
2083 fbb->Add(v);
2084 }
2085 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
2086 has_valid_attr = true;
2087 } else {
2088 LOG(WARNING)
2089 << "Ignoring unsupported type in list attribute with key '"
2090 << key << "'";
2091 }
2092 break;
2093 default:
2094 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
2095 << key << "'";
2096 break;
2097 }
2098 }
2099 if (!has_valid_attr) {
2100 return std::unique_ptr<flexbuffers::Builder>();
2101 }
2102 fbb->EndMap(map_start);
2103 fbb->Finish();
2104 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
2105 }
2106
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const2107 void ReadOptions(const flexbuffers::Map& m,
2108 TensorFlowUnsupportedOperator* op) const {
2109 ::tensorflow::NodeDef node_def;
2110 auto attr = node_def.mutable_attr();
2111
2112 const auto& keys = m.Keys();
2113 for (size_t i = 0; i < keys.size(); ++i) {
2114 const auto key = keys[i].AsKey();
2115 const auto& value = m[key];
2116 // TODO(wvo): hack to make this code compile with 2 different API
2117 // versions.
2118 // Please remove once OS/internal versions are in sync.
2119 // See hardcoded values in the switch below.
2120 switch (value.GetType()) {
2121 case 5: // flexbuffers::FBT_STRING:
2122 (*attr)[key].set_s(value.AsString().c_str());
2123 break;
2124 case 1: // flexbuffers::FBT_INT:
2125 (*attr)[key].set_i(value.AsInt64());
2126 break;
2127 case 3: // flexbuffers::FBT_FLOAT:
2128 (*attr)[key].set_f(value.AsFloat());
2129 break;
2130 case 26: // flexbuffers::FBT_BOOL:
2131 (*attr)[key].set_b(value.AsBool());
2132 if (string(key) == "_output_quantized") {
2133 op->quantized = value.AsBool();
2134 }
2135 if (string(key) == "_support_output_type_float_in_quantized_op") {
2136 op->support_output_type_float_in_quantized_op = value.AsBool();
2137 }
2138 break;
2139 case 11: { // flexbuffers::FBT_VECTOR_INT: {
2140 auto* list = (*attr)[key].mutable_list();
2141 const auto& vector = value.AsTypedVector();
2142 for (size_t i = 0; i < vector.size(); i++) {
2143 list->add_i(vector[i].AsInt64());
2144 }
2145 break;
2146 }
2147 case 13: { // flexbuffers::FBT_VECTOR_FLOAT: {
2148 auto* list = (*attr)[key].mutable_list();
2149 const auto& vector = value.AsTypedVector();
2150 for (size_t i = 0; i < vector.size(); i++) {
2151 list->add_f(vector[i].AsFloat());
2152 }
2153 break;
2154 }
2155 case 15: { // flexbuffers::FBT_VECTOR_STRING: {
2156 auto* list = (*attr)[key].mutable_list();
2157 const auto& vector = value.AsTypedVector();
2158 for (size_t i = 0; i < vector.size(); i++) {
2159 list->add_s(vector[i].AsString().str());
2160 }
2161 break;
2162 }
2163 default:
2164 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
2165 << key << "'";
2166 break;
2167 }
2168 }
2169 node_def.SerializeToString(&op->tensorflow_node_def);
2170 }
2171
GetVersion(const OperatorSignature & op_signature) const2172 int GetVersion(const OperatorSignature& op_signature) const override {
2173 // TODO(ycling): Design and implement a way to plumb the version of
2174 // custom ops.
2175 return 1;
2176 }
2177
2178 private:
2179 const bool enable_select_tf_ops_;
2180 };
2181
2182 class Dequantize
2183 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
2184 ::tflite::BuiltinOptions_DequantizeOptions> {
2185 public:
2186 using BuiltinOperator::BuiltinOperator;
2187
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const2188 flatbuffers::Offset<TfLiteOptions> WriteOptions(
2189 const TocoOperator& op,
2190 flatbuffers::FlatBufferBuilder* builder) const override {
2191 return ::tflite::CreateDequantizeOptions(*builder);
2192 }
2193
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const2194 void ReadOptions(const TfLiteOptions& options,
2195 TocoOperator* op) const override {}
2196
GetVersion(const OperatorSignature & op_signature) const2197 int GetVersion(const OperatorSignature& op_signature) const override {
2198 const string& input_name = op_signature.op->inputs[0];
2199 const Array& input_array = op_signature.model->GetArray(input_name);
2200 // Version 2 supports signed int8 input types.
2201 if (input_array.data_type == ArrayDataType::kInt8) {
2202 return 2;
2203 }
2204 return 1;
2205 }
2206 };
2207
2208 class ReverseSequence
2209 : public BuiltinOperator<ReverseSequenceOperator,
2210 ::tflite::ReverseSequenceOptions,
2211 ::tflite::BuiltinOptions_ReverseSequenceOptions> {
2212 public:
2213 using BuiltinOperator::BuiltinOperator;
2214
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const2215 flatbuffers::Offset<TfLiteOptions> WriteOptions(
2216 const TocoOperator& op,
2217 flatbuffers::FlatBufferBuilder* builder) const override {
2218 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
2219 op.batch_dim);
2220 }
2221
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const2222 void ReadOptions(const TfLiteOptions& options,
2223 TocoOperator* op) const override {
2224 op->seq_dim = options.seq_dim();
2225 op->batch_dim = options.batch_dim();
2226 }
2227
GetVersion(const OperatorSignature & op_signature) const2228 int GetVersion(const OperatorSignature& op_signature) const override {
2229 return 1;
2230 }
2231 };
2232
2233 class Equal : public SimpleOperator<TensorFlowEqualOperator> {
2234 public:
Equal()2235 explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {}
GetVersion(const OperatorSignature & op_signature) const2236 int GetVersion(const OperatorSignature& op_signature) const override {
2237 const string& input_name = op_signature.op->inputs[0];
2238 const Array& input_array = op_signature.model->GetArray(input_name);
2239 // Version 2 supports signed int8 input types.
2240 if (input_array.data_type == ArrayDataType::kInt8) {
2241 return 2;
2242 }
2243 return 1;
2244 }
2245 };
2246
2247 class NotEqual : public SimpleOperator<TensorFlowNotEqualOperator> {
2248 public:
NotEqual()2249 explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {}
GetVersion(const OperatorSignature & op_signature) const2250 int GetVersion(const OperatorSignature& op_signature) const override {
2251 const string& input_name = op_signature.op->inputs[0];
2252 const Array& input_array = op_signature.model->GetArray(input_name);
2253 // Version 2 supports signed int8 input types.
2254 if (input_array.data_type == ArrayDataType::kInt8) {
2255 return 2;
2256 }
2257 return 1;
2258 }
2259 };
2260
2261 class Greater : public SimpleOperator<TensorFlowGreaterOperator> {
2262 public:
Greater()2263 explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {}
GetVersion(const OperatorSignature & op_signature) const2264 int GetVersion(const OperatorSignature& op_signature) const override {
2265 const string& input_name = op_signature.op->inputs[0];
2266 const Array& input_array = op_signature.model->GetArray(input_name);
2267 // Version 2 supports signed int8 input types.
2268 if (input_array.data_type == ArrayDataType::kInt8) {
2269 return 2;
2270 }
2271 return 1;
2272 }
2273 };
2274
2275 class GreaterEqual : public SimpleOperator<TensorFlowGreaterEqualOperator> {
2276 public:
GreaterEqual()2277 explicit GreaterEqual()
2278 : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {}
GetVersion(const OperatorSignature & op_signature) const2279 int GetVersion(const OperatorSignature& op_signature) const override {
2280 const string& input_name = op_signature.op->inputs[0];
2281 const Array& input_array = op_signature.model->GetArray(input_name);
2282 // Version 2 supports signed int8 input types.
2283 if (input_array.data_type == ArrayDataType::kInt8) {
2284 return 2;
2285 }
2286 return 1;
2287 }
2288 };
2289
2290 class Less : public SimpleOperator<TensorFlowLessOperator> {
2291 public:
Less()2292 explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {}
GetVersion(const OperatorSignature & op_signature) const2293 int GetVersion(const OperatorSignature& op_signature) const override {
2294 const string& input_name = op_signature.op->inputs[0];
2295 const Array& input_array = op_signature.model->GetArray(input_name);
2296 // Version 2 supports signed int8 input types.
2297 if (input_array.data_type == ArrayDataType::kInt8) {
2298 return 2;
2299 }
2300 return 1;
2301 }
2302 };
2303
2304 class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
2305 public:
LessEqual()2306 explicit LessEqual()
2307 : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {}
GetVersion(const OperatorSignature & op_signature) const2308 int GetVersion(const OperatorSignature& op_signature) const override {
2309 const string& input_name = op_signature.op->inputs[0];
2310 const Array& input_array = op_signature.model->GetArray(input_name);
2311 // Version 2 supports signed int8 input types.
2312 if (input_array.data_type == ArrayDataType::kInt8) {
2313 return 2;
2314 }
2315 return 1;
2316 }
2317 };
2318
2319 class Select : public SimpleOperator<SelectOperator> {
2320 public:
Select()2321 explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
GetVersion(const OperatorSignature & op_signature) const2322 int GetVersion(const OperatorSignature& op_signature) const override {
2323 const string& input_name = op_signature.op->inputs[0];
2324 const Array& input_array = op_signature.model->GetArray(input_name);
2325 // Version 2 supports signed int8 input types.
2326 if (input_array.data_type == ArrayDataType::kInt8) {
2327 return 2;
2328 }
2329 return 1;
2330 }
2331 };
2332
2333 namespace {
2334 // Build a vector containing all the known operators.
BuildOperatorList(bool enable_select_tf_ops=false)2335 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
2336 bool enable_select_tf_ops = false) {
2337 std::vector<std::unique_ptr<BaseOperator>> ops;
2338 using tensorflow::MakeUnique;
2339 // Builtin Operators.
2340 ops.push_back(
2341 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
2342 ops.push_back(
2343 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
2344 ops.push_back(
2345 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
2346 ops.push_back(
2347 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
2348 ops.push_back(MakeUnique<AveragePool>(
2349 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
2350 ops.push_back(
2351 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
2352 OperatorType::kSpaceToBatchND));
2353 ops.push_back(
2354 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
2355 OperatorType::kBatchToSpaceND));
2356 ops.push_back(MakeUnique<Concatenation>(
2357 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
2358 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
2359 OperatorType::kConv));
2360 ops.push_back(MakeUnique<DepthwiseConvolution>(
2361 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
2362 OperatorType::kDepthwiseConv));
2363 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
2364 OperatorType::kDequantize));
2365 ops.push_back(
2366 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
2367 OperatorType::kFullyConnected));
2368 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
2369 OperatorType::kGather));
2370 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
2371 OperatorType::kGatherNd));
2372 ops.push_back(
2373 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
2374 OperatorType::kL2Normalization));
2375 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
2376 OperatorType::kL2Pool));
2377 ops.push_back(MakeUnique<LocalResponseNormalization>(
2378 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
2379 OperatorType::kLocalResponseNormalization));
2380 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
2381 OperatorType::kMaxPool));
2382 ops.push_back(
2383 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
2384
2385 ops.push_back(
2386 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
2387 ops.push_back(
2388 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
2389 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
2390 OperatorType::kReshape));
2391 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
2392 OperatorType::kSoftmax));
2393 ops.push_back(MakeUnique<SpaceToDepth>(
2394 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
2395 ops.push_back(
2396 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
2397 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
2398 OperatorType::kTranspose));
2399 ops.push_back(
2400 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
2401 ops.push_back(
2402 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
2403 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
2404 OperatorType::kReduceProd));
2405 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
2406 OperatorType::kReduceMax));
2407 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
2408 OperatorType::kReduceMin));
2409 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
2410 OperatorType::kAny));
2411 ops.push_back(
2412 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
2413 OperatorType::kResizeBilinear));
2414 ops.push_back(MakeUnique<ResizeNearestNeighbor>(
2415 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
2416 OperatorType::kResizeNearestNeighbor));
2417 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
2418 OperatorType::kSqueeze));
2419 ops.push_back(
2420 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
2421 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
2422 OperatorType::kSplitV));
2423 ops.push_back(MakeUnique<StridedSlice>(
2424 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
2425 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
2426 OperatorType::kTopK_V2));
2427 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
2428 OperatorType::kLstmCell));
2429 ops.push_back(
2430 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
2431 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
2432 OperatorType::kArgMax));
2433 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
2434 OperatorType::kArgMin));
2435 ops.push_back(
2436 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
2437 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
2438 OperatorType::kExpandDims));
2439 ops.push_back(MakeUnique<TransposeConv>(
2440 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
2441 ops.push_back(MakeUnique<SparseToDense>(
2442 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
2443 ops.push_back(
2444 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
2445 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
2446 OperatorType::kFakeQuant));
2447 ops.push_back(
2448 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
2449 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
2450 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
2451 OperatorType::kUnidirectionalSequenceLstm));
2452 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
2453 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
2454 OperatorType::kBidirectionalSequenceLstm));
2455 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
2456 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
2457 OperatorType::kBidirectionalSequenceRnn));
2458 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
2459 OperatorType::kOneHot));
2460 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
2461 OperatorType::kUnpack));
2462 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
2463 OperatorType::kLeakyRelu));
2464 ops.push_back(MakeUnique<SquaredDifference>(
2465 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
2466 OperatorType::kSquaredDifference));
2467 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
2468 OperatorType::kMirrorPad));
2469 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
2470 OperatorType::kUnique));
2471 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
2472 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
2473 OperatorType::kUnidirectionalSequenceRnn));
2474 ops.push_back(
2475 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
2476 ops.push_back(
2477 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
2478 OperatorType::kReverseSequence));
2479
2480 // Custom Operators.
2481 ops.push_back(
2482 MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
2483 ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
2484 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
2485 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
2486 OperatorType::kUnsupported,
2487 enable_select_tf_ops));
2488
2489 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
2490 // been modified to also export builtins. As TOCO evolved we added warnings
2491 // when custom ops are exported but SimpleOperator bypasses thoses. To
2492 // prevent user confusion we are settling on using SimpleOperator only for
2493 // builtins.
2494 ops.push_back(
2495 MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
2496 ops.push_back(
2497 MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
2498 ops.push_back(
2499 MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
2500 ops.push_back(
2501 MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
2502 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
2503 "RELU_N1_TO_1", OperatorType::kRelu1));
2504 ops.push_back(MakeUnique<Relu6>());
2505 ops.push_back(
2506 MakeUnique<SimpleOperator<PReluOperator>>("PRELU", OperatorType::kPRelu));
2507 ops.push_back(MakeUnique<Logistic>());
2508 ops.push_back(MakeUnique<Tanh>());
2509 ops.push_back(
2510 MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
2511 ops.push_back(
2512 MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
2513 ops.push_back(MakeUnique<LogSoftmax>());
2514 ops.push_back(MakeUnique<Maximum>()); // Element-wise Maximum
2515 ops.push_back(MakeUnique<Minimum>()); // Element-wise Minimum
2516 ops.push_back(MakeUnique<Greater>());
2517 ops.push_back(MakeUnique<GreaterEqual>());
2518 ops.push_back(MakeUnique<Less>());
2519 ops.push_back(MakeUnique<LessEqual>());
2520 ops.push_back(MakeUnique<Equal>());
2521 ops.push_back(MakeUnique<NotEqual>());
2522 ops.push_back(
2523 MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
2524 ops.push_back(MakeUnique<Select>());
2525 ops.push_back(MakeUnique<Slice>());
2526 ops.push_back(
2527 MakeUnique<SimpleOperator<PowOperator>>("POW", OperatorType::kPow));
2528 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
2529 "LOGICAL_OR", OperatorType::kLogicalOr));
2530 ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
2531 "LOGICAL_AND", OperatorType::kLogicalAnd));
2532 ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
2533 "LOGICAL_NOT", OperatorType::kLogicalNot));
2534 ops.emplace_back(new SimpleOperator<FloorDivOperator>(
2535 "FLOOR_DIV", OperatorType::kFloorDiv));
2536 ops.emplace_back(new SimpleOperator<FloorModOperator>(
2537 "FLOOR_MOD", OperatorType::kFloorMod));
2538 ops.emplace_back(
2539 new SimpleOperator<RangeOperator>("RANGE", OperatorType::kRange));
2540 // Element-wise operator
2541 ops.push_back(
2542 MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
2543 ops.push_back(
2544 MakeUnique<SimpleOperator<LogOperator>>("LOG", OperatorType::kLog));
2545 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
2546 "SQRT", OperatorType::kSqrt));
2547 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
2548 "RSQRT", OperatorType::kRsqrt));
2549 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
2550 "SQUARE", OperatorType::kSquare));
2551 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
2552 "ZEROS_LIKE", OperatorType::kZerosLike));
2553 ops.push_back(
2554 MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
2555 ops.push_back(
2556 MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
2557 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
2558 "REVERSE_V2", OperatorType::kReverseV2));
2559 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
2560 "RANK", OperatorType::kRank));
2561 return ops;
2562 }
2563 } // namespace
2564
BuildOperatorByTypeMap(bool enable_select_tf_ops)2565 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
2566 bool enable_select_tf_ops) {
2567 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
2568
2569 std::vector<std::unique_ptr<BaseOperator>> ops =
2570 BuildOperatorList(enable_select_tf_ops);
2571 for (auto& op : ops) {
2572 result[op->type()] = std::move(op);
2573 }
2574
2575 return result;
2576 }
2577
BuildOperatorByNameMap(bool enable_select_tf_ops)2578 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
2579 bool enable_select_tf_ops) {
2580 std::map<string, std::unique_ptr<BaseOperator>> result;
2581
2582 std::vector<std::unique_ptr<BaseOperator>> ops =
2583 BuildOperatorList(enable_select_tf_ops);
2584 for (auto& op : ops) {
2585 result[op->name()] = std::move(op);
2586 }
2587
2588 return result;
2589 }
2590
ShouldExportAsFlexOp(bool enable_select_tf_ops,const string & tensorflow_op_name)2591 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
2592 const string& tensorflow_op_name) {
2593 // If Flex ops aren't allow at all, simply return false.
2594 if (!enable_select_tf_ops) {
2595 return false;
2596 }
2597 // Check if we can find the `OpDef` for the TensorFlow op. If we can find
2598 // it and it has been whitelisted, export the op as an Flex op. Otherwise,
2599 // export it as a regular custom op.
2600 const tensorflow::OpDef* op_def = nullptr;
2601 if (!tensorflow::OpRegistry::Global()
2602 ->LookUpOpDef(tensorflow_op_name, &op_def)
2603 .ok()) {
2604 return false;
2605 }
2606
2607 if (!IsWhitelistedFlexOp(tensorflow_op_name)) {
2608 LOG(WARNING) << "Op " << tensorflow_op_name
2609 << " is a valid TensorFlow op but has not been whitelisted for"
2610 " the TensorFlow Lite flex op set.";
2611 return false;
2612 }
2613
2614 return true;
2615 }
2616
2617 } // namespace tflite
2618
2619 } // namespace toco
2620