1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
16
17 #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
18 #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
19 #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
20 #include "tensorflow/contrib/lite/toco/tflite/types.h"
21
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24
25 namespace toco {
26
27 namespace tflite {
28
29 class AveragePool
30 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
31 ::tflite::BuiltinOptions_Pool2DOptions> {
32 public:
33 using BuiltinOperator::BuiltinOperator;
34
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const35 flatbuffers::Offset<TfLiteOptions> WriteOptions(
36 const TocoOperator& op,
37 flatbuffers::FlatBufferBuilder* builder) const override {
38 auto padding = Padding::Serialize(op.padding.type);
39 auto activation_function =
40 ActivationFunction::Serialize(op.fused_activation_function);
41 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
42 op.stride_height, op.kwidth,
43 op.kheight, activation_function);
44 }
45
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const46 void ReadOptions(const TfLiteOptions& options,
47 TocoOperator* op) const override {
48 op->padding.type = Padding::Deserialize(options.padding());
49 op->stride_width = options.stride_w();
50 op->stride_height = options.stride_h();
51 op->kwidth = options.filter_width();
52 op->kheight = options.filter_height();
53 op->fused_activation_function =
54 ActivationFunction::Deserialize(options.fused_activation_function());
55 }
56 };
57
58 class Convolution
59 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
60 ::tflite::BuiltinOptions_Conv2DOptions> {
61 public:
62 using BuiltinOperator::BuiltinOperator;
63
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const64 flatbuffers::Offset<TfLiteOptions> WriteOptions(
65 const TocoOperator& op,
66 flatbuffers::FlatBufferBuilder* builder) const override {
67 auto padding = Padding::Serialize(op.padding.type);
68 auto activation_function =
69 ActivationFunction::Serialize(op.fused_activation_function);
70 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
71 op.stride_height, activation_function);
72 }
73
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const74 void ReadOptions(const TfLiteOptions& options,
75 TocoOperator* op) const override {
76 op->padding.type = Padding::Deserialize(options.padding());
77 op->stride_width = options.stride_w();
78 op->stride_height = options.stride_h();
79 op->fused_activation_function =
80 ActivationFunction::Deserialize(options.fused_activation_function());
81 }
82 };
83
84 class DepthwiseConvolution
85 : public BuiltinOperator<DepthwiseConvOperator,
86 ::tflite::DepthwiseConv2DOptions,
87 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
88 public:
89 using BuiltinOperator::BuiltinOperator;
90
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const91 flatbuffers::Offset<TfLiteOptions> WriteOptions(
92 const TocoOperator& op,
93 flatbuffers::FlatBufferBuilder* builder) const override {
94 auto padding = Padding::Serialize(op.padding.type);
95 auto activation_function =
96 ActivationFunction::Serialize(op.fused_activation_function);
97 return ::tflite::CreateDepthwiseConv2DOptions(
98 *builder, padding, op.stride_width, op.stride_height,
99 op.depth_multiplier, activation_function);
100 }
101
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const102 void ReadOptions(const TfLiteOptions& options,
103 TocoOperator* op) const override {
104 op->padding.type = Padding::Deserialize(options.padding());
105 op->stride_width = options.stride_w();
106 op->stride_height = options.stride_h();
107 op->depth_multiplier = options.depth_multiplier();
108 op->fused_activation_function =
109 ActivationFunction::Deserialize(options.fused_activation_function());
110 }
111 };
112
113 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
114 ::tflite::BuiltinOptions_AddOptions> {
115 public:
116 using BuiltinOperator::BuiltinOperator;
117
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const118 flatbuffers::Offset<TfLiteOptions> WriteOptions(
119 const TocoOperator& op,
120 flatbuffers::FlatBufferBuilder* builder) const override {
121 auto activation_function =
122 ActivationFunction::Serialize(op.fused_activation_function);
123 return ::tflite::CreateAddOptions(*builder, activation_function);
124 }
125
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const126 void ReadOptions(const TfLiteOptions& options,
127 TocoOperator* op) const override {
128 op->fused_activation_function =
129 ActivationFunction::Deserialize(options.fused_activation_function());
130 }
131 };
132
133 class SpaceToBatchND
134 : public BuiltinOperator<SpaceToBatchNDOperator,
135 ::tflite::SpaceToBatchNDOptions,
136 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
137 public:
138 using BuiltinOperator::BuiltinOperator;
139
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const140 flatbuffers::Offset<TfLiteOptions> WriteOptions(
141 const TocoOperator& op,
142 flatbuffers::FlatBufferBuilder* builder) const override {
143 return ::tflite::CreateSpaceToBatchNDOptions(*builder);
144 }
145
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const146 void ReadOptions(const TfLiteOptions& options,
147 TocoOperator* op) const override {}
148 };
149
150 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
151 ::tflite::BuiltinOptions_SubOptions> {
152 public:
153 using BuiltinOperator::BuiltinOperator;
154
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const155 flatbuffers::Offset<TfLiteOptions> WriteOptions(
156 const TocoOperator& op,
157 flatbuffers::FlatBufferBuilder* builder) const override {
158 auto activation_function =
159 ActivationFunction::Serialize(op.fused_activation_function);
160 return ::tflite::CreateSubOptions(*builder, activation_function);
161 }
162
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const163 void ReadOptions(const TfLiteOptions& options,
164 TocoOperator* op) const override {
165 op->fused_activation_function =
166 ActivationFunction::Deserialize(options.fused_activation_function());
167 }
168 };
169
170 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
171 ::tflite::BuiltinOptions_DivOptions> {
172 public:
173 using BuiltinOperator::BuiltinOperator;
174
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const175 flatbuffers::Offset<TfLiteOptions> WriteOptions(
176 const TocoOperator& op,
177 flatbuffers::FlatBufferBuilder* builder) const override {
178 auto activation_function =
179 ActivationFunction::Serialize(op.fused_activation_function);
180 return ::tflite::CreateDivOptions(*builder, activation_function);
181 }
182
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const183 void ReadOptions(const TfLiteOptions& options,
184 TocoOperator* op) const override {
185 op->fused_activation_function =
186 ActivationFunction::Deserialize(options.fused_activation_function());
187 }
188 };
189
190 class BatchToSpaceND
191 : public BuiltinOperator<BatchToSpaceNDOperator,
192 ::tflite::BatchToSpaceNDOptions,
193 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
194 public:
195 using BuiltinOperator::BuiltinOperator;
196
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const197 flatbuffers::Offset<TfLiteOptions> WriteOptions(
198 const TocoOperator& op,
199 flatbuffers::FlatBufferBuilder* builder) const override {
200 return ::tflite::CreateBatchToSpaceNDOptions(*builder);
201 }
202
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const203 void ReadOptions(const TfLiteOptions& options,
204 TocoOperator* op) const override {}
205 };
206
207 class Cast : public CustomOperator<CastOperator> {
208 public:
209 using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const210 void WriteOptions(const TocoOperator& op,
211 flexbuffers::Builder* fbb) const override {
212 fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
213 fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
214 }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const215 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
216 op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
217 op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
218 }
219 };
220
221 class Concatenation
222 : public BuiltinOperator<ConcatenationOperator,
223 ::tflite::ConcatenationOptions,
224 ::tflite::BuiltinOptions_ConcatenationOptions> {
225 public:
226 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const227 flatbuffers::Offset<TfLiteOptions> WriteOptions(
228 const TocoOperator& op,
229 flatbuffers::FlatBufferBuilder* builder) const override {
230 return ::tflite::CreateConcatenationOptions(*builder, op.axis);
231 }
232
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const233 void ReadOptions(const TfLiteOptions& options,
234 TocoOperator* op) const override {
235 op->axis = options.axis();
236 }
237 };
238
239 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
240 public:
241 using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const242 void WriteOptions(const TocoOperator& op,
243 flexbuffers::Builder* fbb) const override {
244 fbb->Int("block_size", op.block_size);
245 }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const246 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
247 op->block_size = m["block_size"].AsInt64();
248 }
249 };
250
251 class FakeQuant : public CustomOperator<FakeQuantOperator> {
252 public:
253 using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const254 void WriteOptions(const TocoOperator& op,
255 flexbuffers::Builder* fbb) const override {
256 fbb->Float("min", op.minmax->min);
257 fbb->Float("max", op.minmax->max);
258 }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const259 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
260 auto* minmax = new MinMax;
261 minmax->min = m["min"].AsFloat();
262 minmax->max = m["max"].AsFloat();
263 op->minmax.reset(minmax);
264 }
265 };
266
267 class FullyConnected
268 : public BuiltinOperator<FullyConnectedOperator,
269 ::tflite::FullyConnectedOptions,
270 ::tflite::BuiltinOptions_FullyConnectedOptions> {
271 public:
272 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const273 flatbuffers::Offset<TfLiteOptions> WriteOptions(
274 const TocoOperator& op,
275 flatbuffers::FlatBufferBuilder* builder) const override {
276 auto activation_function =
277 ActivationFunction::Serialize(op.fused_activation_function);
278 return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
279 }
280
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const281 void ReadOptions(const TfLiteOptions& options,
282 TocoOperator* op) const override {
283 op->fused_activation_function =
284 ActivationFunction::Deserialize(options.fused_activation_function());
285 }
286 };
287
288 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
289 ::tflite::BuiltinOptions_GatherOptions> {
290 public:
291 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const292 flatbuffers::Offset<TfLiteOptions> WriteOptions(
293 const TocoOperator& op,
294 flatbuffers::FlatBufferBuilder* builder) const override {
295 return ::tflite::CreateGatherOptions(*builder, op.axis);
296 }
297
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const298 void ReadOptions(const TfLiteOptions& options,
299 TocoOperator* op) const override {
300 op->axis = options.axis();
301 }
302 };
303
304 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
305 ::tflite::BuiltinOptions_SVDFOptions> {
306 public:
307 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const308 flatbuffers::Offset<TfLiteOptions> WriteOptions(
309 const TocoOperator& op,
310 flatbuffers::FlatBufferBuilder* builder) const override {
311 auto activation_function =
312 ActivationFunction::Serialize(op.fused_activation_function);
313 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
314 }
315
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const316 void ReadOptions(const TfLiteOptions& options,
317 TocoOperator* op) const override {
318 op->fused_activation_function =
319 ActivationFunction::Deserialize(options.fused_activation_function());
320 op->rank = options.rank();
321 }
322 };
323
324 class L2Normalization
325 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
326 ::tflite::BuiltinOptions_L2NormOptions> {
327 public:
328 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const329 flatbuffers::Offset<TfLiteOptions> WriteOptions(
330 const TocoOperator& op,
331 flatbuffers::FlatBufferBuilder* builder) const override {
332 auto activation_function =
333 ActivationFunction::Serialize(op.fused_activation_function);
334 return ::tflite::CreateL2NormOptions(*builder, activation_function);
335 }
336
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const337 void ReadOptions(const TfLiteOptions& options,
338 TocoOperator* op) const override {
339 op->fused_activation_function =
340 ActivationFunction::Deserialize(options.fused_activation_function());
341 }
342 };
343
344 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
345 ::tflite::BuiltinOptions_Pool2DOptions> {
346 public:
347 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const348 flatbuffers::Offset<TfLiteOptions> WriteOptions(
349 const TocoOperator& op,
350 flatbuffers::FlatBufferBuilder* builder) const override {
351 auto padding = Padding::Serialize(op.padding.type);
352 auto activation_function =
353 ActivationFunction::Serialize(op.fused_activation_function);
354 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
355 op.stride_height, op.kwidth,
356 op.kheight, activation_function);
357 }
358
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const359 void ReadOptions(const TfLiteOptions& options,
360 TocoOperator* op) const override {
361 op->padding.type = Padding::Deserialize(options.padding());
362 op->stride_width = options.stride_w();
363 op->stride_height = options.stride_h();
364 op->kwidth = options.filter_width();
365 op->kheight = options.filter_height();
366 op->fused_activation_function =
367 ActivationFunction::Deserialize(options.fused_activation_function());
368 }
369 };
370
371 class LocalResponseNormalization
372 : public BuiltinOperator<
373 LocalResponseNormalizationOperator,
374 ::tflite::LocalResponseNormalizationOptions,
375 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
376 public:
377 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const378 flatbuffers::Offset<TfLiteOptions> WriteOptions(
379 const TocoOperator& op,
380 flatbuffers::FlatBufferBuilder* builder) const override {
381 return ::tflite::CreateLocalResponseNormalizationOptions(
382 *builder, op.range, op.bias, op.alpha, op.beta);
383 }
384
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const385 void ReadOptions(const TfLiteOptions& options,
386 TocoOperator* op) const override {
387 op->range = options.radius();
388 op->bias = options.bias();
389 op->alpha = options.alpha();
390 op->beta = options.beta();
391 }
392 };
393
394 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
395 ::tflite::BuiltinOptions_Pool2DOptions> {
396 public:
397 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const398 flatbuffers::Offset<TfLiteOptions> WriteOptions(
399 const TocoOperator& op,
400 flatbuffers::FlatBufferBuilder* builder) const override {
401 auto padding = Padding::Serialize(op.padding.type);
402 auto activation_function =
403 ActivationFunction::Serialize(op.fused_activation_function);
404 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
405 op.stride_height, op.kwidth,
406 op.kheight, activation_function);
407 }
408
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const409 void ReadOptions(const TfLiteOptions& options,
410 TocoOperator* op) const override {
411 op->padding.type = Padding::Deserialize(options.padding());
412 op->stride_width = options.stride_w();
413 op->stride_height = options.stride_h();
414 op->kwidth = options.filter_width();
415 op->kheight = options.filter_height();
416 op->fused_activation_function =
417 ActivationFunction::Deserialize(options.fused_activation_function());
418 }
419 };
420
421 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
422 ::tflite::BuiltinOptions_MulOptions> {
423 public:
424 using BuiltinOperator::BuiltinOperator;
425
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const426 flatbuffers::Offset<TfLiteOptions> WriteOptions(
427 const TocoOperator& op,
428 flatbuffers::FlatBufferBuilder* builder) const override {
429 auto activation_function =
430 ActivationFunction::Serialize(op.fused_activation_function);
431 return ::tflite::CreateMulOptions(*builder, activation_function);
432 }
433
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const434 void ReadOptions(const TfLiteOptions& options,
435 TocoOperator* op) const override {
436 op->fused_activation_function =
437 ActivationFunction::Deserialize(options.fused_activation_function());
438 }
439 };
440
441 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
442 ::tflite::BuiltinOptions_PadOptions> {
443 public:
444 using BuiltinOperator::BuiltinOperator;
445
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const446 flatbuffers::Offset<TfLiteOptions> WriteOptions(
447 const TocoOperator& op,
448 flatbuffers::FlatBufferBuilder* builder) const override {
449 return ::tflite::CreatePadOptions(*builder);
450 }
451
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const452 void ReadOptions(const TfLiteOptions& options,
453 TocoOperator* op) const override {}
454 };
455
456 class Reshape
457 : public BuiltinOperator<TensorFlowReshapeOperator,
458 ::tflite::ReshapeOptions,
459 ::tflite::BuiltinOptions_ReshapeOptions> {
460 public:
461 using BuiltinOperator::BuiltinOperator;
462
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const463 flatbuffers::Offset<TfLiteOptions> WriteOptions(
464 const TocoOperator& op,
465 flatbuffers::FlatBufferBuilder* builder) const override {
466 return ::tflite::CreateReshapeOptions(*builder,
467 builder->CreateVector(op.shape));
468 }
469
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const470 void ReadOptions(const TfLiteOptions& options,
471 TocoOperator* op) const override {
472 op->shape.insert(op->shape.end(), options.new_shape()->begin(),
473 options.new_shape()->end());
474 }
475 };
476
477 class Softmax
478 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
479 ::tflite::BuiltinOptions_SoftmaxOptions> {
480 public:
481 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const482 flatbuffers::Offset<TfLiteOptions> WriteOptions(
483 const TocoOperator& op,
484 flatbuffers::FlatBufferBuilder* builder) const override {
485 return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
486 }
487
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const488 void ReadOptions(const TfLiteOptions& options,
489 TocoOperator* op) const override {
490 op->beta = options.beta();
491 }
492 };
493
494 class SpaceToDepth
495 : public BuiltinOperator<SpaceToDepthOperator,
496 ::tflite::SpaceToDepthOptions,
497 ::tflite::BuiltinOptions_SpaceToDepthOptions> {
498 public:
499 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const500 flatbuffers::Offset<TfLiteOptions> WriteOptions(
501 const TocoOperator& op,
502 flatbuffers::FlatBufferBuilder* builder) const override {
503 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
504 }
505
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const506 void ReadOptions(const TfLiteOptions& options,
507 TocoOperator* op) const override {
508 op->block_size = options.block_size();
509 }
510 };
511
512 class Transpose
513 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
514 ::tflite::BuiltinOptions_TransposeOptions> {
515 public:
516 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const517 flatbuffers::Offset<TfLiteOptions> WriteOptions(
518 const TocoOperator& op,
519 flatbuffers::FlatBufferBuilder* builder) const override {
520 return ::tflite::CreateTransposeOptions(*builder);
521 }
522
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const523 void ReadOptions(const TfLiteOptions& options,
524 TocoOperator* op) const override {}
525 };
526
527 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
528 ::tflite::BuiltinOptions_LSTMOptions> {
529 public:
530 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const531 flatbuffers::Offset<TfLiteOptions> WriteOptions(
532 const TocoOperator& op,
533 flatbuffers::FlatBufferBuilder* builder) const override {
534 // Current toco converter only supports tanh, no clip.
535 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
536 ::tflite::ActivationFunctionType_TANH,
537 /*cell_clip=*/0.0,
538 /*proj_clip=*/0.0);
539 }
540
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const541 void ReadOptions(const TfLiteOptions& options,
542 TocoOperator* op) const override {
543 // Only support tanh activation, so check that tflite type is tanh.
544 CHECK(options.fused_activation_function() ==
545 ::tflite::ActivationFunctionType_TANH);
546 }
547 };
548
549 class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
550 ::tflite::BuiltinOptions_MeanOptions> {
551 public:
552 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const553 flatbuffers::Offset<TfLiteOptions> WriteOptions(
554 const TocoOperator& op,
555 flatbuffers::FlatBufferBuilder* builder) const override {
556 return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
557 }
558
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const559 void ReadOptions(const TfLiteOptions& options,
560 TocoOperator* op) const override {
561 op->keep_dims = options.keep_dims();
562 }
563 };
564
565 class ResizeBilinear
566 : public BuiltinOperator<ResizeBilinearOperator,
567 ::tflite::ResizeBilinearOptions,
568 ::tflite::BuiltinOptions_ResizeBilinearOptions> {
569 public:
570 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const571 flatbuffers::Offset<TfLiteOptions> WriteOptions(
572 const TocoOperator& op,
573 flatbuffers::FlatBufferBuilder* builder) const override {
574 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
575 }
576
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const577 void ReadOptions(const TfLiteOptions& options,
578 TocoOperator* op) const override {
579 op->align_corners = options.align_corners();
580 }
581 };
582
583 class Squeeze
584 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
585 ::tflite::BuiltinOptions_SqueezeOptions> {
586 public:
587 using BuiltinOperator::BuiltinOperator;
588
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const589 flatbuffers::Offset<TfLiteOptions> WriteOptions(
590 const TocoOperator& op,
591 flatbuffers::FlatBufferBuilder* builder) const override {
592 auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
593 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
594 }
595
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const596 void ReadOptions(const TfLiteOptions& options,
597 TocoOperator* op) const override {
598 op->squeeze_dims.insert(op->squeeze_dims.end(),
599 options.squeeze_dims()->begin(),
600 options.squeeze_dims()->end());
601 }
602 };
603
604 class Split
605 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
606 ::tflite::BuiltinOptions_SplitOptions> {
607 public:
608 using BuiltinOperator::BuiltinOperator;
609
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const610 flatbuffers::Offset<TfLiteOptions> WriteOptions(
611 const TocoOperator& op,
612 flatbuffers::FlatBufferBuilder* builder) const override {
613 return ::tflite::CreateSplitOptions(*builder, op.num_split);
614 }
615
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const616 void ReadOptions(const TfLiteOptions& options,
617 TocoOperator* op) const override {
618 op->num_split = options.num_splits();
619 }
620 };
621
622 class StridedSlice
623 : public BuiltinOperator<StridedSliceOperator,
624 ::tflite::StridedSliceOptions,
625 ::tflite::BuiltinOptions_StridedSliceOptions> {
626 public:
627 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const628 flatbuffers::Offset<TfLiteOptions> WriteOptions(
629 const TocoOperator& op,
630 flatbuffers::FlatBufferBuilder* builder) const override {
631 return ::tflite::CreateStridedSliceOptions(
632 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
633 op.new_axis_mask, op.shrink_axis_mask);
634 }
635
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const636 void ReadOptions(const TfLiteOptions& options,
637 TocoOperator* op) const override {
638 op->begin_mask = options.begin_mask();
639 op->end_mask = options.end_mask();
640 op->ellipsis_mask = options.ellipsis_mask();
641 op->new_axis_mask = options.new_axis_mask();
642 op->shrink_axis_mask = options.shrink_axis_mask();
643 }
644 };
645
646 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
647 ::tflite::BuiltinOptions_TopKV2Options> {
648 public:
649 using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const650 flatbuffers::Offset<TfLiteOptions> WriteOptions(
651 const TocoOperator& op,
652 flatbuffers::FlatBufferBuilder* builder) const override {
653 return ::tflite::CreateTopKV2Options(*builder);
654 }
655
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const656 void ReadOptions(const TfLiteOptions& options,
657 TocoOperator* op) const override {}
658 };
659
660 class TensorFlowUnsupported : public BaseOperator {
661 public:
662 using BaseOperator::BaseOperator;
663
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const664 Options Serialize(const Operator& op,
665 flatbuffers::FlatBufferBuilder* builder) const override {
666 auto fbb =
667 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
668 if (fbb) {
669 return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
670 } else {
671 return Options::Custom(0);
672 }
673 }
674
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const675 std::unique_ptr<Operator> Deserialize(
676 const BuiltinOptions* builtin_options,
677 const CustomOptions* custom_options) const override {
678 auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
679 if (custom_options) {
680 auto flexbuffer_map =
681 flexbuffers::GetRoot(custom_options->data(), custom_options->size())
682 .AsMap();
683 ReadOptions(flexbuffer_map, op.get());
684 }
685 return std::unique_ptr<Operator>(op.release());
686 }
687
WriteOptions(const TensorFlowUnsupportedOperator & op) const688 std::unique_ptr<flexbuffers::Builder> WriteOptions(
689 const TensorFlowUnsupportedOperator& op) const {
690 auto fbb = absl::make_unique<flexbuffers::Builder>();
691
692 ::tensorflow::NodeDef node_def;
693 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
694 LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
695 return std::unique_ptr<flexbuffers::Builder>();
696 }
697
698 bool has_valid_attr = false;
699 size_t map_start = fbb->StartMap();
700 for (const auto& pair : node_def.attr()) {
701 const char* key = pair.first.c_str();
702 const auto& attr = pair.second;
703 switch (attr.value_case()) {
704 case ::tensorflow::AttrValue::kS:
705 fbb->String(key, attr.s());
706 has_valid_attr = true;
707 break;
708 case ::tensorflow::AttrValue::kI:
709 fbb->Int(key, attr.i());
710 has_valid_attr = true;
711 break;
712 case ::tensorflow::AttrValue::kF:
713 fbb->Float(key, attr.f());
714 has_valid_attr = true;
715 break;
716 case ::tensorflow::AttrValue::kB:
717 fbb->Bool(key, attr.b());
718 has_valid_attr = true;
719 break;
720 default:
721 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
722 << key << "'";
723 break;
724 }
725 }
726 if (!has_valid_attr) {
727 return std::unique_ptr<flexbuffers::Builder>();
728 }
729 fbb->EndMap(map_start);
730 fbb->Finish();
731 return std::unique_ptr<flexbuffers::Builder>(fbb.release());
732 }
733
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const734 void ReadOptions(const flexbuffers::Map& m,
735 TensorFlowUnsupportedOperator* op) const {
736 ::tensorflow::NodeDef node_def;
737 auto attr = node_def.mutable_attr();
738
739 const auto& keys = m.Keys();
740 for (size_t i = 0; i < keys.size(); ++i) {
741 const auto key = keys[i].AsKey();
742 const auto& value = m[key];
743 switch (value.GetType()) {
744 case flexbuffers::TYPE_STRING:
745 (*attr)[key].set_s(value.AsString().c_str());
746 break;
747 case flexbuffers::TYPE_INT:
748 (*attr)[key].set_i(value.AsInt64());
749 break;
750 case flexbuffers::TYPE_FLOAT:
751 (*attr)[key].set_f(value.AsFloat());
752 break;
753 case flexbuffers::TYPE_BOOL:
754 (*attr)[key].set_b(value.AsBool());
755 break;
756 default:
757 LOG(WARNING) << "Ignoring unsupported attribute type with key '"
758 << key << "'";
759 break;
760 }
761 }
762 node_def.SerializeToString(&op->tensorflow_node_def);
763 }
764 };
765
766 namespace {
767 // Build a vector containing all the known operators.
BuildOperatorList()768 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
769 std::vector<std::unique_ptr<BaseOperator>> ops;
770
771 // Builtin Operators.
772 ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
773 ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
774 ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
775 ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
776 OperatorType::kAveragePool));
777 ops.emplace_back(
778 new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
779 OperatorType::kSpaceToBatchND));
780 ops.emplace_back(
781 new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
782 OperatorType::kBatchToSpaceND));
783 ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
784 OperatorType::kConcatenation));
785 ops.emplace_back(
786 new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
787 ops.emplace_back(
788 new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
789 OperatorType::kDepthwiseConv));
790 ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
791 OperatorType::kFullyConnected));
792 ops.emplace_back(
793 new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather));
794 ops.emplace_back(
795 new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
796 OperatorType::kL2Normalization));
797 ops.emplace_back(
798 new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
799 ops.emplace_back(new LocalResponseNormalization(
800 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
801 OperatorType::kLocalResponseNormalization));
802 ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
803 OperatorType::kMaxPool));
804 ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
805 ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
806 ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
807 OperatorType::kTensorFlowReshape));
808 ops.emplace_back(
809 new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
810 ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
811 OperatorType::kSpaceToDepth));
812 ops.emplace_back(
813 new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
814 ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
815 OperatorType::kTranspose));
816 ops.emplace_back(
817 new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
818 ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
819 OperatorType::kResizeBilinear));
820 ops.emplace_back(
821 new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
822 ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
823 OperatorType::kTensorFlowSplit));
824 ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
825 OperatorType::kStridedSlice));
826 ops.emplace_back(
827 new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
828 ops.emplace_back(
829 new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
830
831 // Custom Operators.
832 ops.emplace_back(new Cast("CAST", OperatorType::kCast));
833 ops.emplace_back(
834 new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
835 ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
836 ops.emplace_back(new TensorFlowUnsupported(
837 "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
838
839 // There operators are supported by Toco, but not by TF Lite, and has no
840 // attributes.
841 ops.emplace_back(
842 new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
843 ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
844 ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
845 "RSQRT", OperatorType::kTensorFlowRsqrt));
846 // Simple Operators.
847 ops.emplace_back(new SimpleOperator<DequantizeOperator>(
848 "DEQUANTIZE", OperatorType::kDequantize));
849 ops.emplace_back(
850 new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
851 ops.emplace_back(
852 new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
853 ops.emplace_back(
854 new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
855 ops.emplace_back(
856 new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
857 ops.emplace_back(new SimpleOperator<LogisticOperator>(
858 "LOGISTIC", OperatorType::kLogistic));
859 ops.emplace_back(
860 new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
861 ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
862
863 return ops;
864 }
865 } // namespace
866
BuildOperatorByTypeMap()867 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
868 std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
869
870 std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
871 for (auto& op : ops) {
872 result[op->type()] = std::move(op);
873 }
874
875 return result;
876 }
877
BuildOperatorByNameMap()878 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
879 std::map<string, std::unique_ptr<BaseOperator>> result;
880
881 std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
882 for (auto& op : ops) {
883 result[op->name()] = std::move(op);
884 }
885
886 return result;
887 }
888
889 } // namespace tflite
890
891 } // namespace toco
892