/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/import_tensorflow.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/map.h"
#include "google/protobuf/text_format.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
#include "tensorflow/lite/toco/tensorflow_util.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_COMPLEX64;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;

namespace toco {

namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
  return node.attr().count(attr_name) > 0;
}

bool HasWildcardDimension(const TensorShapeProto& shape) {
  for (const auto& dim : shape.dim()) {
    if (dim.size() == -1) return true;
  }
  return false;
}
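
// Example for HasWildcardDimension above: a placeholder shape
// {-1, 224, 224, 3} (unknown batch size) contains a wildcard dimension,
// while {1, 224, 224, 3} does not.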

const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kS);
  return attr.s();
}

int64 GetIntAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
                                  << node.DebugString();
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kI);
  return attr.i();
}

float GetFloatAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kF);
  return attr.f();
}

bool GetBoolAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kB);
  return attr.b();
}

tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
                                     const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kType);
  return attr.type();
}

const TensorShapeProto& GetShapeAttr(const NodeDef& node,
                                     const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kShape);
  return attr.shape();
}

const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kTensor);
  return attr.tensor();
}

const AttrValue::ListValue& GetListAttr(const NodeDef& node,
                                        const string& attr_name) {
  CHECK(HasAttr(node, attr_name));
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kList);
  return attr.list();
}

tensorflow::Status CheckOptionalAttr(const NodeDef& node,
                                     const string& attr_name,
                                     const string& expected_value) {
  if (HasAttr(node, attr_name)) {
    const string& value = GetStringAttr(node, attr_name);
    if (value != expected_value) {
      return tensorflow::errors::InvalidArgument(
          "Unexpected value for attribute '" + attr_name + "'. Expected '" +
          expected_value + "'");
    }
  }
  return tensorflow::Status::OK();
}

tensorflow::Status CheckOptionalAttr(
    const NodeDef& node, const string& attr_name,
    const tensorflow::DataType& expected_value) {
  if (HasAttr(node, attr_name)) {
    const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
    if (value != expected_value) {
      return tensorflow::errors::InvalidArgument(
          "Unexpected value for attribute '" + attr_name + "'. Expected '" +
          tensorflow::DataType_Name(expected_value) + "'");
    }
  }
  return tensorflow::Status::OK();
}

template <typename T1, typename T2>
tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
                               const string& description) {
  if (v1 == v2) return tensorflow::Status::OK();
  return tensorflow::errors::InvalidArgument(absl::StrCat(
      "Unexpected ", description, ": got ", v1, ", expected ", v2));
}

ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
  if (dtype == DT_UINT8)
    return ArrayDataType::kUint8;
  else if (dtype == DT_FLOAT)
    return ArrayDataType::kFloat;
  else if (dtype == DT_BOOL)
    return ArrayDataType::kBool;
  else if (dtype == DT_INT32)
    return ArrayDataType::kInt32;
  else if (dtype == DT_INT64)
    return ArrayDataType::kInt64;
  else if (dtype == DT_STRING)
    return ArrayDataType::kString;
  else if (dtype == DT_COMPLEX64)
    return ArrayDataType::kComplex64;
  else
    LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
  return ArrayDataType::kNone;
}

tensorflow::Status ImportShape(
    const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
        input_dims,
    int* input_flat_size, Shape* shape) {
  std::vector<int> input_dims_only_sizes;
  bool zero_sized_shape = false;
  for (auto& d : input_dims) {
    // TensorFlow's shapes use int64s, while TOCO uses ints.
    if (d.size() > std::numeric_limits<int>::max()) {
      return tensorflow::errors::InvalidArgument("Shape element overflows");
    }
    if (d.size() == 0) {
      zero_sized_shape = true;
    }
    input_dims_only_sizes.push_back(d.size());
  }

  // Note that up to this point we were OK with the input shape containing
  // elements valued -1 or 0, which are perfectly legal in TensorFlow. However,
  // our CheckValidShapeDimensions() insists on them being >= 1, with the
  // exception of the "scalar" shape [0]. The main issue with zero-valued
  // shape elements is that the corresponding arrays don't contain any data
  // and the allocation code gets a bit confused. It seems that the code
  // expects an empty shape for zero-sized shapes, so we will do just that,
  // except for the [0] case.
  // TODO(b/119325030): In order to correctly import the "scalar" shapes the
  // following test must include "&& input_dims_only_sizes.size() > 1", but
  // that seems to slow everything down a lot.
  if (zero_sized_shape) {
    shape->mutable_dims()->clear();
    if (input_flat_size != nullptr) *input_flat_size = 0;
    return tensorflow::Status::OK();
  }

  *shape->mutable_dims() = input_dims_only_sizes;

  if (input_flat_size == nullptr) return tensorflow::Status::OK();

  return NumElements(input_dims_only_sizes, input_flat_size);
}
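
// Illustrative example of the rules above: dims {2, 3} import as shape
// {2, 3} with *input_flat_size == 6, while dims {2, 0, 3} contain a zero,
// so the shape is cleared to {} and *input_flat_size is set to 0.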

// Define ways to retrieve data from tensors of different types.
// TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
template <typename T>
struct TensorTraits;

template <>
struct TensorTraits<float> {
  static int size(const TensorProto& p) { return p.float_val_size(); }
  static float get(const TensorProto& p, int i) { return p.float_val(i); }
  static string accessor_name() { return "float_val"; }
  static string type_name() { return "float"; }
  static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<uint8_t> {
  static int size(const TensorProto& p) { return p.int_val_size(); }
  static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
  static string accessor_name() { return "int_val"; }
  static string type_name() { return "uint8"; }
  static void CopyFromContent(const TensorProto& p,
                              std::vector<uint8_t>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<std::complex<float>> {
  static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
  static std::complex<float> get(const TensorProto& p, int i) {
    return std::complex<float>(p.scomplex_val(2 * i),
                               p.scomplex_val(2 * i + 1));
  }
  static string accessor_name() { return "scomplex_val"; }
  static string type_name() { return "complex64"; }
  static void CopyFromContent(const TensorProto& p,
                              std::vector<std::complex<float>>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<int32> {
  static int size(const TensorProto& p) { return p.int_val_size(); }
  static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
  static string accessor_name() { return "int_val"; }
  static string type_name() { return "int32"; }
  static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<int64> {
  static int size(const TensorProto& p) { return p.int64_val_size(); }
  static int64 get(const TensorProto& p, int i) { return p.int64_val(i); }
  static string accessor_name() { return "int64_val"; }
  static string type_name() { return "int64"; }
  static void CopyFromContent(const TensorProto& p, std::vector<int64>* data) {
    toco::port::CopyToBuffer(p.tensor_content(),
                             reinterpret_cast<char*>(data->data()));
  }
};

template <>
struct TensorTraits<bool> {
  static int size(const TensorProto& p) { return p.bool_val_size(); }
  static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
  static string accessor_name() { return "bool_val"; }
  static string type_name() { return "bool"; }
  static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
    // std::vector<bool> is bit-packed and exposes no contiguous byte storage,
    // so copy through an intermediate char buffer instead of data->data().
    std::vector<char> buf(p.tensor_content().size());
    toco::port::CopyToBuffer(p.tensor_content(), buf.data());
    for (int i = 0; i < p.tensor_content().size(); i++) {
      (*data)[i] = static_cast<bool>(buf[i]);
    }
  }
};

template <typename T>
tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
                                    int input_flat_size,
                                    std::vector<T>* output_data) {
  CHECK_GE(output_data->size(), input_flat_size);
  int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
  if (num_elements_in_tensor == input_flat_size) {
    for (int i = 0; i < num_elements_in_tensor; i++) {
      (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
    }
  } else if (input_tensor.tensor_content().size() ==
             input_flat_size * sizeof(T)) {
    TensorTraits<T>::CopyFromContent(input_tensor, output_data);
  } else if (num_elements_in_tensor > 0 &&
             num_elements_in_tensor < input_flat_size) {
    // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
    // official way to import tensor data. This particular else-if handles a
    // grappler optimization where the last few elements in a tensor are
    // omitted if they are repeated.
    int i = 0;
    for (; i < num_elements_in_tensor; ++i) {
      (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
    }
    auto last = (*output_data)[i - 1];
    for (; i < input_flat_size; ++i) {
      (*output_data)[i] = last;
    }
  } else {
    string accessor_name = TensorTraits<T>::accessor_name();
    string type_name = TensorTraits<T>::type_name();
    return tensorflow::errors::InvalidArgument(
        absl::StrCat("Neither input_content (",
                     input_tensor.tensor_content().size() / sizeof(T), ") nor ",
                     accessor_name, " (", num_elements_in_tensor,
                     ") have the right dimensions (", input_flat_size,
                     ") for this ", type_name, " tensor"));
  }
  return tensorflow::Status::OK();
}
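
// Illustrative example of the grappler case above: a float tensor of flat
// size 6 whose proto stores float_val = {1, 2, 3} imports as
// {1, 2, 3, 3, 3, 3}, since trailing repeats of the last value are omitted.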

tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_float_data =
      output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
  output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                           0.f);
  return ImportTensorData<float>(input_tensor, input_flat_size,
                                 &output_float_data);
}

tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
                                        Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 4);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_complex_data =
      output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
  output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                             std::complex<float>(0.f, 0.f));
  return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
                                               &output_complex_data);
}

tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
                                     Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
                                   &output_int_data);
}

tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_INT32);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<int32>(input_tensor, input_flat_size,
                                 &output_int_data);
}

tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
                                    Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_INT64);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_int_data =
      output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
  output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
  return ImportTensorData<int64>(input_tensor, input_flat_size,
                                 &output_int_data);
}

tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
                                   Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_BOOL);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  auto& output_bool_data =
      output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
  output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
                          false);
  status =
      ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
  if (!status.ok() && output_bool_data.size() == 1) {
    // Some graphs have bool const nodes without an actual value...
    // assuming that 'false' is implied.
    // So far we have only encountered that in single-entry arrays; let's
    // require that until we encounter a graph where that's not the case.
    output_bool_data[0] = false;
    return tensorflow::Status::OK();
  }
  return status;
}

tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
                                     Array* output_array) {
  CHECK_EQ(input_tensor.dtype(), DT_STRING);
  const auto& input_shape = input_tensor.tensor_shape();
  CHECK_LE(input_shape.dim_size(), 6);
  int input_flat_size;
  auto status = ImportShape(input_shape.dim(), &input_flat_size,
                            output_array->mutable_shape());
  if (!status.ok()) return status;

  if (input_flat_size != input_tensor.string_val_size()) {
    return tensorflow::errors::InvalidArgument(
        "Input_content string_val doesn't have the right dimensions "
        "for this string tensor");
  }

  auto& output_string_data =
      output_array->GetMutableBuffer<ArrayDataType::kString>().data;
  output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
  CHECK_GE(output_string_data.size(), input_flat_size);
  for (int i = 0; i < input_flat_size; ++i) {
    output_string_data[i] = input_tensor.string_val(i);
  }
  return tensorflow::Status::OK();
}

// Count the number of inputs of a given node. If
// `tf_import_flags.drop_control_dependency` is true, count the number of
// non-control-dependency inputs.
int GetInputsCount(const NodeDef& node,
                   const TensorFlowImportFlags& tf_import_flags) {
  if (tf_import_flags.drop_control_dependency) {
    for (size_t i = 0; i < node.input_size(); ++i) {
      // In a GraphDef, control-dependency inputs carry a '^' prefix and, by
      // convention, come after all regular data inputs.
      if (node.input(i)[0] == '^') {
        return i;
      }
    }
  }
  return node.input_size();
}

tensorflow::Status CheckInputsCount(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    int expected_input_count) {
  if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
    return tensorflow::errors::FailedPrecondition(
        node.op(), " node expects ", expected_input_count,
        " input(s) other than control dependencies: ", node.DebugString());
  }
  return tensorflow::Status::OK();
}

// Utility function to create a const 1D array, useful for input parameters.
template <ArrayDataType T>
string CreateConstArray(Model* model, string const& name,
                        std::vector<typename toco::DataType<T>> const& data) {
  string array_name = toco::AvailableArrayName(*model, name);
  auto& array = model->GetOrCreateArray(array_name);
  array.data_type = T;
  array.mutable_shape()->mutable_dims()->emplace_back(data.size());
  array.GetMutableBuffer<T>().data = data;
  return array_name;
}
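
// Hypothetical usage sketch (the name "my_op_axis" is illustrative): create
// a 1-D int32 constant holding a single axis value:
//   const string axis_array =
//       CreateConstArray<ArrayDataType::kInt32>(model, "my_op_axis", {0});
// Note that the returned name may differ from the requested one if that
// name is already taken in the model.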

// Retain the TensorFlow NodeDef in the Toco Operator.
//
// If an op is supported by Toco but not supported by TFLite, the TFLite
// exporter will use the retained NodeDef to populate a Flex op when Flex
// mode is enabled.
//
// This can't be easily applied to all operations, because a single
// TensorFlow node may become multiple Toco operators. Thus we need to call
// this function from the individual operator conversion functions, one by
// one, whenever feasible.
//
// This may cause problems if a graph transformation rule changes parameters
// of the node. When calling this function, please check whether any existing
// graph transformation rule will change an existing operator with the same
// type.
//
// This provides a route to handle Toco-supported & TFLite-unsupported ops
// in Flex mode. However, it's not a solid solution. Eventually we should
// get rid of this.
// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
// this function.
void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
  node.SerializeToString(&op->tensorflow_node_def);
}

tensorflow::Status ConvertConstOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Const");
  const auto& tensor = GetTensorAttr(node, "value");
  const auto dtype = GetDataTypeAttr(node, "dtype");

  tensorflow::Status status = tensorflow::Status::OK();

  auto& array = model->GetOrCreateArray(node.name());
  switch (dtype) {
    case DT_FLOAT:
      array.data_type = ArrayDataType::kFloat;
      status = ImportFloatArray(tensor, &array);
      break;
    case DT_INT32:
      array.data_type = ArrayDataType::kInt32;
      status = ImportInt32Array(tensor, &array);
      break;
    case DT_QUINT8:
      array.data_type = ArrayDataType::kUint8;
      status = ImportQuint8Array(tensor, &array);
      break;
    case DT_INT64:
      array.data_type = ArrayDataType::kInt64;
      status = ImportInt64Array(tensor, &array);
      break;
    case DT_STRING:
      array.data_type = ArrayDataType::kString;
      status = ImportStringArray(tensor, &array);
      break;
    case DT_BOOL:
      array.data_type = ArrayDataType::kBool;
      status = ImportBoolArray(tensor, &array);
      break;
    case DT_COMPLEX64:
      array.data_type = ArrayDataType::kComplex64;
      status = ImportComplex64Array(tensor, &array);
      break;
    default:
      array.data_type = ArrayDataType::kNone;
      // Do nothing, silently ignore the Const data.
      // We just make a dummy buffer to indicate that
      // this array does not rely on external input.
      array.GetMutableBuffer<ArrayDataType::kNone>();
      break;
  }
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      status, " (while processing node '" + node.name() + "')");
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertConvOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Conv2D");
  TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));

  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
  TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));

  const auto& input_name = node.input(0);
  const auto& weights_name = node.input(1);
  const auto& reordered_weights_name =
      AvailableArrayName(*model, weights_name + "_reordered");
  // Check if a ReorderAxesOperator was already created for these weights
  // (that happens when multiple layers share the same weights).
  const Operator* existing_reorder =
      GetOpWithOutput(*model, reordered_weights_name);
  if (existing_reorder) {
    // Check that it is safe to rely on the _reordered naming of the output
    // array!
    CHECK(existing_reorder->type == OperatorType::kReorderAxes);
  } else {
    // Create a new ReorderAxesOperator.
    auto* reorder = new ReorderAxesOperator;
    reorder->inputs = {weights_name};
    reorder->outputs = {reordered_weights_name};
    reorder->input_axes_order = AxesOrder::kHWIO;
    reorder->output_axes_order = AxesOrder::kOHWI;
    model->operators.emplace_back(reorder);
  }
  auto* conv = new ConvOperator;
  conv->inputs = {input_name, reordered_weights_name};
  conv->outputs = {node.name()};
  if (!HasAttr(node, "strides")) {
    return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
  }
  const auto& strides = GetListAttr(node, "strides");
  TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
  TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
  conv->stride_height = strides.i(1);
  conv->stride_width = strides.i(2);
  if (HasAttr(node, "dilations")) {
    const auto& dilations = GetListAttr(node, "dilations");
    TF_RETURN_IF_ERROR(
        ExpectValue(dilations.i_size(), 4, "number of dilations"));
    if (dilations.i(0) != 1 || dilations.i(3) != 1) {
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "Can only import Conv ops with dilation along the height "
          "(1st) or width (2nd) axis. TensorFlow op \"",
          node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
          dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
    }
    conv->dilation_height_factor = dilations.i(1);
    conv->dilation_width_factor = dilations.i(2);
  } else {
    conv->dilation_height_factor = 1;
    conv->dilation_width_factor = 1;
  }
  const auto& padding = GetStringAttr(node, "padding");
  if (padding == "SAME") {
    conv->padding.type = PaddingType::kSame;
  } else if (padding == "VALID") {
    conv->padding.type = PaddingType::kValid;
  } else {
    return tensorflow::errors::InvalidArgument(
        "Bad padding (only SAME and VALID are supported)");
  }
  model->operators.emplace_back(conv);

  return tensorflow::Status::OK();
}

tensorflow::Status ConvertDepthwiseConvOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "DepthwiseConv2dNative");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  if (HasAttr(node, "data_format")) {
    CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
  }
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);

  const auto& input_name = node.input(0);
  const auto& weights_name = node.input(1);
  const auto& reordered_weights_name = weights_name + "_reordered";
  // Check if a ReorderAxesOperator was already created for these weights
  // (that happens when multiple layers share the same weights).
  const Operator* existing_reorder =
      GetOpWithOutput(*model, reordered_weights_name);
  if (existing_reorder) {
    // Check that it is safe to rely on the _reordered naming of the output
    // array!
    CHECK(existing_reorder->type == OperatorType::kReorderAxes);
  } else {
    // Create a new ReorderAxesOperator.
    auto* reorder = new ReorderAxesOperator;
    reorder->inputs = {weights_name};
    reorder->outputs = {reordered_weights_name};
    reorder->input_axes_order = AxesOrder::kHWIM;
    reorder->output_axes_order = AxesOrder::k1HWO;
    model->operators.emplace_back(reorder);
  }
  auto* conv = new DepthwiseConvOperator;
  conv->inputs = {input_name, reordered_weights_name};
  conv->outputs = {node.name()};
  const auto& strides = GetListAttr(node, "strides");
  CHECK_EQ(strides.i_size(), 4);
  CHECK_EQ(strides.i(0), 1);
  CHECK_EQ(strides.i(3), 1);
  conv->stride_height = strides.i(1);
  conv->stride_width = strides.i(2);
  if (HasAttr(node, "dilations")) {
    const auto& dilations = GetListAttr(node, "dilations");
    TF_RETURN_IF_ERROR(
        ExpectValue(dilations.i_size(), 4, "number of dilations"));
    if (dilations.i(0) != 1 || dilations.i(3) != 1) {
      return tensorflow::errors::InvalidArgument(absl::StrCat(
          "Can only import DepthwiseConv ops with dilation along the height "
          "(1st) or width (2nd) axis. TensorFlow op \"",
          node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
          dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
    }
    conv->dilation_height_factor = dilations.i(1);
    conv->dilation_width_factor = dilations.i(2);
  } else {
    conv->dilation_height_factor = 1;
    conv->dilation_width_factor = 1;
  }
  const auto& padding = GetStringAttr(node, "padding");
  if (padding == "SAME") {
    conv->padding.type = PaddingType::kSame;
  } else if (padding == "VALID") {
    conv->padding.type = PaddingType::kValid;
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  model->operators.emplace_back(conv);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertDepthToSpaceOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "DepthToSpace");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  auto* op = new DepthToSpaceOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->block_size = GetIntAttr(node, "block_size");
  QCHECK_GE(op->block_size, 2);
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSpaceToDepthOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "SpaceToDepth");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
  if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
      dtype != DT_INT64) {
    const auto* enum_descriptor = tensorflow::DataType_descriptor();
    LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
               << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
               << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
  }
  auto* op = new SpaceToDepthOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->block_size = GetIntAttr(node, "block_size");
  QCHECK_GE(op->block_size, 2);
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBiasAddOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "BiasAdd");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  const auto& input_name = node.input(0);
  const auto& bias_name = node.input(1);
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  auto* biasadd = new AddOperator;
  biasadd->inputs.push_back(input_name);
  biasadd->inputs.push_back(bias_name);
  biasadd->outputs.push_back(node.name());
  model->operators.emplace_back(biasadd);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertRandomUniform(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "RandomUniform");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));

  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
  auto op = absl::make_unique<RandomUniformOperator>();
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
  op->seed = GetIntAttr(node, "seed");
  op->seed2 = GetIntAttr(node, "seed2");
  CHECK(model != nullptr);
  model->operators.emplace_back(std::move(op));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertIdentityOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
        node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
  auto* op = new TensorFlowIdentityOperator;
  // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
  // identity nodes with multiple inputs, but the other inputs seem
  // to be gratuitous (in the case of rajeev_lstm.pb, these are
  // enumerating the LSTM state arrays). We will just ignore extra
  // inputs beyond the first input.
  QCHECK_GE(node.input_size(), 1)
      << node.op()
      << " node expects at least 1 input other than control dependencies: "
      << node.DebugString();
  const auto& input_name = node.input(0);
  op->inputs.push_back(input_name);
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  auto* op = new FakeQuantOperator;
  op->inputs.push_back(node.input(0));
  op->minmax.reset(new MinMax);
  auto& minmax = *op->minmax;
  minmax.min = GetFloatAttr(node, "min");
  minmax.max = GetFloatAttr(node, "max");
  op->outputs.push_back(node.name());
  // tf.fake_quant_with_min_max_args num_bits defaults to 8.
  op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
  if (HasAttr(node, "narrow_range")) {
    op->narrow_range = GetBoolAttr(node, "narrow_range");
  }
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertFakeQuantWithMinMaxVars(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  QCHECK(num_inputs == 3 || num_inputs == 4)
      << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
         "control dependencies: "
      << node.DebugString();
  auto* op = new FakeQuantOperator;
  for (int i = 0; i < 3; i++) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
  if (HasAttr(node, "narrow_range")) {
    op->narrow_range = GetBoolAttr(node, "narrow_range");
  }
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSqueezeOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Squeeze");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  auto* op = new SqueezeOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());

  // When omitted we are to squeeze all dimensions == 1.
  if (HasAttr(node, "squeeze_dims")) {
    const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
    for (int i = 0; i < squeeze_dims.i_size(); ++i) {
      op->squeeze_dims.push_back(squeeze_dims.i(i));
    }
  }

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSplitOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Split");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto* op = new TensorFlowSplitOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  const int num_split = GetIntAttr(node, "num_split");
  op->outputs.push_back(node.name());
  for (int i = 1; i < num_split; i++) {
    op->outputs.push_back(absl::StrCat(node.name(), ":", i));
  }
  op->num_split = num_split;
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSplitVOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "SplitV");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  auto* op = new TensorFlowSplitVOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  const int num_split = GetIntAttr(node, "num_split");
  op->outputs.push_back(node.name());
  for (int i = 1; i < num_split; i++) {
    op->outputs.push_back(absl::StrCat(node.name(), ":", i));
  }
  op->num_split = num_split;
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSwitchOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Switch");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto* op = new TensorFlowSwitchOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  // Switch operators have two outputs: "name" and "name:1".
  op->outputs.push_back(node.name() + ":1");
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSoftmaxOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Softmax");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto& input_name = node.input(0);
  auto* softmax = new SoftmaxOperator;
  softmax->inputs.push_back(input_name);
  softmax->outputs.push_back(node.name());
  // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
  CHECK(!node.attr().count("beta"));  // Stab in the dark, just in case.
  softmax->beta = 1.f;
  model->operators.emplace_back(softmax);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertLRNOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "LRN");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto& input_name = node.input(0);
  auto* lrn = new LocalResponseNormalizationOperator;
  lrn->inputs.push_back(input_name);
  lrn->outputs.push_back(node.name());
  lrn->range = GetIntAttr(node, "depth_radius");
  lrn->bias = GetFloatAttr(node, "bias");
  lrn->alpha = GetFloatAttr(node, "alpha");
  lrn->beta = GetFloatAttr(node, "beta");
  model->operators.emplace_back(lrn);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertMaxPoolOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "MaxPool");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto& input_name = node.input(0);
  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  if (node.attr().count("data_format")) {
    CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
  }
  if (HasAttr(node, "T")) {
    CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  } else {
    LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
  }
  auto* maxpool = new MaxPoolOperator;
  maxpool->inputs.push_back(input_name);
  maxpool->outputs.push_back(node.name());
  const auto& strides = GetListAttr(node, "strides");
  CHECK_EQ(strides.i_size(), 4);
  CHECK_EQ(strides.i(0), 1);
  CHECK_EQ(strides.i(3), 1);
  maxpool->stride_height = strides.i(1);
  maxpool->stride_width = strides.i(2);
  const auto& ksize = GetListAttr(node, "ksize");
  CHECK_EQ(ksize.i_size(), 4);
  CHECK_EQ(ksize.i(0), 1);
  CHECK_EQ(ksize.i(3), 1);
  maxpool->kheight = ksize.i(1);
  maxpool->kwidth = ksize.i(2);
  const auto& padding = GetStringAttr(node, "padding");
  if (padding == "SAME") {
    maxpool->padding.type = PaddingType::kSame;
  } else if (padding == "VALID") {
    maxpool->padding.type = PaddingType::kValid;
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  model->operators.emplace_back(maxpool);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertAvgPoolOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "AvgPool");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto& input_name = node.input(0);
  // We only support NHWC, which is the default data_format.
  // So if data_format is not defined, we're all good.
  if (node.attr().count("data_format")) {
    CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
  }
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  auto* avgpool = new AveragePoolOperator;
  avgpool->inputs.push_back(input_name);
  avgpool->outputs.push_back(node.name());
  const auto& strides = GetListAttr(node, "strides");
  CHECK_EQ(strides.i_size(), 4);
  CHECK_EQ(strides.i(0), 1);
  CHECK_EQ(strides.i(3), 1);
  avgpool->stride_height = strides.i(1);
  avgpool->stride_width = strides.i(2);
  const auto& ksize = GetListAttr(node, "ksize");
  CHECK_EQ(ksize.i_size(), 4);
  CHECK_EQ(ksize.i(0), 1);
  CHECK_EQ(ksize.i(3), 1);
  avgpool->kheight = ksize.i(1);
  avgpool->kwidth = ksize.i(2);
  const auto& padding = GetStringAttr(node, "padding");
  if (padding == "SAME") {
    avgpool->padding.type = PaddingType::kSame;
  } else if (padding == "VALID") {
    avgpool->padding.type = PaddingType::kValid;
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  model->operators.emplace_back(avgpool);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBatchMatMulOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  auto* batch_matmul = new BatchMatMulOperator;
  // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
  if (HasAttr(node, "adj_x")) {
    batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
  }
  if (HasAttr(node, "adj_y")) {
    batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
  }
  batch_matmul->inputs = {node.input(0), node.input(1)};
  batch_matmul->outputs = {node.name()};

  // For Flex mode. Please read the comments of the function.
  RetainTensorFlowNodeDef(node, batch_matmul);

  model->operators.emplace_back(batch_matmul);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertMatMulOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  CHECK(!HasAttr(node, "adjoint_a") ||
        (GetBoolAttr(node, "adjoint_a") == false));
  CHECK(!HasAttr(node, "adjoint_b") ||
        (GetBoolAttr(node, "adjoint_b") == false));

  auto* matmul = new TensorFlowMatMulOperator;
  if (HasAttr(node, "transpose_a")) {
    matmul->transpose_a = GetBoolAttr(node, "transpose_a");
  }
  if (HasAttr(node, "transpose_b")) {
    matmul->transpose_b = GetBoolAttr(node, "transpose_b");
  }

  matmul->inputs = {node.input(0), node.input(1)};
  matmul->outputs = {node.name()};
  model->operators.emplace_back(matmul);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertConcatOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  Operator* op = nullptr;
  if (node.op() == "Concat") {
    op = new TensorFlowConcatOperator;
  } else if (node.op() == "ConcatV2") {
    op = new TensorFlowConcatV2Operator;
  } else {
    LOG(FATAL) << "Expected Concat or ConcatV2";
  }
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  QCHECK_GE(num_inputs, 2)
      << node.op()
      << " node expects at least 2 inputs other than control dependencies: "
      << node.DebugString();
  CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertMirrorPadOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  if (node.op() != "MirrorPad") {
    LOG(FATAL) << "Expected MirrorPad.";
  }
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  CHECK_EQ(num_inputs, 2);
  auto* op = new MirrorPadOperator;
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  const auto mode = GetStringAttr(node, "mode");
  if (mode == "REFLECT") {
    op->mode = toco::MirrorPadMode::kReflect;
  } else if (mode == "SYMMETRIC") {
    op->mode = toco::MirrorPadMode::kSymmetric;
  }

  model->operators.emplace_back(op);

  return tensorflow::Status::OK();
}

static constexpr int kAnyNumInputs = -1;

enum FlexSupport { kFlexOk, kFlexNotOk };

// Converts a simple operator that takes no attributes. The list of inputs is
// taken from the given NodeDef, and its number must match NumInputs, unless
// kAnyNumInputs is passed in. If kFlexOk is passed in, the resulting operator
// will be eligible for being exported as a Flex op.
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
tensorflow::Status ConvertSimpleOperatorGeneric(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  if (NumInputs != kAnyNumInputs) {
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
  }
  auto* op = new Op;
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  if (NumOutputs > 1) {
    for (int i = 1; i < NumOutputs; ++i) {
      op->outputs.push_back(node.name() + ":" + std::to_string(i));
    }
  }

  if (flex == kFlexOk) {
    RetainTensorFlowNodeDef(node, op);
  }

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

// Converts a simple operator which is not valid as a Flex op.
template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
      node, tf_import_flags, model);
}

// Converts a simple operator which is valid as a Flex op.
template <typename Op, int NumInputs, int NumOutputs>
tensorflow::Status ConvertSimpleOperatorFlexOk(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
      node, tf_import_flags, model);
}
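
// Usage sketch (illustrative; the actual bindings from op names to these
// helpers live in this file's converter table further down): an elementwise
// Add with two inputs and one output could be converted with
//   ConvertSimpleOperator<AddOperator, 2, 1>(node, tf_import_flags, model);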

void GetOutputNamesFromNodeDef(const NodeDef& node,
                               const tensorflow::OpDef& op_def,
                               TensorFlowUnsupportedOperator* op) {
  int next_output = 0;
  auto add_output = [&node, &next_output, op]() {
    if (next_output == 0) {
      op->outputs.push_back(node.name());  // Implicit :0.
    } else {
      op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
    }
    ++next_output;
  };
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    string multiples = op_def.output_arg(i).number_attr();
    if (!multiples.empty()) {
      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
      int num_outputs = GetIntAttr(node, multiples);
      for (int j = 0; j < num_outputs; ++j) {
        add_output();
      }
    } else {
      string list = op_def.output_arg(i).type_list_attr();
      if (!list.empty()) {
        CHECK(HasAttr(node, list)) << "No attr named " << list;
        const AttrValue::ListValue& list_value = GetListAttr(node, list);
        for (int j = 0; j < list_value.type_size(); ++j) {
          add_output();
        }
      } else {
        add_output();
      }
    }
  }
}
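
// Illustrative example: for an OpDef declaring one output arg with
// number_attr "N" and a node where N = 3, the loop above emits the output
// names "node_name", "node_name:1", and "node_name:2".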
1275
void GetOutputTypesFromNodeDef(const NodeDef& node,
                               const tensorflow::OpDef& op_def,
                               TensorFlowUnsupportedOperator* op) {
  // Add the given type to the op's output types, or clear the list if the
  // type is invalid.
  auto add_type = [&node, op](tensorflow::DataType type) {
    if (type == tensorflow::DT_INVALID) {
      LOG(WARNING) << "Op node missing output type attribute: " << node.name();
      op->output_data_types.clear();
    } else {
      op->output_data_types.push_back(ConvertDataType(type));
    }
  };

  // Retrieve the data type according to the OpDef definition: either the
  // "type" or "type_attr" field will be set.
  auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
    if (a.type() != tensorflow::DT_INVALID) {
      return a.type();
    } else if (HasAttr(node, a.type_attr())) {
      return GetDataTypeAttr(node, a.type_attr());
    } else {
      return tensorflow::DT_INVALID;
    }
  };

  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    string multiples = op_def.output_arg(i).number_attr();
    if (!multiples.empty()) {
      CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
      int num_outputs = GetIntAttr(node, multiples);
      auto type = get_type(op_def.output_arg(i));
      for (int j = 0; j < num_outputs; ++j) {
        add_type(type);
      }
    } else {
      string list = op_def.output_arg(i).type_list_attr();
      if (!list.empty()) {
        CHECK(HasAttr(node, list)) << "No attr named " << list;
        const AttrValue::ListValue& list_value = GetListAttr(node, list);
        for (int j = 0; j < list_value.type_size(); ++j) {
          add_type(list_value.type(j));
        }
      } else {
        add_type(get_type(op_def.output_arg(i)));
      }
    }
  }
}

tensorflow::Status ConvertUnsupportedOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  // Names of special attributes in TF graph that are used by Toco.
  static constexpr char kAttrOutputQuantized[] = "_output_quantized";
  static constexpr char kAttrOutputTypes[] = "_output_types";
  static constexpr char kAttrOutputShapes[] = "_output_shapes";
  static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
      "_support_output_type_float_in_quantized_op";

  LOG(INFO) << "Converting unsupported operation: " << node.op();

  auto* op = new TensorFlowUnsupportedOperator;
  op->tensorflow_op = node.op();

  // Retain the NodeDef for Flex mode; see the comments on
  // RetainTensorFlowNodeDef.
  RetainTensorFlowNodeDef(node, op);

  model->operators.emplace_back(op);

  // Parse inputs.
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }

  // Parse outputs. Name them after the node's name, plus an ordinal suffix.
  // Note that the number of some outputs is determined by a named attribute.
  const tensorflow::OpDef* op_def = nullptr;
  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    GetOutputNamesFromNodeDef(node, *op_def, op);
  } else {
    op->outputs.push_back(node.name());  // Implicit :0.
  }

  // Parse whether the op supports quantization.
  if (HasAttr(node, kAttrOutputQuantized)) {
    op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
  }
  // Parse whether the quantized op allows output arrays of type float.
  if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
    op->support_output_type_float_in_quantized_op =
        GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
  }

  // Parse output type(s).
  if (HasAttr(node, kAttrOutputTypes)) {
    const auto& output_types = GetListAttr(node, kAttrOutputTypes);
    for (int i = 0; i < output_types.type_size(); ++i) {
      op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
    }
  } else if (HasAttr(node, "Tout")) {
    const auto& output_type = GetDataTypeAttr(node, "Tout");
    op->output_data_types.push_back(ConvertDataType(output_type));
  } else if (op_def != nullptr) {
    GetOutputTypesFromNodeDef(node, *op_def, op);
  } else {
    // TODO(b/113613439): Figure out how to propagate types for custom ops
    // that have no OpDef.
    LOG(INFO) << "Unable to determine output type for op: " << node.op();
  }

  // Parse output shape(s).
  if (HasAttr(node, kAttrOutputShapes)) {
    const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
    Shape output_shape;
    for (int i = 0; i < output_shapes.shape_size(); ++i) {
      const auto& shape = output_shapes.shape(i);
      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
      // TODO(b/113613439): Handle shape inference for unsupported ops that
      // have shapes with wildcard dimensions.
      if (HasWildcardDimension(shape)) {
        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
                  << node.name();
        op->output_shapes.clear();
        break;
      }
      const auto status =
          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
      if (!status.ok()) {
        return status;
      }
      op->output_shapes.push_back(output_shape);
    }
  }
  return tensorflow::Status::OK();
}

// Same as ConvertConstOperator, but reverts to ConvertUnsupportedOperator if
// the types are not supported. Converting Const operators here avoids
// expensive copies of the protocol buffers downstream in the flex delegate.
tensorflow::Status ConditionallyConvertConstOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  // We avoid incomplete and zero shapes because the resulting arrays
  // are not completely compatible with Eager/TensorFlow.
  const auto& tensor = GetTensorAttr(node, "value");
  const auto& shape = tensor.tensor_shape();
  for (const auto& dim : shape.dim()) {
    if (dim.size() <= 0) {
      return ConvertUnsupportedOperator(node, tf_import_flags, model);
    }
  }

  switch (GetDataTypeAttr(node, "dtype")) {
    case DT_FLOAT:
    case DT_INT32:
    case DT_QUINT8:
    case DT_INT64:
    case DT_STRING:
    case DT_BOOL:
    case DT_COMPLEX64:
      return ConvertConstOperator(node, tf_import_flags, model);
    default:
      return ConvertUnsupportedOperator(node, tf_import_flags, model);
  }
}

tensorflow::Status ConvertStridedSliceOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "StridedSlice");
  // TODO(soroosh): The 4th input (strides) should be optional, to be
  // consistent with TF.
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));

  auto* op = new StridedSliceOperator;
  for (const auto& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());

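  // In TF's StridedSlice semantics, each mask is a bitfield: if bit i of
  // begin_mask (resp. end_mask) is set, the begin (resp. end) value for
  // dimension i is ignored and the maximal range is used instead.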
  op->begin_mask =
      HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
  op->ellipsis_mask =
      HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
  op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
  op->new_axis_mask =
      HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
  op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
                             ? GetIntAttr(node, "shrink_axis_mask")
                             : 0;

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPlaceholderOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
  if (node.op() == "Placeholder") {
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
  }
  auto& array = model->GetOrCreateArray(node.name());
  if (node.attr().count("dtype")) {
    array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
  }
  if (node.attr().count("shape")) {
    const auto& shape = GetShapeAttr(node, "shape");
    auto num_dims = shape.dim_size();
    // TODO(b/62716978): This logic needs to be revisited. During dims
    // refactoring it is an interim fix.
    if (num_dims > 0 && !HasWildcardDimension(shape)) {
      auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
      dst_array_dims.resize(num_dims);
      // num_dims is an int, so use an int index to avoid a signed/unsigned
      // comparison.
      for (int i = 0; i < num_dims; i++) {
        dst_array_dims[i] = shape.dim(i).size();
      }
    }
  }
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertNoOpOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertCastOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Cast");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
  const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
  auto* op = new CastOperator;
  op->src_data_type = ConvertDataType(tf_src_dtype);
  op->dst_data_type = ConvertDataType(tf_dst_dtype);
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertFloorOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Floor");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto data_type = GetDataTypeAttr(node, "T");
  CHECK(data_type == DT_FLOAT);
  auto* op = new FloorOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertCeilOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Ceil");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto data_type = GetDataTypeAttr(node, "T");
  CHECK(data_type == DT_FLOAT);
  auto* op = new CeilOperator;
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertGatherOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK(node.op() == "Gather" || node.op() == "GatherV2");
  if (node.op() == "Gather")
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  if (node.op() == "GatherV2")
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
  CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
  auto* op = new GatherOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  if (node.input_size() >= 3) {
    // GatherV2 form where we are provided an axis. It may be either a constant
    // or runtime-defined value, so we just wire up the array and let
    // ResolveGatherAttributes take care of it later on.
    const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
    CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
    op->inputs.push_back(node.input(2));
  } else {
    // Gather form that assumes axis=0.
    op->axis = {0};
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertGatherNdOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "GatherNd");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
  CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
  auto* op = new GatherNdOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

template <typename Op>
tensorflow::Status ConvertArgMinMaxOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  const auto axis_data_type =
      HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
  const auto output_type = HasAttr(node, "output_type")
                               ? GetDataTypeAttr(node, "output_type")
                               : DT_INT64;
  CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
  CHECK(output_type == DT_INT64 || output_type == DT_INT32);
  auto* op = new Op;
  op->output_data_type = ConvertDataType(output_type);
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertArgMaxOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "ArgMax");
  return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
}

tensorflow::Status ConvertArgMinOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "ArgMin");
  return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
}

tensorflow::Status ConvertResizeBilinearOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "ResizeBilinear");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto* op = new ResizeBilinearOperator;

  op->align_corners = false;
  if (HasAttr(node, "align_corners")) {
    op->align_corners = GetBoolAttr(node, "align_corners");
  }

  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertResizeNearestNeighborOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "ResizeNearestNeighbor");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto* op = new ResizeNearestNeighborOperator;

  op->align_corners = false;
  if (HasAttr(node, "align_corners")) {
    op->align_corners = GetBoolAttr(node, "align_corners");
  }

  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));

  // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
  // to the input, before feeding it into TensorFlowRsqrtOperator.
  // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);

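  // Conceptually, the decomposition below implements
  //   output = (input - mean) * multiplier + beta
  // with multiplier = rsqrt(variance) * gamma when scale_after_normalization
  // is set, and multiplier = rsqrt(variance) otherwise (variance_epsilon is
  // not yet folded in; see the TODO above).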
  string multiplier = node.name() + "_mul";
  if (GetBoolAttr(node, "scale_after_normalization")) {
    // Create graph:
    //   v -> RSQRT ->
    //                 MUL -> multiplier
    //   gamma ----->
    string rsqrt = node.name() + "_rsqrt";

    auto* rsqrt_op = new TensorFlowRsqrtOperator;
    rsqrt_op->inputs.push_back(node.input(2));
    rsqrt_op->outputs.push_back(rsqrt);
    model->operators.emplace_back(rsqrt_op);

    auto* mul_op = new MulOperator;
    mul_op->inputs.push_back(rsqrt);
    mul_op->inputs.push_back(node.input(4));
    mul_op->outputs.push_back(multiplier);
    model->operators.emplace_back(mul_op);
  } else {
    // Create graph:
    //   v -> RSQRT -> multiplier
    auto* rsqrt_op = new TensorFlowRsqrtOperator;
    rsqrt_op->inputs.push_back(node.input(2));
    rsqrt_op->outputs.push_back(multiplier);
    model->operators.emplace_back(rsqrt_op);
  }

  auto* op = new BatchNormalizationOperator;
  op->global_normalization = true;

  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(multiplier);
  op->inputs.push_back(node.input(3));
  op->outputs.push_back(node.name());

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertFusedBatchNormOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "FusedBatchNorm");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));

  // Declare shortcuts for the inputs.
  const string& gamma_input = node.input(1);
  const string& beta_input = node.input(2);
  const string& moving_mean_input = node.input(3);
  const string& moving_variance_input = node.input(4);

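  // The decomposition below implements the standard batch-norm formula:
  //   output = (input - moving_mean) *
  //            [gamma * rsqrt(moving_variance + epsilon)] + beta
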
  // Create an array holding the epsilon value (typically, 0.001).
  const string epsilon_array_name = CreateConstArray<ArrayDataType::kFloat>(
      model, node.name() + "_epsilon_array", {GetFloatAttr(node, "epsilon")});

  // Add epsilon to the moving variance.
  const string epsilon_add_op_name = node.name() + "_epsilon";
  auto* epsilon_add_op = new AddOperator;
  epsilon_add_op->inputs.push_back(moving_variance_input);
  epsilon_add_op->inputs.push_back(epsilon_array_name);
  epsilon_add_op->outputs.push_back(epsilon_add_op_name);
  model->operators.emplace_back(epsilon_add_op);

  // Take the inverse square root of the (variance + epsilon).
  const string rsqrt_op_name = node.name() + "_rsqrt";
  auto* rsqrt_op = new TensorFlowRsqrtOperator;
  rsqrt_op->inputs.push_back(epsilon_add_op_name);
  rsqrt_op->outputs.push_back(rsqrt_op_name);
  model->operators.emplace_back(rsqrt_op);

  // Multiply the result by gamma.
  const string multiplier = node.name() + "_mul";
  auto* mul_op = new MulOperator;
  mul_op->inputs.push_back(rsqrt_op_name);
  mul_op->inputs.push_back(gamma_input);
  mul_op->outputs.push_back(multiplier);
  model->operators.emplace_back(mul_op);

  // Now we have all required inputs for the BatchNormalizationOperator.
  auto* op = new BatchNormalizationOperator;
  op->global_normalization = true;

  op->inputs.push_back(node.input(0));
  op->inputs.push_back(moving_mean_input);
  op->inputs.push_back(multiplier);
  op->inputs.push_back(beta_input);
  op->outputs.push_back(node.name());

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSpaceToBatchNDOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "SpaceToBatchND");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
  CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
  auto* op = new SpaceToBatchNDOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBatchToSpaceNDOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "BatchToSpaceND");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
  CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
  auto* op = new BatchToSpaceNDOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

template <typename T>
tensorflow::Status ConvertReduceOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto* op = new T;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op);
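  // TF has used both spellings of this attribute over time ("keepdims" and
  // the older "keep_dims"), so accept either.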
  if (HasAttr(node, "keepdims")) {
    op->keep_dims = GetBoolAttr(node, "keepdims");
  } else if (HasAttr(node, "keep_dims")) {
    op->keep_dims = GetBoolAttr(node, "keep_dims");
  }
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSvdfOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Svdf");
  const int input_size = GetInputsCount(node, tf_import_flags);
  QCHECK(input_size == 3 || input_size == 4)
      << "Svdf node expects 3 or 4 inputs other than control dependencies: "
      << node.DebugString();
  bool has_bias = (input_size == 4);
  auto* op = new SvdfOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  if (has_bias) {
    op->inputs.push_back(node.input(3));
  }
  op->outputs.push_back(node.name() + "_state");
  op->outputs.push_back(node.name());
  if (node.attr().at("ActivationFunction").s() == "Relu") {
    op->fused_activation_function = FusedActivationFunctionType::kRelu;
  } else {
    op->fused_activation_function = FusedActivationFunctionType::kNone;
  }
  op->rank = node.attr().at("Rank").i();
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

// This is just bare-bones support to get the shapes to propagate.
tensorflow::Status ConvertTransposeConvOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Conv2DBackpropInput");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  auto* op = new TransposeConvOperator;
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  op->outputs.push_back(node.name());
  const auto& strides = GetListAttr(node, "strides");
  CHECK_EQ(strides.i_size(), 4)
      << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
      << node.name() << "\" has " << strides.i_size() << "D strides.";
  CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
      << "Can only import TransposeConv ops with striding along the height "
         "(1st) or width (2nd) axis. TensorFlow op \""
      << node.name() << "\" had strides:[ " << strides.i(0) << ", "
      << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
  op->stride_height = strides.i(1);
  op->stride_width = strides.i(2);
  if (HasAttr(node, "dilations")) {
    const auto& dilations = GetListAttr(node, "dilations");
    CHECK_EQ(dilations.i_size(), 4)
        << "Dilation unsupported in TransposeConv. TensorFlow op \""
        << node.name() << "\" had dilations.";
    CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
          (dilations.i(2) == 1) && (dilations.i(3) == 1))
        << "Dilation unsupported in TransposeConv. TensorFlow op \""
        << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
        << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
        << "].";
  }

  const string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
  const string& transposed_weights_name = weights_name + "_transposed";
  // Check if a TransposeOperator was already created for these weights
  // (can happen when multiple layers share the same weights).
  const Operator* existing_transpose =
      GetOpWithOutput(*model, transposed_weights_name);
  if (existing_transpose) {
    CHECK(existing_transpose->type == OperatorType::kTranspose);
  } else {
    // Transpose weights from HWOI order to OHWI order, which is more efficient
    // for computation. (Note that TensorFlow considers the order as HWIO
    // because it considers this a backward conv, inverting the sense of
    // input/output.)
    TransposeOperator* transpose = new TransposeOperator;
    string perm_array = CreateConstArray<ArrayDataType::kInt32>(
        model, node.name() + "_transpose_perm", {2, 0, 1, 3});
    transpose->inputs = {weights_name, perm_array};
    transpose->outputs = {transposed_weights_name};
    model->operators.emplace_back(transpose);
  }
  op->inputs[1] = transposed_weights_name;

  auto const& padding = GetStringAttr(node, "padding");
  if (padding == "SAME") {
    op->padding.type = PaddingType::kSame;
  } else if (padding == "VALID") {
    op->padding.type = PaddingType::kValid;
  } else {
    LOG(FATAL) << "Only SAME and VALID padding supported on "
                  "Conv2DBackpropInput nodes.";
  }
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertRangeOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Range");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
  auto* op = new RangeOperator;
  if (HasAttr(node, "Tidx")) {
    const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
    CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
          dtype == DT_FLOAT);
    op->dtype = ConvertDataType(dtype);
  }
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  op->inputs.push_back(node.input(2));
  op->outputs.push_back(node.name());

  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack"
// operators also exist, but involve manipulating the TF runtime stack, and
// are not directly related to tf.stack() usage.
tensorflow::Status ConvertPackOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Pack");
  auto op = absl::make_unique<PackOperator>();
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  QCHECK_GE(num_inputs, 1)
      << node.op()
      << " node expects at least 1 input other than control dependencies: "
      << node.DebugString();
  CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
  op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
  op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
  op->outputs.push_back(node.name());
  model->operators.emplace_back(std::move(op));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertUnpackOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Unpack");
  auto op = absl::make_unique<UnpackOperator>();
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  QCHECK_EQ(num_inputs, 1);
  op->inputs.push_back(node.input(0));
  op->num = GetIntAttr(node, "num");
  op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
  op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));

  op->outputs.push_back(node.name());  // Implicit :0.
  for (int i = 1; i < op->num; ++i) {
    op->outputs.push_back(node.name() + ":" + std::to_string(i));
  }
  model->operators.emplace_back(std::move(op));
  return tensorflow::Status::OK();
}

// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
// be able to fully support such graphs, including performing inference,
// anyway. However, rather than erroring out early on cyclic graphs, it
// helps to at least support these just enough to allow getting a
// graph visualization. This is not trivial, as we require graphs to be
// acyclic aside from RNN back-edges. The solution is to special-case
// such ops as RNN back-edges, which is technically incorrect (it does not
// allow representing the op's semantics) but good enough to get a
// graph visualization.
tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  // At the moment, the only type of operator special-cased in this way is
  // NextIteration, occurring only in control-flow cycles.
  CHECK_EQ(node.op(), "NextIteration");
  CHECK_EQ(node.input_size(), 1);
  auto* rnn_state = model->flags.add_rnn_states();
  // This RNN state is not explicitly created by the user, so it's
  // OK for some later graph transformation to discard it.
  rnn_state->set_discardable(true);
  rnn_state->set_state_array(node.name());
  rnn_state->set_back_edge_source_array(node.input(0));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertShapeOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "Shape");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  const auto out_type =
      HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
  CHECK(out_type == DT_INT64 || out_type == DT_INT32);
  auto op = absl::make_unique<TensorFlowShapeOperator>();
  op->output_data_type = ConvertDataType(out_type);
  op->inputs.push_back(node.input(0));
  op->outputs.push_back(node.name());
  model->operators.push_back(std::move(op));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertReverseSequenceOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "ReverseSequence");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  auto op = absl::make_unique<ReverseSequenceOperator>();
  if (HasAttr(node, "seq_dim")) {
    op->seq_dim = GetIntAttr(node, "seq_dim");
  }
  // In tf.reverse_sequence, batch_dim defaults to 0.
  op->batch_dim =
      HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
  const int num_inputs = GetInputsCount(node, tf_import_flags);
  for (int i = 0; i < num_inputs; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  model->operators.push_back(std::move(op));
  return tensorflow::Status::OK();
}

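// In TensorFlow GraphDef, an input name prefixed with "^" denotes a control
// dependency. TOCO's array names must not carry that prefix, so strip it.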
void StripCaretFromArrayNames(Model* model) {
  for (auto& op : model->operators) {
    for (auto& input : op->inputs) {
      input = string(absl::StripPrefix(input, "^"));
    }
    for (auto& output : op->outputs) {
      output = string(absl::StripPrefix(output, "^"));
    }
  }
  for (auto& array : model->GetArrayMap()) {
    if (absl::StartsWith(array.first, "^")) {
      LOG(FATAL) << "Array name should not start with '^': " << array.first;
    }
  }
}

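// Strip the redundant ":0" suffix from node inputs, since "name" and
// "name:0" refer to the same array. E.g. "add:0" becomes "add".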
void StripZeroOutputIndexFromInputs(NodeDef* node) {
  for (auto& input : *node->mutable_input()) {
    input = string(absl::StripSuffix(input, ":0"));
  }
}

// In TensorFlow GraphDef, when a node has multiple outputs, they are named
// name:0, name:1, ...
// where 'name' is the node's name(). Just 'name' is an equivalent shorthand
// form for name:0.
// A TensorFlow GraphDef does not explicitly list all the outputs of each node
// (unlike inputs), it being implied by the node's name and operator type
// (the latter implies the number of outputs).
// This makes it non-trivial for us to reconstruct the list of all arrays
// present in the graph and, for each operator, the list of its outputs.
// We do that by taking advantage of the fact that
// at least each node lists explicitly its inputs, so after we've loaded
// all nodes, we can use that information.
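// For example, if some node consumes "foo:2", the node named "foo" must
// produce at least the outputs "foo", "foo:1", "foo:2"; the loop below
// appends any of those that are missing.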
void AddExtraOutputs(Model* model) {
  // Construct the list of all arrays consumed by anything in the graph.
  std::vector<string> consumed_arrays;
  // Add arrays consumed by an op.
  for (const auto& consumer_op : model->operators) {
    for (const string& input : consumer_op->inputs) {
      consumed_arrays.push_back(input);
    }
  }
  // Add global outputs of the model.
  for (const string& output_array : model->flags.output_arrays()) {
    consumed_arrays.push_back(output_array);
  }
  // Add arrays consumed by an RNN back-edge.
  for (const auto& rnn_state : model->flags.rnn_states()) {
    consumed_arrays.push_back(rnn_state.back_edge_source_array());
  }
  // Now add operator outputs so that all arrays that are consumed,
  // are produced.
  for (const string& consumed_array : consumed_arrays) {
    // Split the consumed array name into the form name:output_index.
    const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
    // If not of the form name:output_index, then this is not an additional
    // output of a node with multiple outputs, so nothing to do here.
    if (split.size() != 2) {
      continue;
    }
    int output_index = 0;
    if (!absl::SimpleAtoi(split[1], &output_index)) {
      continue;
    }
    // Each op is initially recorded as producing at least the array that
    // has its name. We use that to identify the producer node.
    auto* producer_op = GetOpWithOutput(*model, split[0]);
    if (!producer_op) {
      continue;
    }
    // Add extra outputs to that producer node, all the way to the
    // output_index.
    while (static_cast<int>(producer_op->outputs.size()) <= output_index) {
      using toco::port::StringF;
      producer_op->outputs.push_back(
          StringF("%s:%d", split[0], producer_op->outputs.size()));
    }
  }
}

bool InlineAllFunctions(GraphDef* graphdef) {
  if (graphdef->library().function().empty()) {
    VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
    return false;
  }

  // Override the "_noinline" attribute on all functions.
  GraphDef graphdef_copy(*graphdef);
  for (auto& function :
       (*graphdef_copy.mutable_library()->mutable_function())) {
    auto* attributes = function.mutable_attr();
    if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
      (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
    }
  }

  // Construct the minimum resources needed to use ExpandInlineFunctions().
  tensorflow::SessionOptions options;
  auto* device_count = options.config.mutable_device_count();
  device_count->insert({"CPU", 1});
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
      options, "/job:localhost/replica:0/task:0", &devices));

  tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
                                            graphdef_copy.library());
  tensorflow::DeviceMgr device_mgr(std::move(devices));
  tensorflow::OptimizerOptions o_opts;
  tensorflow::ProcessFunctionLibraryRuntime pflr(
      &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
      o_opts, nullptr);
  tensorflow::FunctionLibraryRuntime* flr =
      pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");

  tensorflow::Graph graph(fld);
  tensorflow::ImportGraphDefOptions gc_opts;
  gc_opts.validate_shape = false;
  const auto& tf_convert_status = tensorflow::ImportGraphDef(
      gc_opts, graphdef_copy, &graph, nullptr, nullptr);
  if (!tf_convert_status.ok()) {
    LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
               << tf_convert_status.ToString();
    return false;
  }

  // Iterate over the graph until there are no more nodes to be inlined.
  bool graph_modified = false;
  while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
    graph_modified = true;
  }

  // Output the inlined graph.
  if (graph_modified) {
    LOG(INFO) << "Found and inlined TensorFlow functions.";
    graph.ToGraphDef(graphdef);
  }
  return graph_modified;
}

tensorflow::Status ConvertTopKV2Operator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
  auto op = absl::make_unique<TopKV2Operator>();
  op->inputs.push_back(node.input(0));
  // K may be encoded as an attr (TopK); if so, convert it to a const array.
  if (HasAttr(node, "k")) {
    string k_array = CreateConstArray<ArrayDataType::kInt32>(
        model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
    op->inputs.push_back(k_array);
  } else {
    TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
    op->inputs.push_back(node.input(1));
  }
  // The op has two outputs.
  op->outputs.push_back(node.name());
  op->outputs.push_back(node.name() + ":1");
  model->operators.emplace_back(op.release());
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertDynamicPartitionOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  auto op = absl::make_unique<DynamicPartitionOperator>();
  CHECK(HasAttr(node, "num_partitions"));
  op->num_partitions = GetIntAttr(node, "num_partitions");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
  op->inputs.push_back(node.input(0));
  op->inputs.push_back(node.input(1));
  CHECK_GT(op->num_partitions, 1);
  op->outputs.push_back(node.name());  // Implicit :0.
  for (int i = 1; i < op->num_partitions; ++i) {
    op->outputs.push_back(node.name() + ":" + std::to_string(i));
  }
  model->operators.emplace_back(op.release());
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertDynamicStitchOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  // The parallel and non-parallel variants are the same besides whether they
  // have a parallel loop; there are no behavioral differences.
  CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
  auto op = absl::make_unique<DynamicStitchOperator>();
  CHECK(HasAttr(node, "N"));
  op->num_partitions = GetIntAttr(node, "N");
  // Expect all ID partitions + all value partitions.
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
  for (int i = 0; i < op->num_partitions * 2; ++i) {
    op->inputs.push_back(node.input(i));
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op.release());
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertSparseToDenseOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "SparseToDense");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));

  auto* op = new SparseToDenseOperator;
  for (const string& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());

  op->validate_indices = HasAttr(node, "validate_indices")
                             ? GetBoolAttr(node, "validate_indices")
                             : true;
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertOneHotOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "OneHot");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));

  const auto dtype = GetDataTypeAttr(node, "T");
  // TODO(b/111744875): Support DT_UINT8 and quantization.
  CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
        dtype == DT_BOOL);

  auto op = absl::make_unique<OneHotOperator>();
  op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
  for (const string& input : node.input()) {
    op->inputs.push_back(input);
  }
  op->outputs.push_back(node.name());
  model->operators.emplace_back(op.release());
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));

  auto* op = new CTCBeamSearchDecoderOperator;
  for (const string& input : node.input()) {
    op->inputs.push_back(input);
  }

  op->beam_width =
      HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
  op->top_paths =
      HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
  op->merge_repeated = HasAttr(node, "merge_repeated")
                           ? GetBoolAttr(node, "merge_repeated")
                           : true;

  // There are top_paths + 1 outputs.
  op->outputs.push_back(node.name());  // Implicit :0.
  for (int i = 0; i < op->top_paths; ++i) {
    op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
  }
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

// This isn't a TensorFlow builtin op. Currently this node can only be
// generated with the TfLite OpHint API.
tensorflow::Status ConvertUnidirectionalSequenceLstm(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");

  auto* op = new UnidirectionalSequenceLstmOperator();
  const auto& indices = GetListAttr(node, "_tflite_input_indices");
  if (indices.i_size() != node.input().size()) {
    return tensorflow::errors::InvalidArgument("Input size does not match.");
  }

  // The number of inputs must match what the TfLite UnidirectionalSequenceLstm
  // implementation expects.
  const int kInputsSize = 20;
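  // Illustrative example (hypothetical values): with node inputs {a, b} and
  // _tflite_input_indices = [0, 18], a is wired to op->inputs[0] and b to
  // op->inputs[18]; the remaining slots are filled with optional (empty)
  // arrays below.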

  op->inputs.resize(kInputsSize);
  std::vector<bool> done(kInputsSize);
  int idx = 0;
  for (const string& input : node.input()) {
    int real_index = indices.i(idx);
    op->inputs[real_index] = (input);
    done[real_index] = true;
    idx++;
  }

  for (size_t idx = 0; idx < done.size(); idx++) {
    if (!done[idx]) {
      string optional_name = node.name() + "_" + std::to_string(idx);
      model->CreateOptionalArray(optional_name);
      op->inputs[idx] = optional_name;
    }
  }

  // There are three outputs; only the last one is required.
  op->outputs.push_back(node.name() + ":2");
  model->operators.emplace_back(op);

  return tensorflow::Status::OK();
}

tensorflow::Status ConvertLeakyReluOperator(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  CHECK_EQ(node.op(), "LeakyRelu");
  TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
  const auto& input_name = node.input(0);
  auto* op = new LeakyReluOperator;
  op->inputs.push_back(input_name);
  op->outputs.push_back(node.name());
  op->alpha = GetFloatAttr(node, "alpha");
  model->operators.emplace_back(op);
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertUnidirectionalSequenceRnn(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model) {
  DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");

  auto* op = new UnidirectionalSequenceRnnOperator();
  const auto& indices = GetListAttr(node, "_tflite_input_indices");
  if (indices.i_size() != node.input().size()) {
    return tensorflow::errors::InvalidArgument("Input size does not match.");
  }

  for (const string& input : node.input()) {
    op->inputs.push_back(input);
  }
  // Only the last output is used.
  op->outputs.push_back(node.name() + ":1");
  model->operators.emplace_back(op);

  return tensorflow::Status::OK();
}

}  // namespace

namespace internal {

using ConverterType = tensorflow::Status (*)(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;

ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
  return std::unordered_map<std::string, ConverterType>({
      // We need to let TOCO convert Placeholder information into
      // array data, so that the data types are correct.
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Placeholder", ConvertPlaceholderOperator},
      {"Const", ConditionallyConvertConstOperator},
  });
}

ConverterMapType GetTensorFlowNodeConverterMap() {
  return std::unordered_map<std::string, ConverterType>({
      {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
      {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
      {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
      {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
      {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
      {"ArgMax", ConvertArgMaxOperator},
      {"ArgMin", ConvertArgMinOperator},
      {"Assert",
       ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
      {"AvgPool", ConvertAvgPoolOperator},
      {"BatchMatMul", ConvertBatchMatMulOperator},
      {"BatchNormWithGlobalNormalization",
       ConvertBatchNormWithGlobalNormalizationOperator},
      {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
      {"BiasAdd", ConvertBiasAddOperator},
      {"Cast", ConvertCastOperator},
      {"Ceil", ConvertCeilOperator},
      {"CheckNumerics", ConvertIdentityOperator},
      {"Concat", ConvertConcatOperator},
      {"ConcatV2", ConvertConcatOperator},
      {"Const", ConvertConstOperator},
      {"Conv2D", ConvertConvOperator},
      {"Conv2DBackpropInput", ConvertTransposeConvOperator},
      {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
      {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
      {"DepthToSpace", ConvertDepthToSpaceOperator},
      {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
      {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"DynamicPartition", ConvertDynamicPartitionOperator},
      {"DynamicStitch", ConvertDynamicStitchOperator},
      {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
      {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
      {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
      {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
      {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
      {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
      {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
      {"Floor", ConvertFloorOperator},
      {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
      {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
      {"FusedBatchNorm", ConvertFusedBatchNormOperator},
      {"Gather", ConvertGatherOperator},
      {"GatherV2", ConvertGatherOperator},
      {"GatherNd", ConvertGatherNdOperator},
      {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
      {"GreaterEqual",
       ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
      {"Identity", ConvertIdentityOperator},
      {"LRN", ConvertLRNOperator},
      {"LeakyRelu", ConvertLeakyReluOperator},
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
      {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
      {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
      {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
      {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
      {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
      {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
      {"MatMul", ConvertMatMulOperator},
      {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
      {"MaxPool", ConvertMaxPoolOperator},
      {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
      {"Mean", ConvertReduceOperator<MeanOperator>},
      {"Merge",
       ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
      {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
      {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
      {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
      {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
      {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
      {"NoOp", ConvertNoOpOperator},
      {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
      {"OneHot", ConvertOneHotOperator},
      {"Pack", ConvertPackOperator},
      {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
      {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
      {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
      {"Placeholder", ConvertPlaceholderOperator},
      {"PlaceholderWithDefault", ConvertIdentityOperator},
      {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
      {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
      {"RandomUniform", ConvertRandomUniform},
      {"Range", ConvertRangeOperator},
      {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
      {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
      {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
      {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
      {"ResizeBilinear", ConvertResizeBilinearOperator},
      {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
      {"ReverseSequence", ConvertReverseSequenceOperator},
      {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
      {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
      {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
      {"Shape", ConvertShapeOperator},
      {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
      {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
      {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
      {"Softmax", ConvertSoftmaxOperator},
      {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
      {"SpaceToDepth", ConvertSpaceToDepthOperator},
      {"SparseToDense", ConvertSparseToDenseOperator},
      {"Split", ConvertSplitOperator},
      {"SplitV", ConvertSplitVOperator},
      {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
      {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
      {"SquaredDifference",
       ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
      {"Squeeze", ConvertSqueezeOperator},
      {"StopGradient", ConvertIdentityOperator},
      {"StridedSlice", ConvertStridedSliceOperator},
      {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
      {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
      {"Svdf", ConvertSvdfOperator},
      {"Switch", ConvertSwitchOperator},
      {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
      {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
      {"TopK", ConvertTopKV2Operator},
      {"TopKV2", ConvertTopKV2Operator},
      {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
      {"Unpack", ConvertUnpackOperator},
      {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
      {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
      {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
      {"MirrorPad", ConvertMirrorPadOperator},
      {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
      {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
  });
}

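// Dispatch a single NodeDef to its registered converter; ops without an
// entry in the map fall back to ConvertUnsupportedOperator.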
tensorflow::Status ImportTensorFlowNode(
    const tensorflow::NodeDef& node,
    const TensorFlowImportFlags& tf_import_flags, Model* model,
    const ConverterMapType& converter_map) {
  auto converter = converter_map.find(node.op());
  if (converter == converter_map.end()) {
    return ConvertUnsupportedOperator(node, tf_import_flags, model);
  } else {
    return converter->second(node, tf_import_flags, model);
  }
}
}  // namespace internal

std::unique_ptr<Model> ImportTensorFlowGraphDef(
    const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
    const GraphDef& tf_graph) {
  LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);

  GraphDef inlined_graph(tf_graph);
  if (InlineAllFunctions(&inlined_graph)) {
    LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
  }

  // Check the input and output specifications.
  for (const auto& specified_input_array : model_flags.input_arrays()) {
    CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
        << "Unsupported explicit zero output index: "
        << specified_input_array.name();
  }
  for (const string& specified_output_array : model_flags.output_arrays()) {
    CHECK(!absl::EndsWith(specified_output_array, ":0"))
        << "Unsupported explicit zero output index: " << specified_output_array;
  }

  Model* model = new Model;
  internal::ConverterMapType converter_map;

  // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
  // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
  // converted to TFLite Flex ops.
  if (!tf_import_flags.import_all_ops_as_unsupported) {
    converter_map = internal::GetTensorFlowNodeConverterMap();
  } else {
    converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
  }

  for (auto node : inlined_graph.node()) {
    StripZeroOutputIndexFromInputs(&node);
    auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
                                                 converter_map);
    CHECK(status.ok()) << status.error_message();
  }

  ResolveModelFlags(model_flags, model);

  StripCaretFromArrayNames(model);
  AddExtraOutputs(model);
  FixNoMissingArray(model);
  FixNoOrphanedArray(model);
  FixOperatorOrdering(model);
  CheckInvariants(*model);

  // If RNN state arrays are constant, make them transient.
  for (const auto& rnn_state : model->flags.rnn_states()) {
    model->GetArray(rnn_state.state_array()).buffer = nullptr;
  }

  return std::unique_ptr<Model>(model);
}

std::unique_ptr<Model> ImportTensorFlowGraphDef(
    const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
    const string& input_file_contents) {
  std::unique_ptr<GraphDef> tf_graph(new GraphDef);
  CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));

  std::unique_ptr<GraphDef> pruned_graph =
      MaybeReplaceCompositeSubgraph(*tf_graph);
  if (pruned_graph) {
    tf_graph = std::move(pruned_graph);
  }
  return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
}
}  // namespace toco