1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20
21 #include "google/protobuf/map.h"
22 #include "google/protobuf/text_format.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/lite/toco/model.h"
33 #include "tensorflow/lite/toco/model_flags.pb.h"
34 #include "tensorflow/lite/toco/runtime/types.h"
35 #include "tensorflow/lite/toco/tensorflow_util.h"
36 #include "tensorflow/lite/toco/tooling_util.h"
37
38 using tensorflow::DT_BOOL;
39 using tensorflow::DT_COMPLEX64;
40 using tensorflow::DT_FLOAT;
41 using tensorflow::DT_INT16;
42 using tensorflow::DT_INT32;
43 using tensorflow::DT_INT64;
44 using tensorflow::DT_UINT8;
45 using tensorflow::GraphDef;
46 using tensorflow::TensorProto;
47
48 namespace toco {
49 namespace {
50
GetTensorFlowDataType(ArrayDataType data_type,const string & error_location)51 tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
52 const string& error_location) {
53 switch (data_type) {
54 case ArrayDataType::kBool:
55 return tensorflow::DT_BOOL;
56 case ArrayDataType::kFloat:
57 return tensorflow::DT_FLOAT;
58 case ArrayDataType::kUint8:
59 return tensorflow::DT_UINT8;
60 case ArrayDataType::kInt32:
61 return tensorflow::DT_INT32;
62 case ArrayDataType::kInt64:
63 return tensorflow::DT_INT64;
64 case ArrayDataType::kString:
65 return tensorflow::DT_STRING;
66 case ArrayDataType::kComplex64:
67 return tensorflow::DT_COMPLEX64;
68 default:
69 case ArrayDataType::kNone:
70 LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
71 << "' in " << error_location;
72 return tensorflow::DT_INVALID;
73 }
74 }
75
GetTensorFlowDataTypeForOp(ArrayDataType data_type,const string & op_name)76 tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
77 const string& op_name) {
78 return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
79 }
80
GetTensorFlowDataType(const Model & model,const string & array_name)81 tensorflow::DataType GetTensorFlowDataType(const Model& model,
82 const string& array_name) {
83 return GetTensorFlowDataType(model.GetArray(array_name).data_type,
84 "array '" + array_name + "'");
85 }
86
// Selects how ExportFloatArray treats a 1-D shape whose only dimension is 1.
//
// TensorFlow may reject such "legacy scalars" (see OpKernel::IsLegacyScalar
// and OpKernel::allow_legacy_scalars). By default we therefore export them as
// 0-D shapes instead.
//
// Bias vectors consumed by BiasAdd are the exception. TensorFlow requires the
// bias to be 1-D, so a bias of depth 1 must stay 1-D even though that makes it
// a legacy scalar.
enum class LegacyScalarPolicy {
  // Emit a 1-D, size-1 shape as a 0-D shape.
  kAvoidLegacyScalars,
  // Emit the shape unchanged, even when it is a legacy scalar.
  kDoCreateLegacyScalars,
};
100
ExportFloatArray(const Shape & input_shape,const float * input_data,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)101 void ExportFloatArray(const Shape& input_shape, const float* input_data,
102 TensorProto* output_tensor,
103 LegacyScalarPolicy legacy_scalar_policy) {
104 output_tensor->set_dtype(DT_FLOAT);
105 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
106 auto* shape = output_tensor->mutable_tensor_shape();
107
108 const int kDims = input_shape.dimensions_count();
109 if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
110 kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
111 for (int i = 0; i < kDims; ++i) {
112 shape->add_dim()->set_size(input_shape.dims(i));
113 }
114 }
115 output_tensor->set_tensor_content(
116 string(reinterpret_cast<const char*>(input_data),
117 sizeof(*input_data) * input_flat_size));
118 }
119
ExportFloatArray(AxesOrder input_axes_order,const Shape & input_shape,const float * input_data,AxesOrder output_axes_order,TensorProto * output_tensor,LegacyScalarPolicy legacy_scalar_policy)120 void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
121 const float* input_data, AxesOrder output_axes_order,
122 TensorProto* output_tensor,
123 LegacyScalarPolicy legacy_scalar_policy) {
124 CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
125 output_tensor->set_dtype(DT_FLOAT);
126 CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
127 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
128
129 Shape shuffled_shape;
130 ShuffleDims(input_shape, input_axes_order, output_axes_order,
131 &shuffled_shape);
132 std::vector<float> shuffled_data(input_flat_size);
133 ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
134 input_data, shuffled_data.data());
135
136 ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
137 legacy_scalar_policy);
138 }
139
HasAlreadyExportedConst(const string & name,const GraphDef & tensorflow_graph)140 bool HasAlreadyExportedConst(const string& name,
141 const GraphDef& tensorflow_graph) {
142 for (const auto& node : tensorflow_graph.node()) {
143 if (node.op() == "Const" && node.name() == name) {
144 return true;
145 }
146 }
147 return false;
148 }
149
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph,LegacyScalarPolicy legacy_scalar_policy)150 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
151 const float* input_data,
152 AxesOrder input_axes_order,
153 AxesOrder output_axes_order,
154 GraphDef* tensorflow_graph,
155 LegacyScalarPolicy legacy_scalar_policy) {
156 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
157 return;
158 }
159 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
160 const_op->set_op("Const");
161 const_op->set_name(name);
162 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
163 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
164 ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
165 tensor, legacy_scalar_policy);
166 }
167
ConvertFloatTensorConst(const string & name,const Shape & input_shape,const float * input_data,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)168 void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
169 const float* input_data,
170 AxesOrder input_axes_order,
171 AxesOrder output_axes_order,
172 GraphDef* tensorflow_graph) {
173 ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
174 output_axes_order, tensorflow_graph,
175 LegacyScalarPolicy::kAvoidLegacyScalars);
176 }
177
ConvertFloatTensorConst(const Model & model,const string & name,AxesOrder input_axes_order,AxesOrder output_axes_order,GraphDef * tensorflow_graph)178 void ConvertFloatTensorConst(const Model& model, const string& name,
179 AxesOrder input_axes_order,
180 AxesOrder output_axes_order,
181 GraphDef* tensorflow_graph) {
182 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
183 return;
184 }
185 CHECK(model.HasArray(name));
186 const auto& input_array = model.GetArray(name);
187 const auto& input_shape = input_array.shape();
188 CHECK(input_array.buffer);
189 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
190 const float* input_data =
191 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
192 ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
193 output_axes_order, tensorflow_graph);
194 }
195
ConvertFloatTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)196 void ConvertFloatTensorConst(const Model& model, const string& name,
197 GraphDef* tensorflow_graph) {
198 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
199 return;
200 }
201 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
202 const_op->set_op("Const");
203 const_op->set_name(name);
204 (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
205 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
206 CHECK(model.HasArray(name));
207 const auto& input_array = model.GetArray(name);
208 const auto& input_shape = input_array.shape();
209 CHECK(input_array.buffer);
210 CHECK(input_array.buffer->type == ArrayDataType::kFloat);
211 const float* input_data =
212 input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
213 ExportFloatArray(input_shape, input_data, tensor,
214 LegacyScalarPolicy::kAvoidLegacyScalars);
215 }
216
ConvertBoolTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)217 void ConvertBoolTensorConst(const Model& model, const string& name,
218 GraphDef* tensorflow_graph) {
219 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
220 return;
221 }
222 CHECK(model.HasArray(name));
223 const auto& array = model.GetArray(name);
224 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
225 const_op->set_op("Const");
226 const_op->set_name(name);
227 (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
228 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
229 tensor->set_dtype(DT_BOOL);
230 const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
231 for (auto index : data) {
232 tensor->add_bool_val(index);
233 }
234 const auto& array_shape = array.shape();
235 auto* shape = tensor->mutable_tensor_shape();
236 for (int i = 0; i < array_shape.dimensions_count(); i++) {
237 shape->add_dim()->set_size(array_shape.dims(i));
238 }
239 }
240
ConvertIntTensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)241 void ConvertIntTensorConst(const Model& model, const string& name,
242 GraphDef* tensorflow_graph) {
243 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
244 return;
245 }
246 CHECK(model.HasArray(name));
247 const auto& array = model.GetArray(name);
248 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
249 const_op->set_op("Const");
250 const_op->set_name(name);
251 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
252 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
253 tensor->set_dtype(DT_INT32);
254 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
255 for (auto index : data) {
256 tensor->add_int_val(index);
257 }
258 const auto& array_shape = array.shape();
259 auto* shape = tensor->mutable_tensor_shape();
260 for (int i = 0; i < array_shape.dimensions_count(); i++) {
261 shape->add_dim()->set_size(array_shape.dims(i));
262 }
263 }
264
CreateIntTensorConst(const string & name,const std::vector<int32> & data,const std::vector<int32> & shape,GraphDef * tensorflow_graph)265 void CreateIntTensorConst(const string& name, const std::vector<int32>& data,
266 const std::vector<int32>& shape,
267 GraphDef* tensorflow_graph) {
268 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
269 return;
270 }
271 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
272 const_op->set_op("Const");
273 const_op->set_name(name);
274 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
275 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
276 tensor->set_dtype(DT_INT32);
277 for (auto index : data) {
278 tensor->add_int_val(index);
279 }
280 auto* tensor_shape = tensor->mutable_tensor_shape();
281 int num_elements = 1;
282 for (int size : shape) {
283 tensor_shape->add_dim()->set_size(size);
284 num_elements *= size;
285 }
286 CHECK_EQ(num_elements, data.size());
287 }
288
ConvertComplex64TensorConst(const Model & model,const string & name,GraphDef * tensorflow_graph)289 void ConvertComplex64TensorConst(const Model& model, const string& name,
290 GraphDef* tensorflow_graph) {
291 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
292 return;
293 }
294 CHECK(model.HasArray(name));
295 const auto& array = model.GetArray(name);
296 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
297 const_op->set_op("Const");
298 const_op->set_name(name);
299 (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
300 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
301 tensor->set_dtype(DT_COMPLEX64);
302 const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
303 for (auto index : data) {
304 tensor->add_scomplex_val(std::real(index));
305 tensor->add_scomplex_val(std::imag(index));
306 }
307 const auto& array_shape = array.shape();
308 auto* shape = tensor->mutable_tensor_shape();
309 for (int i = 0; i < array_shape.dimensions_count(); i++) {
310 shape->add_dim()->set_size(array_shape.dims(i));
311 }
312 }
313
CreateMatrixShapeTensorConst(const string & name,int rows,int cols,GraphDef * tensorflow_graph)314 void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
315 GraphDef* tensorflow_graph) {
316 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
317 return;
318 }
319 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
320 const_op->set_op("Const");
321 const_op->set_name(name);
322 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
323 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
324 tensor->set_dtype(DT_INT32);
325 const int32 data[2] = {cols, rows};
326 tensor->set_tensor_content(
327 string(reinterpret_cast<const char*>(data), sizeof(data)));
328 auto* shape = tensor->mutable_tensor_shape();
329 shape->add_dim()->set_size(2);
330 }
331
CreateDummyConcatDimTensorConst(const string & name,int dim,GraphDef * tensorflow_graph)332 void CreateDummyConcatDimTensorConst(const string& name, int dim,
333 GraphDef* tensorflow_graph) {
334 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
335 return;
336 }
337 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
338 const_op->set_op("Const");
339 const_op->set_name(name);
340 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
341 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
342 tensor->set_dtype(DT_INT32);
343 tensor->add_int_val(dim);
344 }
345
CreateReshapeShapeTensorConst(const string & name,const std::vector<int32> & shape,GraphDef * tensorflow_graph)346 void CreateReshapeShapeTensorConst(const string& name,
347 const std::vector<int32>& shape,
348 GraphDef* tensorflow_graph) {
349 if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
350 return;
351 }
352 tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
353 const_op->set_op("Const");
354 const_op->set_name(name);
355 (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
356 auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
357 tensor->set_dtype(DT_INT32);
358 for (auto s : shape) {
359 tensor->add_int_val(s);
360 }
361 // TensorFlow sometimes forbids what it calls "legacy scalars",
362 // which are shapes of size 1 where the unique shape size is 1.
363 // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
364 if (shape.size() > 1) {
365 auto* tensor_shape = tensor->mutable_tensor_shape();
366 tensor_shape->add_dim()->set_size(shape.size());
367 }
368 }
369
// Returns the name of the constant array that backs `name`. If `name` itself
// has a buffer, it is returned unchanged. Otherwise `name` must be produced by
// a FakeQuant operator whose first input has a buffer, and that input's name
// is returned. Any other situation is a fatal CHECK failure.
string WalkUpToConstantArray(const Model& model, const string& name) {
  if (model.GetArray(name).buffer) {
    return name;
  }
  const auto* op = GetOpWithOutput(model, name);
  CHECK(op);
  CHECK(op->type == OperatorType::kFakeQuant);
  const string& fakequant_input_name = op->inputs[0];
  const Array& input_of_fakequant = model.GetArray(fakequant_input_name);
  CHECK(input_of_fakequant.buffer);
  return fakequant_input_name;
}
383
ConvertConvOperator(const Model & model,const ConvOperator & src_op,GraphDef * tensorflow_graph)384 void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
385 GraphDef* tensorflow_graph) {
386 const bool has_bias = src_op.inputs.size() >= 3;
387 string conv_output = src_op.outputs[0];
388 if (has_bias) {
389 conv_output += "/conv";
390 }
391
392 tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
393 conv2d_op->set_op("Conv2D");
394 conv2d_op->set_name(conv_output);
395 *conv2d_op->add_input() = src_op.inputs[0];
396 *conv2d_op->add_input() = src_op.inputs[1];
397 (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
398 const string& weights_array_name =
399 WalkUpToConstantArray(model, src_op.inputs[1]);
400 const auto& weights_array = model.GetArray(weights_array_name);
401 CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
402 ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
403 AxesOrder::kHWIO, tensorflow_graph);
404 auto& strides = (*conv2d_op->mutable_attr())["strides"];
405 strides.mutable_list()->add_i(1);
406 strides.mutable_list()->add_i(src_op.stride_height);
407 strides.mutable_list()->add_i(src_op.stride_width);
408 strides.mutable_list()->add_i(1);
409 if ((src_op.dilation_width_factor != 1) ||
410 (src_op.dilation_height_factor != 1)) {
411 auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
412 dilations.mutable_list()->add_i(1);
413 dilations.mutable_list()->add_i(src_op.dilation_height_factor);
414 dilations.mutable_list()->add_i(src_op.dilation_width_factor);
415 dilations.mutable_list()->add_i(1);
416 }
417 string padding;
418 if (src_op.padding.type == PaddingType::kSame) {
419 padding = "SAME";
420 } else if (src_op.padding.type == PaddingType::kValid) {
421 padding = "VALID";
422 } else {
423 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
424 }
425 (*conv2d_op->mutable_attr())["padding"].set_s(padding);
426
427 if (has_bias) {
428 tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
429 biasadd_op->set_op("BiasAdd");
430 biasadd_op->set_name(src_op.outputs[0]);
431 biasadd_op->add_input(conv_output);
432 biasadd_op->add_input(src_op.inputs[2]);
433 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
434 CHECK(model.HasArray(src_op.inputs[2]));
435 const string& bias_array_name =
436 WalkUpToConstantArray(model, src_op.inputs[2]);
437 const auto& bias_array = model.GetArray(bias_array_name);
438 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
439 Shape bias_shape_1d = bias_array.shape();
440 UnextendShape(&bias_shape_1d, 1);
441 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
442 const float* bias_data =
443 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
444 ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
445 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
446 tensorflow_graph,
447 LegacyScalarPolicy::kDoCreateLegacyScalars);
448 }
449 }
450
ConvertDepthwiseConvOperator(const Model & model,const DepthwiseConvOperator & src_op,GraphDef * tensorflow_graph)451 void ConvertDepthwiseConvOperator(const Model& model,
452 const DepthwiseConvOperator& src_op,
453 GraphDef* tensorflow_graph) {
454 const bool has_bias = src_op.inputs.size() >= 3;
455 string conv_output = src_op.outputs[0];
456 if (has_bias) {
457 conv_output += "/conv";
458 }
459
460 tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
461 dc2d_op->set_op("DepthwiseConv2dNative");
462 dc2d_op->set_name(conv_output);
463 *dc2d_op->add_input() = src_op.inputs[0];
464 *dc2d_op->add_input() = src_op.inputs[1];
465 (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
466
467 // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
468 // We need to convert that to H x W x InputDepth x Multiplier.
469 // That's only a matter of constructing a Dims object; the actual
470 // array layout is the same.
471 CHECK(model.HasArray(src_op.inputs[1]));
472 const string& src_weights_name =
473 WalkUpToConstantArray(model, src_op.inputs[1]);
474 const auto& src_weights_array = model.GetArray(src_weights_name);
475 const auto& src_weights_shape = src_weights_array.shape();
476 CHECK_EQ(src_weights_shape.dimensions_count(), 4);
477 const Shape dst_weights_shape =
478 Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
479 src_weights_shape.dims(3) / src_op.depth_multiplier,
480 src_op.depth_multiplier});
481 CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
482 CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
483 src_weights_shape.dims(3));
484 CHECK_EQ(src_weights_shape.dims(0), 1);
485
486 CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
487 const float* src_weights_data =
488 src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
489 ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
490 AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);
491
492 auto& strides = (*dc2d_op->mutable_attr())["strides"];
493 strides.mutable_list()->add_i(1);
494 strides.mutable_list()->add_i(src_op.stride_height);
495 strides.mutable_list()->add_i(src_op.stride_width);
496 strides.mutable_list()->add_i(1);
497 // TODO(b/116063589): To return a working TF GraphDef, we should be returning
498 // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
499 // the conv since TF doesn't support dilations.
500 if ((src_op.dilation_width_factor != 1) ||
501 (src_op.dilation_height_factor != 1)) {
502 auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
503 dilations.mutable_list()->add_i(1);
504 dilations.mutable_list()->add_i(src_op.dilation_height_factor);
505 dilations.mutable_list()->add_i(src_op.dilation_width_factor);
506 dilations.mutable_list()->add_i(1);
507 }
508 string padding;
509 if (src_op.padding.type == PaddingType::kSame) {
510 padding = "SAME";
511 } else if (src_op.padding.type == PaddingType::kValid) {
512 padding = "VALID";
513 } else {
514 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
515 }
516 (*dc2d_op->mutable_attr())["padding"].set_s(padding);
517
518 if (has_bias) {
519 tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
520 biasadd_op->set_op("BiasAdd");
521 biasadd_op->set_name(src_op.outputs[0]);
522 biasadd_op->add_input(conv_output);
523 biasadd_op->add_input(src_op.inputs[2]);
524 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
525 CHECK(model.HasArray(src_op.inputs[2]));
526 const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
527 const auto& bias_array = model.GetArray(bias_name);
528 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
529 Shape bias_shape_1d = bias_array.shape();
530 UnextendShape(&bias_shape_1d, 1);
531 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
532 const float* bias_data =
533 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
534 ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
535 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
536 tensorflow_graph,
537 LegacyScalarPolicy::kDoCreateLegacyScalars);
538 }
539 }
540
ConvertTransposeConvOperator(const Model & model,const TransposeConvOperator & src_op,GraphDef * tensorflow_graph)541 void ConvertTransposeConvOperator(const Model& model,
542 const TransposeConvOperator& src_op,
543 GraphDef* tensorflow_graph) {
544 tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
545 conv2d_op->set_op("Conv2DBackpropInput");
546 conv2d_op->set_name(src_op.outputs[0]);
547 *conv2d_op->add_input() = src_op.inputs[0];
548 *conv2d_op->add_input() = src_op.inputs[1];
549 *conv2d_op->add_input() = src_op.inputs[2];
550 (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
551 const string& weights_array_name = WalkUpToConstantArray(
552 model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
553 const auto& weights_array = model.GetArray(weights_array_name);
554 CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
555 ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
556 AxesOrder::kHWOI, tensorflow_graph);
557 auto& strides = (*conv2d_op->mutable_attr())["strides"];
558 strides.mutable_list()->add_i(1);
559 strides.mutable_list()->add_i(src_op.stride_height);
560 strides.mutable_list()->add_i(src_op.stride_width);
561 strides.mutable_list()->add_i(1);
562 string padding;
563 if (src_op.padding.type == PaddingType::kSame) {
564 padding = "SAME";
565 } else if (src_op.padding.type == PaddingType::kValid) {
566 padding = "VALID";
567 } else {
568 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
569 }
570 (*conv2d_op->mutable_attr())["padding"].set_s(padding);
571 }
572
ConvertDepthToSpaceOperator(const Model & model,const DepthToSpaceOperator & src_op,GraphDef * tensorflow_graph)573 void ConvertDepthToSpaceOperator(const Model& model,
574 const DepthToSpaceOperator& src_op,
575 GraphDef* tensorflow_graph) {
576 tensorflow::NodeDef* op = tensorflow_graph->add_node();
577 op->set_op("DepthToSpace");
578 op->set_name(src_op.outputs[0]);
579 *op->add_input() = src_op.inputs[0];
580 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
581 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
582 }
583
ConvertSpaceToDepthOperator(const Model & model,const SpaceToDepthOperator & src_op,GraphDef * tensorflow_graph)584 void ConvertSpaceToDepthOperator(const Model& model,
585 const SpaceToDepthOperator& src_op,
586 GraphDef* tensorflow_graph) {
587 tensorflow::NodeDef* op = tensorflow_graph->add_node();
588 op->set_op("SpaceToDepth");
589 op->set_name(src_op.outputs[0]);
590 *op->add_input() = src_op.inputs[0];
591 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
592 (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
593 }
594
ConvertFullyConnectedOperator(const Model & model,const FullyConnectedOperator & src_op,GraphDef * tensorflow_graph)595 void ConvertFullyConnectedOperator(const Model& model,
596 const FullyConnectedOperator& src_op,
597 GraphDef* tensorflow_graph) {
598 // Reshape input activations to have the shape expected by the MatMul.
599 const string reshape_output =
600 AvailableArrayName(model, src_op.outputs[0] + "/reshape");
601 const string reshape_shape =
602 AvailableArrayName(model, reshape_output + "/shape");
603 const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
604 const auto& fc_weights_shape = fc_weights_array.shape();
605 CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
606 CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
607 tensorflow_graph);
608 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
609 reshape_op->set_op("Reshape");
610 reshape_op->set_name(reshape_output);
611 reshape_op->add_input(src_op.inputs[0]);
612 reshape_op->add_input(reshape_shape);
613 (*reshape_op->mutable_attr())["T"].set_type(
614 GetTensorFlowDataType(model, src_op.inputs[0]));
615
616 const bool has_bias = src_op.inputs.size() >= 3;
617 string matmul_output = src_op.outputs[0];
618 if (has_bias) {
619 matmul_output += "/matmul";
620 }
621
622 // Transpose the RHS input from column-major to row-major to match TensorFlow
623 // expectations. This is the inverse of the transpose we do during
624 // ResolveTensorFlowMatMul.
625 const string transpose_output =
626 AvailableArrayName(model, matmul_output + "/transpose_weights");
627 const string transpose_perm =
628 AvailableArrayName(model, transpose_output + "/perm");
629 CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
630 tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
631 transpose_op->set_op("Transpose");
632 transpose_op->set_name(transpose_output);
633 *transpose_op->add_input() = src_op.inputs[1];
634 *transpose_op->add_input() = transpose_perm;
635 (*transpose_op->mutable_attr())["T"].set_type(
636 GetTensorFlowDataType(model, src_op.inputs[1]));
637 (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
638
639 tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
640 matmul_op->set_op("MatMul");
641 matmul_op->set_name(matmul_output);
642 *matmul_op->add_input() = reshape_output;
643 *matmul_op->add_input() = transpose_op->name();
644 (*matmul_op->mutable_attr())["T"].set_type(
645 GetTensorFlowDataType(model, src_op.inputs[0]));
646 (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
647 (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
648 CHECK(model.HasArray(src_op.inputs[1]));
649
650 // Add the bias, if it exists.
651 if (has_bias) {
652 tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
653 biasadd_op->set_op("BiasAdd");
654 biasadd_op->set_name(src_op.outputs[0]);
655 biasadd_op->add_input(matmul_output);
656 biasadd_op->add_input(src_op.inputs[2]);
657 (*biasadd_op->mutable_attr())["T"].set_type(
658 GetTensorFlowDataType(model, src_op.inputs[0]));
659 CHECK(model.HasArray(src_op.inputs[2]));
660 const auto& bias_array = model.GetArray(src_op.inputs[2]);
661 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
662 Shape bias_shape_1d = bias_array.shape();
663 UnextendShape(&bias_shape_1d, 1);
664 CHECK(bias_array.buffer);
665 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
666 const float* bias_data =
667 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
668 ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
669 bias_shape_1d, bias_data, AxesOrder::kOneAxis,
670 AxesOrder::kOneAxis, tensorflow_graph,
671 LegacyScalarPolicy::kDoCreateLegacyScalars);
672 }
673 }
674
ConvertAddOperator(const Model & model,const AddOperator & src_op,GraphDef * tensorflow_graph)675 void ConvertAddOperator(const Model& model, const AddOperator& src_op,
676 GraphDef* tensorflow_graph) {
677 tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
678 add_op->set_op("Add");
679 add_op->set_name(src_op.outputs[0]);
680 CHECK_EQ(src_op.inputs.size(), 2);
681 *add_op->add_input() = src_op.inputs[0];
682 *add_op->add_input() = src_op.inputs[1];
683 (*add_op->mutable_attr())["T"].set_type(
684 GetTensorFlowDataType(model, src_op.outputs[0]));
685 }
686
ConvertAddNOperator(const Model & model,const AddNOperator & src_op,GraphDef * tensorflow_graph)687 void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
688 GraphDef* tensorflow_graph) {
689 tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
690 add_op->set_op("AddN");
691 add_op->set_name(src_op.outputs[0]);
692 for (const auto& input : src_op.inputs) {
693 *add_op->add_input() = input;
694 }
695 (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
696 (*add_op->mutable_attr())["T"].set_type(
697 GetTensorFlowDataType(model, src_op.outputs[0]));
698 }
699
ConvertMulOperator(const Model & model,const MulOperator & src_op,GraphDef * tensorflow_graph)700 void ConvertMulOperator(const Model& model, const MulOperator& src_op,
701 GraphDef* tensorflow_graph) {
702 tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
703 mul_op->set_op("Mul");
704 mul_op->set_name(src_op.outputs[0]);
705 CHECK_EQ(src_op.inputs.size(), 2);
706 *mul_op->add_input() = src_op.inputs[0];
707 *mul_op->add_input() = src_op.inputs[1];
708 (*mul_op->mutable_attr())["T"].set_type(
709 GetTensorFlowDataType(model, src_op.outputs[0]));
710 }
711
ConvertDivOperator(const Model & model,const DivOperator & src_op,GraphDef * tensorflow_graph)712 void ConvertDivOperator(const Model& model, const DivOperator& src_op,
713 GraphDef* tensorflow_graph) {
714 tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
715 div_op->set_op("Div");
716 div_op->set_name(src_op.outputs[0]);
717 CHECK_EQ(src_op.inputs.size(), 2);
718 *div_op->add_input() = src_op.inputs[0];
719 *div_op->add_input() = src_op.inputs[1];
720 (*div_op->mutable_attr())["T"].set_type(
721 GetTensorFlowDataType(model, src_op.outputs[0]));
722 }
723
ConvertReluOperator(const Model & model,const ReluOperator & src_op,GraphDef * tensorflow_graph)724 void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
725 GraphDef* tensorflow_graph) {
726 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
727 relu_op->set_op("Relu");
728 relu_op->set_name(src_op.outputs[0]);
729 *relu_op->add_input() = src_op.inputs[0];
730 (*relu_op->mutable_attr())["T"].set_type(
731 GetTensorFlowDataType(model, src_op.outputs[0]));
732 }
733
ConvertRelu1Operator(const Relu1Operator & src_op,GraphDef * tensorflow_graph)734 void ConvertRelu1Operator(const Relu1Operator& src_op,
735 GraphDef* tensorflow_graph) {
736 const string max_bounds = src_op.outputs[0] + "/max_bounds";
737 const string min_bounds = src_op.outputs[0] + "/min_bounds";
738 const string max_output = src_op.outputs[0] + "/max_output";
739
740 tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
741 max_bounds_const_op->set_op("Const");
742 max_bounds_const_op->set_name(max_bounds);
743 (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
744 auto* max_bounds_const_op_tensor =
745 (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
746 max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
747 max_bounds_const_op_tensor->add_float_val(-1.0f);
748
749 tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
750 min_bounds_const_op->set_op("Const");
751 min_bounds_const_op->set_name(min_bounds);
752 (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
753 auto* min_bounds_const_op_tensor =
754 (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
755 min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
756 min_bounds_const_op_tensor->add_float_val(1.0f);
757
758 tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
759 max_op->set_op("Maximum");
760 max_op->set_name(max_output);
761 *max_op->add_input() = src_op.inputs[0];
762 *max_op->add_input() = max_bounds;
763 (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
764
765 tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
766 min_op->set_op("Minimum");
767 min_op->set_name(src_op.outputs[0]);
768 *min_op->add_input() = max_output;
769 *min_op->add_input() = min_bounds;
770 (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
771 }
772
ConvertRelu6Operator(const Relu6Operator & src_op,GraphDef * tensorflow_graph)773 void ConvertRelu6Operator(const Relu6Operator& src_op,
774 GraphDef* tensorflow_graph) {
775 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
776 relu_op->set_op("Relu6");
777 relu_op->set_name(src_op.outputs[0]);
778 *relu_op->add_input() = src_op.inputs[0];
779 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
780 }
781
ConvertLogOperator(const LogOperator & src_op,GraphDef * tensorflow_graph)782 void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
783 tensorflow::NodeDef* op = tensorflow_graph->add_node();
784 op->set_op("Log");
785 op->set_name(src_op.outputs[0]);
786 CHECK_EQ(src_op.inputs.size(), 1);
787 *op->add_input() = src_op.inputs[0];
788 (*op->mutable_attr())["T"].set_type(DT_FLOAT);
789 }
790
ConvertLogisticOperator(const LogisticOperator & src_op,GraphDef * tensorflow_graph)791 void ConvertLogisticOperator(const LogisticOperator& src_op,
792 GraphDef* tensorflow_graph) {
793 tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
794 relu_op->set_op("Sigmoid");
795 relu_op->set_name(src_op.outputs[0]);
796 *relu_op->add_input() = src_op.inputs[0];
797 (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
798 }
799
ConvertTanhOperator(const TanhOperator & src_op,GraphDef * tensorflow_graph)800 void ConvertTanhOperator(const TanhOperator& src_op,
801 GraphDef* tensorflow_graph) {
802 tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
803 tanh_op->set_op("Tanh");
804 tanh_op->set_name(src_op.outputs[0]);
805 *tanh_op->add_input() = src_op.inputs[0];
806 (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
807 }
808
ConvertSoftmaxOperator(const Model & model,const SoftmaxOperator & src_op,GraphDef * tensorflow_graph)809 void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
810 GraphDef* tensorflow_graph) {
811 string softmax_input;
812 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
813 if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
814 softmax_input = src_op.inputs[0];
815 } else {
816 // Insert a reshape operator that reduces the dimensions down to the 2 that
817 // are required for TensorFlow Logits.
818 const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape";
819 const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
820 softmax_input = reshape_output;
821
822 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
823 reshape_op->set_op("Reshape");
824 reshape_op->set_name(reshape_output);
825 *reshape_op->add_input() = src_op.inputs[0];
826 *reshape_op->add_input() = softmax_size;
827 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
828
829 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
830 int32 flattened_size = 1;
831 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
832 flattened_size *= input_shape.dims(i);
833 }
834 const std::vector<int32> shape_data = {
835 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
836 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
837 }
838
839 tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
840 softmax_op->set_op("Softmax");
841 softmax_op->set_name(src_op.outputs[0]);
842 *softmax_op->add_input() = softmax_input;
843 // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter
844 CHECK_EQ(src_op.beta, 1.f);
845 (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
846 }
847
ConvertLogSoftmaxOperator(const Model & model,const LogSoftmaxOperator & src_op,GraphDef * tensorflow_graph)848 void ConvertLogSoftmaxOperator(const Model& model,
849 const LogSoftmaxOperator& src_op,
850 GraphDef* tensorflow_graph) {
851 string softmax_input;
852 Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
853 if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
854 softmax_input = src_op.inputs[0];
855 } else {
856 // Insert a reshape operator that reduces the dimensions down to the 2 that
857 // are required for TensorFlow Logits.
858 const string reshape_output =
859 src_op.outputs[0] + "/log_softmax_insert_reshape";
860 const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size";
861 softmax_input = reshape_output;
862
863 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
864 reshape_op->set_op("Reshape");
865 reshape_op->set_name(reshape_output);
866 *reshape_op->add_input() = src_op.inputs[0];
867 *reshape_op->add_input() = softmax_size;
868 (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
869
870 const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
871 int32 flattened_size = 1;
872 for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
873 flattened_size *= input_shape.dims(i);
874 }
875 const std::vector<int32> shape_data = {
876 flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
877 CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
878 }
879
880 tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
881 log_softmax_op->set_op("LogSoftmax");
882 log_softmax_op->set_name(src_op.outputs[0]);
883 *log_softmax_op->add_input() = softmax_input;
884 (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
885 }
886
ConvertL2NormalizationOperator(const L2NormalizationOperator & src_op,GraphDef * tensorflow_graph)887 void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
888 GraphDef* tensorflow_graph) {
889 const string square_output = src_op.outputs[0] + "/square";
890 const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices";
891 const string sum_output = src_op.outputs[0] + "/sum";
892 const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
893 const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
894
895 tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
896 sum_reduction_indices_op->set_op("Const");
897 sum_reduction_indices_op->set_name(sum_reduction_indices);
898 (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
899 auto* sum_reduction_indices_tensor =
900 (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
901 sum_reduction_indices_tensor->set_dtype(DT_INT32);
902 auto* sum_reduction_indices_shape =
903 sum_reduction_indices_tensor->mutable_tensor_shape();
904 auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
905 sum_reduction_indices_dim->set_size(2);
906 sum_reduction_indices_tensor->add_int_val(0);
907 sum_reduction_indices_tensor->add_int_val(1);
908
909 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
910 square_op->set_op("Square");
911 square_op->set_name(square_output);
912 *square_op->add_input() = src_op.inputs[0];
913 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
914
915 tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
916 sum_op->set_op("Sum");
917 sum_op->set_name(sum_output);
918 *sum_op->add_input() = square_output;
919 *sum_op->add_input() = sum_reduction_indices;
920 (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
921
922 tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
923 rsqrt_op->set_op("Rsqrt");
924 rsqrt_op->set_name(rsqrt_output);
925 *rsqrt_op->add_input() = sum_output;
926 (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
927
928 tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
929 mul_op->set_op("Mul");
930 mul_op->set_name(src_op.outputs[0]);
931 *mul_op->add_input() = src_op.inputs[0];
932 *mul_op->add_input() = rsqrt_output;
933 (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
934 }
935
ConvertLocalResponseNormalizationOperator(const LocalResponseNormalizationOperator & src_op,GraphDef * tensorflow_graph)936 void ConvertLocalResponseNormalizationOperator(
937 const LocalResponseNormalizationOperator& src_op,
938 GraphDef* tensorflow_graph) {
939 tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
940 lrn_op->set_op("LRN");
941 lrn_op->set_name(src_op.outputs[0]);
942 *lrn_op->add_input() = src_op.inputs[0];
943 (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
944 (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
945 (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
946 (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
947 }
948
ConvertFakeQuantOperator(const FakeQuantOperator & src_op,GraphDef * tensorflow_graph)949 void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
950 GraphDef* tensorflow_graph) {
951 tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
952 fakequant_op->set_op("FakeQuantWithMinMaxArgs");
953 fakequant_op->set_name(src_op.outputs[0]);
954 CHECK_EQ(src_op.inputs.size(), 1);
955 *fakequant_op->add_input() = src_op.inputs[0];
956 CHECK(src_op.minmax);
957 (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
958 (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
959 if (src_op.num_bits) {
960 (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
961 }
962 if (src_op.narrow_range) {
963 (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
964 }
965 }
966
ConvertMaxPoolOperator(const MaxPoolOperator & src_op,GraphDef * tensorflow_graph)967 void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
968 GraphDef* tensorflow_graph) {
969 tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
970 maxpool_op->set_op("MaxPool");
971 maxpool_op->set_name(src_op.outputs[0]);
972 *maxpool_op->add_input() = src_op.inputs[0];
973 auto& strides = (*maxpool_op->mutable_attr())["strides"];
974 strides.mutable_list()->add_i(1);
975 strides.mutable_list()->add_i(src_op.stride_height);
976 strides.mutable_list()->add_i(src_op.stride_width);
977 strides.mutable_list()->add_i(1);
978 string padding;
979 if (src_op.padding.type == PaddingType::kSame) {
980 padding = "SAME";
981 } else if (src_op.padding.type == PaddingType::kValid) {
982 padding = "VALID";
983 } else {
984 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
985 }
986 (*maxpool_op->mutable_attr())["padding"].set_s(padding);
987 (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
988 auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
989 ksize.mutable_list()->add_i(1);
990 ksize.mutable_list()->add_i(src_op.kheight);
991 ksize.mutable_list()->add_i(src_op.kwidth);
992 ksize.mutable_list()->add_i(1);
993 }
994
ConvertAveragePoolOperator(const AveragePoolOperator & src_op,GraphDef * tensorflow_graph)995 void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
996 GraphDef* tensorflow_graph) {
997 tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
998 avgpool_op->set_op("AvgPool");
999 avgpool_op->set_name(src_op.outputs[0]);
1000 *avgpool_op->add_input() = src_op.inputs[0];
1001 auto& strides = (*avgpool_op->mutable_attr())["strides"];
1002 strides.mutable_list()->add_i(1);
1003 strides.mutable_list()->add_i(src_op.stride_height);
1004 strides.mutable_list()->add_i(src_op.stride_width);
1005 strides.mutable_list()->add_i(1);
1006 string padding;
1007 if (src_op.padding.type == PaddingType::kSame) {
1008 padding = "SAME";
1009 } else if (src_op.padding.type == PaddingType::kValid) {
1010 padding = "VALID";
1011 } else {
1012 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1013 }
1014 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1015 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1016 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1017 ksize.mutable_list()->add_i(1);
1018 ksize.mutable_list()->add_i(src_op.kheight);
1019 ksize.mutable_list()->add_i(src_op.kwidth);
1020 ksize.mutable_list()->add_i(1);
1021 }
1022
ConvertConcatenationOperator(const Model & model,const ConcatenationOperator & src_op,GraphDef * tensorflow_graph)1023 void ConvertConcatenationOperator(const Model& model,
1024 const ConcatenationOperator& src_op,
1025 GraphDef* tensorflow_graph) {
1026 tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
1027 dc_op->set_op("ConcatV2");
1028 dc_op->set_name(src_op.outputs[0]);
1029 const string dummy_axis = src_op.outputs[0] + "/axis";
1030 CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
1031 for (const auto& input : src_op.inputs) {
1032 *dc_op->add_input() = input;
1033 }
1034 *dc_op->add_input() = dummy_axis;
1035 (*dc_op->mutable_attr())["T"].set_type(
1036 GetTensorFlowDataType(model, src_op.inputs[0]));
1037 (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1038 (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1039 }
1040
ConvertTensorFlowReshapeOperator(const Model & model,const TensorFlowReshapeOperator & src_op,GraphDef * tensorflow_graph)1041 void ConvertTensorFlowReshapeOperator(const Model& model,
1042 const TensorFlowReshapeOperator& src_op,
1043 GraphDef* tensorflow_graph) {
1044 tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
1045 reshape_op->set_op("Reshape");
1046 reshape_op->set_name(src_op.outputs[0]);
1047 CHECK_EQ(src_op.inputs.size(), 2);
1048 *reshape_op->add_input() = src_op.inputs[0];
1049 *reshape_op->add_input() = src_op.inputs[1];
1050 (*reshape_op->mutable_attr())["T"].set_type(
1051 GetTensorFlowDataType(model, src_op.outputs[0]));
1052 const auto& shape_array = model.GetArray(src_op.inputs[1]);
1053 QCHECK(shape_array.data_type == ArrayDataType::kInt32)
1054 << "Only int32 shape is supported.";
1055 QCHECK(shape_array.buffer != nullptr)
1056 << "Shape inferred at runtime is not supported.";
1057 const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1058 CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
1059 }
1060
ConvertL2PoolOperator(const L2PoolOperator & src_op,GraphDef * tensorflow_graph)1061 void ConvertL2PoolOperator(const L2PoolOperator& src_op,
1062 GraphDef* tensorflow_graph) {
1063 const string square_output = src_op.outputs[0] + "/square";
1064 const string avgpool_output = src_op.outputs[0] + "/avgpool";
1065
1066 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1067 square_op->set_op("Square");
1068 square_op->set_name(square_output);
1069 *square_op->add_input() = src_op.inputs[0];
1070 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1071
1072 string padding;
1073 if (src_op.padding.type == PaddingType::kSame) {
1074 padding = "SAME";
1075 } else if (src_op.padding.type == PaddingType::kValid) {
1076 padding = "VALID";
1077 } else {
1078 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1079 }
1080
1081 tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
1082 avgpool_op->set_op("AvgPool");
1083 avgpool_op->set_name(avgpool_output);
1084 *avgpool_op->add_input() = square_output;
1085 auto& strides = (*avgpool_op->mutable_attr())["strides"];
1086 strides.mutable_list()->add_i(1);
1087 strides.mutable_list()->add_i(src_op.stride_height);
1088 strides.mutable_list()->add_i(src_op.stride_width);
1089 strides.mutable_list()->add_i(1);
1090
1091 (*avgpool_op->mutable_attr())["padding"].set_s(padding);
1092 (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
1093 auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
1094 ksize.mutable_list()->add_i(1);
1095 ksize.mutable_list()->add_i(src_op.kheight);
1096 ksize.mutable_list()->add_i(src_op.kwidth);
1097 ksize.mutable_list()->add_i(1);
1098
1099 tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1100 sqrt_op->set_op("Sqrt");
1101 sqrt_op->set_name(src_op.outputs[0]);
1102 *sqrt_op->add_input() = avgpool_output;
1103 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1104 }
1105
ConvertSquareOperator(const TensorFlowSquareOperator & src_op,GraphDef * tensorflow_graph)1106 void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
1107 GraphDef* tensorflow_graph) {
1108 tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
1109 square_op->set_op("Square");
1110 square_op->set_name(src_op.outputs[0]);
1111 CHECK_EQ(src_op.inputs.size(), 1);
1112 *square_op->add_input() = src_op.inputs[0];
1113 (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
1114 }
1115
ConvertSqrtOperator(const TensorFlowSqrtOperator & src_op,GraphDef * tensorflow_graph)1116 void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
1117 GraphDef* tensorflow_graph) {
1118 tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
1119 sqrt_op->set_op("Sqrt");
1120 sqrt_op->set_name(src_op.outputs[0]);
1121 CHECK_EQ(src_op.inputs.size(), 1);
1122 *sqrt_op->add_input() = src_op.inputs[0];
1123 (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
1124 }
1125
ConvertRsqrtOperator(const Model & model,const TensorFlowRsqrtOperator & src_op,GraphDef * tensorflow_graph)1126 void ConvertRsqrtOperator(const Model& model,
1127 const TensorFlowRsqrtOperator& src_op,
1128 GraphDef* tensorflow_graph) {
1129 tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
1130 rsqrt_op->set_op("Rsqrt");
1131 rsqrt_op->set_name(src_op.outputs[0]);
1132 CHECK_EQ(src_op.inputs.size(), 1);
1133 *rsqrt_op->add_input() = src_op.inputs[0];
1134 const tensorflow::DataType data_type =
1135 GetTensorFlowDataType(model, src_op.inputs[0]);
1136 (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
1137 }
1138
ConvertSplitOperator(const Model & model,const TensorFlowSplitOperator & src_op,GraphDef * tensorflow_graph)1139 void ConvertSplitOperator(const Model& model,
1140 const TensorFlowSplitOperator& src_op,
1141 GraphDef* tensorflow_graph) {
1142 tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1143 split_op->set_op("Split");
1144 split_op->set_name(src_op.outputs[0]);
1145 for (const auto& input : src_op.inputs) {
1146 *split_op->add_input() = input;
1147 }
1148 (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1149 (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1150 const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
1151 CHECK(split_dim_array.buffer);
1152 CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1153 const auto& split_dim_data =
1154 split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1155 CHECK_EQ(split_dim_data.size(), 1);
1156 const int split_dim = split_dim_data[0];
1157 CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1158 tensorflow_graph);
1159 }
1160
ConvertSplitVOperator(const Model & model,const TensorFlowSplitVOperator & src_op,GraphDef * tensorflow_graph)1161 void ConvertSplitVOperator(const Model& model,
1162 const TensorFlowSplitVOperator& src_op,
1163 GraphDef* tensorflow_graph) {
1164 tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
1165 split_v_op->set_op("SplitV");
1166 split_v_op->set_name(src_op.outputs[0]);
1167 for (const auto& input : src_op.inputs) {
1168 *split_v_op->add_input() = input;
1169 }
1170 (*split_v_op->mutable_attr())["T"].set_type(
1171 GetTensorFlowDataType(model, src_op.inputs[0]));
1172 (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
1173 const auto& split_dim_array = model.GetArray(src_op.inputs[1]);
1174 CHECK(split_dim_array.buffer);
1175 CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
1176 const auto& split_dim_data =
1177 split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
1178 CHECK_EQ(split_dim_data.size(), 1);
1179 const int split_dim = split_dim_data[0];
1180 CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
1181 tensorflow_graph);
1182 }
1183
ConvertCastOperator(const Model & model,const CastOperator & src_op,GraphDef * tensorflow_graph)1184 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
1185 GraphDef* tensorflow_graph) {
1186 tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
1187 cast_op->set_op("Cast");
1188 cast_op->set_name(src_op.outputs[0]);
1189 CHECK_EQ(src_op.inputs.size(), 1);
1190 *cast_op->add_input() = src_op.inputs[0];
1191
1192 (*cast_op->mutable_attr())["DstT"].set_type(
1193 GetTensorFlowDataType(model, src_op.outputs[0]));
1194 (*cast_op->mutable_attr())["SrcT"].set_type(
1195 GetTensorFlowDataType(model, src_op.inputs[0]));
1196 }
1197
ConvertFloorOperator(const Model & model,const FloorOperator & src_op,GraphDef * tensorflow_graph)1198 void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
1199 GraphDef* tensorflow_graph) {
1200 tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
1201 floor_op->set_op("Floor");
1202 floor_op->set_name(src_op.outputs[0]);
1203 CHECK_EQ(src_op.inputs.size(), 1);
1204 *floor_op->add_input() = src_op.inputs[0];
1205 (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
1206 }
1207
ConvertCeilOperator(const Model & model,const CeilOperator & src_op,GraphDef * tensorflow_graph)1208 void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
1209 GraphDef* tensorflow_graph) {
1210 tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
1211 ceil_op->set_op("Ceil");
1212 ceil_op->set_name(src_op.outputs[0]);
1213 CHECK_EQ(src_op.inputs.size(), 1);
1214 *ceil_op->add_input() = src_op.inputs[0];
1215 (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
1216 }
1217
ConvertGatherOperator(const Model & model,const GatherOperator & src_op,GraphDef * tensorflow_graph)1218 void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
1219 GraphDef* tensorflow_graph) {
1220 tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
1221 gather_op->set_op("GatherV2");
1222 gather_op->set_name(src_op.outputs[0]);
1223 *gather_op->add_input() = src_op.inputs[0];
1224 *gather_op->add_input() = src_op.inputs[1];
1225
1226 if (!src_op.axis) {
1227 // Dynamic axis.
1228 CHECK_EQ(src_op.inputs.size(), 3);
1229 *gather_op->add_input() = src_op.inputs[2];
1230 } else {
1231 // Constant axis.
1232 CHECK_EQ(src_op.inputs.size(), 2);
1233 const string gather_axis =
1234 AvailableArrayName(model, gather_op->name() + "/axis");
1235 CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
1236 tensorflow_graph);
1237 *gather_op->add_input() = gather_axis;
1238 }
1239
1240 (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
1241 (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
1242 const tensorflow::DataType params_type =
1243 GetTensorFlowDataType(model, src_op.inputs[0]);
1244 (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
1245 }
1246
ConvertArgMaxOperator(const Model & model,const ArgMaxOperator & src_op,GraphDef * tensorflow_graph)1247 void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
1248 GraphDef* tensorflow_graph) {
1249 tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
1250 argmax_op->set_op("ArgMax");
1251 argmax_op->set_name(src_op.outputs[0]);
1252 CHECK_EQ(src_op.inputs.size(), 2);
1253 *argmax_op->add_input() = src_op.inputs[0];
1254 *argmax_op->add_input() = src_op.inputs[1];
1255 (*argmax_op->mutable_attr())["T"].set_type(
1256 GetTensorFlowDataType(model, src_op.inputs[0]));
1257 (*argmax_op->mutable_attr())["Tidx"].set_type(
1258 GetTensorFlowDataType(model, src_op.inputs[1]));
1259 (*argmax_op->mutable_attr())["output_type"].set_type(
1260 GetTensorFlowDataType(model, src_op.outputs[0]));
1261 }
1262
ConvertArgMinOperator(const Model & model,const ArgMinOperator & src_op,GraphDef * tensorflow_graph)1263 void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
1264 GraphDef* tensorflow_graph) {
1265 tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
1266 argmin_op->set_op("ArgMin");
1267 argmin_op->set_name(src_op.outputs[0]);
1268 CHECK_EQ(src_op.inputs.size(), 2);
1269 *argmin_op->add_input() = src_op.inputs[0];
1270 *argmin_op->add_input() = src_op.inputs[1];
1271 (*argmin_op->mutable_attr())["T"].set_type(
1272 GetTensorFlowDataType(model, src_op.inputs[0]));
1273 (*argmin_op->mutable_attr())["Tidx"].set_type(
1274 GetTensorFlowDataType(model, src_op.inputs[1]));
1275 (*argmin_op->mutable_attr())["output_type"].set_type(
1276 GetTensorFlowDataType(model, src_op.outputs[0]));
1277 }
1278
ConvertTransposeOperator(const Model & model,const TransposeOperator & src_op,GraphDef * tensorflow_graph)1279 void ConvertTransposeOperator(const Model& model,
1280 const TransposeOperator& src_op,
1281 GraphDef* tensorflow_graph) {
1282 tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
1283 transpose_op->set_op("Transpose");
1284 transpose_op->set_name(src_op.outputs[0]);
1285 CHECK_EQ(src_op.inputs.size(), 2);
1286 *transpose_op->add_input() = src_op.inputs[0];
1287 *transpose_op->add_input() = src_op.inputs[1];
1288 (*transpose_op->mutable_attr())["T"].set_type(
1289 GetTensorFlowDataType(model, src_op.inputs[0]));
1290 (*transpose_op->mutable_attr())["Tperm"].set_type(
1291 GetTensorFlowDataType(model, src_op.inputs[1]));
1292 }
1293
ConvertTensorFlowShapeOperator(const Model & model,const TensorFlowShapeOperator & src_op,GraphDef * tensorflow_graph)1294 void ConvertTensorFlowShapeOperator(const Model& model,
1295 const TensorFlowShapeOperator& src_op,
1296 GraphDef* tensorflow_graph) {
1297 tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
1298 shape_op->set_op("Shape");
1299 shape_op->set_name(src_op.outputs[0]);
1300 CHECK_EQ(src_op.inputs.size(), 1);
1301 *shape_op->add_input() = src_op.inputs[0];
1302 (*shape_op->mutable_attr())["T"].set_type(
1303 GetTensorFlowDataType(model, src_op.inputs[0]));
1304 (*shape_op->mutable_attr())["out_type"].set_type(
1305 GetTensorFlowDataType(model, src_op.outputs[0]));
1306 }
1307
ConvertRankOperator(const Model & model,const TensorFlowRankOperator & src_op,GraphDef * tensorflow_graph)1308 void ConvertRankOperator(const Model& model,
1309 const TensorFlowRankOperator& src_op,
1310 GraphDef* tensorflow_graph) {
1311 tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
1312 rank_op->set_op("Rank");
1313 rank_op->set_name(src_op.outputs[0]);
1314 CHECK_EQ(src_op.inputs.size(), 1);
1315 *rank_op->add_input() = src_op.inputs[0];
1316 (*rank_op->mutable_attr())["T"].set_type(
1317 GetTensorFlowDataType(model, src_op.inputs[0]));
1318 }
1319
ConvertRangeOperator(const Model & model,const RangeOperator & src_op,GraphDef * tensorflow_graph)1320 void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
1321 GraphDef* tensorflow_graph) {
1322 tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
1323 range_op->set_op("Range");
1324 range_op->set_name(src_op.outputs[0]);
1325 CHECK_EQ(src_op.inputs.size(), 3);
1326 *range_op->add_input() = src_op.inputs[0];
1327 *range_op->add_input() = src_op.inputs[1];
1328 *range_op->add_input() = src_op.inputs[2];
1329 (*range_op->mutable_attr())["Tidx"].set_type(
1330 GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
1331 }
1332
ConvertPackOperator(const Model & model,const PackOperator & src_op,GraphDef * tensorflow_graph)1333 void ConvertPackOperator(const Model& model, const PackOperator& src_op,
1334 GraphDef* tensorflow_graph) {
1335 tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
1336 pack_op->set_op("Pack");
1337 pack_op->set_name(src_op.outputs[0]);
1338 for (const auto& input : src_op.inputs) {
1339 *pack_op->add_input() = input;
1340 }
1341 (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
1342 (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
1343 (*pack_op->mutable_attr())["T"].set_type(
1344 GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1345 }
1346
ConvertFillOperator(const Model & model,const FillOperator & src_op,GraphDef * tensorflow_graph)1347 void ConvertFillOperator(const Model& model, const FillOperator& src_op,
1348 GraphDef* tensorflow_graph) {
1349 tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
1350 fill_op->set_op("Fill");
1351 fill_op->set_name(src_op.outputs[0]);
1352 CHECK_EQ(src_op.inputs.size(), 2);
1353 *fill_op->add_input() = src_op.inputs[0];
1354 *fill_op->add_input() = src_op.inputs[1];
1355 (*fill_op->mutable_attr())["index_type"].set_type(
1356 GetTensorFlowDataType(model, src_op.inputs[0]));
1357 (*fill_op->mutable_attr())["T"].set_type(
1358 GetTensorFlowDataType(model, src_op.inputs[1]));
1359 }
1360
ConvertFloorDivOperator(const Model & model,const FloorDivOperator & src_op,GraphDef * tensorflow_graph)1361 void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
1362 GraphDef* tensorflow_graph) {
1363 tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
1364 floor_div_op->set_op("FloorDiv");
1365 floor_div_op->set_name(src_op.outputs[0]);
1366 CHECK_EQ(src_op.inputs.size(), 2);
1367 *floor_div_op->add_input() = src_op.inputs[0];
1368 *floor_div_op->add_input() = src_op.inputs[1];
1369 (*floor_div_op->mutable_attr())["T"].set_type(
1370 GetTensorFlowDataType(model, src_op.inputs[0]));
1371 }
1372
ConvertFloorModOperator(const Model & model,const FloorModOperator & src_op,GraphDef * tensorflow_graph)1373 void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
1374 GraphDef* tensorflow_graph) {
1375 tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
1376 floor_mod_op->set_op("FloorMod");
1377 floor_mod_op->set_name(src_op.outputs[0]);
1378 DCHECK_EQ(src_op.inputs.size(), 2);
1379 *floor_mod_op->add_input() = src_op.inputs[0];
1380 *floor_mod_op->add_input() = src_op.inputs[1];
1381 (*floor_mod_op->mutable_attr())["T"].set_type(
1382 GetTensorFlowDataType(model, src_op.inputs[0]));
1383 }
1384
ConvertExpandDimsOperator(const Model & model,const ExpandDimsOperator & src_op,GraphDef * tensorflow_graph)1385 void ConvertExpandDimsOperator(const Model& model,
1386 const ExpandDimsOperator& src_op,
1387 GraphDef* tensorflow_graph) {
1388 tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
1389 expand_dims_op->set_op("ExpandDims");
1390 expand_dims_op->set_name(src_op.outputs[0]);
1391 CHECK_EQ(src_op.inputs.size(), 2);
1392 *expand_dims_op->add_input() = src_op.inputs[0];
1393 *expand_dims_op->add_input() = src_op.inputs[1];
1394 (*expand_dims_op->mutable_attr())["T"].set_type(
1395 GetTensorFlowDataType(model, src_op.inputs[0]));
1396 (*expand_dims_op->mutable_attr())["Tdim"].set_type(
1397 GetTensorFlowDataType(model, src_op.inputs[1]));
1398 }
1399
ConvertResizeBilinearOperator(const Model & model,const ResizeBilinearOperator & src_op,GraphDef * tensorflow_graph)1400 void ConvertResizeBilinearOperator(const Model& model,
1401 const ResizeBilinearOperator& src_op,
1402 GraphDef* tensorflow_graph) {
1403 tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
1404 resize_op->set_op("ResizeBilinear");
1405 resize_op->set_name(src_op.outputs[0]);
1406 CHECK_EQ(src_op.inputs.size(), 2);
1407 *resize_op->add_input() = src_op.inputs[0];
1408 *resize_op->add_input() = src_op.inputs[1];
1409 (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1410 (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1411 }
1412
ConvertOneHotOperator(const Model & model,const OneHotOperator & src_op,GraphDef * tensorflow_graph)1413 void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1414 GraphDef* tensorflow_graph) {
1415 tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1416 onehot_op->set_op("OneHot");
1417 onehot_op->set_name(src_op.outputs[0]);
1418 CHECK_EQ(src_op.inputs.size(), 4);
1419 for (const auto& input : src_op.inputs) {
1420 *onehot_op->add_input() = input;
1421 }
1422 (*onehot_op->mutable_attr())["T"].set_type(
1423 GetTensorFlowDataType(model, src_op.outputs[0]));
1424 (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1425 }
1426
1427 namespace {
1428 // TODO(aselle): Remove when available in absl
FindLongestCommonPrefix(absl::string_view a,absl::string_view b)1429 absl::string_view FindLongestCommonPrefix(absl::string_view a,
1430 absl::string_view b) {
1431 if (a.empty() || b.empty()) return absl::string_view();
1432
1433 const char* pa = a.data();
1434 const char* pb = b.data();
1435 string::difference_type count = 0;
1436 const string::difference_type limit = std::min(a.size(), b.size());
1437 while (count < limit && *pa == *pb) {
1438 ++pa;
1439 ++pb;
1440 ++count;
1441 }
1442
1443 return absl::string_view(a.data(), count);
1444 }
1445 } // namespace
1446
ConvertLstmCellOperator(const Model & model,const LstmCellOperator & src_op,GraphDef * tensorflow_graph)1447 void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
1448 GraphDef* tensorflow_graph) {
1449 // Find the base name
1450 const string base(
1451 FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
1452 src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));
1453
1454 // Concatenate inputs
1455 const string concat_output = base + "basic_lstm_cell/concat";
1456 // Op names have been chosen to match the tf.slim LSTM naming
1457 // as closely as possible.
1458 const int axis =
1459 model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
1460 .shape()
1461 .dimensions_count() -
1462 1;
1463 // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
1464 // works the same since the tensor has the same underlying data layout.
1465 const string axis_output = concat_output + "/axis";
1466 CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
1467 tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
1468 concat_op->set_op("ConcatV2");
1469 concat_op->set_name(concat_output);
1470 *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
1471 *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
1472 *concat_op->add_input() = axis_output;
1473 (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
1474 (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
1475 (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs
1476
1477 // Write weights
1478 const string weights_output = base + "weights";
1479 CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
1480 const string weights_name = WalkUpToConstantArray(
1481 model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
1482 const auto& weights_array = model.GetArray(weights_name);
1483 // Convert 4D FullyConnected weights into 2D matrix
1484 const auto& weights_shape = weights_array.shape();
1485 CHECK_EQ(weights_shape.dimensions_count(), 2);
1486 CHECK(weights_array.buffer);
1487 CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
1488 const float* weights_data =
1489 weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1490 ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
1491 AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);
1492
1493 // Fully connected matrix multiply
1494 const string matmul_output = base + "MatMul";
1495 tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
1496 matmul_op->set_op("MatMul");
1497 matmul_op->set_name(matmul_output);
1498 *matmul_op->add_input() = concat_output;
1499 *matmul_op->add_input() = weights_output;
1500 (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
1501 (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
1502 (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
1503
1504 // Write biases
1505 const string biases_output = base + "biases";
1506 CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
1507 const string bias_name = WalkUpToConstantArray(
1508 model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
1509 const auto& bias_array = model.GetArray(bias_name);
1510 // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
1511 Shape bias_shape_1d = bias_array.shape();
1512 UnextendShape(&bias_shape_1d, 1);
1513 CHECK(bias_array.buffer);
1514 CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
1515 const float* bias_data =
1516 bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
1517 ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
1518 AxesOrder::kOneAxis, AxesOrder::kOneAxis,
1519 tensorflow_graph,
1520 LegacyScalarPolicy::kDoCreateLegacyScalars);
1521
1522 // Add biases
1523 string biasadd_output = base + "BiasAdd";
1524 tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
1525 biasadd_op->set_op("BiasAdd");
1526 biasadd_op->set_name(biasadd_output);
1527 biasadd_op->add_input(matmul_output);
1528 biasadd_op->add_input(biases_output);
1529 (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
1530 (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
1531
1532 // Split
1533 string split_dim_output = base + "split/split_dim";
1534 // The dimension is the same as the concatenation dimension
1535 CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
1536 string split_output = base + "split";
1537 tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
1538 split_op->set_op("Split");
1539 split_op->set_name(split_output);
1540 *split_op->add_input() = split_dim_output;
1541 *split_op->add_input() = biasadd_output;
1542 (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
1543 (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs
1544
1545 // Activation functions and memory computations
1546 const string tanh_0_output = base + "Tanh";
1547 tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
1548 tanh_0_op->set_op("Tanh");
1549 tanh_0_op->set_name(tanh_0_output);
1550 *tanh_0_op->add_input() = split_output + ":1";
1551 (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1552
1553 const string sigmoid_1_output = base + "Sigmoid_1";
1554 tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
1555 logistic_1_op->set_op("Sigmoid");
1556 logistic_1_op->set_name(sigmoid_1_output);
1557 *logistic_1_op->add_input() = split_output;
1558 (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1559
1560 const string mul_1_output = base + "mul_1";
1561 tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
1562 mul_1_op->set_op("Mul");
1563 mul_1_op->set_name(mul_1_output);
1564 *mul_1_op->add_input() = sigmoid_1_output;
1565 *mul_1_op->add_input() = tanh_0_output;
1566 (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1567
1568 const string sigmoid_0_output = base + "Sigmoid";
1569 tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
1570 logistic_2_op->set_op("Sigmoid");
1571 logistic_2_op->set_name(sigmoid_0_output);
1572 *logistic_2_op->add_input() = split_output + ":2";
1573 (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1574
1575 const string sigmoid_2_output = base + "Sigmoid_2";
1576 tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
1577 logistic_3_op->set_op("Sigmoid");
1578 logistic_3_op->set_name(sigmoid_2_output);
1579 *logistic_3_op->add_input() = split_output + ":3";
1580 (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
1581
1582 const string mul_0_output = base + "mul";
1583 tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
1584 mul_0_op->set_op("Mul");
1585 mul_0_op->set_name(mul_0_output);
1586 *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
1587 *mul_0_op->add_input() = sigmoid_0_output;
1588 (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
1589
1590 const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
1591 tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
1592 add_1_op->set_op("Add");
1593 add_1_op->set_name(add_1_output);
1594 *add_1_op->add_input() = mul_0_output;
1595 *add_1_op->add_input() = mul_1_output;
1596 (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1597
1598 const string tanh_1_output = base + "Tanh_1";
1599 tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
1600 tanh_1_op->set_op("Tanh");
1601 tanh_1_op->set_name(tanh_1_output);
1602 *tanh_1_op->add_input() = add_1_output;
1603 (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
1604
1605 const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
1606 tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
1607 mul_2_op->set_op("Mul");
1608 mul_2_op->set_name(mul_2_output);
1609 *mul_2_op->add_input() = tanh_1_output;
1610 *mul_2_op->add_input() = sigmoid_2_output;
1611 (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
1612 }
1613
ConvertSpaceToBatchNDOperator(const Model & model,const SpaceToBatchNDOperator & src_op,GraphDef * tensorflow_graph)1614 void ConvertSpaceToBatchNDOperator(const Model& model,
1615 const SpaceToBatchNDOperator& src_op,
1616 GraphDef* tensorflow_graph) {
1617 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1618 new_op->set_op("SpaceToBatchND");
1619 new_op->set_name(src_op.outputs[0]);
1620 CHECK_EQ(src_op.inputs.size(), 3);
1621 *new_op->add_input() = src_op.inputs[0];
1622 *new_op->add_input() = src_op.inputs[1];
1623 *new_op->add_input() = src_op.inputs[2];
1624 const tensorflow::DataType params_type =
1625 GetTensorFlowDataType(model, src_op.inputs[0]);
1626 (*new_op->mutable_attr())["T"].set_type(params_type);
1627 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1628 (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
1629 }
1630
ConvertBatchToSpaceNDOperator(const Model & model,const BatchToSpaceNDOperator & src_op,GraphDef * tensorflow_graph)1631 void ConvertBatchToSpaceNDOperator(const Model& model,
1632 const BatchToSpaceNDOperator& src_op,
1633 GraphDef* tensorflow_graph) {
1634 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1635 new_op->set_op("BatchToSpaceND");
1636 new_op->set_name(src_op.outputs[0]);
1637 CHECK_EQ(src_op.inputs.size(), 3);
1638 *new_op->add_input() = src_op.inputs[0];
1639 *new_op->add_input() = src_op.inputs[1];
1640 *new_op->add_input() = src_op.inputs[2];
1641 const tensorflow::DataType params_type =
1642 GetTensorFlowDataType(model, src_op.inputs[0]);
1643 (*new_op->mutable_attr())["T"].set_type(params_type);
1644 (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
1645 (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
1646 }
1647
ConvertPadOperator(const Model & model,const PadOperator & src_op,GraphDef * tensorflow_graph)1648 void ConvertPadOperator(const Model& model, const PadOperator& src_op,
1649 GraphDef* tensorflow_graph) {
1650 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1651 new_op->set_op("Pad");
1652 new_op->set_name(src_op.outputs[0]);
1653 CHECK_EQ(src_op.inputs.size(), 2);
1654 *new_op->add_input() = src_op.inputs[0];
1655 *new_op->add_input() = src_op.inputs[1];
1656
1657 const tensorflow::DataType params_type =
1658 GetTensorFlowDataType(model, src_op.inputs[0]);
1659 (*new_op->mutable_attr())["T"].set_type(params_type);
1660
1661 // Create the params tensor.
1662 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1663 params_op->set_op("Const");
1664 params_op->set_name(src_op.inputs[1]);
1665 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1666 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1667 tensor->set_dtype(DT_INT32);
1668
1669 CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1670 for (int i = 0; i < src_op.left_padding.size(); ++i) {
1671 tensor->add_int_val(src_op.left_padding[i]);
1672 tensor->add_int_val(src_op.right_padding[i]);
1673 }
1674 auto* shape = tensor->mutable_tensor_shape();
1675 shape->add_dim()->set_size(src_op.left_padding.size());
1676 shape->add_dim()->set_size(2);
1677 }
1678
ConvertPadV2Operator(const Model & model,const PadV2Operator & src_op,GraphDef * tensorflow_graph)1679 void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
1680 GraphDef* tensorflow_graph) {
1681 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1682 new_op->set_op("PadV2");
1683 new_op->set_name(src_op.outputs[0]);
1684 CHECK_EQ(src_op.inputs.size(), 2);
1685 *new_op->add_input() = src_op.inputs[0];
1686 *new_op->add_input() = src_op.inputs[1];
1687 *new_op->add_input() = src_op.inputs[2];
1688
1689 const tensorflow::DataType params_type =
1690 GetTensorFlowDataType(model, src_op.inputs[0]);
1691 (*new_op->mutable_attr())["T"].set_type(params_type);
1692
1693 // Create the params tensor.
1694 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1695 params_op->set_op("Const");
1696 params_op->set_name(src_op.inputs[1]);
1697 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1698 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1699 tensor->set_dtype(DT_INT32);
1700
1701 CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
1702 for (int i = 0; i < src_op.left_padding.size(); ++i) {
1703 tensor->add_int_val(src_op.left_padding[i]);
1704 tensor->add_int_val(src_op.right_padding[i]);
1705 }
1706 auto* shape = tensor->mutable_tensor_shape();
1707 shape->add_dim()->set_size(src_op.left_padding.size());
1708 shape->add_dim()->set_size(2);
1709 }
1710
CreateSliceInput(const string & input_name,const std::vector<int> & values,GraphDef * tensorflow_graph)1711 void CreateSliceInput(const string& input_name, const std::vector<int>& values,
1712 GraphDef* tensorflow_graph) {
1713 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1714 params_op->set_op("Const");
1715 params_op->set_name(input_name);
1716 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1717 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1718 tensor->set_dtype(DT_INT32);
1719
1720 for (int i = 0; i < values.size(); ++i) {
1721 tensor->add_int_val(values[i]);
1722 }
1723 auto* shape = tensor->mutable_tensor_shape();
1724 shape->add_dim()->set_size(values.size());
1725 }
1726
ConvertStridedSliceOperator(const Model & model,const StridedSliceOperator & src_op,GraphDef * tensorflow_graph)1727 void ConvertStridedSliceOperator(const Model& model,
1728 const StridedSliceOperator& src_op,
1729 GraphDef* tensorflow_graph) {
1730 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1731 new_op->set_op("StridedSlice");
1732 new_op->set_name(src_op.outputs[0]);
1733 CHECK_EQ(src_op.inputs.size(), 4);
1734 *new_op->add_input() = src_op.inputs[0];
1735 *new_op->add_input() = src_op.inputs[1];
1736 *new_op->add_input() = src_op.inputs[2];
1737 *new_op->add_input() = src_op.inputs[3];
1738
1739 const tensorflow::DataType params_type =
1740 GetTensorFlowDataType(model, src_op.inputs[0]);
1741 (*new_op->mutable_attr())["T"].set_type(params_type);
1742
1743 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1744 (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
1745 (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
1746 (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
1747 (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
1748 (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);
1749
1750 // Create tensors for start/stop indices and strides.
1751 CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
1752 CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
1753 CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
1754 }
1755
ConvertSliceOperator(const Model & model,const SliceOperator & src_op,GraphDef * tensorflow_graph)1756 void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
1757 GraphDef* tensorflow_graph) {
1758 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1759 new_op->set_op("Slice");
1760 new_op->set_name(src_op.outputs[0]);
1761 CHECK_EQ(src_op.inputs.size(), 3);
1762 *new_op->add_input() = src_op.inputs[0];
1763 *new_op->add_input() = src_op.inputs[1];
1764 *new_op->add_input() = src_op.inputs[2];
1765
1766 const tensorflow::DataType params_type =
1767 GetTensorFlowDataType(model, src_op.inputs[0]);
1768 (*new_op->mutable_attr())["T"].set_type(params_type);
1769 (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
1770
1771 // Create tensors for begin and size inputs.
1772 CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
1773 CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
1774 }
1775
1776 template <typename T>
ConvertReduceOperator(const Model & model,const T & src_op,GraphDef * tensorflow_graph,const string & op_name)1777 void ConvertReduceOperator(const Model& model, const T& src_op,
1778 GraphDef* tensorflow_graph, const string& op_name) {
1779 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1780 new_op->set_op(op_name);
1781 new_op->set_name(src_op.outputs[0]);
1782 CHECK_EQ(src_op.inputs.size(), 2);
1783 *new_op->add_input() = src_op.inputs[0];
1784 *new_op->add_input() = src_op.inputs[1];
1785
1786 if (src_op.type != OperatorType::kAny) {
1787 const tensorflow::DataType params_type =
1788 GetTensorFlowDataType(model, src_op.inputs[0]);
1789 (*new_op->mutable_attr())["T"].set_type(params_type);
1790 }
1791 const tensorflow::DataType indices_type =
1792 GetTensorFlowDataType(model, src_op.inputs[1]);
1793 (*new_op->mutable_attr())["Tidx"].set_type(indices_type);
1794
1795 if (src_op.keep_dims) {
1796 (*new_op->mutable_attr())["keep_dims"].set_b(true);
1797 }
1798
1799 // Create the params tensor.
1800 tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
1801 params_op->set_op("Const");
1802 params_op->set_name(src_op.inputs[1]);
1803 (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
1804 auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
1805 tensor->set_dtype(DT_INT32);
1806
1807 for (int i = 0; i < src_op.axis.size(); ++i) {
1808 tensor->add_int_val(src_op.axis[i]);
1809 }
1810 auto* shape = tensor->mutable_tensor_shape();
1811 shape->add_dim()->set_size(src_op.axis.size());
1812 }
1813
ConvertSqueezeOperator(const Model & model,const SqueezeOperator & src_op,GraphDef * tensorflow_graph)1814 void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
1815 GraphDef* tensorflow_graph) {
1816 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1817 new_op->set_op("Squeeze");
1818 new_op->set_name(src_op.outputs[0]);
1819 CHECK_EQ(src_op.inputs.size(), 1);
1820 *new_op->add_input() = src_op.inputs[0];
1821
1822 const tensorflow::DataType params_type =
1823 GetTensorFlowDataType(model, src_op.inputs[0]);
1824 (*new_op->mutable_attr())["T"].set_type(params_type);
1825
1826 if (!src_op.squeeze_dims.empty()) {
1827 auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
1828 for (int i : src_op.squeeze_dims) {
1829 squeeze_dims.mutable_list()->add_i(i);
1830 }
1831 }
1832 }
1833
ConvertSubOperator(const Model & model,const SubOperator & src_op,GraphDef * tensorflow_graph)1834 void ConvertSubOperator(const Model& model, const SubOperator& src_op,
1835 GraphDef* tensorflow_graph) {
1836 tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
1837 sub_op->set_op("Sub");
1838 sub_op->set_name(src_op.outputs[0]);
1839 CHECK_EQ(src_op.inputs.size(), 2);
1840 *sub_op->add_input() = src_op.inputs[0];
1841 *sub_op->add_input() = src_op.inputs[1];
1842 const tensorflow::DataType data_type =
1843 GetTensorFlowDataType(model, src_op.inputs[0]);
1844 (*sub_op->mutable_attr())["T"].set_type(data_type);
1845 }
1846
ConvertTensorFlowMinimumOperator(const Model & model,const TensorFlowMinimumOperator & src_op,GraphDef * tensorflow_graph)1847 void ConvertTensorFlowMinimumOperator(const Model& model,
1848 const TensorFlowMinimumOperator& src_op,
1849 GraphDef* tensorflow_graph) {
1850 tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
1851 min_op->set_op("Minimum");
1852 min_op->set_name(src_op.outputs[0]);
1853 CHECK_EQ(src_op.inputs.size(), 2);
1854 *min_op->add_input() = src_op.inputs[0];
1855 *min_op->add_input() = src_op.inputs[1];
1856 const tensorflow::DataType data_type =
1857 GetTensorFlowDataType(model, src_op.inputs[0]);
1858 (*min_op->mutable_attr())["T"].set_type(data_type);
1859 }
1860
ConvertTensorFlowMaximumOperator(const Model & model,const TensorFlowMaximumOperator & src_op,GraphDef * tensorflow_graph)1861 void ConvertTensorFlowMaximumOperator(const Model& model,
1862 const TensorFlowMaximumOperator& src_op,
1863 GraphDef* tensorflow_graph) {
1864 tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
1865 max_op->set_op("Maximum");
1866 max_op->set_name(src_op.outputs[0]);
1867 CHECK_EQ(src_op.inputs.size(), 2);
1868 *max_op->add_input() = src_op.inputs[0];
1869 *max_op->add_input() = src_op.inputs[1];
1870 const tensorflow::DataType data_type =
1871 GetTensorFlowDataType(model, src_op.inputs[0]);
1872 (*max_op->mutable_attr())["T"].set_type(data_type);
1873 }
1874
ConvertSelectOperator(const Model & model,const SelectOperator & src_op,GraphDef * tensorflow_graph)1875 void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
1876 GraphDef* tensorflow_graph) {
1877 tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
1878 select_op->set_op("Select");
1879 select_op->set_name(src_op.outputs[0]);
1880 CHECK_EQ(src_op.inputs.size(), 3);
1881 *select_op->add_input() = src_op.inputs[0];
1882 *select_op->add_input() = src_op.inputs[1];
1883 *select_op->add_input() = src_op.inputs[2];
1884 const tensorflow::DataType data_type =
1885 GetTensorFlowDataType(model, src_op.inputs[1]);
1886 (*select_op->mutable_attr())["T"].set_type(data_type);
1887 }
1888
ConvertTileOperator(const Model & model,const TensorFlowTileOperator & src_op,GraphDef * tensorflow_graph)1889 void ConvertTileOperator(const Model& model,
1890 const TensorFlowTileOperator& src_op,
1891 GraphDef* tensorflow_graph) {
1892 tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
1893 tile_op->set_op("Tile");
1894 tile_op->set_name(src_op.outputs[0]);
1895 CHECK_EQ(src_op.inputs.size(), 2);
1896 *tile_op->add_input() = src_op.inputs[0];
1897 *tile_op->add_input() = src_op.inputs[1];
1898 const tensorflow::DataType data_type =
1899 GetTensorFlowDataType(model, src_op.inputs[0]);
1900 (*tile_op->mutable_attr())["T"].set_type(data_type);
1901 const tensorflow::DataType multiples_data_type =
1902 GetTensorFlowDataType(model, src_op.inputs[1]);
1903 (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
1904 }
1905
ConvertTopKV2Operator(const Model & model,const TopKV2Operator & src_op,GraphDef * tensorflow_graph)1906 void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
1907 GraphDef* tensorflow_graph) {
1908 tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
1909 topk_op->set_op("TopKV2");
1910 topk_op->set_name(src_op.outputs[0]);
1911 CHECK_EQ(src_op.inputs.size(), 2);
1912 *topk_op->add_input() = src_op.inputs[0];
1913 *topk_op->add_input() = src_op.inputs[1];
1914 const tensorflow::DataType data_type =
1915 GetTensorFlowDataType(model, src_op.inputs[0]);
1916 (*topk_op->mutable_attr())["T"].set_type(data_type);
1917 (*topk_op->mutable_attr())["sorted"].set_b(true);
1918 }
1919
ConvertRandomUniformOperator(const Model & model,const RandomUniformOperator & src_op,GraphDef * tensorflow_graph)1920 void ConvertRandomUniformOperator(const Model& model,
1921 const RandomUniformOperator& src_op,
1922 GraphDef* tensorflow_graph) {
1923 CHECK(tensorflow_graph != nullptr);
1924 tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
1925 new_op->set_op("RandomUniform");
1926 CHECK_EQ(src_op.inputs.size(), 1);
1927 new_op->set_name(src_op.outputs[0]);
1928 *new_op->add_input() = src_op.inputs[0];
1929 const tensorflow::DataType shape_type =
1930 GetTensorFlowDataType(model, src_op.inputs[0]);
1931 (*new_op->mutable_attr())["T"].set_type(shape_type);
1932 (*new_op->mutable_attr())["dtype"].set_type(
1933 GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
1934 (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
1935 (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
1936 }
1937
ConvertComparisonOperator(const Model & model,const Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)1938 void ConvertComparisonOperator(const Model& model, const Operator& src_op,
1939 const char* op_name,
1940 GraphDef* tensorflow_graph) {
1941 tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
1942 comparison_op->set_op(op_name);
1943 comparison_op->set_name(src_op.outputs[0]);
1944 CHECK_EQ(src_op.inputs.size(), 2);
1945 *comparison_op->add_input() = src_op.inputs[0];
1946 *comparison_op->add_input() = src_op.inputs[1];
1947 const tensorflow::DataType data_type =
1948 GetTensorFlowDataType(model, src_op.inputs[0]);
1949 (*comparison_op->mutable_attr())["T"].set_type(data_type);
1950 }
1951
ConvertSparseToDenseOperator(const Model & model,const SparseToDenseOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1952 void ConvertSparseToDenseOperator(const Model& model,
1953 const SparseToDenseOperator& src_op,
1954 const char* op_name,
1955 GraphDef* tensorflow_graph) {
1956 tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
1957 sparse_to_dense_op->set_op(op_name);
1958 sparse_to_dense_op->set_name(src_op.outputs[0]);
1959 CHECK_EQ(src_op.inputs.size(), 4);
1960 for (int i = 0; i < 4; ++i) {
1961 *sparse_to_dense_op->add_input() = src_op.inputs[i];
1962 }
1963 const tensorflow::DataType data_type =
1964 GetTensorFlowDataType(model, src_op.inputs[3]);
1965 (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
1966 const tensorflow::DataType index_type =
1967 GetTensorFlowDataType(model, src_op.inputs[0]);
1968 (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
1969 (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
1970 src_op.validate_indices);
1971 }
1972
ConvertPowOperator(const Model & model,const PowOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)1973 void ConvertPowOperator(const Model& model, const PowOperator& src_op,
1974 const char* op_name, GraphDef* tensorflow_graph) {
1975 tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
1976 pow_op->set_op(op_name);
1977 pow_op->set_name(src_op.outputs[0]);
1978 CHECK_EQ(src_op.inputs.size(), 2);
1979 for (int i = 0; i < 2; ++i) {
1980 *pow_op->add_input() = src_op.inputs[i];
1981 }
1982 const tensorflow::DataType data_type =
1983 GetTensorFlowDataType(model, src_op.inputs[0]);
1984 (*pow_op->mutable_attr())["T"].set_type(data_type);
1985 }
1986
ConvertLogicalAndOperator(const Model & model,const LogicalAndOperator & src_op,GraphDef * tensorflow_graph)1987 void ConvertLogicalAndOperator(const Model& model,
1988 const LogicalAndOperator& src_op,
1989 GraphDef* tensorflow_graph) {
1990 tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
1991 logical_op->set_op("LogicalAnd");
1992 logical_op->set_name(src_op.outputs[0]);
1993 CHECK_EQ(src_op.inputs.size(), 2);
1994 for (int i = 0; i < 2; ++i) {
1995 *logical_op->add_input() = src_op.inputs[i];
1996 }
1997 }
1998
ConvertLogicalNotOperator(const Model & model,const LogicalNotOperator & src_op,GraphDef * tensorflow_graph)1999 void ConvertLogicalNotOperator(const Model& model,
2000 const LogicalNotOperator& src_op,
2001 GraphDef* tensorflow_graph) {
2002 tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
2003 logical_op->set_op("LogicalNot");
2004 logical_op->set_name(src_op.outputs[0]);
2005 CHECK_EQ(src_op.inputs.size(), 1);
2006 *logical_op->add_input() = src_op.inputs[0];
2007 }
2008
ConvertLogicalOrOperator(const Model & model,const LogicalOrOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2009 void ConvertLogicalOrOperator(const Model& model,
2010 const LogicalOrOperator& src_op,
2011 const char* op_name, GraphDef* tensorflow_graph) {
2012 tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
2013 logical_or_op->set_op(op_name);
2014 logical_or_op->set_name(src_op.outputs[0]);
2015 CHECK_EQ(src_op.inputs.size(), 2);
2016 for (int i = 0; i < 2; ++i) {
2017 *logical_or_op->add_input() = src_op.inputs[i];
2018 }
2019 const tensorflow::DataType data_type =
2020 GetTensorFlowDataType(model, src_op.inputs[0]);
2021 (*logical_or_op->mutable_attr())["T"].set_type(data_type);
2022 }
2023
ConvertCTCBeamSearchDecoderOperator(const Model & model,const CTCBeamSearchDecoderOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2024 void ConvertCTCBeamSearchDecoderOperator(
2025 const Model& model, const CTCBeamSearchDecoderOperator& src_op,
2026 const char* op_name, GraphDef* tensorflow_graph) {
2027 auto* op = tensorflow_graph->add_node();
2028 op->set_op(op_name);
2029 op->set_name(src_op.outputs[0]);
2030 CHECK_EQ(src_op.inputs.size(), 2);
2031 for (int i = 0; i < 2; ++i) {
2032 *op->add_input() = src_op.inputs[i];
2033 }
2034 (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
2035 (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
2036 (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
2037 }
2038
ConvertUnpackOperator(const Model & model,const UnpackOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2039 void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
2040 const char* op_name, GraphDef* tensorflow_graph) {
2041 tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
2042 unpack_op->set_op(op_name);
2043 unpack_op->set_name(src_op.outputs[0]);
2044 CHECK_EQ(src_op.inputs.size(), 2);
2045 *unpack_op->add_input() = src_op.inputs[0];
2046 const tensorflow::DataType data_type =
2047 GetTensorFlowDataType(model, src_op.inputs[0]);
2048 (*unpack_op->mutable_attr())["T"].set_type(data_type);
2049 (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
2050 (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
2051 }
2052
ConvertZerosLikeOperator(const Model & model,const TensorFlowZerosLikeOperator & src_op,const char * op_name,GraphDef * tensorflow_graph)2053 void ConvertZerosLikeOperator(const Model& model,
2054 const TensorFlowZerosLikeOperator& src_op,
2055 const char* op_name, GraphDef* tensorflow_graph) {
2056 tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
2057 zeros_like_op->set_op(op_name);
2058 zeros_like_op->set_name(src_op.outputs[0]);
2059 DCHECK_EQ(src_op.inputs.size(), 1);
2060 *zeros_like_op->add_input() = src_op.inputs[0];
2061 const tensorflow::DataType data_type =
2062 GetTensorFlowDataType(model, src_op.inputs[0]);
2063 (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
2064 }
2065
ConvertReverseV2Operator(const Model & model,const ReverseV2Operator & src_op,const char * op_name,GraphDef * tensorflow_graph)2066 void ConvertReverseV2Operator(const Model& model,
2067 const ReverseV2Operator& src_op,
2068 const char* op_name, GraphDef* tensorflow_graph) {
2069 tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
2070 reverse_v2_op->set_op(op_name);
2071 reverse_v2_op->set_name(src_op.outputs[0]);
2072 DCHECK_EQ(src_op.inputs.size(), 2);
2073 *reverse_v2_op->add_input() = src_op.inputs[0];
2074 *reverse_v2_op->add_input() = src_op.inputs[1];
2075 const tensorflow::DataType data_type =
2076 GetTensorFlowDataType(model, src_op.inputs[0]);
2077 (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
2078 }
2079
ConvertReverseSequenceOperator(const Model & model,const ReverseSequenceOperator & src_op,GraphDef * tensorflow_graph)2080 void ConvertReverseSequenceOperator(const Model& model,
2081 const ReverseSequenceOperator& src_op,
2082 GraphDef* tensorflow_graph) {
2083 tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
2084 reverse_seq_op->set_op("ReverseSequence");
2085 reverse_seq_op->set_name(src_op.outputs[0]);
2086 CHECK_EQ(src_op.inputs.size(), 2);
2087 *reverse_seq_op->add_input() = src_op.inputs[0];
2088 *reverse_seq_op->add_input() = src_op.inputs[1];
2089 (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
2090 (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
2091 }
2092
ConvertOperator(const Model & model,const Operator & src_op,GraphDef * tensorflow_graph)2093 void ConvertOperator(const Model& model, const Operator& src_op,
2094 GraphDef* tensorflow_graph) {
2095 if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
2096 LOG(FATAL)
2097 << "Unsupported: the input model has a fused activation function";
2098 }
2099
2100 if (src_op.type == OperatorType::kConv) {
2101 ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
2102 tensorflow_graph);
2103 } else if (src_op.type == OperatorType::kDepthwiseConv) {
2104 ConvertDepthwiseConvOperator(
2105 model, static_cast<const DepthwiseConvOperator&>(src_op),
2106 tensorflow_graph);
2107 } else if (src_op.type == OperatorType::kDepthToSpace) {
2108 ConvertDepthToSpaceOperator(
2109 model, static_cast<const DepthToSpaceOperator&>(src_op),
2110 tensorflow_graph);
2111 } else if (src_op.type == OperatorType::kSpaceToDepth) {
2112 ConvertSpaceToDepthOperator(
2113 model, static_cast<const SpaceToDepthOperator&>(src_op),
2114 tensorflow_graph);
2115 } else if (src_op.type == OperatorType::kFullyConnected) {
2116 ConvertFullyConnectedOperator(
2117 model, static_cast<const FullyConnectedOperator&>(src_op),
2118 tensorflow_graph);
2119 } else if (src_op.type == OperatorType::kAdd) {
2120 ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
2121 tensorflow_graph);
2122 } else if (src_op.type == OperatorType::kAddN) {
2123 ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
2124 tensorflow_graph);
2125 } else if (src_op.type == OperatorType::kMul) {
2126 ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
2127 tensorflow_graph);
2128 } else if (src_op.type == OperatorType::kDiv) {
2129 ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
2130 tensorflow_graph);
2131 } else if (src_op.type == OperatorType::kRelu) {
2132 ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
2133 tensorflow_graph);
2134 } else if (src_op.type == OperatorType::kRelu1) {
2135 ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
2136 tensorflow_graph);
2137 } else if (src_op.type == OperatorType::kRelu6) {
2138 ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
2139 tensorflow_graph);
2140 } else if (src_op.type == OperatorType::kLog) {
2141 ConvertLogOperator(static_cast<const LogOperator&>(src_op),
2142 tensorflow_graph);
2143 } else if (src_op.type == OperatorType::kLogistic) {
2144 ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
2145 tensorflow_graph);
2146 } else if (src_op.type == OperatorType::kTanh) {
2147 ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
2148 tensorflow_graph);
2149 } else if (src_op.type == OperatorType::kL2Normalization) {
2150 ConvertL2NormalizationOperator(
2151 static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
2152 } else if (src_op.type == OperatorType::kSoftmax) {
2153 ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
2154 tensorflow_graph);
2155 } else if (src_op.type == OperatorType::kLogSoftmax) {
2156 ConvertLogSoftmaxOperator(model,
2157 static_cast<const LogSoftmaxOperator&>(src_op),
2158 tensorflow_graph);
2159 } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
2160 ConvertLocalResponseNormalizationOperator(
2161 static_cast<const LocalResponseNormalizationOperator&>(src_op),
2162 tensorflow_graph);
2163 } else if (src_op.type == OperatorType::kLstmCell) {
2164 ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op),
2165 tensorflow_graph);
2166 } else if (src_op.type == OperatorType::kMaxPool) {
2167 ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
2168 tensorflow_graph);
2169 } else if (src_op.type == OperatorType::kAveragePool) {
2170 ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
2171 tensorflow_graph);
2172 } else if (src_op.type == OperatorType::kConcatenation) {
2173 ConvertConcatenationOperator(
2174 model, static_cast<const ConcatenationOperator&>(src_op),
2175 tensorflow_graph);
2176 } else if (src_op.type == OperatorType::kReshape) {
2177 ConvertTensorFlowReshapeOperator(
2178 model, static_cast<const TensorFlowReshapeOperator&>(src_op),
2179 tensorflow_graph);
2180 } else if (src_op.type == OperatorType::kL2Pool) {
2181 ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
2182 tensorflow_graph);
2183 } else if (src_op.type == OperatorType::kSquare) {
2184 ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
2185 tensorflow_graph);
2186 } else if (src_op.type == OperatorType::kSqrt) {
2187 ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
2188 tensorflow_graph);
2189 } else if (src_op.type == OperatorType::kRsqrt) {
2190 ConvertRsqrtOperator(model,
2191 static_cast<const TensorFlowRsqrtOperator&>(src_op),
2192 tensorflow_graph);
2193 } else if (src_op.type == OperatorType::kSplit) {
2194 ConvertSplitOperator(model,
2195 static_cast<const TensorFlowSplitOperator&>(src_op),
2196 tensorflow_graph);
2197 } else if (src_op.type == OperatorType::kSplitV) {
2198 ConvertSplitVOperator(model,
2199 static_cast<const TensorFlowSplitVOperator&>(src_op),
2200 tensorflow_graph);
2201 } else if (src_op.type == OperatorType::kFakeQuant) {
2202 ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
2203 tensorflow_graph);
2204 } else if (src_op.type == OperatorType::kCast) {
2205 ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
2206 tensorflow_graph);
2207 } else if (src_op.type == OperatorType::kFloor) {
2208 ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
2209 tensorflow_graph);
2210 } else if (src_op.type == OperatorType::kCeil) {
2211 ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
2212 tensorflow_graph);
2213 } else if (src_op.type == OperatorType::kGather) {
2214 ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
2215 tensorflow_graph);
2216 } else if (src_op.type == OperatorType::kResizeBilinear) {
2217 ConvertResizeBilinearOperator(
2218 model, static_cast<const ResizeBilinearOperator&>(src_op),
2219 tensorflow_graph);
2220 } else if (src_op.type == OperatorType::kSpaceToBatchND) {
2221 ConvertSpaceToBatchNDOperator(
2222 model, static_cast<const SpaceToBatchNDOperator&>(src_op),
2223 tensorflow_graph);
2224 } else if (src_op.type == OperatorType::kBatchToSpaceND) {
2225 ConvertBatchToSpaceNDOperator(
2226 model, static_cast<const BatchToSpaceNDOperator&>(src_op),
2227 tensorflow_graph);
2228 } else if (src_op.type == OperatorType::kPad) {
2229 ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
2230 tensorflow_graph);
2231 } else if (src_op.type == OperatorType::kPadV2) {
2232 ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
2233 tensorflow_graph);
2234 } else if (src_op.type == OperatorType::kStridedSlice) {
2235 ConvertStridedSliceOperator(
2236 model, static_cast<const StridedSliceOperator&>(src_op),
2237 tensorflow_graph);
2238 } else if (src_op.type == OperatorType::kMean) {
2239 ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
2240 tensorflow_graph, "Mean");
2241 } else if (src_op.type == OperatorType::kSum) {
2242 ConvertReduceOperator(model,
2243 static_cast<const TensorFlowSumOperator&>(src_op),
2244 tensorflow_graph, "Sum");
2245 } else if (src_op.type == OperatorType::kReduceProd) {
2246 ConvertReduceOperator(model,
2247 static_cast<const TensorFlowProdOperator&>(src_op),
2248 tensorflow_graph, "Prod");
2249 } else if (src_op.type == OperatorType::kReduceMin) {
2250 ConvertReduceOperator(model,
2251 static_cast<const TensorFlowMinOperator&>(src_op),
2252 tensorflow_graph, "Min");
2253 } else if (src_op.type == OperatorType::kReduceMax) {
2254 ConvertReduceOperator(model,
2255 static_cast<const TensorFlowMaxOperator&>(src_op),
2256 tensorflow_graph, "Max");
2257 } else if (src_op.type == OperatorType::kSub) {
2258 ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
2259 tensorflow_graph);
2260 } else if (src_op.type == OperatorType::kMinimum) {
2261 ConvertTensorFlowMinimumOperator(
2262 model, static_cast<const TensorFlowMinimumOperator&>(src_op),
2263 tensorflow_graph);
2264 } else if (src_op.type == OperatorType::kMaximum) {
2265 ConvertTensorFlowMaximumOperator(
2266 model, static_cast<const TensorFlowMaximumOperator&>(src_op),
2267 tensorflow_graph);
2268 } else if (src_op.type == OperatorType::kSqueeze) {
2269 ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
2270 tensorflow_graph);
2271 } else if (src_op.type == OperatorType::kSlice) {
2272 ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
2273 tensorflow_graph);
2274 } else if (src_op.type == OperatorType::kArgMax) {
2275 ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
2276 tensorflow_graph);
2277 } else if (src_op.type == OperatorType::kArgMin) {
2278 ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
2279 tensorflow_graph);
2280 } else if (src_op.type == OperatorType::kTopK_V2) {
2281 ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
2282 tensorflow_graph);
2283 } else if (src_op.type == OperatorType::kTranspose) {
2284 ConvertTransposeOperator(
2285 model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
2286 } else if (src_op.type == OperatorType::kShape) {
2287 ConvertTensorFlowShapeOperator(
2288 model, static_cast<const TensorFlowShapeOperator&>(src_op),
2289 tensorflow_graph);
2290 } else if (src_op.type == OperatorType::kRank) {
2291 ConvertRankOperator(model,
2292 static_cast<const TensorFlowRankOperator&>(src_op),
2293 tensorflow_graph);
2294 } else if (src_op.type == OperatorType::kRange) {
2295 ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
2296 tensorflow_graph);
2297 } else if (src_op.type == OperatorType::kPack) {
2298 ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
2299 tensorflow_graph);
2300 } else if (src_op.type == OperatorType::kFill) {
2301 ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
2302 tensorflow_graph);
2303 } else if (src_op.type == OperatorType::kFloorDiv) {
2304 ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op),
2305 tensorflow_graph);
2306 } else if (src_op.type == OperatorType::kFloorMod) {
2307 ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op),
2308 tensorflow_graph);
2309 } else if (src_op.type == OperatorType::kExpandDims) {
2310 ConvertExpandDimsOperator(model,
2311 static_cast<const ExpandDimsOperator&>(src_op),
2312 tensorflow_graph);
2313 } else if (src_op.type == OperatorType::kTransposeConv) {
2314 ConvertTransposeConvOperator(
2315 model, static_cast<const TransposeConvOperator&>(src_op),
2316 tensorflow_graph);
2317 } else if (src_op.type == OperatorType::kRandomUniform) {
2318 ConvertRandomUniformOperator(
2319 model, static_cast<const RandomUniformOperator&>(src_op),
2320 tensorflow_graph);
2321 } else if (src_op.type == OperatorType::kEqual) {
2322 ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
2323 } else if (src_op.type == OperatorType::kNotEqual) {
2324 ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
2325 } else if (src_op.type == OperatorType::kGreater) {
2326 ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
2327 } else if (src_op.type == OperatorType::kGreaterEqual) {
2328 ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
2329 } else if (src_op.type == OperatorType::kLess) {
2330 ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
2331 } else if (src_op.type == OperatorType::kLessEqual) {
2332 ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
2333 } else if (src_op.type == OperatorType::kSelect) {
2334 ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
2335 tensorflow_graph);
2336 } else if (src_op.type == OperatorType::kTile) {
2337 ConvertTileOperator(model,
2338 static_cast<const TensorFlowTileOperator&>(src_op),
2339 tensorflow_graph);
2340 } else if (src_op.type == OperatorType::kPow) {
2341 ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
2342 tensorflow_graph);
2343 } else if (src_op.type == OperatorType::kAny) {
2344 ConvertReduceOperator(model,
2345 static_cast<const TensorFlowAnyOperator&>(src_op),
2346 tensorflow_graph, "Any");
2347 } else if (src_op.type == OperatorType::kLogicalAnd) {
2348 ConvertLogicalAndOperator(model,
2349 static_cast<const LogicalAndOperator&>(src_op),
2350 tensorflow_graph);
2351 } else if (src_op.type == OperatorType::kLogicalNot) {
2352 ConvertLogicalNotOperator(model,
2353 static_cast<const LogicalNotOperator&>(src_op),
2354 tensorflow_graph);
2355 } else if (src_op.type == OperatorType::kOneHot) {
2356 ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
2357 tensorflow_graph);
2358 } else if (src_op.type == OperatorType::kLogicalOr) {
2359 ConvertLogicalOrOperator(model,
2360 static_cast<const LogicalOrOperator&>(src_op),
2361 "LogicalOr", tensorflow_graph);
2362 } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
2363 ConvertCTCBeamSearchDecoderOperator(
2364 model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
2365 "CTCBeamSearchDecoder", tensorflow_graph);
2366 } else if (src_op.type == OperatorType::kUnpack) {
2367 ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
2368 "Unpack", tensorflow_graph);
2369 } else if (src_op.type == OperatorType::kZerosLike) {
2370 ConvertZerosLikeOperator(
2371 model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
2372 "ZerosLike", tensorflow_graph);
2373 } else if (src_op.type == OperatorType::kReverseV2) {
2374 ConvertReverseV2Operator(model,
2375 static_cast<const ReverseV2Operator&>(src_op),
2376 "Reverse_V2", tensorflow_graph);
2377 } else if (src_op.type == OperatorType::kReverseSequence) {
2378 ConvertReverseSequenceOperator(
2379 model, static_cast<const ReverseSequenceOperator&>(src_op),
2380 tensorflow_graph);
2381 } else {
2382 LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
2383 }
2384 }
2385
AddPlaceholder(const string & name,ArrayDataType type,GraphDef * tensorflow_graph)2386 void AddPlaceholder(const string& name, ArrayDataType type,
2387 GraphDef* tensorflow_graph) {
2388 tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2389 placeholder->set_op("Placeholder");
2390 switch (type) {
2391 case ArrayDataType::kBool:
2392 (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
2393 break;
2394 case ArrayDataType::kFloat:
2395 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2396 break;
2397 case ArrayDataType::kUint8:
2398 (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
2399 break;
2400 case ArrayDataType::kInt32:
2401 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
2402 break;
2403 case ArrayDataType::kInt64:
2404 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
2405 break;
2406 case ArrayDataType::kInt16:
2407 (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
2408 break;
2409 case ArrayDataType::kComplex64:
2410 (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
2411 break;
2412 default:
2413 LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
2414 }
2415 placeholder->set_name(name);
2416 }
2417
AddPlaceholderForRNNState(const Model & model,const string & name,int size,GraphDef * tensorflow_graph)2418 void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
2419 GraphDef* tensorflow_graph) {
2420 tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
2421 placeholder->set_op("Placeholder");
2422 placeholder->set_name(name);
2423 (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
2424
2425 auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
2426 const auto& state_array = model.GetArray(name);
2427 if (state_array.has_shape()) {
2428 const auto& state_shape = state_array.shape();
2429 const int kDims = state_shape.dimensions_count();
2430 for (int i = 0; i < kDims; ++i) {
2431 shape->add_dim()->set_size(state_shape.dims(i));
2432 }
2433 } else {
2434 shape->add_dim()->set_size(1);
2435 shape->add_dim()->set_size(size);
2436 }
2437 }
2438
ExportTensorFlowGraphDefImplementation(const Model & model,GraphDef * tensorflow_graph)2439 void ExportTensorFlowGraphDefImplementation(const Model& model,
2440 GraphDef* tensorflow_graph) {
2441 for (const auto& input_array : model.flags.input_arrays()) {
2442 AddPlaceholder(input_array.name(),
2443 model.GetArray(input_array.name()).data_type,
2444 tensorflow_graph);
2445 }
2446 for (const auto& rnn_state : model.flags.rnn_states()) {
2447 AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
2448 tensorflow_graph);
2449 }
2450 for (const auto& op : model.operators) {
2451 ConvertOperator(model, *op, tensorflow_graph);
2452 }
2453 // Generically export arrays that haven't been exported already
2454 // by the above operators export. It's important that this comes
2455 // after, as some operators need to export arrays that they reference
2456 // in a specific way, rather than in the generic way done below.
2457 for (const auto& array_pair : model.GetArrayMap()) {
2458 const string& array_name = array_pair.first;
2459 const auto& array = *array_pair.second;
2460 if (array.buffer) {
2461 switch (array.data_type) {
2462 case ArrayDataType::kBool:
2463 ConvertBoolTensorConst(model, array_name, tensorflow_graph);
2464 break;
2465 case ArrayDataType::kFloat:
2466 ConvertFloatTensorConst(model, array_name, tensorflow_graph);
2467 break;
2468 case ArrayDataType::kInt32:
2469 ConvertIntTensorConst(model, array_name, tensorflow_graph);
2470 break;
2471 case ArrayDataType::kComplex64:
2472 ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
2473 break;
2474 default:
2475 break;
2476 }
2477 }
2478 }
2479 }
2480 } // namespace
2481
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model * model)2482 void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
2483 for (const auto& array_kv : model->GetArrayMap()) {
2484 const string& array_name = array_kv.first;
2485 Array& array = *array_kv.second;
2486 if (!array.buffer || !array.minmax) {
2487 continue;
2488 }
2489 const string& wrapped_array_name =
2490 AvailableArrayName(*model, array_name + "/data");
2491 Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
2492 wrapped_array.data_type = array.data_type;
2493 wrapped_array.copy_shape(array.shape());
2494 wrapped_array.buffer = std::move(array.buffer);
2495 FakeQuantOperator* fakequant_op = new FakeQuantOperator;
2496 fakequant_op->inputs = {wrapped_array_name};
2497 fakequant_op->outputs = {array_name};
2498 fakequant_op->minmax.reset(new MinMax);
2499 *fakequant_op->minmax = *array.minmax;
2500 const auto& it = FindOpWithInput(*model, array_name);
2501 model->operators.emplace(it, fakequant_op);
2502 }
2503 CheckInvariants(*model);
2504 }
2505
ExportTensorFlowGraphDef(const Model & model,string * output_file_contents)2506 void ExportTensorFlowGraphDef(const Model& model,
2507 string* output_file_contents) {
2508 CHECK(output_file_contents->empty());
2509 GraphDef tensorflow_graph;
2510 ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
2511 LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
2512 CHECK(tensorflow_graph.SerializeToString(output_file_contents));
2513 }
2514 } // namespace toco
2515