• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "google/protobuf/map.h"
22 #include "google/protobuf/text_format.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/model.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/runtime/types.h"
35 #include "tensorflow/lite/toco/tensorflow_util.h"
36 #include "tensorflow/lite/toco/tooling_util.h"
37 
38 using tensorflow::DT_BOOL;
39 using tensorflow::DT_COMPLEX64;
40 using tensorflow::DT_FLOAT;
41 using tensorflow::DT_INT16;
42 using tensorflow::DT_INT32;
43 using tensorflow::DT_INT64;
44 using tensorflow::DT_UINT32;
45 using tensorflow::DT_UINT8;
46 using tensorflow::GraphDef;
47 using tensorflow::TensorProto;
48 
49 namespace toco {
50 namespace {
51 
GetTensorFlowDataType(ArrayDataType data_type,const std::string & error_location)52 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
53                                            const std::string& error_location) {
54   switch (data_type) {
55     case ArrayDataType::kBool:
56       return tensorflow::DT_BOOL;
57     case ArrayDataType::kFloat:
58       return tensorflow::DT_FLOAT;
59     case ArrayDataType::kUint8:
60       return tensorflow::DT_UINT8;
61     case ArrayDataType::kInt32:
62       return tensorflow::DT_INT32;
63     case ArrayDataType::kUint32:
64       return tensorflow::DT_UINT32;
65     case ArrayDataType::kInt64:
66       return tensorflow::DT_INT64;
67     case ArrayDataType::kString:
68       return tensorflow::DT_STRING;
69     case ArrayDataType::kComplex64:
70       return tensorflow::DT_COMPLEX64;
71     default:
72     case ArrayDataType::kNone:
73       LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
74                  << "' in " << error_location;
75       return tensorflow::DT_INVALID;
76   }
77 }
78 
GetTensorFlowDataTypeForOp(ArrayDataType data_type,const std::string & op_name)79 tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
80                                                 const std::string& op_name) {
81   return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
82 }
83 
GetTensorFlowDataType(const Model & model,const std::string & array_name)84 tensorflow::DataType GetTensorFlowDataType(const Model& model,
85                                            const std::string& array_name) {
86   return GetTensorFlowDataType(model.GetArray(array_name).data_type,
87                                "array '" + array_name + "'");
88 }
89 
// TensorFlow sometimes rejects what it calls "legacy scalars": 1-D shapes
// whose single dimension has size 1 (see OpKernel::IsLegacyScalar and
// OpKernel::allow_legacy_scalars). The exporter therefore normally writes
// such shapes as 0-D shapes instead.
//
// Bias vectors are the exception. TensorFlow requires the bias of a BiasAdd
// node to be 1-D, so when the depth is 1 the bias must stay 1-D even though
// that makes it a legacy scalar.
enum class LegacyScalarPolicy {
  // Replace a 1-D shape of size 1 with a 0-D shape.
  kAvoidLegacyScalars,
  // Keep the 1-D shape unconditionally, even if it is a legacy scalar.
  kDoCreateLegacyScalars
};
103 
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)104 void ExportFloatArray(const Shape& input_shape, const float* input_data,
105                       TensorProto* output_tensor,
106                       LegacyScalarPolicy legacy_scalar_policy) {
107   output_tensor->set_dtype(DT_FLOAT);
108   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
109   auto* shape = output_tensor->mutable_tensor_shape();
110 
111   const int kDims = input_shape.dimensions_count();
112   if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
113       kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
114     for (int i = 0; i < kDims; ++i) {
115       shape->add_dim()->set_size(input_shape.dims(i));
116     }
117   }
118   output_tensor->set_tensor_content(
119       std::string(reinterpret_cast<const char*>(input_data),
120                   sizeof(*input_data) * input_flat_size));
121 }
122 
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)123 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
124                       const float* input_data, AxesOrder output_axes_order,
125                       TensorProto* output_tensor,
126                       LegacyScalarPolicy legacy_scalar_policy) {
127   CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
128   output_tensor->set_dtype(DT_FLOAT);
129   CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
130   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
131 
132   Shape shuffled_shape;
133   ShuffleDims(input_shape, input_axes_order, output_axes_order,
134               &shuffled_shape);
135   std::vector<float> shuffled_data(input_flat_size);
136   ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
137                input_data, shuffled_data.data());
138 
139   ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
140                    legacy_scalar_policy);
141 }
142 
HasAlreadyExportedConst(const std::string & name,const GraphDef & tensorflow_graph)143 bool HasAlreadyExportedConst(const std::string& name,
144                              const GraphDef& tensorflow_graph) {
145   for (const auto& node : tensorflow_graph.node()) {
146     if (node.op() == "Const" && node.name() == name) {
147       return true;
148     }
149   }
150   return false;
151 }
152 
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)153 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
154                              const float* input_data,
155                              AxesOrder input_axes_order,
156                              AxesOrder output_axes_order,
157                              GraphDef* tensorflow_graph,
158                              LegacyScalarPolicy legacy_scalar_policy) {
159   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
160     return;
161   }
162   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
163   const_op->set_op("Const");
164   const_op->set_name(name);
165   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
166   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
167   ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
168                    tensor, legacy_scalar_policy);
169 }
170 
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)171 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
172                              const float* input_data,
173                              AxesOrder input_axes_order,
174                              AxesOrder output_axes_order,
175                              GraphDef* tensorflow_graph) {
176   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
177                           output_axes_order, tensorflow_graph,
178                           LegacyScalarPolicy::kAvoidLegacyScalars);
179 }
180 
ConvertFloatTensorConst(const Model & model,const std::string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)181 void ConvertFloatTensorConst(const Model& model, const std::string& name,
182                              AxesOrder input_axes_order,
183                              AxesOrder output_axes_order,
184                              GraphDef* tensorflow_graph) {
185   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
186     return;
187   }
188   CHECK(model.HasArray(name));
189   const auto& input_array = model.GetArray(name);
190   const auto& input_shape = input_array.shape();
191   CHECK(input_array.buffer);
192   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
193   const float* input_data =
194       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
195   ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
196                           output_axes_order, tensorflow_graph);
197 }
198 
ConvertFloatTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)199 void ConvertFloatTensorConst(const Model& model, const std::string& name,
200                              GraphDef* tensorflow_graph) {
201   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
202     return;
203   }
204   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
205   const_op->set_op("Const");
206   const_op->set_name(name);
207   (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
208   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
209   CHECK(model.HasArray(name));
210   const auto& input_array = model.GetArray(name);
211   const auto& input_shape = input_array.shape();
212   CHECK(input_array.buffer);
213   CHECK(input_array.buffer->type == ArrayDataType::kFloat);
214   const float* input_data =
215       input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
216   ExportFloatArray(input_shape, input_data, tensor,
217                    LegacyScalarPolicy::kAvoidLegacyScalars);
218 }
219 
ConvertBoolTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)220 void ConvertBoolTensorConst(const Model& model, const std::string& name,
221                             GraphDef* tensorflow_graph) {
222   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
223     return;
224   }
225   CHECK(model.HasArray(name));
226   const auto& array = model.GetArray(name);
227   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
228   const_op->set_op("Const");
229   const_op->set_name(name);
230   (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
231   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
232   tensor->set_dtype(DT_BOOL);
233   const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
234   for (auto index : data) {
235     tensor->add_bool_val(index);
236   }
237   const auto& array_shape = array.shape();
238   auto* shape = tensor->mutable_tensor_shape();
239   for (int i = 0; i < array_shape.dimensions_count(); i++) {
240     shape->add_dim()->set_size(array_shape.dims(i));
241   }
242 }
243 
ConvertIntTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)244 void ConvertIntTensorConst(const Model& model, const std::string& name,
245                            GraphDef* tensorflow_graph) {
246   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
247     return;
248   }
249   CHECK(model.HasArray(name));
250   const auto& array = model.GetArray(name);
251   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
252   const_op->set_op("Const");
253   const_op->set_name(name);
254   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
255   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
256   tensor->set_dtype(DT_INT32);
257   const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
258   for (auto index : data) {
259     tensor->add_int_val(index);
260   }
261   const auto& array_shape = array.shape();
262   auto* shape = tensor->mutable_tensor_shape();
263   for (int i = 0; i < array_shape.dimensions_count(); i++) {
264     shape->add_dim()->set_size(array_shape.dims(i));
265   }
266 }
267 
CreateIntTensorConst(const std::string & name,const std::vector<int32> & data,const std::vector<int32> & shape,GraphDef * tensorflow_graph)268 void CreateIntTensorConst(const std::string& name,
269                           const std::vector<int32>& data,
270                           const std::vector<int32>& shape,
271                           GraphDef* tensorflow_graph) {
272   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
273     return;
274   }
275   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
276   const_op->set_op("Const");
277   const_op->set_name(name);
278   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
279   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
280   tensor->set_dtype(DT_INT32);
281   for (auto index : data) {
282     tensor->add_int_val(index);
283   }
284   auto* tensor_shape = tensor->mutable_tensor_shape();
285   int num_elements = 1;
286   for (int size : shape) {
287     tensor_shape->add_dim()->set_size(size);
288     num_elements *= size;
289   }
290   CHECK_EQ(num_elements, data.size());
291 }
292 
ConvertComplex64TensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)293 void ConvertComplex64TensorConst(const Model& model, const std::string& name,
294                                  GraphDef* tensorflow_graph) {
295   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
296     return;
297   }
298   CHECK(model.HasArray(name));
299   const auto& array = model.GetArray(name);
300   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
301   const_op->set_op("Const");
302   const_op->set_name(name);
303   (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
304   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
305   tensor->set_dtype(DT_COMPLEX64);
306   const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
307   for (auto index : data) {
308     tensor->add_scomplex_val(std::real(index));
309     tensor->add_scomplex_val(std::imag(index));
310   }
311   const auto& array_shape = array.shape();
312   auto* shape = tensor->mutable_tensor_shape();
313   for (int i = 0; i < array_shape.dimensions_count(); i++) {
314     shape->add_dim()->set_size(array_shape.dims(i));
315   }
316 }
317 
CreateMatrixShapeTensorConst(const std::string & name,int rows,int cols,GraphDef * tensorflow_graph)318 void CreateMatrixShapeTensorConst(const std::string& name, int rows, int cols,
319                                   GraphDef* tensorflow_graph) {
320   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
321     return;
322   }
323   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
324   const_op->set_op("Const");
325   const_op->set_name(name);
326   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
327   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
328   tensor->set_dtype(DT_INT32);
329   const int32 data[2] = {cols, rows};
330   tensor->set_tensor_content(
331       std::string(reinterpret_cast<const char*>(data), sizeof(data)));
332   auto* shape = tensor->mutable_tensor_shape();
333   shape->add_dim()->set_size(2);
334 }
335 
CreateDummyConcatDimTensorConst(const std::string & name,int dim,GraphDef * tensorflow_graph)336 void CreateDummyConcatDimTensorConst(const std::string& name, int dim,
337                                      GraphDef* tensorflow_graph) {
338   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
339     return;
340   }
341   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
342   const_op->set_op("Const");
343   const_op->set_name(name);
344   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
345   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
346   tensor->set_dtype(DT_INT32);
347   tensor->add_int_val(dim);
348 }
349 
CreateReshapeShapeTensorConst(const std::string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)350 void CreateReshapeShapeTensorConst(const std::string& name,
351                                    const std::vector<int32>& shape,
352                                    GraphDef* tensorflow_graph) {
353   if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
354     return;
355   }
356   tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
357   const_op->set_op("Const");
358   const_op->set_name(name);
359   (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
360   auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
361   tensor->set_dtype(DT_INT32);
362   for (auto s : shape) {
363     tensor->add_int_val(s);
364   }
365   // TensorFlow sometimes forbids what it calls "legacy scalars",
366   // which are shapes of size 1 where the unique shape size is 1.
367   // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
368   if (shape.size() > 1) {
369     auto* tensor_shape = tensor->mutable_tensor_shape();
370     tensor_shape->add_dim()->set_size(shape.size());
371   }
372 }
373 
WalkUpToConstantArray(const Model & model,const std::string & name)374 std::string WalkUpToConstantArray(const Model& model, const std::string& name) {
375   const Array& original_array = model.GetArray(name);
376   if (original_array.buffer) {
377     return name;
378   }
379   const auto* op = GetOpWithOutput(model, name);
380   CHECK(op);
381   CHECK(op->type == OperatorType::kFakeQuant);
382   const std::string& input_of_fakequant_name = op->inputs[0];
383   const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
384   CHECK(input_of_fakequant.buffer);
385   return input_of_fakequant_name;
386 }
387 
ConvertConvOperator(const Model & model,const ConvOperator & src_op,GraphDef * tensorflow_graph)388 void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
389                          GraphDef* tensorflow_graph) {
390   const bool has_bias = src_op.inputs.size() >= 3;
391   std::string conv_output = src_op.outputs[0];
392   if (has_bias) {
393     conv_output += "/conv";
394   }
395 
396   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
397   conv2d_op->set_op("Conv2D");
398   conv2d_op->set_name(conv_output);
399   *conv2d_op->add_input() = src_op.inputs[0];
400   *conv2d_op->add_input() = src_op.inputs[1];
401   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
402   const std::string& weights_array_name =
403       WalkUpToConstantArray(model, src_op.inputs[1]);
404   const auto& weights_array = model.GetArray(weights_array_name);
405   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
406   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
407                           AxesOrder::kHWIO, tensorflow_graph);
408   auto& strides = (*conv2d_op->mutable_attr())["strides"];
409   strides.mutable_list()->add_i(1);
410   strides.mutable_list()->add_i(src_op.stride_height);
411   strides.mutable_list()->add_i(src_op.stride_width);
412   strides.mutable_list()->add_i(1);
413   if ((src_op.dilation_width_factor != 1) ||
414       (src_op.dilation_height_factor != 1)) {
415     auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
416     dilations.mutable_list()->add_i(1);
417     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
418     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
419     dilations.mutable_list()->add_i(1);
420   }
421   std::string padding;
422   if (src_op.padding.type == PaddingType::kSame) {
423     padding = "SAME";
424   } else if (src_op.padding.type == PaddingType::kValid) {
425     padding = "VALID";
426   } else {
427     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
428   }
429   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
430 
431   if (has_bias) {
432     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
433     biasadd_op->set_op("BiasAdd");
434     biasadd_op->set_name(src_op.outputs[0]);
435     biasadd_op->add_input(conv_output);
436     biasadd_op->add_input(src_op.inputs[2]);
437     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
438     CHECK(model.HasArray(src_op.inputs[2]));
439     const std::string& bias_array_name =
440         WalkUpToConstantArray(model, src_op.inputs[2]);
441     const auto& bias_array = model.GetArray(bias_array_name);
442     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
443     Shape bias_shape_1d = bias_array.shape();
444     UnextendShape(&bias_shape_1d, 1);
445     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
446     const float* bias_data =
447         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
448     ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
449                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
450                             tensorflow_graph,
451                             LegacyScalarPolicy::kDoCreateLegacyScalars);
452   }
453 }
454 
ConvertDepthwiseConvOperator(const Model & model,const DepthwiseConvOperator & src_op,GraphDef * tensorflow_graph)455 void ConvertDepthwiseConvOperator(const Model& model,
456                                   const DepthwiseConvOperator& src_op,
457                                   GraphDef* tensorflow_graph) {
458   const bool has_bias = src_op.inputs.size() >= 3;
459   std::string conv_output = src_op.outputs[0];
460   if (has_bias) {
461     conv_output += "/conv";
462   }
463 
464   tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
465   dc2d_op->set_op("DepthwiseConv2dNative");
466   dc2d_op->set_name(conv_output);
467   *dc2d_op->add_input() = src_op.inputs[0];
468   *dc2d_op->add_input() = src_op.inputs[1];
469   (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
470 
471   // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
472   // We need to convert that to H x W x InputDepth x Multiplier.
473   // That's only a matter of constructing a Dims object; the actual
474   // array layout is the same.
475   CHECK(model.HasArray(src_op.inputs[1]));
476   const std::string& src_weights_name =
477       WalkUpToConstantArray(model, src_op.inputs[1]);
478   const auto& src_weights_array = model.GetArray(src_weights_name);
479   const auto& src_weights_shape = src_weights_array.shape();
480   CHECK_EQ(src_weights_shape.dimensions_count(), 4);
481   const Shape dst_weights_shape =
482       Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
483              src_weights_shape.dims(3) / src_op.depth_multiplier,
484              src_op.depth_multiplier});
485   CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
486   CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
487         src_weights_shape.dims(3));
488   CHECK_EQ(src_weights_shape.dims(0), 1);
489 
490   CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
491   const float* src_weights_data =
492       src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
493   ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
494                           AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
495 
496   auto& strides = (*dc2d_op->mutable_attr())["strides"];
497   strides.mutable_list()->add_i(1);
498   strides.mutable_list()->add_i(src_op.stride_height);
499   strides.mutable_list()->add_i(src_op.stride_width);
500   strides.mutable_list()->add_i(1);
501   // TODO(b/116063589): To return a working TF GraphDef, we should be returning
502   // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
503   // the conv since TF doesn't support dilations.
504   if ((src_op.dilation_width_factor != 1) ||
505       (src_op.dilation_height_factor != 1)) {
506     auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
507     dilations.mutable_list()->add_i(1);
508     dilations.mutable_list()->add_i(src_op.dilation_height_factor);
509     dilations.mutable_list()->add_i(src_op.dilation_width_factor);
510     dilations.mutable_list()->add_i(1);
511   }
512   std::string padding;
513   if (src_op.padding.type == PaddingType::kSame) {
514     padding = "SAME";
515   } else if (src_op.padding.type == PaddingType::kValid) {
516     padding = "VALID";
517   } else {
518     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
519   }
520   (*dc2d_op->mutable_attr())["padding"].set_s(padding);
521 
522   if (has_bias) {
523     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
524     biasadd_op->set_op("BiasAdd");
525     biasadd_op->set_name(src_op.outputs[0]);
526     biasadd_op->add_input(conv_output);
527     biasadd_op->add_input(src_op.inputs[2]);
528     (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
529     CHECK(model.HasArray(src_op.inputs[2]));
530     const std::string& bias_name =
531         WalkUpToConstantArray(model, src_op.inputs[2]);
532     const auto& bias_array = model.GetArray(bias_name);
533     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
534     Shape bias_shape_1d = bias_array.shape();
535     UnextendShape(&bias_shape_1d, 1);
536     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
537     const float* bias_data =
538         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
539     ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
540                             AxesOrder::kOneAxis, AxesOrder::kOneAxis,
541                             tensorflow_graph,
542                             LegacyScalarPolicy::kDoCreateLegacyScalars);
543   }
544 }
545 
ConvertTransposeConvOperator(const Model & model,const TransposeConvOperator & src_op,GraphDef * tensorflow_graph)546 void ConvertTransposeConvOperator(const Model& model,
547                                   const TransposeConvOperator& src_op,
548                                   GraphDef* tensorflow_graph) {
549   tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
550   conv2d_op->set_op("Conv2DBackpropInput");
551   conv2d_op->set_name(src_op.outputs[0]);
552   *conv2d_op->add_input() = src_op.inputs[0];
553   *conv2d_op->add_input() = src_op.inputs[1];
554   *conv2d_op->add_input() = src_op.inputs[2];
555   (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
556   const std::string& weights_array_name = WalkUpToConstantArray(
557       model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
558   const auto& weights_array = model.GetArray(weights_array_name);
559   CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
560   ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
561                           AxesOrder::kHWOI, tensorflow_graph);
562   auto& strides = (*conv2d_op->mutable_attr())["strides"];
563   strides.mutable_list()->add_i(1);
564   strides.mutable_list()->add_i(src_op.stride_height);
565   strides.mutable_list()->add_i(src_op.stride_width);
566   strides.mutable_list()->add_i(1);
567   std::string padding;
568   if (src_op.padding.type == PaddingType::kSame) {
569     padding = "SAME";
570   } else if (src_op.padding.type == PaddingType::kValid) {
571     padding = "VALID";
572   } else {
573     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
574   }
575   (*conv2d_op->mutable_attr())["padding"].set_s(padding);
576 }
577 
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)578 void ConvertDepthToSpaceOperator(const Model& model,
579                                  const DepthToSpaceOperator& src_op,
580                                  GraphDef* tensorflow_graph) {
581   tensorflow::NodeDef* op = tensorflow_graph->add_node();
582   op->set_op("DepthToSpace");
583   op->set_name(src_op.outputs[0]);
584   *op->add_input() = src_op.inputs[0];
585   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
586   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
587 }
588 
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)589 void ConvertSpaceToDepthOperator(const Model& model,
590                                  const SpaceToDepthOperator& src_op,
591                                  GraphDef* tensorflow_graph) {
592   tensorflow::NodeDef* op = tensorflow_graph->add_node();
593   op->set_op("SpaceToDepth");
594   op->set_name(src_op.outputs[0]);
595   *op->add_input() = src_op.inputs[0];
596   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
597   (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
598 }
599 
ConvertFullyConnectedOperator(const Model & model,const FullyConnectedOperator & src_op,GraphDef * tensorflow_graph)600 void ConvertFullyConnectedOperator(const Model& model,
601                                    const FullyConnectedOperator& src_op,
602                                    GraphDef* tensorflow_graph) {
603   // Reshape input activations to have the shape expected by the MatMul.
604   const std::string reshape_output =
605       AvailableArrayName(model, src_op.outputs[0] + "/reshape");
606   const std::string reshape_shape =
607       AvailableArrayName(model, reshape_output + "/shape");
608   const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
609   const auto& fc_weights_shape = fc_weights_array.shape();
610   CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
611   CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
612                                tensorflow_graph);
613   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
614   reshape_op->set_op("Reshape");
615   reshape_op->set_name(reshape_output);
616   reshape_op->add_input(src_op.inputs[0]);
617   reshape_op->add_input(reshape_shape);
618   (*reshape_op->mutable_attr())["T"].set_type(
619       GetTensorFlowDataType(model, src_op.inputs[0]));
620 
621   const bool has_bias = src_op.inputs.size() >= 3;
622   std::string matmul_output = src_op.outputs[0];
623   if (has_bias) {
624     matmul_output += "/matmul";
625   }
626 
627   // Transpose the RHS input from column-major to row-major to match TensorFlow
628   // expectations. This is the inverse of the transpose we do during
629   // ResolveTensorFlowMatMul.
630   const std::string transpose_output =
631       AvailableArrayName(model, matmul_output + "/transpose_weights");
632   const std::string transpose_perm =
633       AvailableArrayName(model, transpose_output + "/perm");
634   CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
635   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
636   transpose_op->set_op("Transpose");
637   transpose_op->set_name(transpose_output);
638   *transpose_op->add_input() = src_op.inputs[1];
639   *transpose_op->add_input() = transpose_perm;
640   (*transpose_op->mutable_attr())["T"].set_type(
641       GetTensorFlowDataType(model, src_op.inputs[1]));
642   (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
643 
644   tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
645   matmul_op->set_op("MatMul");
646   matmul_op->set_name(matmul_output);
647   *matmul_op->add_input() = reshape_output;
648   *matmul_op->add_input() = transpose_op->name();
649   (*matmul_op->mutable_attr())["T"].set_type(
650       GetTensorFlowDataType(model, src_op.inputs[0]));
651   (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
652   (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
653   CHECK(model.HasArray(src_op.inputs[1]));
654 
655   // Add the bias, if it exists.
656   if (has_bias) {
657     tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
658     biasadd_op->set_op("BiasAdd");
659     biasadd_op->set_name(src_op.outputs[0]);
660     biasadd_op->add_input(matmul_output);
661     biasadd_op->add_input(src_op.inputs[2]);
662     (*biasadd_op->mutable_attr())["T"].set_type(
663         GetTensorFlowDataType(model, src_op.inputs[0]));
664     CHECK(model.HasArray(src_op.inputs[2]));
665     const auto& bias_array = model.GetArray(src_op.inputs[2]);
666     // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
667     Shape bias_shape_1d = bias_array.shape();
668     UnextendShape(&bias_shape_1d, 1);
669     CHECK(bias_array.buffer);
670     CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
671     const float* bias_data =
672         bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
673     ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
674                             bias_shape_1d, bias_data, AxesOrder::kOneAxis,
675                             AxesOrder::kOneAxis, tensorflow_graph,
676                             LegacyScalarPolicy::kDoCreateLegacyScalars);
677   }
678 }
679 
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)680 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
681                         GraphDef* tensorflow_graph) {
682   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
683   add_op->set_op("Add");
684   add_op->set_name(src_op.outputs[0]);
685   CHECK_EQ(src_op.inputs.size(), 2);
686   *add_op->add_input() = src_op.inputs[0];
687   *add_op->add_input() = src_op.inputs[1];
688   (*add_op->mutable_attr())["T"].set_type(
689       GetTensorFlowDataType(model, src_op.outputs[0]));
690 }
691 
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)692 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
693                          GraphDef* tensorflow_graph) {
694   tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
695   add_op->set_op("AddN");
696   add_op->set_name(src_op.outputs[0]);
697   for (const auto& input : src_op.inputs) {
698     *add_op->add_input() = input;
699   }
700   (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
701   (*add_op->mutable_attr())["T"].set_type(
702       GetTensorFlowDataType(model, src_op.outputs[0]));
703 }
704 
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)705 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
706                         GraphDef* tensorflow_graph) {
707   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
708   mul_op->set_op("Mul");
709   mul_op->set_name(src_op.outputs[0]);
710   CHECK_EQ(src_op.inputs.size(), 2);
711   *mul_op->add_input() = src_op.inputs[0];
712   *mul_op->add_input() = src_op.inputs[1];
713   (*mul_op->mutable_attr())["T"].set_type(
714       GetTensorFlowDataType(model, src_op.outputs[0]));
715 }
716 
ConvertDivOperator(const Model & model,const DivOperator & src_op,GraphDef * tensorflow_graph)717 void ConvertDivOperator(const Model& model, const DivOperator& src_op,
718                         GraphDef* tensorflow_graph) {
719   tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
720   div_op->set_op("Div");
721   div_op->set_name(src_op.outputs[0]);
722   CHECK_EQ(src_op.inputs.size(), 2);
723   *div_op->add_input() = src_op.inputs[0];
724   *div_op->add_input() = src_op.inputs[1];
725   (*div_op->mutable_attr())["T"].set_type(
726       GetTensorFlowDataType(model, src_op.outputs[0]));
727 }
728 
ConvertReluOperator(const Model & model,const ReluOperator & src_op,GraphDef * tensorflow_graph)729 void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
730                          GraphDef* tensorflow_graph) {
731   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
732   relu_op->set_op("Relu");
733   relu_op->set_name(src_op.outputs[0]);
734   *relu_op->add_input() = src_op.inputs[0];
735   (*relu_op->mutable_attr())["T"].set_type(
736       GetTensorFlowDataType(model, src_op.outputs[0]));
737 }
738 
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)739 void ConvertRelu1Operator(const Relu1Operator& src_op,
740                           GraphDef* tensorflow_graph) {
741   const std::string max_bounds = src_op.outputs[0] + "/max_bounds";
742   const std::string min_bounds = src_op.outputs[0] + "/min_bounds";
743   const std::string max_output = src_op.outputs[0] + "/max_output";
744 
745   tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
746   max_bounds_const_op->set_op("Const");
747   max_bounds_const_op->set_name(max_bounds);
748   (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
749   auto* max_bounds_const_op_tensor =
750       (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
751   max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
752   max_bounds_const_op_tensor->add_float_val(-1.0f);
753 
754   tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
755   min_bounds_const_op->set_op("Const");
756   min_bounds_const_op->set_name(min_bounds);
757   (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
758   auto* min_bounds_const_op_tensor =
759       (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
760   min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
761   min_bounds_const_op_tensor->add_float_val(1.0f);
762 
763   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
764   max_op->set_op("Maximum");
765   max_op->set_name(max_output);
766   *max_op->add_input() = src_op.inputs[0];
767   *max_op->add_input() = max_bounds;
768   (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
769 
770   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
771   min_op->set_op("Minimum");
772   min_op->set_name(src_op.outputs[0]);
773   *min_op->add_input() = max_output;
774   *min_op->add_input() = min_bounds;
775   (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
776 }
777 
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)778 void ConvertRelu6Operator(const Relu6Operator& src_op,
779                           GraphDef* tensorflow_graph) {
780   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
781   relu_op->set_op("Relu6");
782   relu_op->set_name(src_op.outputs[0]);
783   *relu_op->add_input() = src_op.inputs[0];
784   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
785 }
786 
ConvertLogOperator(const LogOperator & src_op,GraphDef * tensorflow_graph)787 void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
788   tensorflow::NodeDef* op = tensorflow_graph->add_node();
789   op->set_op("Log");
790   op->set_name(src_op.outputs[0]);
791   CHECK_EQ(src_op.inputs.size(), 1);
792   *op->add_input() = src_op.inputs[0];
793   (*op->mutable_attr())["T"].set_type(DT_FLOAT);
794 }
795 
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)796 void ConvertLogisticOperator(const LogisticOperator& src_op,
797                              GraphDef* tensorflow_graph) {
798   tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
799   relu_op->set_op("Sigmoid");
800   relu_op->set_name(src_op.outputs[0]);
801   *relu_op->add_input() = src_op.inputs[0];
802   (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
803 }
804 
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)805 void ConvertTanhOperator(const TanhOperator& src_op,
806                          GraphDef* tensorflow_graph) {
807   tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
808   tanh_op->set_op("Tanh");
809   tanh_op->set_name(src_op.outputs[0]);
810   *tanh_op->add_input() = src_op.inputs[0];
811   (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
812 }
813 
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)814 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
815                             GraphDef* tensorflow_graph) {
816   std::string softmax_input;
817   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
818   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
819     softmax_input = src_op.inputs[0];
820   } else {
821     // Insert a reshape operator that reduces the dimensions down to the 2 that
822     // are required for TensorFlow Logits.
823     const std::string reshape_output =
824         src_op.outputs[0] + "/softmax_insert_reshape";
825     const std::string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
826     softmax_input = reshape_output;
827 
828     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
829     reshape_op->set_op("Reshape");
830     reshape_op->set_name(reshape_output);
831     *reshape_op->add_input() = src_op.inputs[0];
832     *reshape_op->add_input() = softmax_size;
833     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
834 
835     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
836     int32 flattened_size = 1;
837     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
838       flattened_size *= input_shape.dims(i);
839     }
840     const std::vector<int32> shape_data = {
841         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
842     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
843   }
844 
845   tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
846   softmax_op->set_op("Softmax");
847   softmax_op->set_name(src_op.outputs[0]);
848   *softmax_op->add_input() = softmax_input;
849   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
850   CHECK_EQ(src_op.beta, 1.f);
851   (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
852 }
853 
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)854 void ConvertLogSoftmaxOperator(const Model& model,
855                                const LogSoftmaxOperator& src_op,
856                                GraphDef* tensorflow_graph) {
857   std::string softmax_input;
858   Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
859   if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
860     softmax_input = src_op.inputs[0];
861   } else {
862     // Insert a reshape operator that reduces the dimensions down to the 2 that
863     // are required for TensorFlow Logits.
864     const std::string reshape_output =
865         src_op.outputs[0] + "/log_softmax_insert_reshape";
866     const std::string softmax_size =
867         src_op.outputs[0] + "/log_softmax_insert_size";
868     softmax_input = reshape_output;
869 
870     tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
871     reshape_op->set_op("Reshape");
872     reshape_op->set_name(reshape_output);
873     *reshape_op->add_input() = src_op.inputs[0];
874     *reshape_op->add_input() = softmax_size;
875     (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
876 
877     const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
878     int32 flattened_size = 1;
879     for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
880       flattened_size *= input_shape.dims(i);
881     }
882     const std::vector<int32> shape_data = {
883         flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
884     CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
885   }
886 
887   tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
888   log_softmax_op->set_op("LogSoftmax");
889   log_softmax_op->set_name(src_op.outputs[0]);
890   *log_softmax_op->add_input() = softmax_input;
891   (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
892 }
893 
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)894 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
895                                     GraphDef* tensorflow_graph) {
896   const std::string square_output = src_op.outputs[0] + "/square";
897   const std::string sum_reduction_indices =
898       src_op.outputs[0] + "/reduction_indices";
899   const std::string sum_output = src_op.outputs[0] + "/sum";
900   const std::string rsqrt_output = src_op.outputs[0] + "/rsqrt";
901   const std::string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
902 
903   tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
904   sum_reduction_indices_op->set_op("Const");
905   sum_reduction_indices_op->set_name(sum_reduction_indices);
906   (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
907   auto* sum_reduction_indices_tensor =
908       (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
909   sum_reduction_indices_tensor->set_dtype(DT_INT32);
910   auto* sum_reduction_indices_shape =
911       sum_reduction_indices_tensor->mutable_tensor_shape();
912   auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
913   sum_reduction_indices_dim->set_size(2);
914   sum_reduction_indices_tensor->add_int_val(0);
915   sum_reduction_indices_tensor->add_int_val(1);
916 
917   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
918   square_op->set_op("Square");
919   square_op->set_name(square_output);
920   *square_op->add_input() = src_op.inputs[0];
921   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
922 
923   tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
924   sum_op->set_op("Sum");
925   sum_op->set_name(sum_output);
926   *sum_op->add_input() = square_output;
927   *sum_op->add_input() = sum_reduction_indices;
928   (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
929 
930   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
931   rsqrt_op->set_op("Rsqrt");
932   rsqrt_op->set_name(rsqrt_output);
933   *rsqrt_op->add_input() = sum_output;
934   (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
935 
936   tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
937   mul_op->set_op("Mul");
938   mul_op->set_name(src_op.outputs[0]);
939   *mul_op->add_input() = src_op.inputs[0];
940   *mul_op->add_input() = rsqrt_output;
941   (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
942 }
943 
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)944 void ConvertLocalResponseNormalizationOperator(
945     const LocalResponseNormalizationOperator& src_op,
946     GraphDef* tensorflow_graph) {
947   tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
948   lrn_op->set_op("LRN");
949   lrn_op->set_name(src_op.outputs[0]);
950   *lrn_op->add_input() = src_op.inputs[0];
951   (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
952   (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
953   (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
954   (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
955 }
956 
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)957 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
958                               GraphDef* tensorflow_graph) {
959   tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
960   fakequant_op->set_op("FakeQuantWithMinMaxArgs");
961   fakequant_op->set_name(src_op.outputs[0]);
962   CHECK_EQ(src_op.inputs.size(), 1);
963   *fakequant_op->add_input() = src_op.inputs[0];
964   CHECK(src_op.minmax);
965   (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
966   (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
967   if (src_op.num_bits) {
968     (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
969   }
970   if (src_op.narrow_range) {
971     (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
972   }
973 }
974 
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)975 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
976                             GraphDef* tensorflow_graph) {
977   tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
978   maxpool_op->set_op("MaxPool");
979   maxpool_op->set_name(src_op.outputs[0]);
980   *maxpool_op->add_input() = src_op.inputs[0];
981   auto& strides = (*maxpool_op->mutable_attr())["strides"];
982   strides.mutable_list()->add_i(1);
983   strides.mutable_list()->add_i(src_op.stride_height);
984   strides.mutable_list()->add_i(src_op.stride_width);
985   strides.mutable_list()->add_i(1);
986   std::string padding;
987   if (src_op.padding.type == PaddingType::kSame) {
988     padding = "SAME";
989   } else if (src_op.padding.type == PaddingType::kValid) {
990     padding = "VALID";
991   } else {
992     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
993   }
994   (*maxpool_op->mutable_attr())["padding"].set_s(padding);
995   (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
996   auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
997   ksize.mutable_list()->add_i(1);
998   ksize.mutable_list()->add_i(src_op.kheight);
999   ksize.mutable_list()->add_i(src_op.kwidth);
1000   ksize.mutable_list()->add_i(1);
1001 }
1002 
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)1003 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
1004                                 GraphDef* tensorflow_graph) {
1005   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1006   avgpool_op->set_op("AvgPool");
1007   avgpool_op->set_name(src_op.outputs[0]);
1008   *avgpool_op->add_input() = src_op.inputs[0];
1009   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1010   strides.mutable_list()->add_i(1);
1011   strides.mutable_list()->add_i(src_op.stride_height);
1012   strides.mutable_list()->add_i(src_op.stride_width);
1013   strides.mutable_list()->add_i(1);
1014   std::string padding;
1015   if (src_op.padding.type == PaddingType::kSame) {
1016     padding = "SAME";
1017   } else if (src_op.padding.type == PaddingType::kValid) {
1018     padding = "VALID";
1019   } else {
1020     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1021   }
1022   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1023   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1024   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1025   ksize.mutable_list()->add_i(1);
1026   ksize.mutable_list()->add_i(src_op.kheight);
1027   ksize.mutable_list()->add_i(src_op.kwidth);
1028   ksize.mutable_list()->add_i(1);
1029 }
1030 
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)1031 void ConvertConcatenationOperator(const Model& model,
1032                                   const ConcatenationOperator& src_op,
1033                                   GraphDef* tensorflow_graph) {
1034   tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
1035   dc_op->set_op("ConcatV2");
1036   dc_op->set_name(src_op.outputs[0]);
1037   const std::string dummy_axis = src_op.outputs[0] + "/axis";
1038   CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
1039   for (const auto& input : src_op.inputs) {
1040     *dc_op->add_input() = input;
1041   }
1042   *dc_op->add_input() = dummy_axis;
1043   (*dc_op->mutable_attr())["T"].set_type(
1044       GetTensorFlowDataType(model, src_op.inputs[0]));
1045   (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1046   (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1047 }
1048 
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)1049 void ConvertTensorFlowReshapeOperator(const Model& model,
1050                                       const TensorFlowReshapeOperator& src_op,
1051                                       GraphDef* tensorflow_graph) {
1052   tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
1053   reshape_op->set_op("Reshape");
1054   reshape_op->set_name(src_op.outputs[0]);
1055   CHECK_EQ(src_op.inputs.size(), 2);
1056   *reshape_op->add_input() = src_op.inputs[0];
1057   *reshape_op->add_input() = src_op.inputs[1];
1058   (*reshape_op->mutable_attr())["T"].set_type(
1059       GetTensorFlowDataType(model, src_op.outputs[0]));
1060   const auto& shape_array = model.GetArray(src_op.inputs[1]);
1061   QCHECK(shape_array.data_type == ArrayDataType::kInt32)
1062       << "Only int32 shape is supported.";
1063   QCHECK(shape_array.buffer != nullptr)
1064       << "Shape inferred at runtime is not supported.";
1065   const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1066   CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
1067 }
1068 
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)1069 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
1070                            GraphDef* tensorflow_graph) {
1071   const std::string square_output = src_op.outputs[0] + "/square";
1072   const std::string avgpool_output = src_op.outputs[0] + "/avgpool";
1073 
1074   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1075   square_op->set_op("Square");
1076   square_op->set_name(square_output);
1077   *square_op->add_input() = src_op.inputs[0];
1078   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1079 
1080   std::string padding;
1081   if (src_op.padding.type == PaddingType::kSame) {
1082     padding = "SAME";
1083   } else if (src_op.padding.type == PaddingType::kValid) {
1084     padding = "VALID";
1085   } else {
1086     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1087   }
1088 
1089   tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1090   avgpool_op->set_op("AvgPool");
1091   avgpool_op->set_name(avgpool_output);
1092   *avgpool_op->add_input() = square_output;
1093   auto& strides = (*avgpool_op->mutable_attr())["strides"];
1094   strides.mutable_list()->add_i(1);
1095   strides.mutable_list()->add_i(src_op.stride_height);
1096   strides.mutable_list()->add_i(src_op.stride_width);
1097   strides.mutable_list()->add_i(1);
1098 
1099   (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1100   (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1101   auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1102   ksize.mutable_list()->add_i(1);
1103   ksize.mutable_list()->add_i(src_op.kheight);
1104   ksize.mutable_list()->add_i(src_op.kwidth);
1105   ksize.mutable_list()->add_i(1);
1106 
1107   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1108   sqrt_op->set_op("Sqrt");
1109   sqrt_op->set_name(src_op.outputs[0]);
1110   *sqrt_op->add_input() = avgpool_output;
1111   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1112 }
1113 
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)1114 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
1115                            GraphDef* tensorflow_graph) {
1116   tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1117   square_op->set_op("Square");
1118   square_op->set_name(src_op.outputs[0]);
1119   CHECK_EQ(src_op.inputs.size(), 1);
1120   *square_op->add_input() = src_op.inputs[0];
1121   (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1122 }
1123 
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)1124 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
1125                          GraphDef* tensorflow_graph) {
1126   tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1127   sqrt_op->set_op("Sqrt");
1128   sqrt_op->set_name(src_op.outputs[0]);
1129   CHECK_EQ(src_op.inputs.size(), 1);
1130   *sqrt_op->add_input() = src_op.inputs[0];
1131   (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1132 }
1133 
ConvertRsqrtOperator(const Model & model,const TensorFlowRsqrtOperator & src_op,GraphDef * tensorflow_graph)1134 void ConvertRsqrtOperator(const Model& model,
1135                           const TensorFlowRsqrtOperator& src_op,
1136                           GraphDef* tensorflow_graph) {
1137   tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
1138   rsqrt_op->set_op("Rsqrt");
1139   rsqrt_op->set_name(src_op.outputs[0]);
1140   CHECK_EQ(src_op.inputs.size(), 1);
1141   *rsqrt_op->add_input() = src_op.inputs[0];
1142   const tensorflow::DataType data_type =
1143       GetTensorFlowDataType(model, src_op.inputs[0]);
1144   (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
1145 }
1146 
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)1147 void ConvertSplitOperator(const Model& model,
1148                           const TensorFlowSplitOperator& src_op,
1149                           GraphDef* tensorflow_graph) {
1150   tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1151   split_op->set_op("Split");
1152   split_op->set_name(src_op.outputs[0]);
1153   for (const auto& input : src_op.inputs) {
1154     *split_op->add_input() = input;
1155   }
1156   (*split_op->mutable_attr())["T"].set_type(
1157       GetTensorFlowDataType(model, src_op.outputs[0]));
1158   (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1159   const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1160   CHECK(split_dim_array.buffer);
1161   CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1162   const auto& split_dim_data =
1163       split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1164   CHECK_EQ(split_dim_data.size(), 1);
1165   const int split_dim = split_dim_data[0];
1166   CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1167                                   tensorflow_graph);
1168 }
1169 
ConvertSplitVOperator(const Model & model,const TensorFlowSplitVOperator & src_op,GraphDef * tensorflow_graph)1170 void ConvertSplitVOperator(const Model& model,
1171                            const TensorFlowSplitVOperator& src_op,
1172                            GraphDef* tensorflow_graph) {
1173   tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
1174   split_v_op->set_op("SplitV");
1175   split_v_op->set_name(src_op.outputs[0]);
1176   for (const auto& input : src_op.inputs) {
1177     *split_v_op->add_input() = input;
1178   }
1179   (*split_v_op->mutable_attr())["T"].set_type(
1180       GetTensorFlowDataType(model, src_op.outputs[0]));
1181   (*split_v_op->mutable_attr())["Tlen"].set_type(
1182       GetTensorFlowDataType(model, src_op.inputs[1]));
1183   (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1184   ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
1185 }
1186 
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1187 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1188                          GraphDef* tensorflow_graph) {
1189   tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
1190   cast_op->set_op("Cast");
1191   cast_op->set_name(src_op.outputs[0]);
1192   CHECK_EQ(src_op.inputs.size(), 1);
1193   *cast_op->add_input() = src_op.inputs[0];
1194 
1195   (*cast_op->mutable_attr())["DstT"].set_type(
1196       GetTensorFlowDataType(model, src_op.outputs[0]));
1197   (*cast_op->mutable_attr())["SrcT"].set_type(
1198       GetTensorFlowDataType(model, src_op.inputs[0]));
1199 }
1200 
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1201 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1202                           GraphDef* tensorflow_graph) {
1203   tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
1204   floor_op->set_op("Floor");
1205   floor_op->set_name(src_op.outputs[0]);
1206   CHECK_EQ(src_op.inputs.size(), 1);
1207   *floor_op->add_input() = src_op.inputs[0];
1208   (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1209 }
1210 
ConvertCeilOperator(const Model & model,const CeilOperator & src_op,GraphDef * tensorflow_graph)1211 void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
1212                          GraphDef* tensorflow_graph) {
1213   tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
1214   ceil_op->set_op("Ceil");
1215   ceil_op->set_name(src_op.outputs[0]);
1216   CHECK_EQ(src_op.inputs.size(), 1);
1217   *ceil_op->add_input() = src_op.inputs[0];
1218   (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
1219 }
1220 
ConvertRoundOperator(const Model & model,const RoundOperator & src_op,GraphDef * tensorflow_graph)1221 void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
1222                           GraphDef* tensorflow_graph) {
1223   tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
1224   round_op->set_op("Round");
1225   round_op->set_name(src_op.outputs[0]);
1226   CHECK_EQ(src_op.inputs.size(), 1);
1227   *round_op->add_input() = src_op.inputs[0];
1228   (*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
1229 }
1230 
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1231 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1232                            GraphDef* tensorflow_graph) {
1233   tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
1234   gather_op->set_op("GatherV2");
1235   gather_op->set_name(src_op.outputs[0]);
1236   *gather_op->add_input() = src_op.inputs[0];
1237   *gather_op->add_input() = src_op.inputs[1];
1238 
1239   if (!src_op.axis) {
1240     // Dynamic axis.
1241     CHECK_EQ(src_op.inputs.size(), 3);
1242     *gather_op->add_input() = src_op.inputs[2];
1243   } else {
1244     // Constant axis.
1245     CHECK_EQ(src_op.inputs.size(), 2);
1246     const std::string gather_axis =
1247         AvailableArrayName(model, gather_op->name() + "/axis");
1248     CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
1249                          tensorflow_graph);
1250     *gather_op->add_input() = gather_axis;
1251   }
1252 
1253   (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1254   (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
1255   const tensorflow::DataType params_type =
1256       GetTensorFlowDataType(model, src_op.inputs[0]);
1257   (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1258 }
1259 
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1260 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1261                            GraphDef* tensorflow_graph) {
1262   tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
1263   argmax_op->set_op("ArgMax");
1264   argmax_op->set_name(src_op.outputs[0]);
1265   CHECK_EQ(src_op.inputs.size(), 2);
1266   *argmax_op->add_input() = src_op.inputs[0];
1267   *argmax_op->add_input() = src_op.inputs[1];
1268   (*argmax_op->mutable_attr())["T"].set_type(
1269       GetTensorFlowDataType(model, src_op.inputs[0]));
1270   (*argmax_op->mutable_attr())["Tidx"].set_type(
1271       GetTensorFlowDataType(model, src_op.inputs[1]));
1272   (*argmax_op->mutable_attr())["output_type"].set_type(
1273       GetTensorFlowDataType(model, src_op.outputs[0]));
1274 }
1275 
ConvertArgMinOperator(const Model & model,const ArgMinOperator & src_op,GraphDef * tensorflow_graph)1276 void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
1277                            GraphDef* tensorflow_graph) {
1278   tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
1279   argmin_op->set_op("ArgMin");
1280   argmin_op->set_name(src_op.outputs[0]);
1281   CHECK_EQ(src_op.inputs.size(), 2);
1282   *argmin_op->add_input() = src_op.inputs[0];
1283   *argmin_op->add_input() = src_op.inputs[1];
1284   (*argmin_op->mutable_attr())["T"].set_type(
1285       GetTensorFlowDataType(model, src_op.inputs[0]));
1286   (*argmin_op->mutable_attr())["Tidx"].set_type(
1287       GetTensorFlowDataType(model, src_op.inputs[1]));
1288   (*argmin_op->mutable_attr())["output_type"].set_type(
1289       GetTensorFlowDataType(model, src_op.outputs[0]));
1290 }
1291 
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1292 void ConvertTransposeOperator(const Model& model,
1293                               const TransposeOperator& src_op,
1294                               GraphDef* tensorflow_graph) {
1295   tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
1296   transpose_op->set_op("Transpose");
1297   transpose_op->set_name(src_op.outputs[0]);
1298   CHECK_EQ(src_op.inputs.size(), 2);
1299   *transpose_op->add_input() = src_op.inputs[0];
1300   *transpose_op->add_input() = src_op.inputs[1];
1301   (*transpose_op->mutable_attr())["T"].set_type(
1302       GetTensorFlowDataType(model, src_op.inputs[0]));
1303   (*transpose_op->mutable_attr())["Tperm"].set_type(
1304       GetTensorFlowDataType(model, src_op.inputs[1]));
1305 }
1306 
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1307 void ConvertTensorFlowShapeOperator(const Model& model,
1308                                     const TensorFlowShapeOperator& src_op,
1309                                     GraphDef* tensorflow_graph) {
1310   tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
1311   shape_op->set_op("Shape");
1312   shape_op->set_name(src_op.outputs[0]);
1313   CHECK_EQ(src_op.inputs.size(), 1);
1314   *shape_op->add_input() = src_op.inputs[0];
1315   (*shape_op->mutable_attr())["T"].set_type(
1316       GetTensorFlowDataType(model, src_op.inputs[0]));
1317   (*shape_op->mutable_attr())["out_type"].set_type(
1318       GetTensorFlowDataType(model, src_op.outputs[0]));
1319 }
1320 
ConvertRankOperator(const Model & model,const TensorFlowRankOperator & src_op,GraphDef * tensorflow_graph)1321 void ConvertRankOperator(const Model& model,
1322                          const TensorFlowRankOperator& src_op,
1323                          GraphDef* tensorflow_graph) {
1324   tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
1325   rank_op->set_op("Rank");
1326   rank_op->set_name(src_op.outputs[0]);
1327   CHECK_EQ(src_op.inputs.size(), 1);
1328   *rank_op->add_input() = src_op.inputs[0];
1329   (*rank_op->mutable_attr())["T"].set_type(
1330       GetTensorFlowDataType(model, src_op.inputs[0]));
1331 }
1332 
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1333 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1334                           GraphDef* tensorflow_graph) {
1335   tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
1336   range_op->set_op("Range");
1337   range_op->set_name(src_op.outputs[0]);
1338   CHECK_EQ(src_op.inputs.size(), 3);
1339   *range_op->add_input() = src_op.inputs[0];
1340   *range_op->add_input() = src_op.inputs[1];
1341   *range_op->add_input() = src_op.inputs[2];
1342   (*range_op->mutable_attr())["Tidx"].set_type(
1343       GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
1344 }
1345 
ConvertPackOperator(const Model & model,const PackOperator & src_op,GraphDef * tensorflow_graph)1346 void ConvertPackOperator(const Model& model, const PackOperator& src_op,
1347                          GraphDef* tensorflow_graph) {
1348   tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
1349   pack_op->set_op("Pack");
1350   pack_op->set_name(src_op.outputs[0]);
1351   for (const auto& input : src_op.inputs) {
1352     *pack_op->add_input() = input;
1353   }
1354   (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
1355   (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1356   (*pack_op->mutable_attr())["T"].set_type(
1357       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1358 }
1359 
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1360 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1361                          GraphDef* tensorflow_graph) {
1362   tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
1363   fill_op->set_op("Fill");
1364   fill_op->set_name(src_op.outputs[0]);
1365   CHECK_EQ(src_op.inputs.size(), 2);
1366   *fill_op->add_input() = src_op.inputs[0];
1367   *fill_op->add_input() = src_op.inputs[1];
1368   (*fill_op->mutable_attr())["index_type"].set_type(
1369       GetTensorFlowDataType(model, src_op.inputs[0]));
1370   (*fill_op->mutable_attr())["T"].set_type(
1371       GetTensorFlowDataType(model, src_op.inputs[1]));
1372 }
1373 
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1374 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1375                              GraphDef* tensorflow_graph) {
1376   tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
1377   floor_div_op->set_op("FloorDiv");
1378   floor_div_op->set_name(src_op.outputs[0]);
1379   CHECK_EQ(src_op.inputs.size(), 2);
1380   *floor_div_op->add_input() = src_op.inputs[0];
1381   *floor_div_op->add_input() = src_op.inputs[1];
1382   (*floor_div_op->mutable_attr())["T"].set_type(
1383       GetTensorFlowDataType(model, src_op.inputs[0]));
1384 }
1385 
ConvertFloorModOperator(const Model & model,const FloorModOperator & src_op,GraphDef * tensorflow_graph)1386 void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
1387                              GraphDef* tensorflow_graph) {
1388   tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
1389   floor_mod_op->set_op("FloorMod");
1390   floor_mod_op->set_name(src_op.outputs[0]);
1391   DCHECK_EQ(src_op.inputs.size(), 2);
1392   *floor_mod_op->add_input() = src_op.inputs[0];
1393   *floor_mod_op->add_input() = src_op.inputs[1];
1394   (*floor_mod_op->mutable_attr())["T"].set_type(
1395       GetTensorFlowDataType(model, src_op.inputs[0]));
1396 }
1397 
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1398 void ConvertExpandDimsOperator(const Model& model,
1399                                const ExpandDimsOperator& src_op,
1400                                GraphDef* tensorflow_graph) {
1401   tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
1402   expand_dims_op->set_op("ExpandDims");
1403   expand_dims_op->set_name(src_op.outputs[0]);
1404   CHECK_EQ(src_op.inputs.size(), 2);
1405   *expand_dims_op->add_input() = src_op.inputs[0];
1406   *expand_dims_op->add_input() = src_op.inputs[1];
1407   (*expand_dims_op->mutable_attr())["T"].set_type(
1408       GetTensorFlowDataType(model, src_op.inputs[0]));
1409   (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1410       GetTensorFlowDataType(model, src_op.inputs[1]));
1411 }
1412 
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1413 void ConvertResizeBilinearOperator(const Model& model,
1414                                    const ResizeBilinearOperator& src_op,
1415                                    GraphDef* tensorflow_graph) {
1416   tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1417   resize_op->set_op("ResizeBilinear");
1418   resize_op->set_name(src_op.outputs[0]);
1419   CHECK_EQ(src_op.inputs.size(), 2);
1420   *resize_op->add_input() = src_op.inputs[0];
1421   *resize_op->add_input() = src_op.inputs[1];
1422   (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1423   (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1424   (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1425       src_op.half_pixel_centers);
1426 }
1427 
ConvertResizeNearestNeighborOperator(const Model & model,const ResizeNearestNeighborOperator & src_op,GraphDef * tensorflow_graph)1428 void ConvertResizeNearestNeighborOperator(
1429     const Model& model, const ResizeNearestNeighborOperator& src_op,
1430     GraphDef* tensorflow_graph) {
1431   tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1432   resize_op->set_op("ResizeNearestNeighbor");
1433   resize_op->set_name(src_op.outputs[0]);
1434   CHECK_EQ(src_op.inputs.size(), 2);
1435   *resize_op->add_input() = src_op.inputs[0];
1436   *resize_op->add_input() = src_op.inputs[1];
1437   (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1438   (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1439   (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1440       src_op.half_pixel_centers);
1441 }
1442 
ConvertOneHotOperator(const Model & model,const OneHotOperator & src_op,GraphDef * tensorflow_graph)1443 void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1444                            GraphDef* tensorflow_graph) {
1445   tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1446   onehot_op->set_op("OneHot");
1447   onehot_op->set_name(src_op.outputs[0]);
1448   CHECK_EQ(src_op.inputs.size(), 4);
1449   for (const auto& input : src_op.inputs) {
1450     *onehot_op->add_input() = input;
1451   }
1452   (*onehot_op->mutable_attr())["T"].set_type(
1453       GetTensorFlowDataType(model, src_op.outputs[0]));
1454   (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1455 }
1456 
1457 namespace {
1458 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1459 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1460                                           absl::string_view b) {
1461   if (a.empty() || b.empty()) return absl::string_view();
1462 
1463   const char* pa = a.data();
1464   const char* pb = b.data();
1465   std::string::difference_type count = 0;
1466   const std::string::difference_type limit = std::min(a.size(), b.size());
1467   while (count < limit && *pa == *pb) {
1468     ++pa;
1469     ++pb;
1470     ++count;
1471   }
1472 
1473   return absl::string_view(a.data(), count);
1474 }
1475 }  // namespace
1476 
// Expands a toco LstmCellOperator into a subgraph of primitive TensorFlow ops:
//   ConcatV2(DATA_INPUT, PREV_ACTIV_INPUT) -> MatMul(weights) -> BiasAdd
//   -> Split into 4 along the concat axis -> Sigmoid/Tanh gates
//   -> new state  = PREV_STATE * Sigmoid(split:2) + Sigmoid(split:0) * Tanh(split:1)
//   -> activation = Tanh(new state) * Sigmoid(split:3)
// The final "Add" and "Mul" nodes are named exactly after the operator's
// STATE_OUTPUT and ACTIV_OUTPUT arrays, so the rest of the exported graph can
// refer to them by their original names. All other intermediate nodes are
// named "<base>..." where <base> is the longest common prefix of those two
// output names.
//
// Weights and biases must resolve (via WalkUpToConstantArray) to constant
// float arrays; this is enforced with CHECKs.
//
// NOTE(review): intermediate names are not uniquified (no AvailableArrayName
// call), so two cells whose outputs share the same prefix would presumably
// produce colliding node names — confirm callers guarantee distinct prefixes.
void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
                             GraphDef* tensorflow_graph) {
  // Find the base name: the common prefix of the two output array names.
  const std::string base(
      FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
                              src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));

  // Concatenate inputs
  const std::string concat_output = base + "basic_lstm_cell/concat";
  // Op names have been chosen to match the tf.slim LSTM naming
  // as closely as possible.
  // Concatenate (and later split) along the last dimension of the previous
  // activation, i.e. rank(PREV_ACTIV_INPUT) - 1.
  const int axis =
      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
          .shape()
          .dimensions_count() -
      1;
  // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
  // works the same since the tensor has the same underlying data layout.
  const std::string axis_output = concat_output + "/axis";
  CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
  tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
  concat_op->set_op("ConcatV2");
  concat_op->set_name(concat_output);
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
  *concat_op->add_input() = axis_output;
  (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs

  // Write weights as a float Const node named "<base>weights".
  const std::string weights_output = base + "weights";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
  const std::string weights_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
  const auto& weights_array = model.GetArray(weights_name);
  // The weights must already be a 2-D matrix (checked below). The constant is
  // written with ConvertFloatTensorConst going from kCR to kRC axes order.
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);
  CHECK(weights_array.buffer);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  const float* weights_data =
      weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
                          AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);

  // Fully connected matrix multiply: concat x weights, no transposes.
  const std::string matmul_output = base + "MatMul";
  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = concat_output;
  *matmul_op->add_input() = weights_output;
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Write biases as a float Const node named "<base>biases".
  const std::string biases_output = base + "biases";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
  const std::string bias_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
  const auto& bias_array = model.GetArray(bias_name);
  // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
  // UnextendShape presumably strips leading dimensions down to rank 1.
  Shape bias_shape_1d = bias_array.shape();
  UnextendShape(&bias_shape_1d, 1);
  CHECK(bias_array.buffer);
  CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
  const float* bias_data =
      bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
                          AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                          tensorflow_graph,
                          LegacyScalarPolicy::kDoCreateLegacyScalars);

  // Add biases
  std::string biasadd_output = base + "BiasAdd";
  tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
  biasadd_op->set_op("BiasAdd");
  biasadd_op->set_name(biasadd_output);
  biasadd_op->add_input(matmul_output);
  biasadd_op->add_input(biases_output);
  (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
  (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Split the BiasAdd result into four equal parts (one per gate).
  std::string split_dim_output = base + "split/split_dim";
  // The dimension is the same as the concatenation dimension
  CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
  std::string split_output = base + "split";
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(split_output);
  *split_op->add_input() = split_dim_output;
  *split_op->add_input() = biasadd_output;
  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs

  // Activation functions and memory computations. Split outputs are used as:
  //   split:0 -> Sigmoid_1, split:1 -> Tanh, split:2 -> Sigmoid,
  //   split:3 -> Sigmoid_2
  // (presumably input gate, candidate, forget gate and output gate in the
  // tf.slim ordering — TODO confirm against the importer).
  const std::string tanh_0_output = base + "Tanh";
  tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
  tanh_0_op->set_op("Tanh");
  tanh_0_op->set_name(tanh_0_output);
  *tanh_0_op->add_input() = split_output + ":1";
  (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // "split" with no ":N" suffix refers to output 0 of the Split node.
  const std::string sigmoid_1_output = base + "Sigmoid_1";
  tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
  logistic_1_op->set_op("Sigmoid");
  logistic_1_op->set_name(sigmoid_1_output);
  *logistic_1_op->add_input() = split_output;
  (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul_1 = Sigmoid_1 * Tanh
  const std::string mul_1_output = base + "mul_1";
  tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
  mul_1_op->set_op("Mul");
  mul_1_op->set_name(mul_1_output);
  *mul_1_op->add_input() = sigmoid_1_output;
  *mul_1_op->add_input() = tanh_0_output;
  (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_0_output = base + "Sigmoid";
  tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
  logistic_2_op->set_op("Sigmoid");
  logistic_2_op->set_name(sigmoid_0_output);
  *logistic_2_op->add_input() = split_output + ":2";
  (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_2_output = base + "Sigmoid_2";
  tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
  logistic_3_op->set_op("Sigmoid");
  logistic_3_op->set_name(sigmoid_2_output);
  *logistic_3_op->add_input() = split_output + ":3";
  (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul = PREV_STATE * Sigmoid
  const std::string mul_0_output = base + "mul";
  tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
  mul_0_op->set_op("Mul");
  mul_0_op->set_name(mul_0_output);
  *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
  *mul_0_op->add_input() = sigmoid_0_output;
  (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New state = mul + mul_1, named after the operator's STATE_OUTPUT array.
  const std::string add_1_output =
      src_op.outputs[LstmCellOperator::STATE_OUTPUT];
  tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
  add_1_op->set_op("Add");
  add_1_op->set_name(add_1_output);
  *add_1_op->add_input() = mul_0_output;
  *add_1_op->add_input() = mul_1_output;
  (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string tanh_1_output = base + "Tanh_1";
  tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
  tanh_1_op->set_op("Tanh");
  tanh_1_op->set_name(tanh_1_output);
  *tanh_1_op->add_input() = add_1_output;
  (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New activation = Tanh_1 * Sigmoid_2, named after ACTIV_OUTPUT.
  const std::string mul_2_output =
      src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
  tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
  mul_2_op->set_op("Mul");
  mul_2_op->set_name(mul_2_output);
  *mul_2_op->add_input() = tanh_1_output;
  *mul_2_op->add_input() = sigmoid_2_output;
  (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
1645 
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1646 void ConvertSpaceToBatchNDOperator(const Model& model,
1647                                    const SpaceToBatchNDOperator& src_op,
1648                                    GraphDef* tensorflow_graph) {
1649   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1650   new_op->set_op("SpaceToBatchND");
1651   new_op->set_name(src_op.outputs[0]);
1652   CHECK_EQ(src_op.inputs.size(), 3);
1653   *new_op->add_input() = src_op.inputs[0];
1654   *new_op->add_input() = src_op.inputs[1];
1655   *new_op->add_input() = src_op.inputs[2];
1656   const tensorflow::DataType params_type =
1657       GetTensorFlowDataType(model, src_op.inputs[0]);
1658   (*new_op->mutable_attr())["T"].set_type(params_type);
1659   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1660   (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1661 }
1662 
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1663 void ConvertBatchToSpaceNDOperator(const Model& model,
1664                                    const BatchToSpaceNDOperator& src_op,
1665                                    GraphDef* tensorflow_graph) {
1666   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1667   new_op->set_op("BatchToSpaceND");
1668   new_op->set_name(src_op.outputs[0]);
1669   CHECK_EQ(src_op.inputs.size(), 3);
1670   *new_op->add_input() = src_op.inputs[0];
1671   *new_op->add_input() = src_op.inputs[1];
1672   *new_op->add_input() = src_op.inputs[2];
1673   const tensorflow::DataType params_type =
1674       GetTensorFlowDataType(model, src_op.inputs[0]);
1675   (*new_op->mutable_attr())["T"].set_type(params_type);
1676   (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1677   (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1678 }
1679 
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1680 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1681                         GraphDef* tensorflow_graph) {
1682   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1683   new_op->set_op("Pad");
1684   new_op->set_name(src_op.outputs[0]);
1685   CHECK_EQ(src_op.inputs.size(), 2);
1686   *new_op->add_input() = src_op.inputs[0];
1687   *new_op->add_input() = src_op.inputs[1];
1688 
1689   const tensorflow::DataType params_type =
1690       GetTensorFlowDataType(model, src_op.inputs[0]);
1691   (*new_op->mutable_attr())["T"].set_type(params_type);
1692 
1693   // Create the params tensor.
1694   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1695   params_op->set_op("Const");
1696   params_op->set_name(src_op.inputs[1]);
1697   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1698   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1699   tensor->set_dtype(DT_INT32);
1700 
1701   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1702   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1703     tensor->add_int_val(src_op.left_padding[i]);
1704     tensor->add_int_val(src_op.right_padding[i]);
1705   }
1706   auto* shape = tensor->mutable_tensor_shape();
1707   shape->add_dim()->set_size(src_op.left_padding.size());
1708   shape->add_dim()->set_size(2);
1709 }
1710 
ConvertPadV2Operator(const Model & model,const PadV2Operator & src_op,GraphDef * tensorflow_graph)1711 void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
1712                           GraphDef* tensorflow_graph) {
1713   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1714   new_op->set_op("PadV2");
1715   new_op->set_name(src_op.outputs[0]);
1716   CHECK_EQ(src_op.inputs.size(), 2);
1717   *new_op->add_input() = src_op.inputs[0];
1718   *new_op->add_input() = src_op.inputs[1];
1719   *new_op->add_input() = src_op.inputs[2];
1720 
1721   const tensorflow::DataType params_type =
1722       GetTensorFlowDataType(model, src_op.inputs[0]);
1723   (*new_op->mutable_attr())["T"].set_type(params_type);
1724 
1725   // Create the params tensor.
1726   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1727   params_op->set_op("Const");
1728   params_op->set_name(src_op.inputs[1]);
1729   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1730   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1731   tensor->set_dtype(DT_INT32);
1732 
1733   CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1734   for (int i = 0; i < src_op.left_padding.size(); ++i) {
1735     tensor->add_int_val(src_op.left_padding[i]);
1736     tensor->add_int_val(src_op.right_padding[i]);
1737   }
1738   auto* shape = tensor->mutable_tensor_shape();
1739   shape->add_dim()->set_size(src_op.left_padding.size());
1740   shape->add_dim()->set_size(2);
1741 }
1742 
CreateSliceInput(const std::string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1743 void CreateSliceInput(const std::string& input_name,
1744                       const std::vector<int>& values,
1745                       GraphDef* tensorflow_graph) {
1746   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1747   params_op->set_op("Const");
1748   params_op->set_name(input_name);
1749   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1750   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1751   tensor->set_dtype(DT_INT32);
1752 
1753   for (int i = 0; i < values.size(); ++i) {
1754     tensor->add_int_val(values[i]);
1755   }
1756   auto* shape = tensor->mutable_tensor_shape();
1757   shape->add_dim()->set_size(values.size());
1758 }
1759 
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1760 void ConvertStridedSliceOperator(const Model& model,
1761                                  const StridedSliceOperator& src_op,
1762                                  GraphDef* tensorflow_graph) {
1763   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1764   new_op->set_op("StridedSlice");
1765   new_op->set_name(src_op.outputs[0]);
1766   CHECK_EQ(src_op.inputs.size(), 4);
1767   *new_op->add_input() = src_op.inputs[0];
1768   *new_op->add_input() = src_op.inputs[1];
1769   *new_op->add_input() = src_op.inputs[2];
1770   *new_op->add_input() = src_op.inputs[3];
1771 
1772   const tensorflow::DataType params_type =
1773       GetTensorFlowDataType(model, src_op.inputs[0]);
1774   (*new_op->mutable_attr())["T"].set_type(params_type);
1775 
1776   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1777   (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1778   (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1779   (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1780   (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1781   (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1782 
1783   // Create tensors for start/stop indices and strides.
1784   CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1785   CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1786   CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1787 }
1788 
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1789 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1790                           GraphDef* tensorflow_graph) {
1791   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1792   new_op->set_op("Slice");
1793   new_op->set_name(src_op.outputs[0]);
1794   CHECK_EQ(src_op.inputs.size(), 3);
1795   *new_op->add_input() = src_op.inputs[0];
1796   *new_op->add_input() = src_op.inputs[1];
1797   *new_op->add_input() = src_op.inputs[2];
1798 
1799   const tensorflow::DataType params_type =
1800       GetTensorFlowDataType(model, src_op.inputs[0]);
1801   (*new_op->mutable_attr())["T"].set_type(params_type);
1802   (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1803 
1804   // Create tensors for begin and size inputs.
1805   CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1806   CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1807 }
1808 
1809 template <typename T>
ConvertReduceOperator(const Model & model,const T & src_op,GraphDef * tensorflow_graph,const std::string & op_name)1810 void ConvertReduceOperator(const Model& model, const T& src_op,
1811                            GraphDef* tensorflow_graph,
1812                            const std::string& op_name) {
1813   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1814   new_op->set_op(op_name);
1815   new_op->set_name(src_op.outputs[0]);
1816   CHECK_EQ(src_op.inputs.size(), 2);
1817   *new_op->add_input() = src_op.inputs[0];
1818   *new_op->add_input() = src_op.inputs[1];
1819 
1820   if (src_op.type != OperatorType::kAny) {
1821     const tensorflow::DataType params_type =
1822         GetTensorFlowDataType(model, src_op.inputs[0]);
1823     (*new_op->mutable_attr())["T"].set_type(params_type);
1824   }
1825   const tensorflow::DataType indices_type =
1826       GetTensorFlowDataType(model, src_op.inputs[1]);
1827   (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
1828 
1829   if (src_op.keep_dims) {
1830     (*new_op->mutable_attr())["keep_dims"].set_b(true);
1831   }
1832 
1833   // Create the params tensor.
1834   tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1835   params_op->set_op("Const");
1836   params_op->set_name(src_op.inputs[1]);
1837   (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1838   auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1839   tensor->set_dtype(DT_INT32);
1840 
1841   for (int i = 0; i < src_op.axis.size(); ++i) {
1842     tensor->add_int_val(src_op.axis[i]);
1843   }
1844   auto* shape = tensor->mutable_tensor_shape();
1845   shape->add_dim()->set_size(src_op.axis.size());
1846 }
1847 
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1848 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1849                             GraphDef* tensorflow_graph) {
1850   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1851   new_op->set_op("Squeeze");
1852   new_op->set_name(src_op.outputs[0]);
1853   CHECK_EQ(src_op.inputs.size(), 1);
1854   *new_op->add_input() = src_op.inputs[0];
1855 
1856   const tensorflow::DataType params_type =
1857       GetTensorFlowDataType(model, src_op.inputs[0]);
1858   (*new_op->mutable_attr())["T"].set_type(params_type);
1859 
1860   if (!src_op.squeeze_dims.empty()) {
1861     auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1862     for (int i : src_op.squeeze_dims) {
1863       squeeze_dims.mutable_list()->add_i(i);
1864     }
1865   }
1866 }
1867 
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1868 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1869                         GraphDef* tensorflow_graph) {
1870   tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
1871   sub_op->set_op("Sub");
1872   sub_op->set_name(src_op.outputs[0]);
1873   CHECK_EQ(src_op.inputs.size(), 2);
1874   *sub_op->add_input() = src_op.inputs[0];
1875   *sub_op->add_input() = src_op.inputs[1];
1876   const tensorflow::DataType data_type =
1877       GetTensorFlowDataType(model, src_op.inputs[0]);
1878   (*sub_op->mutable_attr())["T"].set_type(data_type);
1879 }
1880 
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1881 void ConvertTensorFlowMinimumOperator(const Model& model,
1882                                       const TensorFlowMinimumOperator& src_op,
1883                                       GraphDef* tensorflow_graph) {
1884   tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
1885   min_op->set_op("Minimum");
1886   min_op->set_name(src_op.outputs[0]);
1887   CHECK_EQ(src_op.inputs.size(), 2);
1888   *min_op->add_input() = src_op.inputs[0];
1889   *min_op->add_input() = src_op.inputs[1];
1890   const tensorflow::DataType data_type =
1891       GetTensorFlowDataType(model, src_op.inputs[0]);
1892   (*min_op->mutable_attr())["T"].set_type(data_type);
1893 }
1894 
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1895 void ConvertTensorFlowMaximumOperator(const Model& model,
1896                                       const TensorFlowMaximumOperator& src_op,
1897                                       GraphDef* tensorflow_graph) {
1898   tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
1899   max_op->set_op("Maximum");
1900   max_op->set_name(src_op.outputs[0]);
1901   CHECK_EQ(src_op.inputs.size(), 2);
1902   *max_op->add_input() = src_op.inputs[0];
1903   *max_op->add_input() = src_op.inputs[1];
1904   const tensorflow::DataType data_type =
1905       GetTensorFlowDataType(model, src_op.inputs[0]);
1906   (*max_op->mutable_attr())["T"].set_type(data_type);
1907 }
1908 
ConvertSelectOperator(const Model & model,const SelectOperator & src_op,GraphDef * tensorflow_graph)1909 void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
1910                            GraphDef* tensorflow_graph) {
1911   tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
1912   select_op->set_op("Select");
1913   select_op->set_name(src_op.outputs[0]);
1914   CHECK_EQ(src_op.inputs.size(), 3);
1915   *select_op->add_input() = src_op.inputs[0];
1916   *select_op->add_input() = src_op.inputs[1];
1917   *select_op->add_input() = src_op.inputs[2];
1918   const tensorflow::DataType data_type =
1919       GetTensorFlowDataType(model, src_op.inputs[1]);
1920   (*select_op->mutable_attr())["T"].set_type(data_type);
1921 }
1922 
ConvertTileOperator(const Model & model,const TensorFlowTileOperator & src_op,GraphDef * tensorflow_graph)1923 void ConvertTileOperator(const Model& model,
1924                          const TensorFlowTileOperator& src_op,
1925                          GraphDef* tensorflow_graph) {
1926   tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
1927   tile_op->set_op("Tile");
1928   tile_op->set_name(src_op.outputs[0]);
1929   CHECK_EQ(src_op.inputs.size(), 2);
1930   *tile_op->add_input() = src_op.inputs[0];
1931   *tile_op->add_input() = src_op.inputs[1];
1932   const tensorflow::DataType data_type =
1933       GetTensorFlowDataType(model, src_op.inputs[0]);
1934   (*tile_op->mutable_attr())["T"].set_type(data_type);
1935   const tensorflow::DataType multiples_data_type =
1936       GetTensorFlowDataType(model, src_op.inputs[1]);
1937   (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
1938 }
1939 
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1940 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1941                            GraphDef* tensorflow_graph) {
1942   tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
1943   topk_op->set_op("TopKV2");
1944   topk_op->set_name(src_op.outputs[0]);
1945   CHECK_EQ(src_op.inputs.size(), 2);
1946   *topk_op->add_input() = src_op.inputs[0];
1947   *topk_op->add_input() = src_op.inputs[1];
1948   const tensorflow::DataType data_type =
1949       GetTensorFlowDataType(model, src_op.inputs[0]);
1950   (*topk_op->mutable_attr())["T"].set_type(data_type);
1951   (*topk_op->mutable_attr())["sorted"].set_b(true);
1952 }
1953 
ConvertRandomUniformOperator(const Model & model,const RandomUniformOperator & src_op,GraphDef * tensorflow_graph)1954 void ConvertRandomUniformOperator(const Model& model,
1955                                   const RandomUniformOperator& src_op,
1956                                   GraphDef* tensorflow_graph) {
1957   CHECK(tensorflow_graph != nullptr);
1958   tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1959   new_op->set_op("RandomUniform");
1960   CHECK_EQ(src_op.inputs.size(), 1);
1961   new_op->set_name(src_op.outputs[0]);
1962   *new_op->add_input() = src_op.inputs[0];
1963   const tensorflow::DataType shape_type =
1964       GetTensorFlowDataType(model, src_op.inputs[0]);
1965   (*new_op->mutable_attr())["T"].set_type(shape_type);
1966   (*new_op->mutable_attr())["dtype"].set_type(
1967       GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1968   (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
1969   (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
1970 }
1971 
ConvertComparisonOperator(const Model & model,const Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)1972 void ConvertComparisonOperator(const Model& model, const Operator& src_op,
1973                                const char* op_name,
1974                                GraphDef* tensorflow_graph) {
1975   tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
1976   comparison_op->set_op(op_name);
1977   comparison_op->set_name(src_op.outputs[0]);
1978   CHECK_EQ(src_op.inputs.size(), 2);
1979   *comparison_op->add_input() = src_op.inputs[0];
1980   *comparison_op->add_input() = src_op.inputs[1];
1981   const tensorflow::DataType data_type =
1982       GetTensorFlowDataType(model, src_op.inputs[0]);
1983   (*comparison_op->mutable_attr())["T"].set_type(data_type);
1984 }
1985 
ConvertSparseToDenseOperator(const Model & model,const SparseToDenseOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1986 void ConvertSparseToDenseOperator(const Model& model,
1987                                   const SparseToDenseOperator& src_op,
1988                                   const char* op_name,
1989                                   GraphDef* tensorflow_graph) {
1990   tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
1991   sparse_to_dense_op->set_op(op_name);
1992   sparse_to_dense_op->set_name(src_op.outputs[0]);
1993   CHECK_EQ(src_op.inputs.size(), 4);
1994   for (int i = 0; i < 4; ++i) {
1995     *sparse_to_dense_op->add_input() = src_op.inputs[i];
1996   }
1997   const tensorflow::DataType data_type =
1998       GetTensorFlowDataType(model, src_op.inputs[3]);
1999   (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
2000   const tensorflow::DataType index_type =
2001       GetTensorFlowDataType(model, src_op.inputs[0]);
2002   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
2003   (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
2004       src_op.validate_indices);
2005 }
2006 
ConvertPowOperator(const Model & model,const PowOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2007 void ConvertPowOperator(const Model& model, const PowOperator& src_op,
2008                         const char* op_name, GraphDef* tensorflow_graph) {
2009   tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
2010   pow_op->set_op(op_name);
2011   pow_op->set_name(src_op.outputs[0]);
2012   CHECK_EQ(src_op.inputs.size(), 2);
2013   for (int i = 0; i < 2; ++i) {
2014     *pow_op->add_input() = src_op.inputs[i];
2015   }
2016   const tensorflow::DataType data_type =
2017       GetTensorFlowDataType(model, src_op.inputs[0]);
2018   (*pow_op->mutable_attr())["T"].set_type(data_type);
2019 }
2020 
ConvertLogicalAndOperator(const Model & model,const LogicalAndOperator & src_op,GraphDef * tensorflow_graph)2021 void ConvertLogicalAndOperator(const Model& model,
2022                                const LogicalAndOperator& src_op,
2023                                GraphDef* tensorflow_graph) {
2024   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2025   logical_op->set_op("LogicalAnd");
2026   logical_op->set_name(src_op.outputs[0]);
2027   CHECK_EQ(src_op.inputs.size(), 2);
2028   for (int i = 0; i < 2; ++i) {
2029     *logical_op->add_input() = src_op.inputs[i];
2030   }
2031 }
2032 
ConvertLogicalNotOperator(const Model & model,const LogicalNotOperator & src_op,GraphDef * tensorflow_graph)2033 void ConvertLogicalNotOperator(const Model& model,
2034                                const LogicalNotOperator& src_op,
2035                                GraphDef* tensorflow_graph) {
2036   tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2037   logical_op->set_op("LogicalNot");
2038   logical_op->set_name(src_op.outputs[0]);
2039   CHECK_EQ(src_op.inputs.size(), 1);
2040   *logical_op->add_input() = src_op.inputs[0];
2041 }
2042 
ConvertLogicalOrOperator(const Model & model,const LogicalOrOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2043 void ConvertLogicalOrOperator(const Model& model,
2044                               const LogicalOrOperator& src_op,
2045                               const char* op_name, GraphDef* tensorflow_graph) {
2046   tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
2047   logical_or_op->set_op(op_name);
2048   logical_or_op->set_name(src_op.outputs[0]);
2049   CHECK_EQ(src_op.inputs.size(), 2);
2050   for (int i = 0; i < 2; ++i) {
2051     *logical_or_op->add_input() = src_op.inputs[i];
2052   }
2053   const tensorflow::DataType data_type =
2054       GetTensorFlowDataType(model, src_op.inputs[0]);
2055   (*logical_or_op->mutable_attr())["T"].set_type(data_type);
2056 }
2057 
ConvertCTCBeamSearchDecoderOperator(const Model & model,const CTCBeamSearchDecoderOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2058 void ConvertCTCBeamSearchDecoderOperator(
2059     const Model& model, const CTCBeamSearchDecoderOperator& src_op,
2060     const char* op_name, GraphDef* tensorflow_graph) {
2061   auto* op = tensorflow_graph->add_node();
2062   op->set_op(op_name);
2063   op->set_name(src_op.outputs[0]);
2064   CHECK_EQ(src_op.inputs.size(), 2);
2065   for (int i = 0; i < 2; ++i) {
2066     *op->add_input() = src_op.inputs[i];
2067   }
2068   (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
2069   (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
2070   (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
2071 }
2072 
ConvertUnpackOperator(const Model & model,const UnpackOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2073 void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
2074                            const char* op_name, GraphDef* tensorflow_graph) {
2075   tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
2076   unpack_op->set_op(op_name);
2077   unpack_op->set_name(src_op.outputs[0]);
2078   CHECK_EQ(src_op.inputs.size(), 2);
2079   *unpack_op->add_input() = src_op.inputs[0];
2080   const tensorflow::DataType data_type =
2081       GetTensorFlowDataType(model, src_op.inputs[0]);
2082   (*unpack_op->mutable_attr())["T"].set_type(data_type);
2083   (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
2084   (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
2085 }
2086 
ConvertZerosLikeOperator(const Model & model,const TensorFlowZerosLikeOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2087 void ConvertZerosLikeOperator(const Model& model,
2088                               const TensorFlowZerosLikeOperator& src_op,
2089                               const char* op_name, GraphDef* tensorflow_graph) {
2090   tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
2091   zeros_like_op->set_op(op_name);
2092   zeros_like_op->set_name(src_op.outputs[0]);
2093   DCHECK_EQ(src_op.inputs.size(), 1);
2094   *zeros_like_op->add_input() = src_op.inputs[0];
2095   const tensorflow::DataType data_type =
2096       GetTensorFlowDataType(model, src_op.inputs[0]);
2097   (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
2098 }
2099 
ConvertReverseV2Operator(const Model & model,const ReverseV2Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)2100 void ConvertReverseV2Operator(const Model& model,
2101                               const ReverseV2Operator& src_op,
2102                               const char* op_name, GraphDef* tensorflow_graph) {
2103   tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
2104   reverse_v2_op->set_op(op_name);
2105   reverse_v2_op->set_name(src_op.outputs[0]);
2106   DCHECK_EQ(src_op.inputs.size(), 2);
2107   *reverse_v2_op->add_input() = src_op.inputs[0];
2108   *reverse_v2_op->add_input() = src_op.inputs[1];
2109   const tensorflow::DataType data_type =
2110       GetTensorFlowDataType(model, src_op.inputs[0]);
2111   (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
2112 }
2113 
ConvertReverseSequenceOperator(const Model & model,const ReverseSequenceOperator & src_op,GraphDef * tensorflow_graph)2114 void ConvertReverseSequenceOperator(const Model& model,
2115                                     const ReverseSequenceOperator& src_op,
2116                                     GraphDef* tensorflow_graph) {
2117   tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
2118   reverse_seq_op->set_op("ReverseSequence");
2119   reverse_seq_op->set_name(src_op.outputs[0]);
2120   CHECK_EQ(src_op.inputs.size(), 2);
2121   *reverse_seq_op->add_input() = src_op.inputs[0];
2122   *reverse_seq_op->add_input() = src_op.inputs[1];
2123   (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
2124   (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
2125 }
2126 
ConvertOperator(const Model & model,const Operator & src_op,GraphDef * tensorflow_graph)2127 void ConvertOperator(const Model& model, const Operator& src_op,
2128                      GraphDef* tensorflow_graph) {
2129   if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
2130     LOG(FATAL)
2131         << "Unsupported: the input model has a fused activation function";
2132   }
2133 
2134   if (src_op.type == OperatorType::kConv) {
2135     ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
2136                         tensorflow_graph);
2137   } else if (src_op.type == OperatorType::kDepthwiseConv) {
2138     ConvertDepthwiseConvOperator(
2139         model, static_cast<const DepthwiseConvOperator&>(src_op),
2140         tensorflow_graph);
2141   } else if (src_op.type == OperatorType::kDepthToSpace) {
2142     ConvertDepthToSpaceOperator(
2143         model, static_cast<const DepthToSpaceOperator&>(src_op),
2144         tensorflow_graph);
2145   } else if (src_op.type == OperatorType::kSpaceToDepth) {
2146     ConvertSpaceToDepthOperator(
2147         model, static_cast<const SpaceToDepthOperator&>(src_op),
2148         tensorflow_graph);
2149   } else if (src_op.type == OperatorType::kFullyConnected) {
2150     ConvertFullyConnectedOperator(
2151         model, static_cast<const FullyConnectedOperator&>(src_op),
2152         tensorflow_graph);
2153   } else if (src_op.type == OperatorType::kAdd) {
2154     ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
2155                        tensorflow_graph);
2156   } else if (src_op.type == OperatorType::kAddN) {
2157     ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
2158                         tensorflow_graph);
2159   } else if (src_op.type == OperatorType::kMul) {
2160     ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
2161                        tensorflow_graph);
2162   } else if (src_op.type == OperatorType::kDiv) {
2163     ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
2164                        tensorflow_graph);
2165   } else if (src_op.type == OperatorType::kRelu) {
2166     ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
2167                         tensorflow_graph);
2168   } else if (src_op.type == OperatorType::kRelu1) {
2169     ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
2170                          tensorflow_graph);
2171   } else if (src_op.type == OperatorType::kRelu6) {
2172     ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
2173                          tensorflow_graph);
2174   } else if (src_op.type == OperatorType::kLog) {
2175     ConvertLogOperator(static_cast<const LogOperator&>(src_op),
2176                        tensorflow_graph);
2177   } else if (src_op.type == OperatorType::kLogistic) {
2178     ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
2179                             tensorflow_graph);
2180   } else if (src_op.type == OperatorType::kTanh) {
2181     ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
2182                         tensorflow_graph);
2183   } else if (src_op.type == OperatorType::kL2Normalization) {
2184     ConvertL2NormalizationOperator(
2185         static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
2186   } else if (src_op.type == OperatorType::kSoftmax) {
2187     ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
2188                            tensorflow_graph);
2189   } else if (src_op.type == OperatorType::kLogSoftmax) {
2190     ConvertLogSoftmaxOperator(model,
2191                               static_cast<const LogSoftmaxOperator&>(src_op),
2192                               tensorflow_graph);
2193   } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
2194     ConvertLocalResponseNormalizationOperator(
2195         static_cast<const LocalResponseNormalizationOperator&>(src_op),
2196         tensorflow_graph);
2197   } else if (src_op.type == OperatorType::kLstmCell) {
2198     ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
2199                             tensorflow_graph);
2200   } else if (src_op.type == OperatorType::kMaxPool) {
2201     ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
2202                            tensorflow_graph);
2203   } else if (src_op.type == OperatorType::kAveragePool) {
2204     ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
2205                                tensorflow_graph);
2206   } else if (src_op.type == OperatorType::kConcatenation) {
2207     ConvertConcatenationOperator(
2208         model, static_cast<const ConcatenationOperator&>(src_op),
2209         tensorflow_graph);
2210   } else if (src_op.type == OperatorType::kReshape) {
2211     ConvertTensorFlowReshapeOperator(
2212         model, static_cast<const TensorFlowReshapeOperator&>(src_op),
2213         tensorflow_graph);
2214   } else if (src_op.type == OperatorType::kL2Pool) {
2215     ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
2216                           tensorflow_graph);
2217   } else if (src_op.type == OperatorType::kSquare) {
2218     ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
2219                           tensorflow_graph);
2220   } else if (src_op.type == OperatorType::kSqrt) {
2221     ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
2222                         tensorflow_graph);
2223   } else if (src_op.type == OperatorType::kRsqrt) {
2224     ConvertRsqrtOperator(model,
2225                          static_cast<const TensorFlowRsqrtOperator&>(src_op),
2226                          tensorflow_graph);
2227   } else if (src_op.type == OperatorType::kSplit) {
2228     ConvertSplitOperator(model,
2229                          static_cast<const TensorFlowSplitOperator&>(src_op),
2230                          tensorflow_graph);
2231   } else if (src_op.type == OperatorType::kSplitV) {
2232     ConvertSplitVOperator(model,
2233                           static_cast<const TensorFlowSplitVOperator&>(src_op),
2234                           tensorflow_graph);
2235   } else if (src_op.type == OperatorType::kFakeQuant) {
2236     ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
2237                              tensorflow_graph);
2238   } else if (src_op.type == OperatorType::kCast) {
2239     ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
2240                         tensorflow_graph);
2241   } else if (src_op.type == OperatorType::kFloor) {
2242     ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
2243                          tensorflow_graph);
2244   } else if (src_op.type == OperatorType::kCeil) {
2245     ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
2246                         tensorflow_graph);
2247   } else if (src_op.type == OperatorType::kRound) {
2248     ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
2249                          tensorflow_graph);
2250   } else if (src_op.type == OperatorType::kGather) {
2251     ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
2252                           tensorflow_graph);
2253   } else if (src_op.type == OperatorType::kResizeBilinear) {
2254     ConvertResizeBilinearOperator(
2255         model, static_cast<const ResizeBilinearOperator&>(src_op),
2256         tensorflow_graph);
2257   } else if (src_op.type == OperatorType::kResizeNearestNeighbor) {
2258     ConvertResizeNearestNeighborOperator(
2259         model, static_cast<const ResizeNearestNeighborOperator&>(src_op),
2260         tensorflow_graph);
2261   } else if (src_op.type == OperatorType::kSpaceToBatchND) {
2262     ConvertSpaceToBatchNDOperator(
2263         model, static_cast<const SpaceToBatchNDOperator&>(src_op),
2264         tensorflow_graph);
2265   } else if (src_op.type == OperatorType::kBatchToSpaceND) {
2266     ConvertBatchToSpaceNDOperator(
2267         model, static_cast<const BatchToSpaceNDOperator&>(src_op),
2268         tensorflow_graph);
2269   } else if (src_op.type == OperatorType::kPad) {
2270     ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
2271                        tensorflow_graph);
2272   } else if (src_op.type == OperatorType::kPadV2) {
2273     ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
2274                          tensorflow_graph);
2275   } else if (src_op.type == OperatorType::kStridedSlice) {
2276     ConvertStridedSliceOperator(
2277         model, static_cast<const StridedSliceOperator&>(src_op),
2278         tensorflow_graph);
2279   } else if (src_op.type == OperatorType::kMean) {
2280     ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
2281                           tensorflow_graph, "Mean");
2282   } else if (src_op.type == OperatorType::kSum) {
2283     ConvertReduceOperator(model,
2284                           static_cast<const TensorFlowSumOperator&>(src_op),
2285                           tensorflow_graph, "Sum");
2286   } else if (src_op.type == OperatorType::kReduceProd) {
2287     ConvertReduceOperator(model,
2288                           static_cast<const TensorFlowProdOperator&>(src_op),
2289                           tensorflow_graph, "Prod");
2290   } else if (src_op.type == OperatorType::kReduceMin) {
2291     ConvertReduceOperator(model,
2292                           static_cast<const TensorFlowMinOperator&>(src_op),
2293                           tensorflow_graph, "Min");
2294   } else if (src_op.type == OperatorType::kReduceMax) {
2295     ConvertReduceOperator(model,
2296                           static_cast<const TensorFlowMaxOperator&>(src_op),
2297                           tensorflow_graph, "Max");
2298   } else if (src_op.type == OperatorType::kSub) {
2299     ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
2300                        tensorflow_graph);
2301   } else if (src_op.type == OperatorType::kMinimum) {
2302     ConvertTensorFlowMinimumOperator(
2303         model, static_cast<const TensorFlowMinimumOperator&>(src_op),
2304         tensorflow_graph);
2305   } else if (src_op.type == OperatorType::kMaximum) {
2306     ConvertTensorFlowMaximumOperator(
2307         model, static_cast<const TensorFlowMaximumOperator&>(src_op),
2308         tensorflow_graph);
2309   } else if (src_op.type == OperatorType::kSqueeze) {
2310     ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
2311                            tensorflow_graph);
2312   } else if (src_op.type == OperatorType::kSlice) {
2313     ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
2314                          tensorflow_graph);
2315   } else if (src_op.type == OperatorType::kArgMax) {
2316     ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
2317                           tensorflow_graph);
2318   } else if (src_op.type == OperatorType::kArgMin) {
2319     ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
2320                           tensorflow_graph);
2321   } else if (src_op.type == OperatorType::kTopK_V2) {
2322     ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
2323                           tensorflow_graph);
2324   } else if (src_op.type == OperatorType::kTranspose) {
2325     ConvertTransposeOperator(
2326         model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
2327   } else if (src_op.type == OperatorType::kShape) {
2328     ConvertTensorFlowShapeOperator(
2329         model, static_cast<const TensorFlowShapeOperator&>(src_op),
2330         tensorflow_graph);
2331   } else if (src_op.type == OperatorType::kRank) {
2332     ConvertRankOperator(model,
2333                         static_cast<const TensorFlowRankOperator&>(src_op),
2334                         tensorflow_graph);
2335   } else if (src_op.type == OperatorType::kRange) {
2336     ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
2337                          tensorflow_graph);
2338   } else if (src_op.type == OperatorType::kPack) {
2339     ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
2340                         tensorflow_graph);
2341   } else if (src_op.type == OperatorType::kFill) {
2342     ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
2343                         tensorflow_graph);
2344   } else if (src_op.type == OperatorType::kFloorDiv) {
2345     ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
2346                             tensorflow_graph);
2347   } else if (src_op.type == OperatorType::kFloorMod) {
2348     ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op),
2349                             tensorflow_graph);
2350   } else if (src_op.type == OperatorType::kExpandDims) {
2351     ConvertExpandDimsOperator(model,
2352                               static_cast<const ExpandDimsOperator&>(src_op),
2353                               tensorflow_graph);
2354   } else if (src_op.type == OperatorType::kTransposeConv) {
2355     ConvertTransposeConvOperator(
2356         model, static_cast<const TransposeConvOperator&>(src_op),
2357         tensorflow_graph);
2358   } else if (src_op.type == OperatorType::kRandomUniform) {
2359     ConvertRandomUniformOperator(
2360         model, static_cast<const RandomUniformOperator&>(src_op),
2361         tensorflow_graph);
2362   } else if (src_op.type == OperatorType::kEqual) {
2363     ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
2364   } else if (src_op.type == OperatorType::kNotEqual) {
2365     ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
2366   } else if (src_op.type == OperatorType::kGreater) {
2367     ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
2368   } else if (src_op.type == OperatorType::kGreaterEqual) {
2369     ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
2370   } else if (src_op.type == OperatorType::kLess) {
2371     ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
2372   } else if (src_op.type == OperatorType::kLessEqual) {
2373     ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
2374   } else if (src_op.type == OperatorType::kSelect) {
2375     ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
2376                           tensorflow_graph);
2377   } else if (src_op.type == OperatorType::kTile) {
2378     ConvertTileOperator(model,
2379                         static_cast<const TensorFlowTileOperator&>(src_op),
2380                         tensorflow_graph);
2381   } else if (src_op.type == OperatorType::kPow) {
2382     ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
2383                        tensorflow_graph);
2384   } else if (src_op.type == OperatorType::kAny) {
2385     ConvertReduceOperator(model,
2386                           static_cast<const TensorFlowAnyOperator&>(src_op),
2387                           tensorflow_graph, "Any");
2388   } else if (src_op.type == OperatorType::kLogicalAnd) {
2389     ConvertLogicalAndOperator(model,
2390                               static_cast<const LogicalAndOperator&>(src_op),
2391                               tensorflow_graph);
2392   } else if (src_op.type == OperatorType::kLogicalNot) {
2393     ConvertLogicalNotOperator(model,
2394                               static_cast<const LogicalNotOperator&>(src_op),
2395                               tensorflow_graph);
2396   } else if (src_op.type == OperatorType::kOneHot) {
2397     ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
2398                           tensorflow_graph);
2399   } else if (src_op.type == OperatorType::kLogicalOr) {
2400     ConvertLogicalOrOperator(model,
2401                              static_cast<const LogicalOrOperator&>(src_op),
2402                              "LogicalOr", tensorflow_graph);
2403   } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
2404     ConvertCTCBeamSearchDecoderOperator(
2405         model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
2406         "CTCBeamSearchDecoder", tensorflow_graph);
2407   } else if (src_op.type == OperatorType::kUnpack) {
2408     ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
2409                           "Unpack", tensorflow_graph);
2410   } else if (src_op.type == OperatorType::kZerosLike) {
2411     ConvertZerosLikeOperator(
2412         model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
2413         "ZerosLike", tensorflow_graph);
2414   } else if (src_op.type == OperatorType::kReverseV2) {
2415     ConvertReverseV2Operator(model,
2416                              static_cast<const ReverseV2Operator&>(src_op),
2417                              "Reverse_V2", tensorflow_graph);
2418   } else if (src_op.type == OperatorType::kReverseSequence) {
2419     ConvertReverseSequenceOperator(
2420         model, static_cast<const ReverseSequenceOperator&>(src_op),
2421         tensorflow_graph);
2422   } else {
2423     LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
2424   }
2425 }
2426 
AddPlaceholder(const std::string & name,ArrayDataType type,GraphDef * tensorflow_graph)2427 void AddPlaceholder(const std::string& name, ArrayDataType type,
2428                     GraphDef* tensorflow_graph) {
2429   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2430   placeholder->set_op("Placeholder");
2431   switch (type) {
2432     case ArrayDataType::kBool:
2433       (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
2434       break;
2435     case ArrayDataType::kFloat:
2436       (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2437       break;
2438     case ArrayDataType::kUint8:
2439       (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
2440       break;
2441     case ArrayDataType::kInt32:
2442       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
2443       break;
2444     case ArrayDataType::kUint32:
2445       (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
2446       break;
2447     case ArrayDataType::kInt64:
2448       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
2449       break;
2450     case ArrayDataType::kInt16:
2451       (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
2452       break;
2453     case ArrayDataType::kComplex64:
2454       (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
2455       break;
2456     default:
2457       LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
2458   }
2459   placeholder->set_name(name);
2460 }
2461 
AddPlaceholderForRNNState(const Model & model,const std::string & name,int size,GraphDef * tensorflow_graph)2462 void AddPlaceholderForRNNState(const Model& model, const std::string& name,
2463                                int size, GraphDef* tensorflow_graph) {
2464   tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2465   placeholder->set_op("Placeholder");
2466   placeholder->set_name(name);
2467   (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2468 
2469   auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
2470   const auto& state_array = model.GetArray(name);
2471   if (state_array.has_shape()) {
2472     const auto& state_shape = state_array.shape();
2473     const int kDims = state_shape.dimensions_count();
2474     for (int i = 0; i < kDims; ++i) {
2475       shape->add_dim()->set_size(state_shape.dims(i));
2476     }
2477   } else {
2478     shape->add_dim()->set_size(1);
2479     shape->add_dim()->set_size(size);
2480   }
2481 }
2482 
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)2483 void ExportTensorFlowGraphDefImplementation(const Model& model,
2484                                             GraphDef* tensorflow_graph) {
2485   for (const auto& input_array : model.flags.input_arrays()) {
2486     AddPlaceholder(input_array.name(),
2487                    model.GetArray(input_array.name()).data_type,
2488                    tensorflow_graph);
2489   }
2490   for (const auto& rnn_state : model.flags.rnn_states()) {
2491     AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
2492                               tensorflow_graph);
2493   }
2494   for (const auto& op : model.operators) {
2495     ConvertOperator(model, *op, tensorflow_graph);
2496   }
2497   // Generically export arrays that haven't been exported already
2498   // by the above operators export. It's important that this comes
2499   // after, as some operators need to export arrays that they reference
2500   // in a specific way, rather than in the generic way done below.
2501   for (const auto& array_pair : model.GetArrayMap()) {
2502     const std::string& array_name = array_pair.first;
2503     const auto& array = *array_pair.second;
2504     if (array.buffer) {
2505       switch (array.data_type) {
2506         case ArrayDataType::kBool:
2507           ConvertBoolTensorConst(model, array_name, tensorflow_graph);
2508           break;
2509         case ArrayDataType::kFloat:
2510           ConvertFloatTensorConst(model, array_name, tensorflow_graph);
2511           break;
2512         case ArrayDataType::kInt32:
2513           ConvertIntTensorConst(model, array_name, tensorflow_graph);
2514           break;
2515         case ArrayDataType::kComplex64:
2516           ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
2517           break;
2518         default:
2519           break;
2520       }
2521     }
2522   }
2523 }
2524 }  // namespace
2525 
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)2526 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
2527   for (const auto& array_kv : model->GetArrayMap()) {
2528     const std::string& array_name = array_kv.first;
2529     Array& array = *array_kv.second;
2530     if (!array.buffer || !array.minmax) {
2531       continue;
2532     }
2533     const std::string& wrapped_array_name =
2534         AvailableArrayName(*model, array_name + "/data");
2535     Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
2536     wrapped_array.data_type = array.data_type;
2537     wrapped_array.copy_shape(array.shape());
2538     wrapped_array.buffer = std::move(array.buffer);
2539     FakeQuantOperator* fakequant_op = new FakeQuantOperator;
2540     fakequant_op->inputs = {wrapped_array_name};
2541     fakequant_op->outputs = {array_name};
2542     fakequant_op->minmax.reset(new MinMax);
2543     *fakequant_op->minmax = *array.minmax;
2544     const auto& it = FindOpWithInput(*model, array_name);
2545     model->operators.emplace(it, fakequant_op);
2546   }
2547   CheckInvariants(*model);
2548 }
2549 
ExportTensorFlowGraphDef(const Model & model,std::string * output_file_contents)2550 void ExportTensorFlowGraphDef(const Model& model,
2551                               std::string* output_file_contents) {
2552   CHECK(output_file_contents->empty());
2553   GraphDef tensorflow_graph;
2554   ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
2555   LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
2556   CHECK(tensorflow_graph.SerializeToString(output_file_contents));
2557 }
2558 }  // namespace toco
2559