1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "google/protobuf/map.h"
22 #include "google/protobuf/text_format.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/model.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/runtime/types.h"
35 #include "tensorflow/lite/toco/tensorflow_util.h"
36 #include "tensorflow/lite/toco/tooling_util.h"
37
38 using tensorflow::DT_BOOL;
39 using tensorflow::DT_COMPLEX64;
40 using tensorflow::DT_FLOAT;
41 using tensorflow::DT_INT16;
42 using tensorflow::DT_INT32;
43 using tensorflow::DT_INT64;
44 using tensorflow::DT_UINT32;
45 using tensorflow::DT_UINT8;
46 using tensorflow::GraphDef;
47 using tensorflow::TensorProto;
48
49 namespace toco {
50 namespace {
51
GetTensorFlowDataType(ArrayDataType data_type,const std::string & error_location)52 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
53 const std::string& error_location) {
54 switch (data_type) {
55 case ArrayDataType::kBool:
56 return tensorflow::DT_BOOL;
57 case ArrayDataType::kFloat:
58 return tensorflow::DT_FLOAT;
59 case ArrayDataType::kUint8:
60 return tensorflow::DT_UINT8;
61 case ArrayDataType::kInt32:
62 return tensorflow::DT_INT32;
63 case ArrayDataType::kUint32:
64 return tensorflow::DT_UINT32;
65 case ArrayDataType::kInt64:
66 return tensorflow::DT_INT64;
67 case ArrayDataType::kString:
68 return tensorflow::DT_STRING;
69 case ArrayDataType::kComplex64:
70 return tensorflow::DT_COMPLEX64;
71 default:
72 case ArrayDataType::kNone:
73 LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
74 << "' in " << error_location;
75 return tensorflow::DT_INVALID;
76 }
77 }
78
GetTensorFlowDataTypeForOp(ArrayDataType data_type,const std::string & op_name)79 tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
80 const std::string& op_name) {
81 return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
82 }
83
GetTensorFlowDataType(const Model & model,const std::string & array_name)84 tensorflow::DataType GetTensorFlowDataType(const Model& model,
85 const std::string& array_name) {
86 return GetTensorFlowDataType(model.GetArray(array_name).data_type,
87 "array '" + array_name + "'");
88 }
89
// TensorFlow sometimes forbids what it calls "legacy scalars",
// which are 1-D shapes where the unique shape size is 1.
// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
// For that reason, we generally avoid creating legacy scalars,
// by detecting the case where a 1-D shape would be of size 1 and
// replacing that by a 0-D shape.
// However, there is a special circumstance where we must not do that
// and must unconditionally create a 1-D shape even if it is going to
// be of size 1: that is the case of bias vectors, with BiasAdd nodes.
// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of
// a depth of 1, that would be a legacy scalar, so in that case we
// must go ahead and keep the shape 1-D, letting it be a legacy scalar.
// This policy is threaded through ExportFloatArray and the
// ConvertFloatTensorConst helpers below.
enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars };
103
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)104 void ExportFloatArray(const Shape& input_shape, const float* input_data,
105 TensorProto* output_tensor,
106 LegacyScalarPolicy legacy_scalar_policy) {
107 output_tensor->set_dtype(DT_FLOAT);
108 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
109 auto* shape = output_tensor->mutable_tensor_shape();
110
111 const int kDims = input_shape.dimensions_count();
112 if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
113 kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
114 for (int i = 0; i < kDims; ++i) {
115 shape->add_dim()->set_size(input_shape.dims(i));
116 }
117 }
118 output_tensor->set_tensor_content(
119 std::string(reinterpret_cast<const char*>(input_data),
120 sizeof(*input_data) * input_flat_size));
121 }
122
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)123 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
124 const float* input_data, AxesOrder output_axes_order,
125 TensorProto* output_tensor,
126 LegacyScalarPolicy legacy_scalar_policy) {
127 CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
128 output_tensor->set_dtype(DT_FLOAT);
129 CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
130 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
131
132 Shape shuffled_shape;
133 ShuffleDims(input_shape, input_axes_order, output_axes_order,
134 &shuffled_shape);
135 std::vector<float> shuffled_data(input_flat_size);
136 ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
137 input_data, shuffled_data.data());
138
139 ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
140 legacy_scalar_policy);
141 }
142
HasAlreadyExportedConst(const std::string & name,const GraphDef & tensorflow_graph)143 bool HasAlreadyExportedConst(const std::string& name,
144 const GraphDef& tensorflow_graph) {
145 for (const auto& node : tensorflow_graph.node()) {
146 if (node.op() == "Const" && node.name() == name) {
147 return true;
148 }
149 }
150 return false;
151 }
152
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)153 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
154 const float* input_data,
155 AxesOrder input_axes_order,
156 AxesOrder output_axes_order,
157 GraphDef* tensorflow_graph,
158 LegacyScalarPolicy legacy_scalar_policy) {
159 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
160 return;
161 }
162 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
163 const_op->set_op("Const");
164 const_op->set_name(name);
165 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
166 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
167 ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
168 tensor, legacy_scalar_policy);
169 }
170
ConvertFloatTensorConst(const std::string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)171 void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
172 const float* input_data,
173 AxesOrder input_axes_order,
174 AxesOrder output_axes_order,
175 GraphDef* tensorflow_graph) {
176 ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
177 output_axes_order, tensorflow_graph,
178 LegacyScalarPolicy::kAvoidLegacyScalars);
179 }
180
ConvertFloatTensorConst(const Model & model,const std::string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)181 void ConvertFloatTensorConst(const Model& model, const std::string& name,
182 AxesOrder input_axes_order,
183 AxesOrder output_axes_order,
184 GraphDef* tensorflow_graph) {
185 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
186 return;
187 }
188 CHECK(model.HasArray(name));
189 const auto& input_array = model.GetArray(name);
190 const auto& input_shape = input_array.shape();
191 CHECK(input_array.buffer);
192 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
193 const float* input_data =
194 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
195 ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
196 output_axes_order, tensorflow_graph);
197 }
198
ConvertFloatTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)199 void ConvertFloatTensorConst(const Model& model, const std::string& name,
200 GraphDef* tensorflow_graph) {
201 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
202 return;
203 }
204 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
205 const_op->set_op("Const");
206 const_op->set_name(name);
207 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
208 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
209 CHECK(model.HasArray(name));
210 const auto& input_array = model.GetArray(name);
211 const auto& input_shape = input_array.shape();
212 CHECK(input_array.buffer);
213 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
214 const float* input_data =
215 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
216 ExportFloatArray(input_shape, input_data, tensor,
217 LegacyScalarPolicy::kAvoidLegacyScalars);
218 }
219
ConvertBoolTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)220 void ConvertBoolTensorConst(const Model& model, const std::string& name,
221 GraphDef* tensorflow_graph) {
222 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
223 return;
224 }
225 CHECK(model.HasArray(name));
226 const auto& array = model.GetArray(name);
227 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
228 const_op->set_op("Const");
229 const_op->set_name(name);
230 (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
231 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
232 tensor->set_dtype(DT_BOOL);
233 const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
234 for (auto index : data) {
235 tensor->add_bool_val(index);
236 }
237 const auto& array_shape = array.shape();
238 auto* shape = tensor->mutable_tensor_shape();
239 for (int i = 0; i < array_shape.dimensions_count(); i++) {
240 shape->add_dim()->set_size(array_shape.dims(i));
241 }
242 }
243
ConvertIntTensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)244 void ConvertIntTensorConst(const Model& model, const std::string& name,
245 GraphDef* tensorflow_graph) {
246 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
247 return;
248 }
249 CHECK(model.HasArray(name));
250 const auto& array = model.GetArray(name);
251 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
252 const_op->set_op("Const");
253 const_op->set_name(name);
254 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
255 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
256 tensor->set_dtype(DT_INT32);
257 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
258 for (auto index : data) {
259 tensor->add_int_val(index);
260 }
261 const auto& array_shape = array.shape();
262 auto* shape = tensor->mutable_tensor_shape();
263 for (int i = 0; i < array_shape.dimensions_count(); i++) {
264 shape->add_dim()->set_size(array_shape.dims(i));
265 }
266 }
267
CreateIntTensorConst(const std::string & name,const std::vector<int32> & data,const std::vector<int32> & shape,GraphDef * tensorflow_graph)268 void CreateIntTensorConst(const std::string& name,
269 const std::vector<int32>& data,
270 const std::vector<int32>& shape,
271 GraphDef* tensorflow_graph) {
272 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
273 return;
274 }
275 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
276 const_op->set_op("Const");
277 const_op->set_name(name);
278 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
279 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
280 tensor->set_dtype(DT_INT32);
281 for (auto index : data) {
282 tensor->add_int_val(index);
283 }
284 auto* tensor_shape = tensor->mutable_tensor_shape();
285 int num_elements = 1;
286 for (int size : shape) {
287 tensor_shape->add_dim()->set_size(size);
288 num_elements *= size;
289 }
290 CHECK_EQ(num_elements, data.size());
291 }
292
ConvertComplex64TensorConst(const Model & model,const std::string & name,GraphDef * tensorflow_graph)293 void ConvertComplex64TensorConst(const Model& model, const std::string& name,
294 GraphDef* tensorflow_graph) {
295 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
296 return;
297 }
298 CHECK(model.HasArray(name));
299 const auto& array = model.GetArray(name);
300 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
301 const_op->set_op("Const");
302 const_op->set_name(name);
303 (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
304 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
305 tensor->set_dtype(DT_COMPLEX64);
306 const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
307 for (auto index : data) {
308 tensor->add_scomplex_val(std::real(index));
309 tensor->add_scomplex_val(std::imag(index));
310 }
311 const auto& array_shape = array.shape();
312 auto* shape = tensor->mutable_tensor_shape();
313 for (int i = 0; i < array_shape.dimensions_count(); i++) {
314 shape->add_dim()->set_size(array_shape.dims(i));
315 }
316 }
317
// Emits a 2-element int32 Const node named `name` holding {cols, rows} --
// note the deliberate reversal relative to the parameter order: the caller
// (ConvertFullyConnectedOperator) passes (rows=depth, cols=-1) to obtain a
// Reshape shape vector of {-1, depth}. No-op if already exported.
void CreateMatrixShapeTensorConst(const std::string& name, int rows, int cols,
                                  GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  // {cols, rows}, intentionally reversed -- see header comment.
  const int32 data[2] = {cols, rows};
  // NOTE(review): serializes the raw bytes in host byte order; presumably
  // producer and consumer agree on endianness -- TODO confirm.
  tensor->set_tensor_content(
      std::string(reinterpret_cast<const char*>(data), sizeof(data)));
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(2);
}
335
CreateDummyConcatDimTensorConst(const std::string & name,int dim,GraphDef * tensorflow_graph)336 void CreateDummyConcatDimTensorConst(const std::string& name, int dim,
337 GraphDef* tensorflow_graph) {
338 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
339 return;
340 }
341 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
342 const_op->set_op("Const");
343 const_op->set_name(name);
344 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
345 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
346 tensor->set_dtype(DT_INT32);
347 tensor->add_int_val(dim);
348 }
349
CreateReshapeShapeTensorConst(const std::string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)350 void CreateReshapeShapeTensorConst(const std::string& name,
351 const std::vector<int32>& shape,
352 GraphDef* tensorflow_graph) {
353 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
354 return;
355 }
356 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
357 const_op->set_op("Const");
358 const_op->set_name(name);
359 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
360 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
361 tensor->set_dtype(DT_INT32);
362 for (auto s : shape) {
363 tensor->add_int_val(s);
364 }
365 // TensorFlow sometimes forbids what it calls "legacy scalars",
366 // which are shapes of size 1 where the unique shape size is 1.
367 // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
368 if (shape.size() > 1) {
369 auto* tensor_shape = tensor->mutable_tensor_shape();
370 tensor_shape->add_dim()->set_size(shape.size());
371 }
372 }
373
WalkUpToConstantArray(const Model & model,const std::string & name)374 std::string WalkUpToConstantArray(const Model& model, const std::string& name) {
375 const Array& original_array = model.GetArray(name);
376 if (original_array.buffer) {
377 return name;
378 }
379 const auto* op = GetOpWithOutput(model, name);
380 CHECK(op);
381 CHECK(op->type == OperatorType::kFakeQuant);
382 const std::string& input_of_fakequant_name = op->inputs[0];
383 const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
384 CHECK(input_of_fakequant.buffer);
385 return input_of_fakequant_name;
386 }
387
// Exports a toco Conv operator as a TF Conv2D node, plus a BiasAdd node
// when a third input carries a bias. The weights constant is emitted
// permuted from toco's OHWI order to TF's HWIO order.
void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
                         GraphDef* tensorflow_graph) {
  const bool has_bias = src_op.inputs.size() >= 3;
  // With a bias, the Conv2D node gets an intermediate "/conv" name and the
  // final output name goes to the BiasAdd node emitted below.
  std::string conv_output = src_op.outputs[0];
  if (has_bias) {
    conv_output += "/conv";
  }

  tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
  conv2d_op->set_op("Conv2D");
  conv2d_op->set_name(conv_output);
  *conv2d_op->add_input() = src_op.inputs[0];
  *conv2d_op->add_input() = src_op.inputs[1];
  (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
  // The weights may sit behind a FakeQuant op; walk up to the constant.
  const std::string& weights_array_name =
      WalkUpToConstantArray(model, src_op.inputs[1]);
  const auto& weights_array = model.GetArray(weights_array_name);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
                          AxesOrder::kHWIO, tensorflow_graph);
  // Strides in NHWC layout: {1, stride_h, stride_w, 1}.
  auto& strides = (*conv2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  // Only emit the "dilations" attr when it differs from the default of 1.
  if ((src_op.dilation_width_factor != 1) ||
      (src_op.dilation_height_factor != 1)) {
    auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
    dilations.mutable_list()->add_i(1);
    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
    dilations.mutable_list()->add_i(1);
  }
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*conv2d_op->mutable_attr())["padding"].set_s(padding);

  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(conv_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
    CHECK(model.HasArray(src_op.inputs[2]));
    const std::string& bias_array_name =
        WalkUpToConstantArray(model, src_op.inputs[2]);
    const auto& bias_array = model.GetArray(bias_array_name);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    // BiasAdd requires a 1-D bias even when its size is 1, so force the
    // legacy-scalar policy here (see the LegacyScalarPolicy comment).
    ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
                            AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                            tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}
454
// Exports a toco DepthwiseConv operator as a TF DepthwiseConv2dNative node,
// plus a BiasAdd node when a third input carries a bias. toco stores the
// weights as 1 x H x W x OutputDepth while TF wants
// H x W x InputDepth x Multiplier; the buffer layout is identical, only the
// shape is reinterpreted.
void ConvertDepthwiseConvOperator(const Model& model,
                                  const DepthwiseConvOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  const bool has_bias = src_op.inputs.size() >= 3;
  // With a bias, the conv node gets an intermediate "/conv" name and the
  // final output name goes to the BiasAdd node emitted below.
  std::string conv_output = src_op.outputs[0];
  if (has_bias) {
    conv_output += "/conv";
  }

  tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
  dc2d_op->set_op("DepthwiseConv2dNative");
  dc2d_op->set_name(conv_output);
  *dc2d_op->add_input() = src_op.inputs[0];
  *dc2d_op->add_input() = src_op.inputs[1];
  (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
  // We need to convert that to H x W x InputDepth x Multiplier.
  // That's only a matter of constructing a Dims object; the actual
  // array layout is the same.
  CHECK(model.HasArray(src_op.inputs[1]));
  const std::string& src_weights_name =
      WalkUpToConstantArray(model, src_op.inputs[1]);
  const auto& src_weights_array = model.GetArray(src_weights_name);
  const auto& src_weights_shape = src_weights_array.shape();
  CHECK_EQ(src_weights_shape.dimensions_count(), 4);
  const Shape dst_weights_shape =
      Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
             src_weights_shape.dims(3) / src_op.depth_multiplier,
             src_op.depth_multiplier});
  // OutputDepth must factor exactly into InputDepth * Multiplier.
  CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
  CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
        src_weights_shape.dims(3));
  CHECK_EQ(src_weights_shape.dims(0), 1);

  CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
  const float* src_weights_data =
      src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  // Same axes order on both sides: only the shape changes, not the layout.
  ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
                          AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);

  // Strides in NHWC layout: {1, stride_h, stride_w, 1}.
  auto& strides = (*dc2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  // TODO(b/116063589): To return a working TF GraphDef, we should be returning
  // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
  // the conv since TF doesn't support dilations.
  if ((src_op.dilation_width_factor != 1) ||
      (src_op.dilation_height_factor != 1)) {
    auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
    dilations.mutable_list()->add_i(1);
    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
    dilations.mutable_list()->add_i(1);
  }
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*dc2d_op->mutable_attr())["padding"].set_s(padding);

  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(conv_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
    CHECK(model.HasArray(src_op.inputs[2]));
    const std::string& bias_name =
        WalkUpToConstantArray(model, src_op.inputs[2]);
    const auto& bias_array = model.GetArray(bias_name);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    // BiasAdd requires a 1-D bias even when its size is 1, so force the
    // legacy-scalar policy here (see the LegacyScalarPolicy comment).
    ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
                            AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                            tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}
