• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
16 
17 #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
18 #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
19 #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
20 #include "tensorflow/contrib/lite/toco/tflite/types.h"
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 
25 namespace toco {
26 
27 namespace tflite {
28 
29 class AveragePool
30     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
31                              ::tflite::BuiltinOptions_Pool2DOptions> {
32  public:
33   using BuiltinOperator::BuiltinOperator;
34 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const35   flatbuffers::Offset<TfLiteOptions> WriteOptions(
36       const TocoOperator& op,
37       flatbuffers::FlatBufferBuilder* builder) const override {
38     auto padding = Padding::Serialize(op.padding.type);
39     auto activation_function =
40         ActivationFunction::Serialize(op.fused_activation_function);
41     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
42                                          op.stride_height, op.kwidth,
43                                          op.kheight, activation_function);
44   }
45 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const46   void ReadOptions(const TfLiteOptions& options,
47                    TocoOperator* op) const override {
48     op->padding.type = Padding::Deserialize(options.padding());
49     op->stride_width = options.stride_w();
50     op->stride_height = options.stride_h();
51     op->kwidth = options.filter_width();
52     op->kheight = options.filter_height();
53     op->fused_activation_function =
54         ActivationFunction::Deserialize(options.fused_activation_function());
55   }
56 };
57 
58 class Convolution
59     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
60                              ::tflite::BuiltinOptions_Conv2DOptions> {
61  public:
62   using BuiltinOperator::BuiltinOperator;
63 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const64   flatbuffers::Offset<TfLiteOptions> WriteOptions(
65       const TocoOperator& op,
66       flatbuffers::FlatBufferBuilder* builder) const override {
67     auto padding = Padding::Serialize(op.padding.type);
68     auto activation_function =
69         ActivationFunction::Serialize(op.fused_activation_function);
70     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
71                                          op.stride_height, activation_function);
72   }
73 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const74   void ReadOptions(const TfLiteOptions& options,
75                    TocoOperator* op) const override {
76     op->padding.type = Padding::Deserialize(options.padding());
77     op->stride_width = options.stride_w();
78     op->stride_height = options.stride_h();
79     op->fused_activation_function =
80         ActivationFunction::Deserialize(options.fused_activation_function());
81   }
82 };
83 
84 class DepthwiseConvolution
85     : public BuiltinOperator<DepthwiseConvOperator,
86                              ::tflite::DepthwiseConv2DOptions,
87                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
88  public:
89   using BuiltinOperator::BuiltinOperator;
90 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const91   flatbuffers::Offset<TfLiteOptions> WriteOptions(
92       const TocoOperator& op,
93       flatbuffers::FlatBufferBuilder* builder) const override {
94     auto padding = Padding::Serialize(op.padding.type);
95     auto activation_function =
96         ActivationFunction::Serialize(op.fused_activation_function);
97     return ::tflite::CreateDepthwiseConv2DOptions(
98         *builder, padding, op.stride_width, op.stride_height,
99         op.depth_multiplier, activation_function);
100   }
101 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const102   void ReadOptions(const TfLiteOptions& options,
103                    TocoOperator* op) const override {
104     op->padding.type = Padding::Deserialize(options.padding());
105     op->stride_width = options.stride_w();
106     op->stride_height = options.stride_h();
107     op->depth_multiplier = options.depth_multiplier();
108     op->fused_activation_function =
109         ActivationFunction::Deserialize(options.fused_activation_function());
110   }
111 };
112 
113 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
114                                    ::tflite::BuiltinOptions_AddOptions> {
115  public:
116   using BuiltinOperator::BuiltinOperator;
117 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const118   flatbuffers::Offset<TfLiteOptions> WriteOptions(
119       const TocoOperator& op,
120       flatbuffers::FlatBufferBuilder* builder) const override {
121     auto activation_function =
122         ActivationFunction::Serialize(op.fused_activation_function);
123     return ::tflite::CreateAddOptions(*builder, activation_function);
124   }
125 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const126   void ReadOptions(const TfLiteOptions& options,
127                    TocoOperator* op) const override {
128     op->fused_activation_function =
129         ActivationFunction::Deserialize(options.fused_activation_function());
130   }
131 };
132 
133 class SpaceToBatchND
134     : public BuiltinOperator<SpaceToBatchNDOperator,
135                              ::tflite::SpaceToBatchNDOptions,
136                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
137  public:
138   using BuiltinOperator::BuiltinOperator;
139 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const140   flatbuffers::Offset<TfLiteOptions> WriteOptions(
141       const TocoOperator& op,
142       flatbuffers::FlatBufferBuilder* builder) const override {
143     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
144   }
145 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const146   void ReadOptions(const TfLiteOptions& options,
147                    TocoOperator* op) const override {}
148 };
149 
150 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
151                                    ::tflite::BuiltinOptions_SubOptions> {
152  public:
153   using BuiltinOperator::BuiltinOperator;
154 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const155   flatbuffers::Offset<TfLiteOptions> WriteOptions(
156       const TocoOperator& op,
157       flatbuffers::FlatBufferBuilder* builder) const override {
158     auto activation_function =
159         ActivationFunction::Serialize(op.fused_activation_function);
160     return ::tflite::CreateSubOptions(*builder, activation_function);
161   }
162 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const163   void ReadOptions(const TfLiteOptions& options,
164                    TocoOperator* op) const override {
165     op->fused_activation_function =
166         ActivationFunction::Deserialize(options.fused_activation_function());
167   }
168 };
169 
170 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
171                                    ::tflite::BuiltinOptions_DivOptions> {
172  public:
173   using BuiltinOperator::BuiltinOperator;
174 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const175   flatbuffers::Offset<TfLiteOptions> WriteOptions(
176       const TocoOperator& op,
177       flatbuffers::FlatBufferBuilder* builder) const override {
178     auto activation_function =
179         ActivationFunction::Serialize(op.fused_activation_function);
180     return ::tflite::CreateDivOptions(*builder, activation_function);
181   }
182 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const183   void ReadOptions(const TfLiteOptions& options,
184                    TocoOperator* op) const override {
185     op->fused_activation_function =
186         ActivationFunction::Deserialize(options.fused_activation_function());
187   }
188 };
189 
190 class BatchToSpaceND
191     : public BuiltinOperator<BatchToSpaceNDOperator,
192                              ::tflite::BatchToSpaceNDOptions,
193                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
194  public:
195   using BuiltinOperator::BuiltinOperator;
196 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const197   flatbuffers::Offset<TfLiteOptions> WriteOptions(
198       const TocoOperator& op,
199       flatbuffers::FlatBufferBuilder* builder) const override {
200     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
201   }
202 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const203   void ReadOptions(const TfLiteOptions& options,
204                    TocoOperator* op) const override {}
205 };
206 
207 class Cast : public CustomOperator<CastOperator> {
208  public:
209   using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const210   void WriteOptions(const TocoOperator& op,
211                     flexbuffers::Builder* fbb) const override {
212     fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
213     fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
214   }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const215   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
216     op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
217     op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
218   }
219 };
220 
221 class Concatenation
222     : public BuiltinOperator<ConcatenationOperator,
223                              ::tflite::ConcatenationOptions,
224                              ::tflite::BuiltinOptions_ConcatenationOptions> {
225  public:
226   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const227   flatbuffers::Offset<TfLiteOptions> WriteOptions(
228       const TocoOperator& op,
229       flatbuffers::FlatBufferBuilder* builder) const override {
230     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
231   }
232 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const233   void ReadOptions(const TfLiteOptions& options,
234                    TocoOperator* op) const override {
235     op->axis = options.axis();
236   }
237 };
238 
239 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
240  public:
241   using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const242   void WriteOptions(const TocoOperator& op,
243                     flexbuffers::Builder* fbb) const override {
244     fbb->Int("block_size", op.block_size);
245   }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const246   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
247     op->block_size = m["block_size"].AsInt64();
248   }
249 };
250 
251 class FakeQuant : public CustomOperator<FakeQuantOperator> {
252  public:
253   using CustomOperator::CustomOperator;
WriteOptions(const TocoOperator & op,flexbuffers::Builder * fbb) const254   void WriteOptions(const TocoOperator& op,
255                     flexbuffers::Builder* fbb) const override {
256     fbb->Float("min", op.minmax->min);
257     fbb->Float("max", op.minmax->max);
258   }
ReadOptions(const flexbuffers::Map & m,TocoOperator * op) const259   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
260     auto* minmax = new MinMax;
261     minmax->min = m["min"].AsFloat();
262     minmax->max = m["max"].AsFloat();
263     op->minmax.reset(minmax);
264   }
265 };
266 
267 class FullyConnected
268     : public BuiltinOperator<FullyConnectedOperator,
269                              ::tflite::FullyConnectedOptions,
270                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
271  public:
272   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const273   flatbuffers::Offset<TfLiteOptions> WriteOptions(
274       const TocoOperator& op,
275       flatbuffers::FlatBufferBuilder* builder) const override {
276     auto activation_function =
277         ActivationFunction::Serialize(op.fused_activation_function);
278     return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
279   }
280 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const281   void ReadOptions(const TfLiteOptions& options,
282                    TocoOperator* op) const override {
283     op->fused_activation_function =
284         ActivationFunction::Deserialize(options.fused_activation_function());
285   }
286 };
287 
288 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
289                                       ::tflite::BuiltinOptions_GatherOptions> {
290  public:
291   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const292   flatbuffers::Offset<TfLiteOptions> WriteOptions(
293       const TocoOperator& op,
294       flatbuffers::FlatBufferBuilder* builder) const override {
295     return ::tflite::CreateGatherOptions(*builder, op.axis);
296   }
297 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const298   void ReadOptions(const TfLiteOptions& options,
299                    TocoOperator* op) const override {
300     op->axis = options.axis();
301   }
302 };
303 
304 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
305                                     ::tflite::BuiltinOptions_SVDFOptions> {
306  public:
307   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const308   flatbuffers::Offset<TfLiteOptions> WriteOptions(
309       const TocoOperator& op,
310       flatbuffers::FlatBufferBuilder* builder) const override {
311     auto activation_function =
312         ActivationFunction::Serialize(op.fused_activation_function);
313     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
314   }
315 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const316   void ReadOptions(const TfLiteOptions& options,
317                    TocoOperator* op) const override {
318     op->fused_activation_function =
319         ActivationFunction::Deserialize(options.fused_activation_function());
320     op->rank = options.rank();
321   }
322 };
323 
324 class L2Normalization
325     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
326                              ::tflite::BuiltinOptions_L2NormOptions> {
327  public:
328   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const329   flatbuffers::Offset<TfLiteOptions> WriteOptions(
330       const TocoOperator& op,
331       flatbuffers::FlatBufferBuilder* builder) const override {
332     auto activation_function =
333         ActivationFunction::Serialize(op.fused_activation_function);
334     return ::tflite::CreateL2NormOptions(*builder, activation_function);
335   }
336 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const337   void ReadOptions(const TfLiteOptions& options,
338                    TocoOperator* op) const override {
339     op->fused_activation_function =
340         ActivationFunction::Deserialize(options.fused_activation_function());
341   }
342 };
343 
344 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
345                                       ::tflite::BuiltinOptions_Pool2DOptions> {
346  public:
347   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const348   flatbuffers::Offset<TfLiteOptions> WriteOptions(
349       const TocoOperator& op,
350       flatbuffers::FlatBufferBuilder* builder) const override {
351     auto padding = Padding::Serialize(op.padding.type);
352     auto activation_function =
353         ActivationFunction::Serialize(op.fused_activation_function);
354     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
355                                          op.stride_height, op.kwidth,
356                                          op.kheight, activation_function);
357   }
358 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const359   void ReadOptions(const TfLiteOptions& options,
360                    TocoOperator* op) const override {
361     op->padding.type = Padding::Deserialize(options.padding());
362     op->stride_width = options.stride_w();
363     op->stride_height = options.stride_h();
364     op->kwidth = options.filter_width();
365     op->kheight = options.filter_height();
366     op->fused_activation_function =
367         ActivationFunction::Deserialize(options.fused_activation_function());
368   }
369 };
370 
371 class LocalResponseNormalization
372     : public BuiltinOperator<
373           LocalResponseNormalizationOperator,
374           ::tflite::LocalResponseNormalizationOptions,
375           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
376  public:
377   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const378   flatbuffers::Offset<TfLiteOptions> WriteOptions(
379       const TocoOperator& op,
380       flatbuffers::FlatBufferBuilder* builder) const override {
381     return ::tflite::CreateLocalResponseNormalizationOptions(
382         *builder, op.range, op.bias, op.alpha, op.beta);
383   }
384 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const385   void ReadOptions(const TfLiteOptions& options,
386                    TocoOperator* op) const override {
387     op->range = options.radius();
388     op->bias = options.bias();
389     op->alpha = options.alpha();
390     op->beta = options.beta();
391   }
392 };
393 
394 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
395                                        ::tflite::BuiltinOptions_Pool2DOptions> {
396  public:
397   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const398   flatbuffers::Offset<TfLiteOptions> WriteOptions(
399       const TocoOperator& op,
400       flatbuffers::FlatBufferBuilder* builder) const override {
401     auto padding = Padding::Serialize(op.padding.type);
402     auto activation_function =
403         ActivationFunction::Serialize(op.fused_activation_function);
404     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
405                                          op.stride_height, op.kwidth,
406                                          op.kheight, activation_function);
407   }
408 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const409   void ReadOptions(const TfLiteOptions& options,
410                    TocoOperator* op) const override {
411     op->padding.type = Padding::Deserialize(options.padding());
412     op->stride_width = options.stride_w();
413     op->stride_height = options.stride_h();
414     op->kwidth = options.filter_width();
415     op->kheight = options.filter_height();
416     op->fused_activation_function =
417         ActivationFunction::Deserialize(options.fused_activation_function());
418   }
419 };
420 
421 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
422                                    ::tflite::BuiltinOptions_MulOptions> {
423  public:
424   using BuiltinOperator::BuiltinOperator;
425 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const426   flatbuffers::Offset<TfLiteOptions> WriteOptions(
427       const TocoOperator& op,
428       flatbuffers::FlatBufferBuilder* builder) const override {
429     auto activation_function =
430         ActivationFunction::Serialize(op.fused_activation_function);
431     return ::tflite::CreateMulOptions(*builder, activation_function);
432   }
433 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const434   void ReadOptions(const TfLiteOptions& options,
435                    TocoOperator* op) const override {
436     op->fused_activation_function =
437         ActivationFunction::Deserialize(options.fused_activation_function());
438   }
439 };
440 
441 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
442                                    ::tflite::BuiltinOptions_PadOptions> {
443  public:
444   using BuiltinOperator::BuiltinOperator;
445 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const446   flatbuffers::Offset<TfLiteOptions> WriteOptions(
447       const TocoOperator& op,
448       flatbuffers::FlatBufferBuilder* builder) const override {
449     return ::tflite::CreatePadOptions(*builder);
450   }
451 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const452   void ReadOptions(const TfLiteOptions& options,
453                    TocoOperator* op) const override {}
454 };
455 
456 class Reshape
457     : public BuiltinOperator<TensorFlowReshapeOperator,
458                              ::tflite::ReshapeOptions,
459                              ::tflite::BuiltinOptions_ReshapeOptions> {
460  public:
461   using BuiltinOperator::BuiltinOperator;
462 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const463   flatbuffers::Offset<TfLiteOptions> WriteOptions(
464       const TocoOperator& op,
465       flatbuffers::FlatBufferBuilder* builder) const override {
466     return ::tflite::CreateReshapeOptions(*builder,
467                                           builder->CreateVector(op.shape));
468   }
469 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const470   void ReadOptions(const TfLiteOptions& options,
471                    TocoOperator* op) const override {
472     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
473                      options.new_shape()->end());
474   }
475 };
476 
477 class Softmax
478     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
479                              ::tflite::BuiltinOptions_SoftmaxOptions> {
480  public:
481   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const482   flatbuffers::Offset<TfLiteOptions> WriteOptions(
483       const TocoOperator& op,
484       flatbuffers::FlatBufferBuilder* builder) const override {
485     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
486   }
487 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const488   void ReadOptions(const TfLiteOptions& options,
489                    TocoOperator* op) const override {
490     op->beta = options.beta();
491   }
492 };
493 
494 class SpaceToDepth
495     : public BuiltinOperator<SpaceToDepthOperator,
496                              ::tflite::SpaceToDepthOptions,
497                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
498  public:
499   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const500   flatbuffers::Offset<TfLiteOptions> WriteOptions(
501       const TocoOperator& op,
502       flatbuffers::FlatBufferBuilder* builder) const override {
503     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
504   }
505 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const506   void ReadOptions(const TfLiteOptions& options,
507                    TocoOperator* op) const override {
508     op->block_size = options.block_size();
509   }
510 };
511 
512 class Transpose
513     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
514                              ::tflite::BuiltinOptions_TransposeOptions> {
515  public:
516   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const517   flatbuffers::Offset<TfLiteOptions> WriteOptions(
518       const TocoOperator& op,
519       flatbuffers::FlatBufferBuilder* builder) const override {
520     return ::tflite::CreateTransposeOptions(*builder);
521   }
522 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const523   void ReadOptions(const TfLiteOptions& options,
524                    TocoOperator* op) const override {}
525 };
526 
527 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
528                                     ::tflite::BuiltinOptions_LSTMOptions> {
529  public:
530   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const531   flatbuffers::Offset<TfLiteOptions> WriteOptions(
532       const TocoOperator& op,
533       flatbuffers::FlatBufferBuilder* builder) const override {
534     // Current toco converter only supports tanh, no clip.
535     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
536                                        ::tflite::ActivationFunctionType_TANH,
537                                        /*cell_clip=*/0.0,
538                                        /*proj_clip=*/0.0);
539   }
540 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const541   void ReadOptions(const TfLiteOptions& options,
542                    TocoOperator* op) const override {
543     // Only support tanh activation, so check that tflite type is tanh.
544     CHECK(options.fused_activation_function() ==
545           ::tflite::ActivationFunctionType_TANH);
546   }
547 };
548 
549 class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
550                                     ::tflite::BuiltinOptions_MeanOptions> {
551  public:
552   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const553   flatbuffers::Offset<TfLiteOptions> WriteOptions(
554       const TocoOperator& op,
555       flatbuffers::FlatBufferBuilder* builder) const override {
556     return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
557   }
558 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const559   void ReadOptions(const TfLiteOptions& options,
560                    TocoOperator* op) const override {
561     op->keep_dims = options.keep_dims();
562   }
563 };
564 
565 class ResizeBilinear
566     : public BuiltinOperator<ResizeBilinearOperator,
567                              ::tflite::ResizeBilinearOptions,
568                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
569  public:
570   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const571   flatbuffers::Offset<TfLiteOptions> WriteOptions(
572       const TocoOperator& op,
573       flatbuffers::FlatBufferBuilder* builder) const override {
574     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
575   }
576 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const577   void ReadOptions(const TfLiteOptions& options,
578                    TocoOperator* op) const override {
579     op->align_corners = options.align_corners();
580   }
581 };
582 
583 class Squeeze
584     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
585                              ::tflite::BuiltinOptions_SqueezeOptions> {
586  public:
587   using BuiltinOperator::BuiltinOperator;
588 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const589   flatbuffers::Offset<TfLiteOptions> WriteOptions(
590       const TocoOperator& op,
591       flatbuffers::FlatBufferBuilder* builder) const override {
592     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
593     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
594   }
595 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const596   void ReadOptions(const TfLiteOptions& options,
597                    TocoOperator* op) const override {
598     op->squeeze_dims.insert(op->squeeze_dims.end(),
599                             options.squeeze_dims()->begin(),
600                             options.squeeze_dims()->end());
601   }
602 };
603 
604 class Split
605     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
606                              ::tflite::BuiltinOptions_SplitOptions> {
607  public:
608   using BuiltinOperator::BuiltinOperator;
609 
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const610   flatbuffers::Offset<TfLiteOptions> WriteOptions(
611       const TocoOperator& op,
612       flatbuffers::FlatBufferBuilder* builder) const override {
613     return ::tflite::CreateSplitOptions(*builder, op.num_split);
614   }
615 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const616   void ReadOptions(const TfLiteOptions& options,
617                    TocoOperator* op) const override {
618     op->num_split = options.num_splits();
619   }
620 };
621 
622 class StridedSlice
623     : public BuiltinOperator<StridedSliceOperator,
624                              ::tflite::StridedSliceOptions,
625                              ::tflite::BuiltinOptions_StridedSliceOptions> {
626  public:
627   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const628   flatbuffers::Offset<TfLiteOptions> WriteOptions(
629       const TocoOperator& op,
630       flatbuffers::FlatBufferBuilder* builder) const override {
631     return ::tflite::CreateStridedSliceOptions(
632         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
633         op.new_axis_mask, op.shrink_axis_mask);
634   }
635 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const636   void ReadOptions(const TfLiteOptions& options,
637                    TocoOperator* op) const override {
638     op->begin_mask = options.begin_mask();
639     op->end_mask = options.end_mask();
640     op->ellipsis_mask = options.ellipsis_mask();
641     op->new_axis_mask = options.new_axis_mask();
642     op->shrink_axis_mask = options.shrink_axis_mask();
643   }
644 };
645 
646 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
647                                        ::tflite::BuiltinOptions_TopKV2Options> {
648  public:
649   using BuiltinOperator::BuiltinOperator;
WriteOptions(const TocoOperator & op,flatbuffers::FlatBufferBuilder * builder) const650   flatbuffers::Offset<TfLiteOptions> WriteOptions(
651       const TocoOperator& op,
652       flatbuffers::FlatBufferBuilder* builder) const override {
653     return ::tflite::CreateTopKV2Options(*builder);
654   }
655 
ReadOptions(const TfLiteOptions & options,TocoOperator * op) const656   void ReadOptions(const TfLiteOptions& options,
657                    TocoOperator* op) const override {}
658 };
659 
660 class TensorFlowUnsupported : public BaseOperator {
661  public:
662   using BaseOperator::BaseOperator;
663 
Serialize(const Operator & op,flatbuffers::FlatBufferBuilder * builder) const664   Options Serialize(const Operator& op,
665                     flatbuffers::FlatBufferBuilder* builder) const override {
666     auto fbb =
667         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
668     if (fbb) {
669       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
670     } else {
671       return Options::Custom(0);
672     }
673   }
674 
Deserialize(const BuiltinOptions * builtin_options,const CustomOptions * custom_options) const675   std::unique_ptr<Operator> Deserialize(
676       const BuiltinOptions* builtin_options,
677       const CustomOptions* custom_options) const override {
678     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
679     if (custom_options) {
680       auto flexbuffer_map =
681           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
682               .AsMap();
683       ReadOptions(flexbuffer_map, op.get());
684     }
685     return std::unique_ptr<Operator>(op.release());
686   }
687 
WriteOptions(const TensorFlowUnsupportedOperator & op) const688   std::unique_ptr<flexbuffers::Builder> WriteOptions(
689       const TensorFlowUnsupportedOperator& op) const {
690     auto fbb = absl::make_unique<flexbuffers::Builder>();
691 
692     ::tensorflow::NodeDef node_def;
693     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
694       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
695       return std::unique_ptr<flexbuffers::Builder>();
696     }
697 
698     bool has_valid_attr = false;
699     size_t map_start = fbb->StartMap();
700     for (const auto& pair : node_def.attr()) {
701       const char* key = pair.first.c_str();
702       const auto& attr = pair.second;
703       switch (attr.value_case()) {
704         case ::tensorflow::AttrValue::kS:
705           fbb->String(key, attr.s());
706           has_valid_attr = true;
707           break;
708         case ::tensorflow::AttrValue::kI:
709           fbb->Int(key, attr.i());
710           has_valid_attr = true;
711           break;
712         case ::tensorflow::AttrValue::kF:
713           fbb->Float(key, attr.f());
714           has_valid_attr = true;
715           break;
716         case ::tensorflow::AttrValue::kB:
717           fbb->Bool(key, attr.b());
718           has_valid_attr = true;
719           break;
720         default:
721           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
722                        << key << "'";
723           break;
724       }
725     }
726     if (!has_valid_attr) {
727       return std::unique_ptr<flexbuffers::Builder>();
728     }
729     fbb->EndMap(map_start);
730     fbb->Finish();
731     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
732   }
733 
ReadOptions(const flexbuffers::Map & m,TensorFlowUnsupportedOperator * op) const734   void ReadOptions(const flexbuffers::Map& m,
735                    TensorFlowUnsupportedOperator* op) const {
736     ::tensorflow::NodeDef node_def;
737     auto attr = node_def.mutable_attr();
738 
739     const auto& keys = m.Keys();
740     for (size_t i = 0; i < keys.size(); ++i) {
741       const auto key = keys[i].AsKey();
742       const auto& value = m[key];
743       switch (value.GetType()) {
744         case flexbuffers::TYPE_STRING:
745           (*attr)[key].set_s(value.AsString().c_str());
746           break;
747         case flexbuffers::TYPE_INT:
748           (*attr)[key].set_i(value.AsInt64());
749           break;
750         case flexbuffers::TYPE_FLOAT:
751           (*attr)[key].set_f(value.AsFloat());
752           break;
753         case flexbuffers::TYPE_BOOL:
754           (*attr)[key].set_b(value.AsBool());
755           break;
756         default:
757           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
758                        << key << "'";
759           break;
760       }
761     }
762     node_def.SerializeToString(&op->tensorflow_node_def);
763   }
764 };
765 
766 namespace {
767 // Build a vector containing all the known operators.
BuildOperatorList()768 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
769   std::vector<std::unique_ptr<BaseOperator>> ops;
770 
771   // Builtin Operators.
772   ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
773   ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
774   ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
775   ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
776                                    OperatorType::kAveragePool));
777   ops.emplace_back(
778       new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
779                          OperatorType::kSpaceToBatchND));
780   ops.emplace_back(
781       new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
782                          OperatorType::kBatchToSpaceND));
783   ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
784                                      OperatorType::kConcatenation));
785   ops.emplace_back(
786       new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
787   ops.emplace_back(
788       new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
789                                OperatorType::kDepthwiseConv));
790   ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
791                                       OperatorType::kFullyConnected));
792   ops.emplace_back(
793       new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather));
794   ops.emplace_back(
795       new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
796                           OperatorType::kL2Normalization));
797   ops.emplace_back(
798       new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
799   ops.emplace_back(new LocalResponseNormalization(
800       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
801       OperatorType::kLocalResponseNormalization));
802   ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
803                                OperatorType::kMaxPool));
804   ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
805   ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
806   ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
807                                OperatorType::kTensorFlowReshape));
808   ops.emplace_back(
809       new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
810   ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
811                                     OperatorType::kSpaceToDepth));
812   ops.emplace_back(
813       new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
814   ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
815                                  OperatorType::kTranspose));
816   ops.emplace_back(
817       new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
818   ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
819                                       OperatorType::kResizeBilinear));
820   ops.emplace_back(
821       new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
822   ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
823                              OperatorType::kTensorFlowSplit));
824   ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
825                                     OperatorType::kStridedSlice));
826   ops.emplace_back(
827       new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
828   ops.emplace_back(
829       new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
830 
831   // Custom Operators.
832   ops.emplace_back(new Cast("CAST", OperatorType::kCast));
833   ops.emplace_back(
834       new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
835   ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
836   ops.emplace_back(new TensorFlowUnsupported(
837       "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
838 
839   // There operators are supported by Toco, but not by TF Lite, and has no
840   // attributes.
841   ops.emplace_back(
842       new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
843   ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
844   ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
845       "RSQRT", OperatorType::kTensorFlowRsqrt));
846   // Simple Operators.
847   ops.emplace_back(new SimpleOperator<DequantizeOperator>(
848       "DEQUANTIZE", OperatorType::kDequantize));
849   ops.emplace_back(
850       new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
851   ops.emplace_back(
852       new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
853   ops.emplace_back(
854       new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
855   ops.emplace_back(
856       new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
857   ops.emplace_back(new SimpleOperator<LogisticOperator>(
858       "LOGISTIC", OperatorType::kLogistic));
859   ops.emplace_back(
860       new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
861   ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
862 
863   return ops;
864 }
865 }  // namespace
866 
BuildOperatorByTypeMap()867 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
868   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
869 
870   std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
871   for (auto& op : ops) {
872     result[op->type()] = std::move(op);
873   }
874 
875   return result;
876 }
877 
BuildOperatorByNameMap()878 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
879   std::map<string, std::unique_ptr<BaseOperator>> result;
880 
881   std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
882   for (auto& op : ops) {
883     result[op->name()] = std::move(op);
884   }
885 
886   return result;
887 }
888 
889 }  // namespace tflite
890 
891 }  // namespace toco
892