545
// Exports a toco TransposeConv operator as a TF Conv2DBackpropInput node.
// All three toco inputs are forwarded positionally; the weights input
// (indexed via TransposeConvOperator::WEIGHTS) is emitted as a constant
// permuted from toco's OHWI order to HWOI -- note this differs from the
// HWIO order used for plain Conv2D.
// NOTE(review): presumably inputs[0] is the output-shape tensor as
// Conv2DBackpropInput expects -- confirm against TransposeConvOperator's
// input-slot declaration.
void ConvertTransposeConvOperator(const Model& model,
                                  const TransposeConvOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
  conv2d_op->set_op("Conv2DBackpropInput");
  conv2d_op->set_name(src_op.outputs[0]);
  *conv2d_op->add_input() = src_op.inputs[0];
  *conv2d_op->add_input() = src_op.inputs[1];
  *conv2d_op->add_input() = src_op.inputs[2];
  (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
  // The weights may sit behind a FakeQuant op; walk up to the constant.
  const std::string& weights_array_name = WalkUpToConstantArray(
      model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
  const auto& weights_array = model.GetArray(weights_array_name);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
                          AxesOrder::kHWOI, tensorflow_graph);
  // Strides in NHWC layout: {1, stride_h, stride_w, 1}.
  auto& strides = (*conv2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*conv2d_op->mutable_attr())["padding"].set_s(padding);
}
577
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)578 void ConvertDepthToSpaceOperator(const Model& model,
579 const DepthToSpaceOperator& src_op,
580 GraphDef* tensorflow_graph) {
581 tensorflow::NodeDef* op = tensorflow_graph->add_node();
582 op->set_op("DepthToSpace");
583 op->set_name(src_op.outputs[0]);
584 *op->add_input() = src_op.inputs[0];
585 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
586 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
587 }
588
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)589 void ConvertSpaceToDepthOperator(const Model& model,
590 const SpaceToDepthOperator& src_op,
591 GraphDef* tensorflow_graph) {
592 tensorflow::NodeDef* op = tensorflow_graph->add_node();
593 op->set_op("SpaceToDepth");
594 op->set_name(src_op.outputs[0]);
595 *op->add_input() = src_op.inputs[0];
596 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
597 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
598 }
599
// Exports a toco FullyConnected operator as Reshape + Transpose + MatMul
// nodes, plus a BiasAdd node when a third input carries a bias. The
// activations are first reshaped to a rank-2 {-1, input_depth} matrix and
// the weights are transposed back to TF's row-major expectation.
void ConvertFullyConnectedOperator(const Model& model,
                                   const FullyConnectedOperator& src_op,
                                   GraphDef* tensorflow_graph) {
  // Reshape input activations to have the shape expected by the MatMul.
  const std::string reshape_output =
      AvailableArrayName(model, src_op.outputs[0] + "/reshape");
  const std::string reshape_shape =
      AvailableArrayName(model, reshape_output + "/shape");
  const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
  const auto& fc_weights_shape = fc_weights_array.shape();
  CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
  // Produces the shape constant {-1, input_depth} (note the reversed
  // argument convention of CreateMatrixShapeTensorConst).
  CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
                               tensorflow_graph);
  tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
  reshape_op->set_op("Reshape");
  reshape_op->set_name(reshape_output);
  reshape_op->add_input(src_op.inputs[0]);
  reshape_op->add_input(reshape_shape);
  (*reshape_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));

  const bool has_bias = src_op.inputs.size() >= 3;
  // With a bias, the MatMul gets an intermediate name and the final output
  // name goes to the BiasAdd node below.
  std::string matmul_output = src_op.outputs[0];
  if (has_bias) {
    matmul_output += "/matmul";
  }

  // Transpose the RHS input from column-major to row-major to match TensorFlow
  // expectations. This is the inverse of the transpose we do during
  // ResolveTensorFlowMatMul.
  const std::string transpose_output =
      AvailableArrayName(model, matmul_output + "/transpose_weights");
  const std::string transpose_perm =
      AvailableArrayName(model, transpose_output + "/perm");
  CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
  tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
  transpose_op->set_op("Transpose");
  transpose_op->set_name(transpose_output);
  *transpose_op->add_input() = src_op.inputs[1];
  *transpose_op->add_input() = transpose_perm;
  (*transpose_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);

  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = reshape_output;
  *matmul_op->add_input() = transpose_op->name();
  (*matmul_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  // The transpose was already materialized above, so MatMul itself does
  // no transposing.
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  CHECK(model.HasArray(src_op.inputs[1]));

  // Add the bias, if it exists.
  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(matmul_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(
        GetTensorFlowDataType(model, src_op.inputs[0]));
    CHECK(model.HasArray(src_op.inputs[2]));
    const auto& bias_array = model.GetArray(src_op.inputs[2]);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    // BiasAdd requires a 1-D bias even when its size is 1, so force the
    // legacy-scalar policy here (see the LegacyScalarPolicy comment).
    ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
                            bias_shape_1d, bias_data, AxesOrder::kOneAxis,
                            AxesOrder::kOneAxis, tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}
679
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)680 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
681 GraphDef* tensorflow_graph) {
682 tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
683 add_op->set_op("Add");
684 add_op->set_name(src_op.outputs[0]);
685 CHECK_EQ(src_op.inputs.size(), 2);
686 *add_op->add_input() = src_op.inputs[0];
687 *add_op->add_input() = src_op.inputs[1];
688 (*add_op->mutable_attr())["T"].set_type(
689 GetTensorFlowDataType(model, src_op.outputs[0]));
690 }
691
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)692 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
693 GraphDef* tensorflow_graph) {
694 tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
695 add_op->set_op("AddN");
696 add_op->set_name(src_op.outputs[0]);
697 for (const auto& input : src_op.inputs) {
698 *add_op->add_input() = input;
699 }
700 (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
701 (*add_op->mutable_attr())["T"].set_type(
702 GetTensorFlowDataType(model, src_op.outputs[0]));
703 }
704
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)705 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
706 GraphDef* tensorflow_graph) {
707 tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
708 mul_op->set_op("Mul");
709 mul_op->set_name(src_op.outputs[0]);
710 CHECK_EQ(src_op.inputs.size(), 2);
711 *mul_op->add_input() = src_op.inputs[0];
712 *mul_op->add_input() = src_op.inputs[1];
713 (*mul_op->mutable_attr())["T"].set_type(
714 GetTensorFlowDataType(model, src_op.outputs[0]));
715 }
716
ConvertDivOperator(const Model & model,const DivOperator & src_op,GraphDef * tensorflow_graph)717 void ConvertDivOperator(const Model& model, const DivOperator& src_op,
718 GraphDef* tensorflow_graph) {
719 tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
720 div_op->set_op("Div");
721 div_op->set_name(src_op.outputs[0]);
722 CHECK_EQ(src_op.inputs.size(), 2);
723 *div_op->add_input() = src_op.inputs[0];
724 *div_op->add_input() = src_op.inputs[1];
725 (*div_op->mutable_attr())["T"].set_type(
726 GetTensorFlowDataType(model, src_op.outputs[0]));
727 }
728
ConvertReluOperator(const Model & model,const ReluOperator & src_op,GraphDef * tensorflow_graph)729 void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
730 GraphDef* tensorflow_graph) {
731 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
732 relu_op->set_op("Relu");
733 relu_op->set_name(src_op.outputs[0]);
734 *relu_op->add_input() = src_op.inputs[0];
735 (*relu_op->mutable_attr())["T"].set_type(
736 GetTensorFlowDataType(model, src_op.outputs[0]));
737 }
738
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)739 void ConvertRelu1Operator(const Relu1Operator& src_op,
740 GraphDef* tensorflow_graph) {
741 const std::string max_bounds = src_op.outputs[0] + "/max_bounds";
742 const std::string min_bounds = src_op.outputs[0] + "/min_bounds";
743 const std::string max_output = src_op.outputs[0] + "/max_output";
744
745 tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
746 max_bounds_const_op->set_op("Const");
747 max_bounds_const_op->set_name(max_bounds);
748 (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
749 auto* max_bounds_const_op_tensor =
750 (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
751 max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
752 max_bounds_const_op_tensor->add_float_val(-1.0f);
753
754 tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
755 min_bounds_const_op->set_op("Const");
756 min_bounds_const_op->set_name(min_bounds);
757 (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
758 auto* min_bounds_const_op_tensor =
759 (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
760 min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
761 min_bounds_const_op_tensor->add_float_val(1.0f);
762
763 tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
764 max_op->set_op("Maximum");
765 max_op->set_name(max_output);
766 *max_op->add_input() = src_op.inputs[0];
767 *max_op->add_input() = max_bounds;
768 (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
769
770 tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
771 min_op->set_op("Minimum");
772 min_op->set_name(src_op.outputs[0]);
773 *min_op->add_input() = max_output;
774 *min_op->add_input() = min_bounds;
775 (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
776 }
777
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)778 void ConvertRelu6Operator(const Relu6Operator& src_op,
779 GraphDef* tensorflow_graph) {
780 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
781 relu_op->set_op("Relu6");
782 relu_op->set_name(src_op.outputs[0]);
783 *relu_op->add_input() = src_op.inputs[0];
784 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
785 }
786
ConvertLogOperator(const LogOperator & src_op,GraphDef * tensorflow_graph)787 void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
788 tensorflow::NodeDef* op = tensorflow_graph->add_node();
789 op->set_op("Log");
790 op->set_name(src_op.outputs[0]);
791 CHECK_EQ(src_op.inputs.size(), 1);
792 *op->add_input() = src_op.inputs[0];
793 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
794 }
795
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)796 void ConvertLogisticOperator(const LogisticOperator& src_op,
797 GraphDef* tensorflow_graph) {
798 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
799 relu_op->set_op("Sigmoid");
800 relu_op->set_name(src_op.outputs[0]);
801 *relu_op->add_input() = src_op.inputs[0];
802 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
803 }
804
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)805 void ConvertTanhOperator(const TanhOperator& src_op,
806 GraphDef* tensorflow_graph) {
807 tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
808 tanh_op->set_op("Tanh");
809 tanh_op->set_name(src_op.outputs[0]);
810 *tanh_op->add_input() = src_op.inputs[0];
811 (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
812 }
813
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)814 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
815 GraphDef* tensorflow_graph) {
816 std::string softmax_input;
817 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
818 if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
819 softmax_input = src_op.inputs[0];
820 } else {
821 // Insert a reshape operator that reduces the dimensions down to the 2 that
822 // are required for TensorFlow Logits.
823 const std::string reshape_output =
824 src_op.outputs[0] + "/softmax_insert_reshape";
825 const std::string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
826 softmax_input = reshape_output;
827
828 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
829 reshape_op->set_op("Reshape");
830 reshape_op->set_name(reshape_output);
831 *reshape_op->add_input() = src_op.inputs[0];
832 *reshape_op->add_input() = softmax_size;
833 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
834
835 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
836 int32 flattened_size = 1;
837 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
838 flattened_size *= input_shape.dims(i);
839 }
840 const std::vector<int32> shape_data = {
841 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
842 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
843 }
844
845 tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
846 softmax_op->set_op("Softmax");
847 softmax_op->set_name(src_op.outputs[0]);
848 *softmax_op->add_input() = softmax_input;
849 // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
850 CHECK_EQ(src_op.beta, 1.f);
851 (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
852 }
853
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)854 void ConvertLogSoftmaxOperator(const Model& model,
855 const LogSoftmaxOperator& src_op,
856 GraphDef* tensorflow_graph) {
857 std::string softmax_input;
858 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
859 if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
860 softmax_input = src_op.inputs[0];
861 } else {
862 // Insert a reshape operator that reduces the dimensions down to the 2 that
863 // are required for TensorFlow Logits.
864 const std::string reshape_output =
865 src_op.outputs[0] + "/log_softmax_insert_reshape";
866 const std::string softmax_size =
867 src_op.outputs[0] + "/log_softmax_insert_size";
868 softmax_input = reshape_output;
869
870 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
871 reshape_op->set_op("Reshape");
872 reshape_op->set_name(reshape_output);
873 *reshape_op->add_input() = src_op.inputs[0];
874 *reshape_op->add_input() = softmax_size;
875 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
876
877 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
878 int32 flattened_size = 1;
879 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
880 flattened_size *= input_shape.dims(i);
881 }
882 const std::vector<int32> shape_data = {
883 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
884 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
885 }
886
887 tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
888 log_softmax_op->set_op("LogSoftmax");
889 log_softmax_op->set_name(src_op.outputs[0]);
890 *log_softmax_op->add_input() = softmax_input;
891 (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
892 }
893
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)894 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
895 GraphDef* tensorflow_graph) {
896 const std::string square_output = src_op.outputs[0] + "/square";
897 const std::string sum_reduction_indices =
898 src_op.outputs[0] + "/reduction_indices";
899 const std::string sum_output = src_op.outputs[0] + "/sum";
900 const std::string rsqrt_output = src_op.outputs[0] + "/rsqrt";
901 const std::string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
902
903 tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
904 sum_reduction_indices_op->set_op("Const");
905 sum_reduction_indices_op->set_name(sum_reduction_indices);
906 (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
907 auto* sum_reduction_indices_tensor =
908 (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
909 sum_reduction_indices_tensor->set_dtype(DT_INT32);
910 auto* sum_reduction_indices_shape =
911 sum_reduction_indices_tensor->mutable_tensor_shape();
912 auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
913 sum_reduction_indices_dim->set_size(2);
914 sum_reduction_indices_tensor->add_int_val(0);
915 sum_reduction_indices_tensor->add_int_val(1);
916
917 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
918 square_op->set_op("Square");
919 square_op->set_name(square_output);
920 *square_op->add_input() = src_op.inputs[0];
921 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
922
923 tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
924 sum_op->set_op("Sum");
925 sum_op->set_name(sum_output);
926 *sum_op->add_input() = square_output;
927 *sum_op->add_input() = sum_reduction_indices;
928 (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
929
930 tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
931 rsqrt_op->set_op("Rsqrt");
932 rsqrt_op->set_name(rsqrt_output);
933 *rsqrt_op->add_input() = sum_output;
934 (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
935
936 tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
937 mul_op->set_op("Mul");
938 mul_op->set_name(src_op.outputs[0]);
939 *mul_op->add_input() = src_op.inputs[0];
940 *mul_op->add_input() = rsqrt_output;
941 (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
942 }
943
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)944 void ConvertLocalResponseNormalizationOperator(
945 const LocalResponseNormalizationOperator& src_op,
946 GraphDef* tensorflow_graph) {
947 tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
948 lrn_op->set_op("LRN");
949 lrn_op->set_name(src_op.outputs[0]);
950 *lrn_op->add_input() = src_op.inputs[0];
951 (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
952 (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
953 (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
954 (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
955 }
956
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)957 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
958 GraphDef* tensorflow_graph) {
959 tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
960 fakequant_op->set_op("FakeQuantWithMinMaxArgs");
961 fakequant_op->set_name(src_op.outputs[0]);
962 CHECK_EQ(src_op.inputs.size(), 1);
963 *fakequant_op->add_input() = src_op.inputs[0];
964 CHECK(src_op.minmax);
965 (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
966 (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
967 if (src_op.num_bits) {
968 (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
969 }
970 if (src_op.narrow_range) {
971 (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
972 }
973 }
974
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)975 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
976 GraphDef* tensorflow_graph) {
977 tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
978 maxpool_op->set_op("MaxPool");
979 maxpool_op->set_name(src_op.outputs[0]);
980 *maxpool_op->add_input() = src_op.inputs[0];
981 auto& strides = (*maxpool_op->mutable_attr())["strides"];
982 strides.mutable_list()->add_i(1);
983 strides.mutable_list()->add_i(src_op.stride_height);
984 strides.mutable_list()->add_i(src_op.stride_width);
985 strides.mutable_list()->add_i(1);
986 std::string padding;
987 if (src_op.padding.type == PaddingType::kSame) {
988 padding = "SAME";
989 } else if (src_op.padding.type == PaddingType::kValid) {
990 padding = "VALID";
991 } else {
992 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
993 }
994 (*maxpool_op->mutable_attr())["padding"].set_s(padding);
995 (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
996 auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
997 ksize.mutable_list()->add_i(1);
998 ksize.mutable_list()->add_i(src_op.kheight);
999 ksize.mutable_list()->add_i(src_op.kwidth);
1000 ksize.mutable_list()->add_i(1);
1001 }
1002
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)1003 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
1004 GraphDef* tensorflow_graph) {
1005 tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1006 avgpool_op->set_op("AvgPool");
1007 avgpool_op->set_name(src_op.outputs[0]);
1008 *avgpool_op->add_input() = src_op.inputs[0];
1009 auto& strides = (*avgpool_op->mutable_attr())["strides"];
1010 strides.mutable_list()->add_i(1);
1011 strides.mutable_list()->add_i(src_op.stride_height);
1012 strides.mutable_list()->add_i(src_op.stride_width);
1013 strides.mutable_list()->add_i(1);
1014 std::string padding;
1015 if (src_op.padding.type == PaddingType::kSame) {
1016 padding = "SAME";
1017 } else if (src_op.padding.type == PaddingType::kValid) {
1018 padding = "VALID";
1019 } else {
1020 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1021 }
1022 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1023 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1024 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1025 ksize.mutable_list()->add_i(1);
1026 ksize.mutable_list()->add_i(src_op.kheight);
1027 ksize.mutable_list()->add_i(src_op.kwidth);
1028 ksize.mutable_list()->add_i(1);
1029 }
1030
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)1031 void ConvertConcatenationOperator(const Model& model,
1032 const ConcatenationOperator& src_op,
1033 GraphDef* tensorflow_graph) {
1034 tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
1035 dc_op->set_op("ConcatV2");
1036 dc_op->set_name(src_op.outputs[0]);
1037 const std::string dummy_axis = src_op.outputs[0] + "/axis";
1038 CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
1039 for (const auto& input : src_op.inputs) {
1040 *dc_op->add_input() = input;
1041 }
1042 *dc_op->add_input() = dummy_axis;
1043 (*dc_op->mutable_attr())["T"].set_type(
1044 GetTensorFlowDataType(model, src_op.inputs[0]));
1045 (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1046 (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1047 }
1048
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)1049 void ConvertTensorFlowReshapeOperator(const Model& model,
1050 const TensorFlowReshapeOperator& src_op,
1051 GraphDef* tensorflow_graph) {
1052 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
1053 reshape_op->set_op("Reshape");
1054 reshape_op->set_name(src_op.outputs[0]);
1055 CHECK_EQ(src_op.inputs.size(), 2);
1056 *reshape_op->add_input() = src_op.inputs[0];
1057 *reshape_op->add_input() = src_op.inputs[1];
1058 (*reshape_op->mutable_attr())["T"].set_type(
1059 GetTensorFlowDataType(model, src_op.outputs[0]));
1060 const auto& shape_array = model.GetArray(src_op.inputs[1]);
1061 QCHECK(shape_array.data_type == ArrayDataType::kInt32)
1062 << "Only int32 shape is supported.";
1063 QCHECK(shape_array.buffer != nullptr)
1064 << "Shape inferred at runtime is not supported.";
1065 const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1066 CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
1067 }
1068
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)1069 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
1070 GraphDef* tensorflow_graph) {
1071 const std::string square_output = src_op.outputs[0] + "/square";
1072 const std::string avgpool_output = src_op.outputs[0] + "/avgpool";
1073
1074 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1075 square_op->set_op("Square");
1076 square_op->set_name(square_output);
1077 *square_op->add_input() = src_op.inputs[0];
1078 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1079
1080 std::string padding;
1081 if (src_op.padding.type == PaddingType::kSame) {
1082 padding = "SAME";
1083 } else if (src_op.padding.type == PaddingType::kValid) {
1084 padding = "VALID";
1085 } else {
1086 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1087 }
1088
1089 tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1090 avgpool_op->set_op("AvgPool");
1091 avgpool_op->set_name(avgpool_output);
1092 *avgpool_op->add_input() = square_output;
1093 auto& strides = (*avgpool_op->mutable_attr())["strides"];
1094 strides.mutable_list()->add_i(1);
1095 strides.mutable_list()->add_i(src_op.stride_height);
1096 strides.mutable_list()->add_i(src_op.stride_width);
1097 strides.mutable_list()->add_i(1);
1098
1099 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1100 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1101 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1102 ksize.mutable_list()->add_i(1);
1103 ksize.mutable_list()->add_i(src_op.kheight);
1104 ksize.mutable_list()->add_i(src_op.kwidth);
1105 ksize.mutable_list()->add_i(1);
1106
1107 tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1108 sqrt_op->set_op("Sqrt");
1109 sqrt_op->set_name(src_op.outputs[0]);
1110 *sqrt_op->add_input() = avgpool_output;
1111 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1112 }
1113
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)1114 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
1115 GraphDef* tensorflow_graph) {
1116 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1117 square_op->set_op("Square");
1118 square_op->set_name(src_op.outputs[0]);
1119 CHECK_EQ(src_op.inputs.size(), 1);
1120 *square_op->add_input() = src_op.inputs[0];
1121 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1122 }
1123
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)1124 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
1125 GraphDef* tensorflow_graph) {
1126 tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1127 sqrt_op->set_op("Sqrt");
1128 sqrt_op->set_name(src_op.outputs[0]);
1129 CHECK_EQ(src_op.inputs.size(), 1);
1130 *sqrt_op->add_input() = src_op.inputs[0];
1131 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1132 }
1133
ConvertRsqrtOperator(const Model & model,const TensorFlowRsqrtOperator & src_op,GraphDef * tensorflow_graph)1134 void ConvertRsqrtOperator(const Model& model,
1135 const TensorFlowRsqrtOperator& src_op,
1136 GraphDef* tensorflow_graph) {
1137 tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
1138 rsqrt_op->set_op("Rsqrt");
1139 rsqrt_op->set_name(src_op.outputs[0]);
1140 CHECK_EQ(src_op.inputs.size(), 1);
1141 *rsqrt_op->add_input() = src_op.inputs[0];
1142 const tensorflow::DataType data_type =
1143 GetTensorFlowDataType(model, src_op.inputs[0]);
1144 (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
1145 }
1146
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)1147 void ConvertSplitOperator(const Model& model,
1148 const TensorFlowSplitOperator& src_op,
1149 GraphDef* tensorflow_graph) {
1150 tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1151 split_op->set_op("Split");
1152 split_op->set_name(src_op.outputs[0]);
1153 for (const auto& input : src_op.inputs) {
1154 *split_op->add_input() = input;
1155 }
1156 (*split_op->mutable_attr())["T"].set_type(
1157 GetTensorFlowDataType(model, src_op.outputs[0]));
1158 (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1159 const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1160 CHECK(split_dim_array.buffer);
1161 CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1162 const auto& split_dim_data =
1163 split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1164 CHECK_EQ(split_dim_data.size(), 1);
1165 const int split_dim = split_dim_data[0];
1166 CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1167 tensorflow_graph);
1168 }
1169
ConvertSplitVOperator(const Model & model,const TensorFlowSplitVOperator & src_op,GraphDef * tensorflow_graph)1170 void ConvertSplitVOperator(const Model& model,
1171 const TensorFlowSplitVOperator& src_op,
1172 GraphDef* tensorflow_graph) {
1173 tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
1174 split_v_op->set_op("SplitV");
1175 split_v_op->set_name(src_op.outputs[0]);
1176 for (const auto& input : src_op.inputs) {
1177 *split_v_op->add_input() = input;
1178 }
1179 (*split_v_op->mutable_attr())["T"].set_type(
1180 GetTensorFlowDataType(model, src_op.outputs[0]));
1181 (*split_v_op->mutable_attr())["Tlen"].set_type(
1182 GetTensorFlowDataType(model, src_op.inputs[1]));
1183 (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1184 ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
1185 }
1186
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1187 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1188 GraphDef* tensorflow_graph) {
1189 tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
1190 cast_op->set_op("Cast");
1191 cast_op->set_name(src_op.outputs[0]);
1192 CHECK_EQ(src_op.inputs.size(), 1);
1193 *cast_op->add_input() = src_op.inputs[0];
1194
1195 (*cast_op->mutable_attr())["DstT"].set_type(
1196 GetTensorFlowDataType(model, src_op.outputs[0]));
1197 (*cast_op->mutable_attr())["SrcT"].set_type(
1198 GetTensorFlowDataType(model, src_op.inputs[0]));
1199 }
1200
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1201 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1202 GraphDef* tensorflow_graph) {
1203 tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
1204 floor_op->set_op("Floor");
1205 floor_op->set_name(src_op.outputs[0]);
1206 CHECK_EQ(src_op.inputs.size(), 1);
1207 *floor_op->add_input() = src_op.inputs[0];
1208 (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1209 }
1210
ConvertCeilOperator(const Model & model,const CeilOperator & src_op,GraphDef * tensorflow_graph)1211 void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
1212 GraphDef* tensorflow_graph) {
1213 tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
1214 ceil_op->set_op("Ceil");
1215 ceil_op->set_name(src_op.outputs[0]);
1216 CHECK_EQ(src_op.inputs.size(), 1);
1217 *ceil_op->add_input() = src_op.inputs[0];
1218 (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
1219 }
1220
ConvertRoundOperator(const Model & model,const RoundOperator & src_op,GraphDef * tensorflow_graph)1221 void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
1222 GraphDef* tensorflow_graph) {
1223 tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
1224 round_op->set_op("Round");
1225 round_op->set_name(src_op.outputs[0]);
1226 CHECK_EQ(src_op.inputs.size(), 1);
1227 *round_op->add_input() = src_op.inputs[0];
1228 (*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
1229 }
1230
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1231 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1232 GraphDef* tensorflow_graph) {
1233 tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
1234 gather_op->set_op("GatherV2");
1235 gather_op->set_name(src_op.outputs[0]);
1236 *gather_op->add_input() = src_op.inputs[0];
1237 *gather_op->add_input() = src_op.inputs[1];
1238
1239 if (!src_op.axis) {
1240 // Dynamic axis.
1241 CHECK_EQ(src_op.inputs.size(), 3);
1242 *gather_op->add_input() = src_op.inputs[2];
1243 } else {
1244 // Constant axis.
1245 CHECK_EQ(src_op.inputs.size(), 2);
1246 const std::string gather_axis =
1247 AvailableArrayName(model, gather_op->name() + "/axis");
1248 CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
1249 tensorflow_graph);
1250 *gather_op->add_input() = gather_axis;
1251 }
1252
1253 (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1254 (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
1255 const tensorflow::DataType params_type =
1256 GetTensorFlowDataType(model, src_op.inputs[0]);
1257 (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1258 }
1259
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1260 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1261 GraphDef* tensorflow_graph) {
1262 tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
1263 argmax_op->set_op("ArgMax");
1264 argmax_op->set_name(src_op.outputs[0]);
1265 CHECK_EQ(src_op.inputs.size(), 2);
1266 *argmax_op->add_input() = src_op.inputs[0];
1267 *argmax_op->add_input() = src_op.inputs[1];
1268 (*argmax_op->mutable_attr())["T"].set_type(
1269 GetTensorFlowDataType(model, src_op.inputs[0]));
1270 (*argmax_op->mutable_attr())["Tidx"].set_type(
1271 GetTensorFlowDataType(model, src_op.inputs[1]));
1272 (*argmax_op->mutable_attr())["output_type"].set_type(
1273 GetTensorFlowDataType(model, src_op.outputs[0]));
1274 }
1275
ConvertArgMinOperator(const Model & model,const ArgMinOperator & src_op,GraphDef * tensorflow_graph)1276 void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
1277 GraphDef* tensorflow_graph) {
1278 tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
1279 argmin_op->set_op("ArgMin");
1280 argmin_op->set_name(src_op.outputs[0]);
1281 CHECK_EQ(src_op.inputs.size(), 2);
1282 *argmin_op->add_input() = src_op.inputs[0];
1283 *argmin_op->add_input() = src_op.inputs[1];
1284 (*argmin_op->mutable_attr())["T"].set_type(
1285 GetTensorFlowDataType(model, src_op.inputs[0]));
1286 (*argmin_op->mutable_attr())["Tidx"].set_type(
1287 GetTensorFlowDataType(model, src_op.inputs[1]));
1288 (*argmin_op->mutable_attr())["output_type"].set_type(
1289 GetTensorFlowDataType(model, src_op.outputs[0]));
1290 }
1291
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1292 void ConvertTransposeOperator(const Model& model,
1293 const TransposeOperator& src_op,
1294 GraphDef* tensorflow_graph) {
1295 tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
1296 transpose_op->set_op("Transpose");
1297 transpose_op->set_name(src_op.outputs[0]);
1298 CHECK_EQ(src_op.inputs.size(), 2);
1299 *transpose_op->add_input() = src_op.inputs[0];
1300 *transpose_op->add_input() = src_op.inputs[1];
1301 (*transpose_op->mutable_attr())["T"].set_type(
1302 GetTensorFlowDataType(model, src_op.inputs[0]));
1303 (*transpose_op->mutable_attr())["Tperm"].set_type(
1304 GetTensorFlowDataType(model, src_op.inputs[1]));
1305 }
1306
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1307 void ConvertTensorFlowShapeOperator(const Model& model,
1308 const TensorFlowShapeOperator& src_op,
1309 GraphDef* tensorflow_graph) {
1310 tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
1311 shape_op->set_op("Shape");
1312 shape_op->set_name(src_op.outputs[0]);
1313 CHECK_EQ(src_op.inputs.size(), 1);
1314 *shape_op->add_input() = src_op.inputs[0];
1315 (*shape_op->mutable_attr())["T"].set_type(
1316 GetTensorFlowDataType(model, src_op.inputs[0]));
1317 (*shape_op->mutable_attr())["out_type"].set_type(
1318 GetTensorFlowDataType(model, src_op.outputs[0]));
1319 }
1320
ConvertRankOperator(const Model & model,const TensorFlowRankOperator & src_op,GraphDef * tensorflow_graph)1321 void ConvertRankOperator(const Model& model,
1322 const TensorFlowRankOperator& src_op,
1323 GraphDef* tensorflow_graph) {
1324 tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
1325 rank_op->set_op("Rank");
1326 rank_op->set_name(src_op.outputs[0]);
1327 CHECK_EQ(src_op.inputs.size(), 1);
1328 *rank_op->add_input() = src_op.inputs[0];
1329 (*rank_op->mutable_attr())["T"].set_type(
1330 GetTensorFlowDataType(model, src_op.inputs[0]));
1331 }
1332
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1333 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1334 GraphDef* tensorflow_graph) {
1335 tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
1336 range_op->set_op("Range");
1337 range_op->set_name(src_op.outputs[0]);
1338 CHECK_EQ(src_op.inputs.size(), 3);
1339 *range_op->add_input() = src_op.inputs[0];
1340 *range_op->add_input() = src_op.inputs[1];
1341 *range_op->add_input() = src_op.inputs[2];
1342 (*range_op->mutable_attr())["Tidx"].set_type(
1343 GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
1344 }
1345
ConvertPackOperator(const Model & model,const PackOperator & src_op,GraphDef * tensorflow_graph)1346 void ConvertPackOperator(const Model& model, const PackOperator& src_op,
1347 GraphDef* tensorflow_graph) {
1348 tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
1349 pack_op->set_op("Pack");
1350 pack_op->set_name(src_op.outputs[0]);
1351 for (const auto& input : src_op.inputs) {
1352 *pack_op->add_input() = input;
1353 }
1354 (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
1355 (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1356 (*pack_op->mutable_attr())["T"].set_type(
1357 GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1358 }
1359
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1360 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1361 GraphDef* tensorflow_graph) {
1362 tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
1363 fill_op->set_op("Fill");
1364 fill_op->set_name(src_op.outputs[0]);
1365 CHECK_EQ(src_op.inputs.size(), 2);
1366 *fill_op->add_input() = src_op.inputs[0];
1367 *fill_op->add_input() = src_op.inputs[1];
1368 (*fill_op->mutable_attr())["index_type"].set_type(
1369 GetTensorFlowDataType(model, src_op.inputs[0]));
1370 (*fill_op->mutable_attr())["T"].set_type(
1371 GetTensorFlowDataType(model, src_op.inputs[1]));
1372 }
1373
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1374 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1375 GraphDef* tensorflow_graph) {
1376 tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
1377 floor_div_op->set_op("FloorDiv");
1378 floor_div_op->set_name(src_op.outputs[0]);
1379 CHECK_EQ(src_op.inputs.size(), 2);
1380 *floor_div_op->add_input() = src_op.inputs[0];
1381 *floor_div_op->add_input() = src_op.inputs[1];
1382 (*floor_div_op->mutable_attr())["T"].set_type(
1383 GetTensorFlowDataType(model, src_op.inputs[0]));
1384 }
1385
ConvertFloorModOperator(const Model & model,const FloorModOperator & src_op,GraphDef * tensorflow_graph)1386 void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
1387 GraphDef* tensorflow_graph) {
1388 tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
1389 floor_mod_op->set_op("FloorMod");
1390 floor_mod_op->set_name(src_op.outputs[0]);
1391 DCHECK_EQ(src_op.inputs.size(), 2);
1392 *floor_mod_op->add_input() = src_op.inputs[0];
1393 *floor_mod_op->add_input() = src_op.inputs[1];
1394 (*floor_mod_op->mutable_attr())["T"].set_type(
1395 GetTensorFlowDataType(model, src_op.inputs[0]));
1396 }
1397
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1398 void ConvertExpandDimsOperator(const Model& model,
1399 const ExpandDimsOperator& src_op,
1400 GraphDef* tensorflow_graph) {
1401 tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
1402 expand_dims_op->set_op("ExpandDims");
1403 expand_dims_op->set_name(src_op.outputs[0]);
1404 CHECK_EQ(src_op.inputs.size(), 2);
1405 *expand_dims_op->add_input() = src_op.inputs[0];
1406 *expand_dims_op->add_input() = src_op.inputs[1];
1407 (*expand_dims_op->mutable_attr())["T"].set_type(
1408 GetTensorFlowDataType(model, src_op.inputs[0]));
1409 (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1410 GetTensorFlowDataType(model, src_op.inputs[1]));
1411 }
1412
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1413 void ConvertResizeBilinearOperator(const Model& model,
1414 const ResizeBilinearOperator& src_op,
1415 GraphDef* tensorflow_graph) {
1416 tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1417 resize_op->set_op("ResizeBilinear");
1418 resize_op->set_name(src_op.outputs[0]);
1419 CHECK_EQ(src_op.inputs.size(), 2);
1420 *resize_op->add_input() = src_op.inputs[0];
1421 *resize_op->add_input() = src_op.inputs[1];
1422 (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1423 (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1424 (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1425 src_op.half_pixel_centers);
1426 }
1427
ConvertResizeNearestNeighborOperator(const Model & model,const ResizeNearestNeighborOperator & src_op,GraphDef * tensorflow_graph)1428 void ConvertResizeNearestNeighborOperator(
1429 const Model& model, const ResizeNearestNeighborOperator& src_op,
1430 GraphDef* tensorflow_graph) {
1431 tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1432 resize_op->set_op("ResizeNearestNeighbor");
1433 resize_op->set_name(src_op.outputs[0]);
1434 CHECK_EQ(src_op.inputs.size(), 2);
1435 *resize_op->add_input() = src_op.inputs[0];
1436 *resize_op->add_input() = src_op.inputs[1];
1437 (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1438 (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1439 (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1440 src_op.half_pixel_centers);
1441 }
1442
ConvertOneHotOperator(const Model & model,const OneHotOperator & src_op,GraphDef * tensorflow_graph)1443 void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1444 GraphDef* tensorflow_graph) {
1445 tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1446 onehot_op->set_op("OneHot");
1447 onehot_op->set_name(src_op.outputs[0]);
1448 CHECK_EQ(src_op.inputs.size(), 4);
1449 for (const auto& input : src_op.inputs) {
1450 *onehot_op->add_input() = input;
1451 }
1452 (*onehot_op->mutable_attr())["T"].set_type(
1453 GetTensorFlowDataType(model, src_op.outputs[0]));
1454 (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1455 }
1456
1457 namespace {
1458 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1459 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1460 absl::string_view b) {
1461 if (a.empty() || b.empty()) return absl::string_view();
1462
1463 const char* pa = a.data();
1464 const char* pb = b.data();
1465 std::string::difference_type count = 0;
1466 const std::string::difference_type limit = std::min(a.size(), b.size());
1467 while (count < limit && *pa == *pb) {
1468 ++pa;
1469 ++pb;
1470 ++count;
1471 }
1472
1473 return absl::string_view(a.data(), count);
1474 }
1475 } // namespace
1476
// Expands a fused toco LstmCell into the equivalent subgraph of primitive
// TensorFlow ops: ConcatV2 -> MatMul -> BiasAdd -> Split, followed by the
// gate nonlinearities (Sigmoid/Tanh) and the Mul/Add state updates. Emitted
// node names are derived from the common prefix of the two outputs so the
// graph resembles a tf.slim basic LSTM cell.
void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
                             GraphDef* tensorflow_graph) {
  // Find the base name
  const std::string base(
      FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
                              src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));

  // Concatenate inputs
  const std::string concat_output = base + "basic_lstm_cell/concat";
  // Op names have been chosen to match the tf.slim LSTM naming
  // as closely as possible.
  // Concatenate along the innermost dimension of the previous activations.
  const int axis =
      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
          .shape()
          .dimensions_count() -
      1;
  // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
  // works the same since the tensor has the same underlying data layout.
  const std::string axis_output = concat_output + "/axis";
  CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
  tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
  concat_op->set_op("ConcatV2");
  concat_op->set_name(concat_output);
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
  *concat_op->add_input() = axis_output;
  (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs

  // Write weights
  const std::string weights_output = base + "weights";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
  // Resolve through any identity-like ops to the underlying constant array.
  const std::string weights_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
  const auto& weights_array = model.GetArray(weights_name);
  // Convert 4D FullyConnected weights into 2D matrix
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);
  CHECK(weights_array.buffer);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  const float* weights_data =
      weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  // Emit the weights constant, reordering from (column, row) to (row, column).
  ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
                          AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);

  // Fully connected matrix multiply
  const std::string matmul_output = base + "MatMul";
  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = concat_output;
  *matmul_op->add_input() = weights_output;
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Write biases
  const std::string biases_output = base + "biases";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
  const std::string bias_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
  const auto& bias_array = model.GetArray(bias_name);
  // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
  Shape bias_shape_1d = bias_array.shape();
  UnextendShape(&bias_shape_1d, 1);
  CHECK(bias_array.buffer);
  CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
  const float* bias_data =
      bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
                          AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                          tensorflow_graph,
                          LegacyScalarPolicy::kDoCreateLegacyScalars);

  // Add biases
  std::string biasadd_output = base + "BiasAdd";
  tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
  biasadd_op->set_op("BiasAdd");
  biasadd_op->set_name(biasadd_output);
  biasadd_op->add_input(matmul_output);
  biasadd_op->add_input(biases_output);
  (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
  (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Split
  std::string split_dim_output = base + "split/split_dim";
  // The dimension is the same as the concatenation dimension
  CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
  std::string split_output = base + "split";
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(split_output);
  *split_op->add_input() = split_dim_output;
  *split_op->add_input() = biasadd_output;
  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs

  // Activation functions and memory computations
  // The four split outputs feed the gate nonlinearities below:
  // :1 -> Tanh (candidate), :0/:2/:3 -> Sigmoids.
  const std::string tanh_0_output = base + "Tanh";
  tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
  tanh_0_op->set_op("Tanh");
  tanh_0_op->set_name(tanh_0_output);
  *tanh_0_op->add_input() = split_output + ":1";
  (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_1_output = base + "Sigmoid_1";
  tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
  logistic_1_op->set_op("Sigmoid");
  logistic_1_op->set_name(sigmoid_1_output);
  *logistic_1_op->add_input() = split_output;
  (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Gate the candidate values by Sigmoid_1.
  const std::string mul_1_output = base + "mul_1";
  tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
  mul_1_op->set_op("Mul");
  mul_1_op->set_name(mul_1_output);
  *mul_1_op->add_input() = sigmoid_1_output;
  *mul_1_op->add_input() = tanh_0_output;
  (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_0_output = base + "Sigmoid";
  tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
  logistic_2_op->set_op("Sigmoid");
  logistic_2_op->set_name(sigmoid_0_output);
  *logistic_2_op->add_input() = split_output + ":2";
  (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_2_output = base + "Sigmoid_2";
  tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
  logistic_3_op->set_op("Sigmoid");
  logistic_3_op->set_name(sigmoid_2_output);
  *logistic_3_op->add_input() = split_output + ":3";
  (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Gate the previous state by the "Sigmoid" output.
  const std::string mul_0_output = base + "mul";
  tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
  mul_0_op->set_op("Mul");
  mul_0_op->set_name(mul_0_output);
  *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
  *mul_0_op->add_input() = sigmoid_0_output;
  (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New cell state: gated previous state + gated candidate. This node is
  // named after (and produces) the operator's STATE_OUTPUT.
  const std::string add_1_output =
      src_op.outputs[LstmCellOperator::STATE_OUTPUT];
  tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
  add_1_op->set_op("Add");
  add_1_op->set_name(add_1_output);
  *add_1_op->add_input() = mul_0_output;
  *add_1_op->add_input() = mul_1_output;
  (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string tanh_1_output = base + "Tanh_1";
  tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
  tanh_1_op->set_op("Tanh");
  tanh_1_op->set_name(tanh_1_output);
  *tanh_1_op->add_input() = add_1_output;
  (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New activations: Tanh of the new state, gated by Sigmoid_2. This node is
  // named after (and produces) the operator's ACTIV_OUTPUT.
  const std::string mul_2_output =
      src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
  tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
  mul_2_op->set_op("Mul");
  mul_2_op->set_name(mul_2_output);
  *mul_2_op->add_input() = tanh_1_output;
  *mul_2_op->add_input() = sigmoid_2_output;
  (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
1645
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1646 void ConvertSpaceToBatchNDOperator(const Model& model,
1647 const SpaceToBatchNDOperator& src_op,
1648 GraphDef* tensorflow_graph) {
1649 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1650 new_op->set_op("SpaceToBatchND");
1651 new_op->set_name(src_op.outputs[0]);
1652 CHECK_EQ(src_op.inputs.size(), 3);
1653 *new_op->add_input() = src_op.inputs[0];
1654 *new_op->add_input() = src_op.inputs[1];
1655 *new_op->add_input() = src_op.inputs[2];
1656 const tensorflow::DataType params_type =
1657 GetTensorFlowDataType(model, src_op.inputs[0]);
1658 (*new_op->mutable_attr())["T"].set_type(params_type);
1659 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1660 (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1661 }
1662
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1663 void ConvertBatchToSpaceNDOperator(const Model& model,
1664 const BatchToSpaceNDOperator& src_op,
1665 GraphDef* tensorflow_graph) {
1666 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1667 new_op->set_op("BatchToSpaceND");
1668 new_op->set_name(src_op.outputs[0]);
1669 CHECK_EQ(src_op.inputs.size(), 3);
1670 *new_op->add_input() = src_op.inputs[0];
1671 *new_op->add_input() = src_op.inputs[1];
1672 *new_op->add_input() = src_op.inputs[2];
1673 const tensorflow::DataType params_type =
1674 GetTensorFlowDataType(model, src_op.inputs[0]);
1675 (*new_op->mutable_attr())["T"].set_type(params_type);
1676 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1677 (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1678 }
1679
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1680 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1681 GraphDef* tensorflow_graph) {
1682 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1683 new_op->set_op("Pad");
1684 new_op->set_name(src_op.outputs[0]);
1685 CHECK_EQ(src_op.inputs.size(), 2);
1686 *new_op->add_input() = src_op.inputs[0];
1687 *new_op->add_input() = src_op.inputs[1];
1688
1689 const tensorflow::DataType params_type =
1690 GetTensorFlowDataType(model, src_op.inputs[0]);
1691 (*new_op->mutable_attr())["T"].set_type(params_type);
1692
1693 // Create the params tensor.
1694 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1695 params_op->set_op("Const");
1696 params_op->set_name(src_op.inputs[1]);
1697 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1698 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1699 tensor->set_dtype(DT_INT32);
1700
1701 CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1702 for (int i = 0; i < src_op.left_padding.size(); ++i) {
1703 tensor->add_int_val(src_op.left_padding[i]);
1704 tensor->add_int_val(src_op.right_padding[i]);
1705 }
1706 auto* shape = tensor->mutable_tensor_shape();
1707 shape->add_dim()->set_size(src_op.left_padding.size());
1708 shape->add_dim()->set_size(2);
1709 }
1710
ConvertPadV2Operator(const Model & model,const PadV2Operator & src_op,GraphDef * tensorflow_graph)1711 void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
1712 GraphDef* tensorflow_graph) {
1713 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1714 new_op->set_op("PadV2");
1715 new_op->set_name(src_op.outputs[0]);
1716 CHECK_EQ(src_op.inputs.size(), 2);
1717 *new_op->add_input() = src_op.inputs[0];
1718 *new_op->add_input() = src_op.inputs[1];
1719 *new_op->add_input() = src_op.inputs[2];
1720
1721 const tensorflow::DataType params_type =
1722 GetTensorFlowDataType(model, src_op.inputs[0]);
1723 (*new_op->mutable_attr())["T"].set_type(params_type);
1724
1725 // Create the params tensor.
1726 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1727 params_op->set_op("Const");
1728 params_op->set_name(src_op.inputs[1]);
1729 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1730 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1731 tensor->set_dtype(DT_INT32);
1732
1733 CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1734 for (int i = 0; i < src_op.left_padding.size(); ++i) {
1735 tensor->add_int_val(src_op.left_padding[i]);
1736 tensor->add_int_val(src_op.right_padding[i]);
1737 }
1738 auto* shape = tensor->mutable_tensor_shape();
1739 shape->add_dim()->set_size(src_op.left_padding.size());
1740 shape->add_dim()->set_size(2);
1741 }
1742
CreateSliceInput(const std::string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1743 void CreateSliceInput(const std::string& input_name,
1744 const std::vector<int>& values,
1745 GraphDef* tensorflow_graph) {
1746 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1747 params_op->set_op("Const");
1748 params_op->set_name(input_name);
1749 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1750 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1751 tensor->set_dtype(DT_INT32);
1752
1753 for (int i = 0; i < values.size(); ++i) {
1754 tensor->add_int_val(values[i]);
1755 }
1756 auto* shape = tensor->mutable_tensor_shape();
1757 shape->add_dim()->set_size(values.size());
1758 }
1759
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1760 void ConvertStridedSliceOperator(const Model& model,
1761 const StridedSliceOperator& src_op,
1762 GraphDef* tensorflow_graph) {
1763 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1764 new_op->set_op("StridedSlice");
1765 new_op->set_name(src_op.outputs[0]);
1766 CHECK_EQ(src_op.inputs.size(), 4);
1767 *new_op->add_input() = src_op.inputs[0];
1768 *new_op->add_input() = src_op.inputs[1];
1769 *new_op->add_input() = src_op.inputs[2];
1770 *new_op->add_input() = src_op.inputs[3];
1771
1772 const tensorflow::DataType params_type =
1773 GetTensorFlowDataType(model, src_op.inputs[0]);
1774 (*new_op->mutable_attr())["T"].set_type(params_type);
1775
1776 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1777 (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1778 (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1779 (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1780 (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1781 (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1782
1783 // Create tensors for start/stop indices and strides.
1784 CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1785 CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1786 CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1787 }
1788
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1789 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1790 GraphDef* tensorflow_graph) {
1791 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1792 new_op->set_op("Slice");
1793 new_op->set_name(src_op.outputs[0]);
1794 CHECK_EQ(src_op.inputs.size(), 3);
1795 *new_op->add_input() = src_op.inputs[0];
1796 *new_op->add_input() = src_op.inputs[1];
1797 *new_op->add_input() = src_op.inputs[2];
1798
1799 const tensorflow::DataType params_type =
1800 GetTensorFlowDataType(model, src_op.inputs[0]);
1801 (*new_op->mutable_attr())["T"].set_type(params_type);
1802 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1803
1804 // Create tensors for begin and size inputs.
1805 CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1806 CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1807 }
1808
1809 template <typename T>
ConvertReduceOperator(const Model & model,const T & src_op,GraphDef * tensorflow_graph,const std::string & op_name)1810 void ConvertReduceOperator(const Model& model, const T& src_op,
1811 GraphDef* tensorflow_graph,
1812 const std::string& op_name) {
1813 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1814 new_op->set_op(op_name);
1815 new_op->set_name(src_op.outputs[0]);
1816 CHECK_EQ(src_op.inputs.size(), 2);
1817 *new_op->add_input() = src_op.inputs[0];
1818 *new_op->add_input() = src_op.inputs[1];
1819
1820 if (src_op.type != OperatorType::kAny) {
1821 const tensorflow::DataType params_type =
1822 GetTensorFlowDataType(model, src_op.inputs[0]);
1823 (*new_op->mutable_attr())["T"].set_type(params_type);
1824 }
1825 const tensorflow::DataType indices_type =
1826 GetTensorFlowDataType(model, src_op.inputs[1]);
1827 (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
1828
1829 if (src_op.keep_dims) {
1830 (*new_op->mutable_attr())["keep_dims"].set_b(true);
1831 }
1832
1833 // Create the params tensor.
1834 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1835 params_op->set_op("Const");
1836 params_op->set_name(src_op.inputs[1]);
1837 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1838 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1839 tensor->set_dtype(DT_INT32);
1840
1841 for (int i = 0; i < src_op.axis.size(); ++i) {
1842 tensor->add_int_val(src_op.axis[i]);
1843 }
1844 auto* shape = tensor->mutable_tensor_shape();
1845 shape->add_dim()->set_size(src_op.axis.size());
1846 }
1847
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1848 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1849 GraphDef* tensorflow_graph) {
1850 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1851 new_op->set_op("Squeeze");
1852 new_op->set_name(src_op.outputs[0]);
1853 CHECK_EQ(src_op.inputs.size(), 1);
1854 *new_op->add_input() = src_op.inputs[0];
1855
1856 const tensorflow::DataType params_type =
1857 GetTensorFlowDataType(model, src_op.inputs[0]);
1858 (*new_op->mutable_attr())["T"].set_type(params_type);
1859
1860 if (!src_op.squeeze_dims.empty()) {
1861 auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1862 for (int i : src_op.squeeze_dims) {
1863 squeeze_dims.mutable_list()->add_i(i);
1864 }
1865 }
1866 }
1867
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1868 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1869 GraphDef* tensorflow_graph) {
1870 tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
1871 sub_op->set_op("Sub");
1872 sub_op->set_name(src_op.outputs[0]);
1873 CHECK_EQ(src_op.inputs.size(), 2);
1874 *sub_op->add_input() = src_op.inputs[0];
1875 *sub_op->add_input() = src_op.inputs[1];
1876 const tensorflow::DataType data_type =
1877 GetTensorFlowDataType(model, src_op.inputs[0]);
1878 (*sub_op->mutable_attr())["T"].set_type(data_type);
1879 }
1880
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1881 void ConvertTensorFlowMinimumOperator(const Model& model,
1882 const TensorFlowMinimumOperator& src_op,
1883 GraphDef* tensorflow_graph) {
1884 tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
1885 min_op->set_op("Minimum");
1886 min_op->set_name(src_op.outputs[0]);
1887 CHECK_EQ(src_op.inputs.size(), 2);
1888 *min_op->add_input() = src_op.inputs[0];
1889 *min_op->add_input() = src_op.inputs[1];
1890 const tensorflow::DataType data_type =
1891 GetTensorFlowDataType(model, src_op.inputs[0]);
1892 (*min_op->mutable_attr())["T"].set_type(data_type);
1893 }
1894
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1895 void ConvertTensorFlowMaximumOperator(const Model& model,
1896 const TensorFlowMaximumOperator& src_op,
1897 GraphDef* tensorflow_graph) {
1898 tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
1899 max_op->set_op("Maximum");
1900 max_op->set_name(src_op.outputs[0]);
1901 CHECK_EQ(src_op.inputs.size(), 2);
1902 *max_op->add_input() = src_op.inputs[0];
1903 *max_op->add_input() = src_op.inputs[1];
1904 const tensorflow::DataType data_type =
1905 GetTensorFlowDataType(model, src_op.inputs[0]);
1906 (*max_op->mutable_attr())["T"].set_type(data_type);
1907 }
1908
ConvertSelectOperator(const Model & model,const SelectOperator & src_op,GraphDef * tensorflow_graph)1909 void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
1910 GraphDef* tensorflow_graph) {
1911 tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
1912 select_op->set_op("Select");
1913 select_op->set_name(src_op.outputs[0]);
1914 CHECK_EQ(src_op.inputs.size(), 3);
1915 *select_op->add_input() = src_op.inputs[0];
1916 *select_op->add_input() = src_op.inputs[1];
1917 *select_op->add_input() = src_op.inputs[2];
1918 const tensorflow::DataType data_type =
1919 GetTensorFlowDataType(model, src_op.inputs[1]);
1920 (*select_op->mutable_attr())["T"].set_type(data_type);
1921 }
1922
ConvertTileOperator(const Model & model,const TensorFlowTileOperator & src_op,GraphDef * tensorflow_graph)1923 void ConvertTileOperator(const Model& model,
1924 const TensorFlowTileOperator& src_op,
1925 GraphDef* tensorflow_graph) {
1926 tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
1927 tile_op->set_op("Tile");
1928 tile_op->set_name(src_op.outputs[0]);
1929 CHECK_EQ(src_op.inputs.size(), 2);
1930 *tile_op->add_input() = src_op.inputs[0];
1931 *tile_op->add_input() = src_op.inputs[1];
1932 const tensorflow::DataType data_type =
1933 GetTensorFlowDataType(model, src_op.inputs[0]);
1934 (*tile_op->mutable_attr())["T"].set_type(data_type);
1935 const tensorflow::DataType multiples_data_type =
1936 GetTensorFlowDataType(model, src_op.inputs[1]);
1937 (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
1938 }
1939
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1940 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1941 GraphDef* tensorflow_graph) {
1942 tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
1943 topk_op->set_op("TopKV2");
1944 topk_op->set_name(src_op.outputs[0]);
1945 CHECK_EQ(src_op.inputs.size(), 2);
1946 *topk_op->add_input() = src_op.inputs[0];
1947 *topk_op->add_input() = src_op.inputs[1];
1948 const tensorflow::DataType data_type =
1949 GetTensorFlowDataType(model, src_op.inputs[0]);
1950 (*topk_op->mutable_attr())["T"].set_type(data_type);
1951 (*topk_op->mutable_attr())["sorted"].set_b(true);
1952 }
1953
ConvertRandomUniformOperator(const Model & model,const RandomUniformOperator & src_op,GraphDef * tensorflow_graph)1954 void ConvertRandomUniformOperator(const Model& model,
1955 const RandomUniformOperator& src_op,
1956 GraphDef* tensorflow_graph) {
1957 CHECK(tensorflow_graph != nullptr);
1958 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1959 new_op->set_op("RandomUniform");
1960 CHECK_EQ(src_op.inputs.size(), 1);
1961 new_op->set_name(src_op.outputs[0]);
1962 *new_op->add_input() = src_op.inputs[0];
1963 const tensorflow::DataType shape_type =
1964 GetTensorFlowDataType(model, src_op.inputs[0]);
1965 (*new_op->mutable_attr())["T"].set_type(shape_type);
1966 (*new_op->mutable_attr())["dtype"].set_type(
1967 GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1968 (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
1969 (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
1970 }
1971
ConvertComparisonOperator(const Model & model,const Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)1972 void ConvertComparisonOperator(const Model& model, const Operator& src_op,
1973 const char* op_name,
1974 GraphDef* tensorflow_graph) {
1975 tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
1976 comparison_op->set_op(op_name);
1977 comparison_op->set_name(src_op.outputs[0]);
1978 CHECK_EQ(src_op.inputs.size(), 2);
1979 *comparison_op->add_input() = src_op.inputs[0];
1980 *comparison_op->add_input() = src_op.inputs[1];
1981 const tensorflow::DataType data_type =
1982 GetTensorFlowDataType(model, src_op.inputs[0]);
1983 (*comparison_op->mutable_attr())["T"].set_type(data_type);
1984 }
1985
ConvertSparseToDenseOperator(const Model & model,const SparseToDenseOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1986 void ConvertSparseToDenseOperator(const Model& model,
1987 const SparseToDenseOperator& src_op,
1988 const char* op_name,
1989 GraphDef* tensorflow_graph) {
1990 tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
1991 sparse_to_dense_op->set_op(op_name);
1992 sparse_to_dense_op->set_name(src_op.outputs[0]);
1993 CHECK_EQ(src_op.inputs.size(), 4);
1994 for (int i = 0; i < 4; ++i) {
1995 *sparse_to_dense_op->add_input() = src_op.inputs[i];
1996 }
1997 const tensorflow::DataType data_type =
1998 GetTensorFlowDataType(model, src_op.inputs[3]);
1999 (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
2000 const tensorflow::DataType index_type =
2001 GetTensorFlowDataType(model, src_op.inputs[0]);
2002 (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
2003 (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
2004 src_op.validate_indices);
2005 }
2006
ConvertPowOperator(const Model & model,const PowOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2007 void ConvertPowOperator(const Model& model, const PowOperator& src_op,
2008 const char* op_name, GraphDef* tensorflow_graph) {
2009 tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
2010 pow_op->set_op(op_name);
2011 pow_op->set_name(src_op.outputs[0]);
2012 CHECK_EQ(src_op.inputs.size(), 2);
2013 for (int i = 0; i < 2; ++i) {
2014 *pow_op->add_input() = src_op.inputs[i];
2015 }
2016 const tensorflow::DataType data_type =
2017 GetTensorFlowDataType(model, src_op.inputs[0]);
2018 (*pow_op->mutable_attr())["T"].set_type(data_type);
2019 }
2020
ConvertLogicalAndOperator(const Model & model,const LogicalAndOperator & src_op,GraphDef * tensorflow_graph)2021 void ConvertLogicalAndOperator(const Model& model,
2022 const LogicalAndOperator& src_op,
2023 GraphDef* tensorflow_graph) {
2024 tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2025 logical_op->set_op("LogicalAnd");
2026 logical_op->set_name(src_op.outputs[0]);
2027 CHECK_EQ(src_op.inputs.size(), 2);
2028 for (int i = 0; i < 2; ++i) {
2029 *logical_op->add_input() = src_op.inputs[i];
2030 }
2031 }
2032
ConvertLogicalNotOperator(const Model & model,const LogicalNotOperator & src_op,GraphDef * tensorflow_graph)2033 void ConvertLogicalNotOperator(const Model& model,
2034 const LogicalNotOperator& src_op,
2035 GraphDef* tensorflow_graph) {
2036 tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2037 logical_op->set_op("LogicalNot");
2038 logical_op->set_name(src_op.outputs[0]);
2039 CHECK_EQ(src_op.inputs.size(), 1);
2040 *logical_op->add_input() = src_op.inputs[0];
2041 }
2042
ConvertLogicalOrOperator(const Model & model,const LogicalOrOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2043 void ConvertLogicalOrOperator(const Model& model,
2044 const LogicalOrOperator& src_op,
2045 const char* op_name, GraphDef* tensorflow_graph) {
2046 tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
2047 logical_or_op->set_op(op_name);
2048 logical_or_op->set_name(src_op.outputs[0]);
2049 CHECK_EQ(src_op.inputs.size(), 2);
2050 for (int i = 0; i < 2; ++i) {
2051 *logical_or_op->add_input() = src_op.inputs[i];
2052 }
2053 const tensorflow::DataType data_type =
2054 GetTensorFlowDataType(model, src_op.inputs[0]);
2055 (*logical_or_op->mutable_attr())["T"].set_type(data_type);
2056 }
2057
ConvertCTCBeamSearchDecoderOperator(const Model & model,const CTCBeamSearchDecoderOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2058 void ConvertCTCBeamSearchDecoderOperator(
2059 const Model& model, const CTCBeamSearchDecoderOperator& src_op,
2060 const char* op_name, GraphDef* tensorflow_graph) {
2061 auto* op = tensorflow_graph->add_node();
2062 op->set_op(op_name);
2063 op->set_name(src_op.outputs[0]);
2064 CHECK_EQ(src_op.inputs.size(), 2);
2065 for (int i = 0; i < 2; ++i) {
2066 *op->add_input() = src_op.inputs[i];
2067 }
2068 (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
2069 (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
2070 (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
2071 }
2072
ConvertUnpackOperator(const Model & model,const UnpackOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2073 void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
2074 const char* op_name, GraphDef* tensorflow_graph) {
2075 tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
2076 unpack_op->set_op(op_name);
2077 unpack_op->set_name(src_op.outputs[0]);
2078 CHECK_EQ(src_op.inputs.size(), 2);
2079 *unpack_op->add_input() = src_op.inputs[0];
2080 const tensorflow::DataType data_type =
2081 GetTensorFlowDataType(model, src_op.inputs[0]);
2082 (*unpack_op->mutable_attr())["T"].set_type(data_type);
2083 (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
2084 (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
2085 }
2086
ConvertZerosLikeOperator(const Model & model,const TensorFlowZerosLikeOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2087 void ConvertZerosLikeOperator(const Model& model,
2088 const TensorFlowZerosLikeOperator& src_op,
2089 const char* op_name, GraphDef* tensorflow_graph) {
2090 tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
2091 zeros_like_op->set_op(op_name);
2092 zeros_like_op->set_name(src_op.outputs[0]);
2093 DCHECK_EQ(src_op.inputs.size(), 1);
2094 *zeros_like_op->add_input() = src_op.inputs[0];
2095 const tensorflow::DataType data_type =
2096 GetTensorFlowDataType(model, src_op.inputs[0]);
2097 (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
2098 }
2099
ConvertReverseV2Operator(const Model & model,const ReverseV2Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)2100 void ConvertReverseV2Operator(const Model& model,
2101 const ReverseV2Operator& src_op,
2102 const char* op_name, GraphDef* tensorflow_graph) {
2103 tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
2104 reverse_v2_op->set_op(op_name);
2105 reverse_v2_op->set_name(src_op.outputs[0]);
2106 DCHECK_EQ(src_op.inputs.size(), 2);
2107 *reverse_v2_op->add_input() = src_op.inputs[0];
2108 *reverse_v2_op->add_input() = src_op.inputs[1];
2109 const tensorflow::DataType data_type =
2110 GetTensorFlowDataType(model, src_op.inputs[0]);
2111 (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
2112 }
2113
ConvertReverseSequenceOperator(const Model & model,const ReverseSequenceOperator & src_op,GraphDef * tensorflow_graph)2114 void ConvertReverseSequenceOperator(const Model& model,
2115 const ReverseSequenceOperator& src_op,
2116 GraphDef* tensorflow_graph) {
2117 tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
2118 reverse_seq_op->set_op("ReverseSequence");
2119 reverse_seq_op->set_name(src_op.outputs[0]);
2120 CHECK_EQ(src_op.inputs.size(), 2);
2121 *reverse_seq_op->add_input() = src_op.inputs[0];
2122 *reverse_seq_op->add_input() = src_op.inputs[1];
2123 (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
2124 (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
2125 }
2126
// Dispatches a single toco Operator to its matching Convert*Operator helper,
// which appends the corresponding TensorFlow node(s) to `tensorflow_graph`.
// Fails fatally if the operator carries a fused activation function (there is
// no single-node TensorFlow equivalent here) or if its type has no converter.
void ConvertOperator(const Model& model, const Operator& src_op,
                     GraphDef* tensorflow_graph) {
  if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
    LOG(FATAL)
        << "Unsupported: the input model has a fused activation function";
  }

  if (src_op.type == OperatorType::kConv) {
    ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kDepthwiseConv) {
    ConvertDepthwiseConvOperator(
        model, static_cast<const DepthwiseConvOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kDepthToSpace) {
    ConvertDepthToSpaceOperator(
        model, static_cast<const DepthToSpaceOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSpaceToDepth) {
    ConvertSpaceToDepthOperator(
        model, static_cast<const SpaceToDepthOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFullyConnected) {
    ConvertFullyConnectedOperator(
        model, static_cast<const FullyConnectedOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kAdd) {
    ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kAddN) {
    ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMul) {
    ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kDiv) {
    ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu) {
    ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu1) {
    ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu6) {
    ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kLog) {
    ConvertLogOperator(static_cast<const LogOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogistic) {
    ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kTanh) {
    ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kL2Normalization) {
    ConvertL2NormalizationOperator(
        static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
  } else if (src_op.type == OperatorType::kSoftmax) {
    ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogSoftmax) {
    ConvertLogSoftmaxOperator(model,
                              static_cast<const LogSoftmaxOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
    ConvertLocalResponseNormalizationOperator(
        static_cast<const LocalResponseNormalizationOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kLstmCell) {
    ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kMaxPool) {
    ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kAveragePool) {
    ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
                               tensorflow_graph);
  } else if (src_op.type == OperatorType::kConcatenation) {
    ConvertConcatenationOperator(
        model, static_cast<const ConcatenationOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kReshape) {
    ConvertTensorFlowReshapeOperator(
        model, static_cast<const TensorFlowReshapeOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kL2Pool) {
    ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kSquare) {
    ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kSqrt) {
    ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRsqrt) {
    ConvertRsqrtOperator(model,
                         static_cast<const TensorFlowRsqrtOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kSplit) {
    ConvertSplitOperator(model,
                         static_cast<const TensorFlowSplitOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kSplitV) {
    ConvertSplitVOperator(model,
                          static_cast<const TensorFlowSplitVOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kFakeQuant) {
    ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
                             tensorflow_graph);
  } else if (src_op.type == OperatorType::kCast) {
    ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloor) {
    ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kCeil) {
    ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRound) {
    ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kGather) {
    ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kResizeBilinear) {
    ConvertResizeBilinearOperator(
        model, static_cast<const ResizeBilinearOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kResizeNearestNeighbor) {
    ConvertResizeNearestNeighborOperator(
        model, static_cast<const ResizeNearestNeighborOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSpaceToBatchND) {
    ConvertSpaceToBatchNDOperator(
        model, static_cast<const SpaceToBatchNDOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kBatchToSpaceND) {
    ConvertBatchToSpaceNDOperator(
        model, static_cast<const BatchToSpaceNDOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kPad) {
    ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kPadV2) {
    ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kStridedSlice) {
    ConvertStridedSliceOperator(
        model, static_cast<const StridedSliceOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMean) {
    // Reductions share one converter, parameterized by the TF op name.
    ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
                          tensorflow_graph, "Mean");
  } else if (src_op.type == OperatorType::kSum) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowSumOperator&>(src_op),
                          tensorflow_graph, "Sum");
  } else if (src_op.type == OperatorType::kReduceProd) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowProdOperator&>(src_op),
                          tensorflow_graph, "Prod");
  } else if (src_op.type == OperatorType::kReduceMin) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowMinOperator&>(src_op),
                          tensorflow_graph, "Min");
  } else if (src_op.type == OperatorType::kReduceMax) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowMaxOperator&>(src_op),
                          tensorflow_graph, "Max");
  } else if (src_op.type == OperatorType::kSub) {
    ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kMinimum) {
    ConvertTensorFlowMinimumOperator(
        model, static_cast<const TensorFlowMinimumOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMaximum) {
    ConvertTensorFlowMaximumOperator(
        model, static_cast<const TensorFlowMaximumOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSqueeze) {
    ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kSlice) {
    ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kArgMax) {
    ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kArgMin) {
    ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTopK_V2) {
    ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTranspose) {
    ConvertTransposeOperator(
        model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
  } else if (src_op.type == OperatorType::kShape) {
    ConvertTensorFlowShapeOperator(
        model, static_cast<const TensorFlowShapeOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRank) {
    ConvertRankOperator(model,
                        static_cast<const TensorFlowRankOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRange) {
    ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kPack) {
    ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFill) {
    ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloorDiv) {
    ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloorMod) {
    ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kExpandDims) {
    ConvertExpandDimsOperator(model,
                              static_cast<const ExpandDimsOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kTransposeConv) {
    ConvertTransposeConvOperator(
        model, static_cast<const TransposeConvOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRandomUniform) {
    ConvertRandomUniformOperator(
        model, static_cast<const RandomUniformOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kEqual) {
    // Comparison ops share one converter, parameterized by the TF op name.
    ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
  } else if (src_op.type == OperatorType::kNotEqual) {
    ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kGreater) {
    ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
  } else if (src_op.type == OperatorType::kGreaterEqual) {
    ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kLess) {
    ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
  } else if (src_op.type == OperatorType::kLessEqual) {
    ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kSelect) {
    ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTile) {
    ConvertTileOperator(model,
                        static_cast<const TensorFlowTileOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kPow) {
    ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kAny) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowAnyOperator&>(src_op),
                          tensorflow_graph, "Any");
  } else if (src_op.type == OperatorType::kLogicalAnd) {
    ConvertLogicalAndOperator(model,
                              static_cast<const LogicalAndOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogicalNot) {
    ConvertLogicalNotOperator(model,
                              static_cast<const LogicalNotOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kOneHot) {
    ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogicalOr) {
    ConvertLogicalOrOperator(model,
                             static_cast<const LogicalOrOperator&>(src_op),
                             "LogicalOr", tensorflow_graph);
  } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
    ConvertCTCBeamSearchDecoderOperator(
        model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
        "CTCBeamSearchDecoder", tensorflow_graph);
  } else if (src_op.type == OperatorType::kUnpack) {
    ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
                          "Unpack", tensorflow_graph);
  } else if (src_op.type == OperatorType::kZerosLike) {
    ConvertZerosLikeOperator(
        model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
        "ZerosLike", tensorflow_graph);
  } else if (src_op.type == OperatorType::kReverseV2) {
    ConvertReverseV2Operator(model,
                             static_cast<const ReverseV2Operator&>(src_op),
                             "Reverse_V2", tensorflow_graph);
  } else if (src_op.type == OperatorType::kReverseSequence) {
    ConvertReverseSequenceOperator(
        model, static_cast<const ReverseSequenceOperator&>(src_op),
        tensorflow_graph);
  } else {
    LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
  }
}
2426
AddPlaceholder(const std::string & name,ArrayDataType type,GraphDef * tensorflow_graph)2427 void AddPlaceholder(const std::string& name, ArrayDataType type,
2428 GraphDef* tensorflow_graph) {
2429 tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2430 placeholder->set_op("Placeholder");
2431 switch (type) {
2432 case ArrayDataType::kBool:
2433 (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
2434 break;
2435 case ArrayDataType::kFloat:
2436 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2437 break;
2438 case ArrayDataType::kUint8:
2439 (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
2440 break;
2441 case ArrayDataType::kInt32:
2442 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
2443 break;
2444 case ArrayDataType::kUint32:
2445 (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
2446 break;
2447 case ArrayDataType::kInt64:
2448 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
2449 break;
2450 case ArrayDataType::kInt16:
2451 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
2452 break;
2453 case ArrayDataType::kComplex64:
2454 (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
2455 break;
2456 default:
2457 LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
2458 }
2459 placeholder->set_name(name);
2460 }
2461
AddPlaceholderForRNNState(const Model & model,const std::string & name,int size,GraphDef * tensorflow_graph)2462 void AddPlaceholderForRNNState(const Model& model, const std::string& name,
2463 int size, GraphDef* tensorflow_graph) {
2464 tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2465 placeholder->set_op("Placeholder");
2466 placeholder->set_name(name);
2467 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2468
2469 auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
2470 const auto& state_array = model.GetArray(name);
2471 if (state_array.has_shape()) {
2472 const auto& state_shape = state_array.shape();
2473 const int kDims = state_shape.dimensions_count();
2474 for (int i = 0; i < kDims; ++i) {
2475 shape->add_dim()->set_size(state_shape.dims(i));
2476 }
2477 } else {
2478 shape->add_dim()->set_size(1);
2479 shape->add_dim()->set_size(size);
2480 }
2481 }
2482
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)2483 void ExportTensorFlowGraphDefImplementation(const Model& model,
2484 GraphDef* tensorflow_graph) {
2485 for (const auto& input_array : model.flags.input_arrays()) {
2486 AddPlaceholder(input_array.name(),
2487 model.GetArray(input_array.name()).data_type,
2488 tensorflow_graph);
2489 }
2490 for (const auto& rnn_state : model.flags.rnn_states()) {
2491 AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
2492 tensorflow_graph);
2493 }
2494 for (const auto& op : model.operators) {
2495 ConvertOperator(model, *op, tensorflow_graph);
2496 }
2497 // Generically export arrays that haven't been exported already
2498 // by the above operators export. It's important that this comes
2499 // after, as some operators need to export arrays that they reference
2500 // in a specific way, rather than in the generic way done below.
2501 for (const auto& array_pair : model.GetArrayMap()) {
2502 const std::string& array_name = array_pair.first;
2503 const auto& array = *array_pair.second;
2504 if (array.buffer) {
2505 switch (array.data_type) {
2506 case ArrayDataType::kBool:
2507 ConvertBoolTensorConst(model, array_name, tensorflow_graph);
2508 break;
2509 case ArrayDataType::kFloat:
2510 ConvertFloatTensorConst(model, array_name, tensorflow_graph);
2511 break;
2512 case ArrayDataType::kInt32:
2513 ConvertIntTensorConst(model, array_name, tensorflow_graph);
2514 break;
2515 case ArrayDataType::kComplex64:
2516 ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
2517 break;
2518 default:
2519 break;
2520 }
2521 }
2522 }
2523 }
2524 } // namespace
2525
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)2526 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
2527 for (const auto& array_kv : model->GetArrayMap()) {
2528 const std::string& array_name = array_kv.first;
2529 Array& array = *array_kv.second;
2530 if (!array.buffer || !array.minmax) {
2531 continue;
2532 }
2533 const std::string& wrapped_array_name =
2534 AvailableArrayName(*model, array_name + "/data");
2535 Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
2536 wrapped_array.data_type = array.data_type;
2537 wrapped_array.copy_shape(array.shape());
2538 wrapped_array.buffer = std::move(array.buffer);
2539 FakeQuantOperator* fakequant_op = new FakeQuantOperator;
2540 fakequant_op->inputs = {wrapped_array_name};
2541 fakequant_op->outputs = {array_name};
2542 fakequant_op->minmax.reset(new MinMax);
2543 *fakequant_op->minmax = *array.minmax;
2544 const auto& it = FindOpWithInput(*model, array_name);
2545 model->operators.emplace(it, fakequant_op);
2546 }
2547 CheckInvariants(*model);
2548 }
2549
ExportTensorFlowGraphDef(const Model & model,std::string * output_file_contents)2550 void ExportTensorFlowGraphDef(const Model& model,
2551 std::string* output_file_contents) {
2552 CHECK(output_file_contents->empty());
2553 GraphDef tensorflow_graph;
2554 ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
2555 LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
2556 CHECK(tensorflow_graph.SerializeToString(output_file_contents));
2557 }
2558 } // namespace toco
2559