• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/import_tensorflow.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/map.h"
23 #include "google/protobuf/text_format.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/numbers.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/str_split.h"
29 #include "absl/strings/strip.h"
30 #include "tensorflow/lite/toco/model.h"
31 #include "tensorflow/lite/toco/model_flags.pb.h"
32 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
33 #include "tensorflow/lite/toco/tensorflow_util.h"
34 #include "tensorflow/lite/toco/tooling_util.h"
35 #include "tensorflow/core/common_runtime/device_factory.h"
36 #include "tensorflow/core/common_runtime/function.h"
37 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
38 #include "tensorflow/core/framework/attr_value.pb.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph.pb.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/tensor.pb.h"
43 #include "tensorflow/core/framework/tensor_shape.pb.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/graph/graph_constructor.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/public/session_options.h"
50 #include "tensorflow/core/public/version.h"
51 
52 using tensorflow::AttrValue;
53 using tensorflow::DT_BOOL;
54 using tensorflow::DT_COMPLEX64;
55 using tensorflow::DT_FLOAT;
56 using tensorflow::DT_INT32;
57 using tensorflow::DT_INT64;
58 using tensorflow::DT_QUINT8;
59 using tensorflow::DT_STRING;
60 using tensorflow::DT_UINT8;
61 using tensorflow::GraphDef;
62 using tensorflow::NodeDef;
63 using tensorflow::OpRegistry;
64 using tensorflow::TensorProto;
65 using tensorflow::TensorShapeProto;
66 
67 namespace toco {
68 
69 namespace {
HasAttr(const NodeDef & node,const string & attr_name)70 bool HasAttr(const NodeDef& node, const string& attr_name) {
71   return node.attr().count(attr_name) > 0;
72 }
73 
HasWildcardDimension(const TensorShapeProto & shape)74 bool HasWildcardDimension(const TensorShapeProto& shape) {
75   for (const auto& dim : shape.dim()) {
76     if (dim.size() == -1) return true;
77   }
78   return false;
79 }
80 
// Returns the string value of attribute `attr_name` on `node`.
// CHECK-fails if the attribute is absent or does not hold a string. The
// failure message names the attribute and dumps the node, matching
// GetIntAttr().
const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
  CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
                                  << node.DebugString();
  const auto& attr = node.attr().at(attr_name);
  CHECK_EQ(attr.value_case(), AttrValue::kS);
  return attr.s();
}
87 
GetIntAttr(const NodeDef & node,const string & attr_name)88 int64 GetIntAttr(const NodeDef& node, const string& attr_name) {
89   CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
90                                   << node.DebugString();
91   const auto& attr = node.attr().at(attr_name);
92   CHECK_EQ(attr.value_case(), AttrValue::kI);
93   return attr.i();
94 }
95 
GetFloatAttr(const NodeDef & node,const string & attr_name)96 float GetFloatAttr(const NodeDef& node, const string& attr_name) {
97   CHECK(HasAttr(node, attr_name));
98   const auto& attr = node.attr().at(attr_name);
99   CHECK_EQ(attr.value_case(), AttrValue::kF);
100   return attr.f();
101 }
102 
GetBoolAttr(const NodeDef & node,const string & attr_name)103 bool GetBoolAttr(const NodeDef& node, const string& attr_name) {
104   CHECK(HasAttr(node, attr_name));
105   const auto& attr = node.attr().at(attr_name);
106   CHECK_EQ(attr.value_case(), AttrValue::kB);
107   return attr.b();
108 }
109 
GetDataTypeAttr(const NodeDef & node,const string & attr_name)110 tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
111                                      const string& attr_name) {
112   CHECK(HasAttr(node, attr_name));
113   const auto& attr = node.attr().at(attr_name);
114   CHECK_EQ(attr.value_case(), AttrValue::kType);
115   return attr.type();
116 }
117 
GetShapeAttr(const NodeDef & node,const string & attr_name)118 const TensorShapeProto& GetShapeAttr(const NodeDef& node,
119                                      const string& attr_name) {
120   CHECK(HasAttr(node, attr_name));
121   const auto& attr = node.attr().at(attr_name);
122   CHECK_EQ(attr.value_case(), AttrValue::kShape);
123   return attr.shape();
124 }
125 
GetTensorAttr(const NodeDef & node,const string & attr_name)126 const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
127   CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
128   const auto& attr = node.attr().at(attr_name);
129   CHECK_EQ(attr.value_case(), AttrValue::kTensor);
130   return attr.tensor();
131 }
132 
GetListAttr(const NodeDef & node,const string & attr_name)133 const AttrValue::ListValue& GetListAttr(const NodeDef& node,
134                                         const string& attr_name) {
135   CHECK(HasAttr(node, attr_name));
136   const auto& attr = node.attr().at(attr_name);
137   CHECK_EQ(attr.value_case(), AttrValue::kList);
138   return attr.list();
139 }
140 
CheckOptionalAttr(const NodeDef & node,const string & attr_name,const string & expected_value)141 tensorflow::Status CheckOptionalAttr(const NodeDef& node,
142                                      const string& attr_name,
143                                      const string& expected_value) {
144   if (HasAttr(node, attr_name)) {
145     const string& value = GetStringAttr(node, attr_name);
146     if (value != expected_value) {
147       return tensorflow::errors::InvalidArgument(
148           "Unexpected value for attribute '" + attr_name + "'. Expected '" +
149           expected_value + "'");
150     }
151   }
152   return tensorflow::Status::OK();
153 }
154 
// Verifies that the optional DataType attribute `attr_name`, if present on
// `node`, equals `expected_value`. An absent attribute is accepted.
// Returns InvalidArgument (naming the expected type) on a mismatch.
tensorflow::Status CheckOptionalAttr(
    const NodeDef& node, const string& attr_name,
    const tensorflow::DataType& expected_value) {
  if (!HasAttr(node, attr_name)) {
    return tensorflow::Status::OK();
  }
  const tensorflow::DataType actual = GetDataTypeAttr(node, attr_name);
  if (actual == expected_value) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Unexpected value for attribute '" + attr_name + "'. Expected '" +
      tensorflow::DataType_Name(expected_value) + "'");
}
168 
169 template <typename T1, typename T2>
ExpectValue(const T1 & v1,const T2 & v2,const string & description)170 tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
171                                const string& description) {
172   if (v1 == v2) return tensorflow::Status::OK();
173   return tensorflow::errors::InvalidArgument(absl::StrCat(
174       "Unexpected ", description, ": got ", v1, ", expected ", v2));
175 }
176 
ConvertDataType(tensorflow::DataType dtype)177 ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
178   if (dtype == DT_UINT8)
179     return ArrayDataType::kUint8;
180   else if (dtype == DT_FLOAT)
181     return ArrayDataType::kFloat;
182   else if (dtype == DT_BOOL)
183     return ArrayDataType::kBool;
184   else if (dtype == DT_INT32)
185     return ArrayDataType::kInt32;
186   else if (dtype == DT_INT64)
187     return ArrayDataType::kInt64;
188   else if (dtype == DT_STRING)
189     return ArrayDataType::kString;
190   else if (dtype == DT_COMPLEX64)
191     return ArrayDataType::kComplex64;
192   else
193     LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
194   return ArrayDataType::kNone;
195 }
196 
ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim> & input_dims,int * input_flat_size,Shape * shape)197 tensorflow::Status ImportShape(
198     const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
199         input_dims,
200     int* input_flat_size, Shape* shape) {
201   std::vector<int> input_dims_only_sizes;
202   bool zero_sized_shape = false;
203   for (auto& d : input_dims) {
204     // TensorFlow's shapes use int64s, while TOCO uses ints.
205     if (d.size() > std::numeric_limits<int>::max()) {
206       return tensorflow::errors::InvalidArgument("Shape element overflows");
207     }
208     if (d.size() == 0) {
209       zero_sized_shape = true;
210     }
211     input_dims_only_sizes.push_back(d.size());
212   }
213 
214   // Note that up to this point we were OK with the input shape containing
215   // elements valued -1 or 0, which are perfectly legal in tensorflow. However
216   // our CheckValidShapeDimensions() insists on them being >= 1, with the
217   // exception of the "scalar" shape [0]. The main issue with zero-values shape
218   // elements is that the corresponding arrays don't contain any data and the
219   // allocation code gets a bit confused. It seems that the code expects an
220   // empty shape for zero-sized shapes, so we will do just that, except for the
221   // [0] case.
222   // TODO(b/119325030): In order to correctly import the "scalar" shapes the
223   // following test must include "&& input_dims_only_sizes.size() > 1", but
224   // that seems to slow everything down a lot.
225   if (zero_sized_shape) {
226     shape->mutable_dims()->clear();
227     if (input_flat_size != nullptr) *input_flat_size = 0;
228     return tensorflow::Status::OK();
229   }
230 
231   *shape->mutable_dims() = input_dims_only_sizes;
232 
233   if (input_flat_size == nullptr) return tensorflow::Status::OK();
234 
235   return NumElements(input_dims_only_sizes, input_flat_size);
236 }
237 
238 // Define ways to retrieve data from tensors of different types.
239 // TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
240 template <typename T>
241 struct TensorTraits;
242 
243 template <>
244 struct TensorTraits<float> {
sizetoco::__anonb974ebe00111::TensorTraits245   static int size(const TensorProto& p) { return p.float_val_size(); }
gettoco::__anonb974ebe00111::TensorTraits246   static float get(const TensorProto& p, int i) { return p.float_val(i); }
accessor_nametoco::__anonb974ebe00111::TensorTraits247   static string accessor_name() { return "float_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits248   static string type_name() { return "float"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits249   static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
250     toco::port::CopyToBuffer(p.tensor_content(),
251                              reinterpret_cast<char*>(data->data()));
252   }
253 };
254 
255 template <>
256 struct TensorTraits<uint8_t> {
sizetoco::__anonb974ebe00111::TensorTraits257   static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anonb974ebe00111::TensorTraits258   static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anonb974ebe00111::TensorTraits259   static string accessor_name() { return "int_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits260   static string type_name() { return "uint8"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits261   static void CopyFromContent(const TensorProto& p,
262                               std::vector<uint8_t>* data) {
263     toco::port::CopyToBuffer(p.tensor_content(),
264                              reinterpret_cast<char*>(data->data()));
265   }
266 };
267 
268 template <>
269 struct TensorTraits<std::complex<float>> {
sizetoco::__anonb974ebe00111::TensorTraits270   static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
gettoco::__anonb974ebe00111::TensorTraits271   static std::complex<float> get(const TensorProto& p, int i) {
272     return std::complex<float>(p.scomplex_val(2 * i),
273                                p.scomplex_val(2 * i + 1));
274   }
accessor_nametoco::__anonb974ebe00111::TensorTraits275   static string accessor_name() { return "scomplex_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits276   static string type_name() { return "complex64"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits277   static void CopyFromContent(const TensorProto& p,
278                               std::vector<std::complex<float>>* data) {
279     toco::port::CopyToBuffer(p.tensor_content(),
280                              reinterpret_cast<char*>(data->data()));
281   }
282 };
283 
284 template <>
285 struct TensorTraits<int32> {
sizetoco::__anonb974ebe00111::TensorTraits286   static int size(const TensorProto& p) { return p.int_val_size(); }
gettoco::__anonb974ebe00111::TensorTraits287   static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
accessor_nametoco::__anonb974ebe00111::TensorTraits288   static string accessor_name() { return "int_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits289   static string type_name() { return "int32"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits290   static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
291     toco::port::CopyToBuffer(p.tensor_content(),
292                              reinterpret_cast<char*>(data->data()));
293   }
294 };
295 
296 template <>
297 struct TensorTraits<int64> {
sizetoco::__anonb974ebe00111::TensorTraits298   static int size(const TensorProto& p) { return p.int64_val_size(); }
gettoco::__anonb974ebe00111::TensorTraits299   static int64 get(const TensorProto& p, int i) { return p.int64_val(i); }
accessor_nametoco::__anonb974ebe00111::TensorTraits300   static string accessor_name() { return "int64_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits301   static string type_name() { return "int64"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits302   static void CopyFromContent(const TensorProto& p, std::vector<int64>* data) {
303     toco::port::CopyToBuffer(p.tensor_content(),
304                              reinterpret_cast<char*>(data->data()));
305   }
306 };
307 
308 template <>
309 struct TensorTraits<bool> {
sizetoco::__anonb974ebe00111::TensorTraits310   static int size(const TensorProto& p) { return p.bool_val_size(); }
gettoco::__anonb974ebe00111::TensorTraits311   static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
accessor_nametoco::__anonb974ebe00111::TensorTraits312   static string accessor_name() { return "bool_val"; }
type_nametoco::__anonb974ebe00111::TensorTraits313   static string type_name() { return "bool"; }
CopyFromContenttoco::__anonb974ebe00111::TensorTraits314   static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
315     std::vector<char> buf(p.tensor_content().size());
316     toco::port::CopyToBuffer(p.tensor_content(), buf.data());
317     for (int i = 0; i < p.tensor_content().size(); i++) {
318       (*data)[i] = static_cast<bool>(buf[i]);
319     }
320   }
321 };
322 
323 template <typename T>
ImportTensorData(const TensorProto & input_tensor,int input_flat_size,std::vector<T> * output_data)324 tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
325                                     int input_flat_size,
326                                     std::vector<T>* output_data) {
327   CHECK_GE(output_data->size(), input_flat_size);
328   int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
329   if (num_elements_in_tensor == input_flat_size) {
330     for (int i = 0; i < num_elements_in_tensor; i++) {
331       (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
332     }
333   } else if (input_tensor.tensor_content().size() ==
334              input_flat_size * sizeof(T)) {
335     TensorTraits<T>::CopyFromContent(input_tensor, output_data);
336   } else if (num_elements_in_tensor > 0 &&
337              num_elements_in_tensor < input_flat_size) {
338     // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
339     // official way to import tensor data. This particular else-if handles a
340     // grappler optimization where the last few elements in a tensor are
341     // omitted if they are repeated.
342     int i = 0;
343     for (; i < num_elements_in_tensor; ++i) {
344       (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
345     }
346     auto last = (*output_data)[i - 1];
347     for (; i < input_flat_size; ++i) {
348       (*output_data)[i] = last;
349     }
350   } else {
351     string accessor_name = TensorTraits<T>::accessor_name();
352     string type_name = TensorTraits<T>::type_name();
353     return tensorflow::errors::InvalidArgument(
354         absl::StrCat("Neither input_content (",
355                      input_tensor.tensor_content().size() / sizeof(T), ") nor ",
356                      accessor_name, " (", num_elements_in_tensor,
357                      ") have the right dimensions (", input_flat_size,
358                      ") for this ", type_name, " tensor"));
359   }
360   return tensorflow::Status::OK();
361 }
362 
ImportFloatArray(const TensorProto & input_tensor,Array * output_array)363 tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
364                                     Array* output_array) {
365   CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
366   const auto& input_shape = input_tensor.tensor_shape();
367   CHECK_LE(input_shape.dim_size(), 6);
368   int input_flat_size;
369   auto status = ImportShape(input_shape.dim(), &input_flat_size,
370                             output_array->mutable_shape());
371   if (!status.ok()) return status;
372 
373   auto& output_float_data =
374       output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
375   output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
376                            0.f);
377   return ImportTensorData<float>(input_tensor, input_flat_size,
378                                  &output_float_data);
379 }
380 
ImportComplex64Array(const TensorProto & input_tensor,Array * output_array)381 tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
382                                         Array* output_array) {
383   CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
384   const auto& input_shape = input_tensor.tensor_shape();
385   CHECK_LE(input_shape.dim_size(), 4);
386   int input_flat_size;
387   auto status = ImportShape(input_shape.dim(), &input_flat_size,
388                             output_array->mutable_shape());
389   if (!status.ok()) return status;
390 
391   auto& output_complex_data =
392       output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
393   output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
394                              std::complex<float>(0.f, 0.f));
395   return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
396                                                &output_complex_data);
397 }
398 
ImportQuint8Array(const TensorProto & input_tensor,Array * output_array)399 tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
400                                      Array* output_array) {
401   CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
402   const auto& input_shape = input_tensor.tensor_shape();
403   CHECK_LE(input_shape.dim_size(), 6);
404   int input_flat_size;
405   auto status = ImportShape(input_shape.dim(), &input_flat_size,
406                             output_array->mutable_shape());
407   if (!status.ok()) return status;
408 
409   auto& output_int_data =
410       output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
411   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
412   return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
413                                    &output_int_data);
414 }
415 
ImportInt32Array(const TensorProto & input_tensor,Array * output_array)416 tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
417                                     Array* output_array) {
418   CHECK_EQ(input_tensor.dtype(), DT_INT32);
419   const auto& input_shape = input_tensor.tensor_shape();
420   CHECK_LE(input_shape.dim_size(), 6);
421   int input_flat_size;
422   auto status = ImportShape(input_shape.dim(), &input_flat_size,
423                             output_array->mutable_shape());
424   if (!status.ok()) return status;
425 
426   auto& output_int_data =
427       output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
428   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
429   return ImportTensorData<int32>(input_tensor, input_flat_size,
430                                  &output_int_data);
431 }
432 
ImportInt64Array(const TensorProto & input_tensor,Array * output_array)433 tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
434                                     Array* output_array) {
435   CHECK_EQ(input_tensor.dtype(), DT_INT64);
436   const auto& input_shape = input_tensor.tensor_shape();
437   CHECK_LE(input_shape.dim_size(), 6);
438   int input_flat_size;
439   auto status = ImportShape(input_shape.dim(), &input_flat_size,
440                             output_array->mutable_shape());
441   if (!status.ok()) return status;
442 
443   auto& output_int_data =
444       output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
445   output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
446   return ImportTensorData<int64>(input_tensor, input_flat_size,
447                                  &output_int_data);
448 }
449 
ImportBoolArray(const TensorProto & input_tensor,Array * output_array)450 tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
451                                    Array* output_array) {
452   CHECK_EQ(input_tensor.dtype(), DT_BOOL);
453   const auto& input_shape = input_tensor.tensor_shape();
454   CHECK_LE(input_shape.dim_size(), 6);
455   int input_flat_size;
456   auto status = ImportShape(input_shape.dim(), &input_flat_size,
457                             output_array->mutable_shape());
458   if (!status.ok()) return status;
459 
460   auto& output_bool_data =
461       output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
462   output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
463                           false);
464   status =
465       ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
466   if (!status.ok() && output_bool_data.size() == 1) {
467     // Some graphs have bool const nodes without actual value...
468     // assuming that 'false' is implied.
469     // So far only encountered that in an array with 1 entry, let's
470     // require that until we encounter a graph where that's not the case.
471     output_bool_data[0] = false;
472     return tensorflow::Status::OK();
473   }
474   return status;
475 }
476 
ImportStringArray(const TensorProto & input_tensor,Array * output_array)477 tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
478                                      Array* output_array) {
479   CHECK_EQ(input_tensor.dtype(), DT_STRING);
480   const auto& input_shape = input_tensor.tensor_shape();
481   CHECK_LE(input_shape.dim_size(), 6);
482   int input_flat_size;
483   auto status = ImportShape(input_shape.dim(), &input_flat_size,
484                             output_array->mutable_shape());
485   if (!status.ok()) return status;
486 
487   if (input_flat_size != input_tensor.string_val_size()) {
488     return tensorflow::errors::InvalidArgument(
489         "Input_content string_val doesn't have the right dimensions "
490         "for this string tensor");
491   }
492 
493   auto& output_string_data =
494       output_array->GetMutableBuffer<ArrayDataType::kString>().data;
495   output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
496   CHECK_GE(output_string_data.size(), input_flat_size);
497   for (int i = 0; i < input_flat_size; ++i) {
498     output_string_data[i] = input_tensor.string_val(i);
499   }
500   return tensorflow::Status::OK();
501 }
502 
503 // Count the number of inputs of a given node. If
504 // `tf_import_flags.drop_control_dependency` is true, count the number of
505 // non-control-dependency inputs.
GetInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags)506 int GetInputsCount(const NodeDef& node,
507                    const TensorFlowImportFlags& tf_import_flags) {
508   if (tf_import_flags.drop_control_dependency) {
509     for (size_t i = 0; i < node.input_size(); ++i) {
510       if (node.input(i)[0] == '^') {
511         return i;
512       }
513     }
514   }
515   return node.input_size();
516 }
517 
CheckInputsCount(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,int expected_input_count)518 tensorflow::Status CheckInputsCount(
519     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
520     int expected_input_count) {
521   if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
522     return tensorflow::errors::FailedPrecondition(
523         node.op(), " node expects ", expected_input_count,
524         " input(s) other than control dependencies: ", node.DebugString());
525   }
526   return tensorflow::Status::OK();
527 }
528 
529 template <ArrayDataType T>
CreateConstArray(Model * model,string const & name,std::vector<typename toco::DataType<T>> const & data)530 string CreateConstArray(Model* model, string const& name,
531                         std::vector<typename toco::DataType<T> > const& data) {
532   // Utility function to create a const 1D array, useful for input parameters.
533   string array_name = toco::AvailableArrayName(*model, name);
534   auto& array = model->GetOrCreateArray(array_name);
535   array.data_type = T;
536   array.mutable_shape()->mutable_dims()->emplace_back(data.size());
537   array.GetMutableBuffer<T>().data = data;
538   return array_name;
539 }
540 
541 // Retain TensorFlow NodeDef in Toco Operator.
542 //
543 // If an op is supported by Toco but not supported by TFLite, TFLite exporter
544 // will use the retained NodeDef to populate a Flex op when Flex mode is
545 // enabled.
546 //
547 // This can't be easily applied to all operations, because a TensorFlow node
548 // may become multiple Toco operators. Thus we need to call this function in
549 // operator conversion functions one by one whenever feasible.
550 //
551 // This may cause problems if a graph transformation rule changes parameters
552 // of the node. When calling this function, please check if any existing
553 // graph transformation rule will change an existing operator with the same
554 // type.
555 //
556 // This provides a route to handle Toco-supported & TFLite-unsupported ops
557 // in Flex mode. However it's not a solid solution. Eventually we should
558 // get rid of this.
559 // TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
560 // this function.
RetainTensorFlowNodeDef(const NodeDef & node,Operator * op)561 void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
562   node.SerializeToString(&op->tensorflow_node_def);
563 }
564 
ConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)565 tensorflow::Status ConvertConstOperator(
566     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
567     Model* model) {
568   CHECK_EQ(node.op(), "Const");
569   const auto& tensor = GetTensorAttr(node, "value");
570   const auto dtype = GetDataTypeAttr(node, "dtype");
571 
572   tensorflow::Status status = tensorflow::Status::OK();
573 
574   auto& array = model->GetOrCreateArray(node.name());
575   switch (dtype) {
576     case DT_FLOAT:
577       array.data_type = ArrayDataType::kFloat;
578       status = ImportFloatArray(tensor, &array);
579       break;
580     case DT_INT32:
581       array.data_type = ArrayDataType::kInt32;
582       status = ImportInt32Array(tensor, &array);
583       break;
584     case DT_QUINT8:
585       array.data_type = ArrayDataType::kUint8;
586       status = ImportQuint8Array(tensor, &array);
587       break;
588     case DT_INT64:
589       array.data_type = ArrayDataType::kInt64;
590       status = ImportInt64Array(tensor, &array);
591       break;
592     case DT_STRING:
593       array.data_type = ArrayDataType::kString;
594       status = ImportStringArray(tensor, &array);
595       break;
596     case DT_BOOL:
597       array.data_type = ArrayDataType::kBool;
598       status = ImportBoolArray(tensor, &array);
599       break;
600     case DT_COMPLEX64:
601       array.data_type = ArrayDataType::kComplex64;
602       status = ImportComplex64Array(tensor, &array);
603       break;
604     default:
605       array.data_type = ArrayDataType::kNone;
606       // do nothing, silently ignore the Const data.
607       // We just make a dummy buffer to indicate that
608       // this array does not rely on external input.
609       array.GetMutableBuffer<ArrayDataType::kNone>();
610       break;
611   }
612   TF_RETURN_WITH_CONTEXT_IF_ERROR(
613       status, " (while processing node '" + node.name() + "')");
614   return tensorflow::Status::OK();
615 }
616 
ConvertConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)617 tensorflow::Status ConvertConvOperator(
618     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
619     Model* model) {
620   CHECK_EQ(node.op(), "Conv2D");
621   TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
622 
623   // We only support NHWC, which is the default data_format.
624   // So if data_format is not defined, we're all good.
625   TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
626   TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
627 
628   const auto& input_name = node.input(0);
629   const auto& weights_name = node.input(1);
630   const auto& reordered_weights_name =
631       AvailableArrayName(*model, weights_name + "_reordered");
632   // Check if a ReorderAxesOperator was already created for these weights
633   // (that happens when multiple layers share the same weights).
634   const Operator* existing_reorder =
635       GetOpWithOutput(*model, reordered_weights_name);
636   if (existing_reorder) {
637     // Check that it is safe to rely on the _reordered naming of the output
638     // array!
639     CHECK(existing_reorder->type == OperatorType::kReorderAxes);
640   } else {
641     // Create a new ReorderAxesOperator
642     auto* reorder = new ReorderAxesOperator;
643     reorder->inputs = {weights_name};
644     reorder->outputs = {reordered_weights_name};
645     reorder->input_axes_order = AxesOrder::kHWIO;
646     reorder->output_axes_order = AxesOrder::kOHWI;
647     model->operators.emplace_back(reorder);
648   }
649   auto* conv = new ConvOperator;
650   conv->inputs = {input_name, reordered_weights_name};
651   conv->outputs = {node.name()};
652   if (!HasAttr(node, "strides")) {
653     return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
654   }
655   const auto& strides = GetListAttr(node, "strides");
656   TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
657   TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
658   TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
659   conv->stride_height = strides.i(1);
660   conv->stride_width = strides.i(2);
661   if (HasAttr(node, "dilations")) {
662     const auto& dilations = GetListAttr(node, "dilations");
663     TF_RETURN_IF_ERROR(
664         ExpectValue(dilations.i_size(), 4, "number of dilations"));
665     if (dilations.i(0) != 1 || dilations.i(3) != 1) {
666       return tensorflow::errors::InvalidArgument(absl::StrCat(
667           "Can only import Conv ops with dilation along the height "
668           "(1st) or width (2nd) axis. TensorFlow op \"",
669           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
670           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
671     }
672     conv->dilation_height_factor = dilations.i(1);
673     conv->dilation_width_factor = dilations.i(2);
674   } else {
675     conv->dilation_height_factor = 1;
676     conv->dilation_width_factor = 1;
677   }
678   const auto& padding = GetStringAttr(node, "padding");
679   if (padding == "SAME") {
680     conv->padding.type = PaddingType::kSame;
681   } else if (padding == "VALID") {
682     conv->padding.type = PaddingType::kValid;
683   } else {
684     return tensorflow::errors::InvalidArgument(
685         "Bad padding (only SAME and VALID are supported)");
686   }
687   model->operators.emplace_back(conv);
688 
689   return tensorflow::Status::OK();
690 }
691 
ConvertDepthwiseConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)692 tensorflow::Status ConvertDepthwiseConvOperator(
693     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
694     Model* model) {
695   CHECK_EQ(node.op(), "DepthwiseConv2dNative");
696   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
697 
698   // We only support NHWC, which is the default data_format.
699   // So if data_format is not defined, we're all good.
700   if (HasAttr(node, "data_format")) {
701     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
702   }
703   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
704 
705   const auto& input_name = node.input(0);
706   const auto& weights_name = node.input(1);
707   const auto& reordered_weights_name = weights_name + "_reordered";
708   // Check if a ReorderAxesOperator was already created for these weights
709   // (that happens when multiple layers share the same weights).
710   const Operator* existing_reorder =
711       GetOpWithOutput(*model, reordered_weights_name);
712   if (existing_reorder) {
713     // Check that it is safe to rely on the _reordered naming of the output
714     // array!
715     CHECK(existing_reorder->type == OperatorType::kReorderAxes);
716   } else {
717     // Create a new ReorderAxesOperator
718     auto* reorder = new ReorderAxesOperator;
719     reorder->inputs = {weights_name};
720     reorder->outputs = {reordered_weights_name};
721     reorder->input_axes_order = AxesOrder::kHWIM;
722     reorder->output_axes_order = AxesOrder::k1HWO;
723     model->operators.emplace_back(reorder);
724   }
725   auto* conv = new DepthwiseConvOperator;
726   conv->inputs = {input_name, reordered_weights_name};
727   conv->outputs = {node.name()};
728   const auto& strides = GetListAttr(node, "strides");
729   CHECK_EQ(strides.i_size(), 4);
730   CHECK_EQ(strides.i(0), 1);
731   CHECK_EQ(strides.i(3), 1);
732   conv->stride_height = strides.i(1);
733   conv->stride_width = strides.i(2);
734   if (HasAttr(node, "dilations")) {
735     const auto& dilations = GetListAttr(node, "dilations");
736     TF_RETURN_IF_ERROR(
737         ExpectValue(dilations.i_size(), 4, "number of dilations"));
738     if (dilations.i(0) != 1 || dilations.i(3) != 1) {
739       return tensorflow::errors::InvalidArgument(absl::StrCat(
740           "Can only import Conv ops with dilation along the height "
741           "(1st) or width (2nd) axis. TensorFlow op \"",
742           node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
743           dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
744     }
745     conv->dilation_height_factor = dilations.i(1);
746     conv->dilation_width_factor = dilations.i(2);
747   } else {
748     conv->dilation_height_factor = 1;
749     conv->dilation_width_factor = 1;
750   }
751   const auto& padding = GetStringAttr(node, "padding");
752   if (padding == "SAME") {
753     conv->padding.type = PaddingType::kSame;
754   } else if (padding == "VALID") {
755     conv->padding.type = PaddingType::kValid;
756   } else {
757     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
758   }
759   model->operators.emplace_back(conv);
760   return tensorflow::Status::OK();
761 }
762 
ConvertDepthToSpaceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)763 tensorflow::Status ConvertDepthToSpaceOperator(
764     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
765     Model* model) {
766   CHECK_EQ(node.op(), "DepthToSpace");
767   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
768 
769   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
770   auto* op = new DepthToSpaceOperator;
771   op->inputs.push_back(node.input(0));
772   op->outputs.push_back(node.name());
773   op->block_size = GetIntAttr(node, "block_size");
774   QCHECK_GE(op->block_size, 2);
775   model->operators.emplace_back(op);
776   return tensorflow::Status::OK();
777 }
778 
ConvertSpaceToDepthOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)779 tensorflow::Status ConvertSpaceToDepthOperator(
780     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
781     Model* model) {
782   CHECK_EQ(node.op(), "SpaceToDepth");
783   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
784 
785   tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
786   if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
787       dtype != DT_INT64) {
788     const auto* enum_descriptor = tensorflow::DataType_descriptor();
789     LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
790                << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
791                << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}.";
792   }
793   auto* op = new SpaceToDepthOperator;
794   op->inputs.push_back(node.input(0));
795   op->outputs.push_back(node.name());
796   op->block_size = GetIntAttr(node, "block_size");
797   QCHECK_GE(op->block_size, 2);
798   model->operators.emplace_back(op);
799   return tensorflow::Status::OK();
800 }
801 
ConvertBiasAddOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)802 tensorflow::Status ConvertBiasAddOperator(
803     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
804     Model* model) {
805   CHECK_EQ(node.op(), "BiasAdd");
806   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
807 
808   const auto& input_name = node.input(0);
809   const auto& bias_name = node.input(1);
810   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
811   auto* biasadd = new AddOperator;
812   biasadd->inputs.push_back(input_name);
813   biasadd->inputs.push_back(bias_name);
814   biasadd->outputs.push_back(node.name());
815   model->operators.emplace_back(biasadd);
816   return tensorflow::Status::OK();
817 }
818 
ConvertRandomUniform(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)819 tensorflow::Status ConvertRandomUniform(
820     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
821     Model* model) {
822   CHECK_EQ(node.op(), "RandomUniform");
823   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
824 
825   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
826   auto op = absl::make_unique<RandomUniformOperator>();
827   op->inputs.push_back(node.input(0));
828   op->outputs.push_back(node.name());
829   op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
830   op->seed = GetIntAttr(node, "seed");
831   op->seed2 = GetIntAttr(node, "seed2");
832   CHECK(model != nullptr);
833   model->operators.emplace_back(std::move(op));
834   return tensorflow::Status::OK();
835 }
836 
ConvertIdentityOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)837 tensorflow::Status ConvertIdentityOperator(
838     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
839     Model* model) {
840   CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
841         node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
842   auto* op = new TensorFlowIdentityOperator;
843   // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
844   // identity nodes with multiple inputs, but the other inputs seem
845   // to be gratuitous (in the case of rajeev_lstm.pb, these are
846   // enumerating the LSTM state arrays). We will just ignore extra
847   // inputs beyond the first input.
848   QCHECK_GE(node.input_size(), 1)
849       << node.op()
850       << " node expects at least 1 input other than control dependencies: "
851       << node.DebugString();
852   const auto& input_name = node.input(0);
853   op->inputs.push_back(input_name);
854   op->outputs.push_back(node.name());
855   model->operators.emplace_back(op);
856   return tensorflow::Status::OK();
857 }
858 
ConvertFakeQuantWithMinMaxArgs(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)859 tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
860     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
861     Model* model) {
862   CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
863   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
864   auto* op = new FakeQuantOperator;
865   op->inputs.push_back(node.input(0));
866   op->minmax.reset(new MinMax);
867   auto& minmax = *op->minmax;
868   minmax.min = GetFloatAttr(node, "min");
869   minmax.max = GetFloatAttr(node, "max");
870   op->outputs.push_back(node.name());
871   // tf.fake_quant_with_min_max_args num_bits defaults to 8.
872   op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
873   if (HasAttr(node, "narrow_range")) {
874     op->narrow_range = GetBoolAttr(node, "narrow_range");
875   }
876   model->operators.emplace_back(op);
877   return tensorflow::Status::OK();
878 }
879 
ConvertFakeQuantWithMinMaxVars(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)880 tensorflow::Status ConvertFakeQuantWithMinMaxVars(
881     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
882     Model* model) {
883   CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
884   const int num_inputs = GetInputsCount(node, tf_import_flags);
885   QCHECK(num_inputs == 3 || num_inputs == 4)
886       << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
887          "control dependencies: "
888       << node.DebugString();
889   auto* op = new FakeQuantOperator;
890   for (int i = 0; i < 3; i++) {
891     op->inputs.push_back(node.input(i));
892   }
893   op->outputs.push_back(node.name());
894   op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
895   if (HasAttr(node, "narrow_range")) {
896     op->narrow_range = GetBoolAttr(node, "narrow_range");
897   }
898   model->operators.emplace_back(op);
899   return tensorflow::Status::OK();
900 }
901 
ConvertSqueezeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)902 tensorflow::Status ConvertSqueezeOperator(
903     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
904     Model* model) {
905   CHECK_EQ(node.op(), "Squeeze");
906   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
907   auto* op = new SqueezeOperator;
908   op->inputs.push_back(node.input(0));
909   op->outputs.push_back(node.name());
910 
911   // When omitted we are to squeeze all dimensions == 1.
912   if (HasAttr(node, "squeeze_dims")) {
913     const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
914     for (int i = 0; i < squeeze_dims.i_size(); ++i) {
915       op->squeeze_dims.push_back(squeeze_dims.i(i));
916     }
917   }
918 
919   model->operators.emplace_back(op);
920   return tensorflow::Status::OK();
921 }
922 
ConvertSplitOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)923 tensorflow::Status ConvertSplitOperator(
924     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
925     Model* model) {
926   CHECK_EQ(node.op(), "Split");
927   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
928   auto* op = new TensorFlowSplitOperator;
929   op->inputs.push_back(node.input(0));
930   op->inputs.push_back(node.input(1));
931   const int num_split = GetIntAttr(node, "num_split");
932   op->outputs.push_back(node.name());
933   for (int i = 1; i < num_split; i++) {
934     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
935   }
936   op->num_split = num_split;
937   model->operators.emplace_back(op);
938   return tensorflow::Status::OK();
939 }
940 
ConvertSplitVOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)941 tensorflow::Status ConvertSplitVOperator(
942     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
943     Model* model) {
944   CHECK_EQ(node.op(), "SplitV");
945   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
946   auto* op = new TensorFlowSplitVOperator;
947   op->inputs.push_back(node.input(0));
948   op->inputs.push_back(node.input(1));
949   op->inputs.push_back(node.input(2));
950   const int num_split = GetIntAttr(node, "num_split");
951   op->outputs.push_back(node.name());
952   for (int i = 1; i < num_split; i++) {
953     op->outputs.push_back(absl::StrCat(node.name(), ":", i));
954   }
955   op->num_split = num_split;
956   model->operators.emplace_back(op);
957   return tensorflow::Status::OK();
958 }
959 
ConvertSwitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)960 tensorflow::Status ConvertSwitchOperator(
961     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
962     Model* model) {
963   CHECK_EQ(node.op(), "Switch");
964   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
965   auto* op = new TensorFlowSwitchOperator;
966   op->inputs.push_back(node.input(0));
967   op->inputs.push_back(node.input(1));
968   op->outputs.push_back(node.name());
969   // Switch operators have two outputs: "name" and "name:1".
970   op->outputs.push_back(node.name() + ":1");
971   model->operators.emplace_back(op);
972   return tensorflow::Status::OK();
973 }
974 
ConvertSoftmaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)975 tensorflow::Status ConvertSoftmaxOperator(
976     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
977     Model* model) {
978   CHECK_EQ(node.op(), "Softmax");
979   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
980   const auto& input_name = node.input(0);
981   auto* softmax = new SoftmaxOperator;
982   softmax->inputs.push_back(input_name);
983   softmax->outputs.push_back(node.name());
984   // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
985   CHECK(!node.attr().count("beta"));  // Stab in the dark, just in case.
986   softmax->beta = 1.f;
987   model->operators.emplace_back(softmax);
988   return tensorflow::Status::OK();
989 }
990 
ConvertLRNOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)991 tensorflow::Status ConvertLRNOperator(
992     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
993     Model* model) {
994   CHECK_EQ(node.op(), "LRN");
995   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
996   const auto& input_name = node.input(0);
997   auto* lrn = new LocalResponseNormalizationOperator;
998   lrn->inputs.push_back(input_name);
999   lrn->outputs.push_back(node.name());
1000   lrn->range = GetIntAttr(node, "depth_radius");
1001   lrn->bias = GetFloatAttr(node, "bias");
1002   lrn->alpha = GetFloatAttr(node, "alpha");
1003   lrn->beta = GetFloatAttr(node, "beta");
1004   model->operators.emplace_back(lrn);
1005   return tensorflow::Status::OK();
1006 }
1007 
ConvertMaxPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1008 tensorflow::Status ConvertMaxPoolOperator(
1009     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1010     Model* model) {
1011   CHECK_EQ(node.op(), "MaxPool");
1012   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1013   const auto& input_name = node.input(0);
1014   // We only support NHWC, which is the default data_format.
1015   // So if data_format is not defined, we're all good.
1016   if (node.attr().count("data_format")) {
1017     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1018   }
1019   if (HasAttr(node, "T")) {
1020     CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1021   } else {
1022     LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
1023   }
1024   auto* maxpool = new MaxPoolOperator;
1025   maxpool->inputs.push_back(input_name);
1026   maxpool->outputs.push_back(node.name());
1027   const auto& strides = GetListAttr(node, "strides");
1028   CHECK_EQ(strides.i_size(), 4);
1029   CHECK_EQ(strides.i(0), 1);
1030   CHECK_EQ(strides.i(3), 1);
1031   maxpool->stride_height = strides.i(1);
1032   maxpool->stride_width = strides.i(2);
1033   const auto& ksize = GetListAttr(node, "ksize");
1034   CHECK_EQ(ksize.i_size(), 4);
1035   CHECK_EQ(ksize.i(0), 1);
1036   CHECK_EQ(ksize.i(3), 1);
1037   maxpool->kheight = ksize.i(1);
1038   maxpool->kwidth = ksize.i(2);
1039   const auto& padding = GetStringAttr(node, "padding");
1040   if (padding == "SAME") {
1041     maxpool->padding.type = PaddingType::kSame;
1042   } else if (padding == "VALID") {
1043     maxpool->padding.type = PaddingType::kValid;
1044   } else {
1045     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1046   }
1047   model->operators.emplace_back(maxpool);
1048   return tensorflow::Status::OK();
1049 }
1050 
ConvertAvgPoolOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1051 tensorflow::Status ConvertAvgPoolOperator(
1052     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1053     Model* model) {
1054   CHECK_EQ(node.op(), "AvgPool");
1055   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1056   const auto& input_name = node.input(0);
1057   // We only support NHWC, which is the default data_format.
1058   // So if data_format is not defined, we're all good.
1059   if (node.attr().count("data_format")) {
1060     CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1061   }
1062   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1063   auto* avgpool = new AveragePoolOperator;
1064   avgpool->inputs.push_back(input_name);
1065   avgpool->outputs.push_back(node.name());
1066   const auto& strides = GetListAttr(node, "strides");
1067   CHECK_EQ(strides.i_size(), 4);
1068   CHECK_EQ(strides.i(0), 1);
1069   CHECK_EQ(strides.i(3), 1);
1070   avgpool->stride_height = strides.i(1);
1071   avgpool->stride_width = strides.i(2);
1072   const auto& ksize = GetListAttr(node, "ksize");
1073   CHECK_EQ(ksize.i_size(), 4);
1074   CHECK_EQ(ksize.i(0), 1);
1075   CHECK_EQ(ksize.i(3), 1);
1076   avgpool->kheight = ksize.i(1);
1077   avgpool->kwidth = ksize.i(2);
1078   const auto& padding = GetStringAttr(node, "padding");
1079   if (padding == "SAME") {
1080     avgpool->padding.type = PaddingType::kSame;
1081   } else if (padding == "VALID") {
1082     avgpool->padding.type = PaddingType::kValid;
1083   } else {
1084     LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1085   }
1086   model->operators.emplace_back(avgpool);
1087   return tensorflow::Status::OK();
1088 }
1089 
ConvertBatchMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1090 tensorflow::Status ConvertBatchMatMulOperator(
1091     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1092     Model* model) {
1093   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1094 
1095   auto* batch_matmul = new BatchMatMulOperator;
1096   // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
1097   if (HasAttr(node, "adj_x")) {
1098     batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
1099   }
1100   if (HasAttr(node, "adj_y")) {
1101     batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
1102   }
1103   batch_matmul->inputs = {node.input(0), node.input(1)};
1104   batch_matmul->outputs = {node.name()};
1105 
1106   // For Flex mode. Please read the comments of the function.
1107   RetainTensorFlowNodeDef(node, batch_matmul);
1108 
1109   model->operators.emplace_back(batch_matmul);
1110   return tensorflow::Status::OK();
1111 }
1112 
ConvertMatMulOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1113 tensorflow::Status ConvertMatMulOperator(
1114     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1115     Model* model) {
1116   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1117 
1118   CHECK(!HasAttr(node, "adjoint_a") ||
1119         (GetBoolAttr(node, "adjoint_a") == false));
1120   CHECK(!HasAttr(node, "adjoint_b") ||
1121         (GetBoolAttr(node, "adjoint_b") == false));
1122 
1123   auto* matmul = new TensorFlowMatMulOperator;
1124   if (HasAttr(node, "transpose_a")) {
1125     matmul->transpose_a = GetBoolAttr(node, "transpose_a");
1126   }
1127   if (HasAttr(node, "transpose_b")) {
1128     matmul->transpose_b = GetBoolAttr(node, "transpose_b");
1129   }
1130 
1131   matmul->inputs = {node.input(0), node.input(1)};
1132   matmul->outputs = {node.name()};
1133   model->operators.emplace_back(matmul);
1134   return tensorflow::Status::OK();
1135 }
1136 
ConvertConcatOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1137 tensorflow::Status ConvertConcatOperator(
1138     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1139     Model* model) {
1140   Operator* op = nullptr;
1141   if (node.op() == "Concat") {
1142     op = new TensorFlowConcatOperator;
1143   } else if (node.op() == "ConcatV2") {
1144     op = new TensorFlowConcatV2Operator;
1145   } else {
1146     LOG(FATAL) << "Expected Concat or ConcatV2";
1147   }
1148   const int num_inputs = GetInputsCount(node, tf_import_flags);
1149   QCHECK_GE(num_inputs, 2)
1150       << node.op()
1151       << " node expects at least 2 inputs other than control dependencies: "
1152       << node.DebugString();
1153   CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1154   for (int i = 0; i < num_inputs; ++i) {
1155     op->inputs.push_back(node.input(i));
1156   }
1157   op->outputs.push_back(node.name());
1158   model->operators.emplace_back(op);
1159   return tensorflow::Status::OK();
1160 }
1161 
ConvertMirrorPadOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1162 tensorflow::Status ConvertMirrorPadOperator(
1163     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1164     Model* model) {
1165   if (node.op() != "MirrorPad") {
1166     LOG(FATAL) << "Expected MirrorPad.";
1167   }
1168   const int num_inputs = GetInputsCount(node, tf_import_flags);
1169   CHECK_EQ(num_inputs, 2);
1170   auto* op = new MirrorPadOperator;
1171   for (int i = 0; i < num_inputs; ++i) {
1172     op->inputs.push_back(node.input(i));
1173   }
1174   op->outputs.push_back(node.name());
1175   const auto mode = GetStringAttr(node, "mode");
1176   if (mode == "REFLECT") {
1177     op->mode = toco::MirrorPadMode::kReflect;
1178   } else if (mode == "SYMMETRIC") {
1179     op->mode = toco::MirrorPadMode::kSymmetric;
1180   }
1181 
1182   model->operators.emplace_back(op);
1183 
1184   return tensorflow::Status::OK();
1185 }
1186 
// Sentinel NumInputs value: skip the input-count check and accept any number
// of inputs.
constexpr int kAnyNumInputs = -1;

// Whether a simple converted operator keeps its NodeDef (via
// RetainTensorFlowNodeDef) so it may be exported as a flex op.
enum FlexSupport { kFlexOk = 0, kFlexNotOk = 1 };
1190 
1191 // This method supports simple operators without additional attributes.
1192 // Converts a simple operator that takes no attributes. The list of inputs is
1193 // taken from the given NodeDef, and its number must match NumInputs, unless
1194 // kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
1195 // will be eligible for being exported as a flex op.
1196 template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
ConvertSimpleOperatorGeneric(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1197 tensorflow::Status ConvertSimpleOperatorGeneric(
1198     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1199     Model* model) {
1200   if (NumInputs != kAnyNumInputs) {
1201     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
1202   }
1203   auto* op = new Op;
1204   const int num_inputs = GetInputsCount(node, tf_import_flags);
1205   for (int i = 0; i < num_inputs; ++i) {
1206     op->inputs.push_back(node.input(i));
1207   }
1208   op->outputs.push_back(node.name());
1209   if (NumOutputs > 1) {
1210     for (int i = 1; i < NumOutputs; ++i) {
1211       op->outputs.push_back(node.name() + ":" + std::to_string(i));
1212     }
1213   }
1214 
1215   if (flex == kFlexOk) {
1216     RetainTensorFlowNodeDef(node, op);
1217   }
1218 
1219   model->operators.emplace_back(op);
1220   return tensorflow::Status::OK();
1221 }
1222 
1223 // Convert a simple operator which is not valid as a flex op.
1224 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1225 tensorflow::Status ConvertSimpleOperator(
1226     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1227     Model* model) {
1228   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
1229       node, tf_import_flags, model);
1230 }
1231 
1232 // Convert a simple operator which is valid as a flex op.
1233 template <typename Op, int NumInputs, int NumOutputs>
ConvertSimpleOperatorFlexOk(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1234 tensorflow::Status ConvertSimpleOperatorFlexOk(
1235     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1236     Model* model) {
1237   return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
1238       node, tf_import_flags, model);
1239 }
1240 
GetOutputNamesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)1241 void GetOutputNamesFromNodeDef(const NodeDef& node,
1242                                const tensorflow::OpDef& op_def,
1243                                TensorFlowUnsupportedOperator* op) {
1244   int next_output = 0;
1245   auto add_output = [&node, &next_output, op]() {
1246     if (next_output == 0) {
1247       op->outputs.push_back(node.name());  // Implicit :0.
1248     } else {
1249       op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
1250     }
1251     ++next_output;
1252   };
1253   for (int i = 0; i < op_def.output_arg_size(); ++i) {
1254     string multiples = op_def.output_arg(i).number_attr();
1255     if (!multiples.empty()) {
1256       CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
1257       int num_outputs = GetIntAttr(node, multiples);
1258       for (int j = 0; j < num_outputs; ++j) {
1259         add_output();
1260       }
1261     } else {
1262       string list = op_def.output_arg(i).type_list_attr();
1263       if (!list.empty()) {
1264         CHECK(HasAttr(node, list)) << "No attr named " << list;
1265         const AttrValue::ListValue& list_value = GetListAttr(node, list);
1266         for (int j = 0; j < list_value.type_size(); ++j) {
1267           add_output();
1268         }
1269       } else {
1270         add_output();
1271       }
1272     }
1273   }
1274 }
1275 
GetOutputTypesFromNodeDef(const NodeDef & node,const tensorflow::OpDef & op_def,TensorFlowUnsupportedOperator * op)1276 void GetOutputTypesFromNodeDef(const NodeDef& node,
1277                                const tensorflow::OpDef& op_def,
1278                                TensorFlowUnsupportedOperator* op) {
1279   // The given type to the op, or clear the types if invalid.
1280   auto add_type = [&node, op](tensorflow::DataType type) {
1281     if (type == tensorflow::DT_INVALID) {
1282       LOG(WARNING) << "Op node missing output type attribute: " << node.name();
1283       op->output_data_types.clear();
1284     } else {
1285       op->output_data_types.push_back(ConvertDataType(type));
1286     }
1287   };
1288 
1289   // Retrieve the data type according to the OpDef definition: either the
1290   // "type" or "type_attr" field will be set.
1291   auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
1292     if (a.type() != tensorflow::DT_INVALID) {
1293       return a.type();
1294     } else if (HasAttr(node, a.type_attr())) {
1295       return GetDataTypeAttr(node, a.type_attr());
1296     } else {
1297       return tensorflow::DT_INVALID;
1298     }
1299   };
1300 
1301   for (int i = 0; i < op_def.output_arg_size(); ++i) {
1302     string multiples = op_def.output_arg(i).number_attr();
1303     if (!multiples.empty()) {
1304       CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
1305       int num_outputs = GetIntAttr(node, multiples);
1306       auto type = get_type(op_def.output_arg(i));
1307       for (int j = 0; j < num_outputs; ++j) {
1308         add_type(type);
1309       }
1310     } else {
1311       string list = op_def.output_arg(i).type_list_attr();
1312       if (!list.empty()) {
1313         CHECK(HasAttr(node, list)) << "No attr named " << list;
1314         const AttrValue::ListValue& list_value = GetListAttr(node, list);
1315         for (int j = 0; j < list_value.type_size(); ++j) {
1316           add_type(list_value.type(j));
1317         }
1318       } else {
1319         add_type(get_type(op_def.output_arg(i)));
1320       }
1321     }
1322   }
1323 }
1324 
ConvertUnsupportedOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1325 tensorflow::Status ConvertUnsupportedOperator(
1326     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1327     Model* model) {
1328   // Names of special attributes in TF graph that are used by Toco.
1329   static constexpr char kAttrOutputQuantized[] = "_output_quantized";
1330   static constexpr char kAttrOutputTypes[] = "_output_types";
1331   static constexpr char kAttrOutputShapes[] = "_output_shapes";
1332   static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
1333       "_support_output_type_float_in_quantized_op";
1334 
1335   LOG(INFO) << "Converting unsupported operation: " << node.op();
1336 
1337   auto* op = new TensorFlowUnsupportedOperator;
1338   op->tensorflow_op = node.op();
1339 
1340   // For Flex mode. Please read the comments of the function.
1341   RetainTensorFlowNodeDef(node, op);
1342 
1343   model->operators.emplace_back(op);
1344 
1345   // Parse inputs.
1346   const int num_inputs = GetInputsCount(node, tf_import_flags);
1347   for (int i = 0; i < num_inputs; ++i) {
1348     op->inputs.push_back(node.input(i));
1349   }
1350 
1351   // Parse outputs. Name them after the node's name, plus an ordinal suffix.
1352   // Note that some outputs are to be multiplied by a named attribute.
1353   const tensorflow::OpDef* op_def = nullptr;
1354   if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
1355     GetOutputNamesFromNodeDef(node, *op_def, op);
1356   } else {
1357     op->outputs.push_back(node.name());  // Implicit :0.
1358   }
1359 
1360   // Parse if the op supports quantization
1361   if (HasAttr(node, kAttrOutputQuantized)) {
1362     op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
1363   }
1364   // Parse if the quantized op allows output arrays of type float
1365   if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
1366     op->support_output_type_float_in_quantized_op =
1367         GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
1368   }
1369 
1370   // Parse output type(s).
1371   if (HasAttr(node, kAttrOutputTypes)) {
1372     const auto& output_types = GetListAttr(node, kAttrOutputTypes);
1373     for (int i = 0; i < output_types.type_size(); ++i) {
1374       op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
1375     }
1376   } else if (HasAttr(node, "Tout")) {
1377     const auto& output_type = GetDataTypeAttr(node, "Tout");
1378     op->output_data_types.push_back(ConvertDataType(output_type));
1379   } else if (op_def != nullptr) {
1380     GetOutputTypesFromNodeDef(node, *op_def, op);
1381   } else {
1382     // TODO(b/113613439): Figure out how to propagate types for custom ops
1383     // that have no OpDef.
1384     LOG(INFO) << "Unable to determine output type for op: " << node.op();
1385   }
1386 
1387   // Parse output shape(s).
1388   if (HasAttr(node, kAttrOutputShapes)) {
1389     const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
1390     Shape output_shape;
1391     for (int i = 0; i < output_shapes.shape_size(); ++i) {
1392       const auto& shape = output_shapes.shape(i);
1393       // TOCO doesn't yet properly handle shapes with wildcard dimensions.
1394       // TODO(b/113613439): Handle shape inference for unsupported ops that have
1395       // shapes with wildcard dimensions.
1396       if (HasWildcardDimension(shape)) {
1397         LOG(INFO) << "Skipping wildcard output shape(s) for node: "
1398                   << node.name();
1399         op->output_shapes.clear();
1400         break;
1401       }
1402       const auto status =
1403           ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
1404       if (!status.ok()) {
1405         return status;
1406       }
1407       op->output_shapes.push_back(output_shape);
1408     }
1409   }
1410   return tensorflow::Status::OK();
1411 }
1412 
1413 // Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
1414 // the types are not supported. Converting Const operators here avoids
1415 // expensive copies of the protocol buffers downstream in the flex delegate.
ConditionallyConvertConstOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1416 tensorflow::Status ConditionallyConvertConstOperator(
1417     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1418     Model* model) {
1419   // We avoid incomplete and zero shapes because the resulting arrays
1420   // are not completely compatible with Eager/TensorFlow.
1421   const auto& tensor = GetTensorAttr(node, "value");
1422   const auto& shape = tensor.tensor_shape();
1423   for (const auto& dim : shape.dim()) {
1424     if (dim.size() <= 0) {
1425       return ConvertUnsupportedOperator(node, tf_import_flags, model);
1426     }
1427   }
1428 
1429   switch (GetDataTypeAttr(node, "dtype")) {
1430     case DT_FLOAT:
1431     case DT_INT32:
1432     case DT_QUINT8:
1433     case DT_INT64:
1434     case DT_STRING:
1435     case DT_BOOL:
1436     case DT_COMPLEX64:
1437       return ConvertConstOperator(node, tf_import_flags, model);
1438     default:
1439       return ConvertUnsupportedOperator(node, tf_import_flags, model);
1440   }
1441 }
1442 
ConvertStridedSliceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1443 tensorflow::Status ConvertStridedSliceOperator(
1444     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1445     Model* model) {
1446   CHECK_EQ(node.op(), "StridedSlice");
1447   // TODO(soroosh): The 4th input (strides) should be e optional, to be
1448   // consistent with TF.
1449   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
1450 
1451   auto* op = new StridedSliceOperator;
1452   for (const auto& input : node.input()) {
1453     op->inputs.push_back(input);
1454   }
1455   op->outputs.push_back(node.name());
1456 
1457   op->begin_mask =
1458       HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
1459   op->ellipsis_mask =
1460       HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
1461   op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
1462   op->new_axis_mask =
1463       HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
1464   op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
1465                              ? GetIntAttr(node, "shrink_axis_mask")
1466                              : 0;
1467 
1468   model->operators.emplace_back(op);
1469   return tensorflow::Status::OK();
1470 }
1471 
ConvertPlaceholderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1472 tensorflow::Status ConvertPlaceholderOperator(
1473     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1474     Model* model) {
1475   CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1476   if (node.op() == "Placeholder") {
1477     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
1478   }
1479   auto& array = model->GetOrCreateArray(node.name());
1480   if (node.attr().count("dtype")) {
1481     array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1482   }
1483   if (node.attr().count("shape")) {
1484     const auto& shape = GetShapeAttr(node, "shape");
1485     auto num_dims = shape.dim_size();
1486     // TODO(b/62716978): This logic needs to be revisited.  During dims
1487     // refactoring it is an interim fix.
1488     if (num_dims > 0 && !HasWildcardDimension(shape)) {
1489       auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
1490       dst_array_dims.resize(num_dims);
1491       for (std::size_t i = 0; i < num_dims; i++) {
1492         dst_array_dims[i] = shape.dim(i).size();
1493       }
1494     }
1495   }
1496   return tensorflow::Status::OK();
1497 }
1498 
ConvertNoOpOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1499 tensorflow::Status ConvertNoOpOperator(
1500     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1501     Model* model) {
1502   return tensorflow::Status::OK();
1503 }
1504 
ConvertCastOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1505 tensorflow::Status ConvertCastOperator(
1506     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1507     Model* model) {
1508   CHECK_EQ(node.op(), "Cast");
1509   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1510   const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1511   const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1512   auto* op = new CastOperator;
1513   op->src_data_type = ConvertDataType(tf_src_dtype);
1514   op->dst_data_type = ConvertDataType(tf_dst_dtype);
1515   op->inputs.push_back(node.input(0));
1516   op->outputs.push_back(node.name());
1517   model->operators.emplace_back(op);
1518   return tensorflow::Status::OK();
1519 }
1520 
ConvertFloorOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1521 tensorflow::Status ConvertFloorOperator(
1522     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1523     Model* model) {
1524   CHECK_EQ(node.op(), "Floor");
1525   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1526   const auto data_type = GetDataTypeAttr(node, "T");
1527   CHECK(data_type == DT_FLOAT);
1528   auto* op = new FloorOperator;
1529   op->inputs.push_back(node.input(0));
1530   op->outputs.push_back(node.name());
1531   model->operators.emplace_back(op);
1532   return tensorflow::Status::OK();
1533 }
1534 
ConvertCeilOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1535 tensorflow::Status ConvertCeilOperator(
1536     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1537     Model* model) {
1538   CHECK_EQ(node.op(), "Ceil");
1539   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1540   const auto data_type = GetDataTypeAttr(node, "T");
1541   CHECK(data_type == DT_FLOAT);
1542   auto* op = new CeilOperator;
1543   op->inputs.push_back(node.input(0));
1544   op->outputs.push_back(node.name());
1545   model->operators.emplace_back(op);
1546   return tensorflow::Status::OK();
1547 }
1548 
ConvertGatherOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1549 tensorflow::Status ConvertGatherOperator(
1550     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1551     Model* model) {
1552   CHECK(node.op() == "Gather" || node.op() == "GatherV2");
1553   if (node.op() == "Gather")
1554     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1555   if (node.op() == "GatherV2")
1556     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1557   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1558   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1559   auto* op = new GatherOperator;
1560   op->inputs.push_back(node.input(0));
1561   op->inputs.push_back(node.input(1));
1562   if (node.input_size() >= 3) {
1563     // GatherV2 form where we are provided an axis. It may be either a constant
1564     // or runtime defined value, so we just wire up the array and let
1565     // ResolveGatherAttributes take care of it later on.
1566     const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
1567     CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
1568     op->inputs.push_back(node.input(2));
1569   } else {
1570     // Gather form that assumes axis=0.
1571     op->axis = {0};
1572   }
1573   op->outputs.push_back(node.name());
1574   model->operators.emplace_back(op);
1575   return tensorflow::Status::OK();
1576 }
1577 
ConvertGatherNdOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1578 tensorflow::Status ConvertGatherNdOperator(
1579     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1580     Model* model) {
1581   CHECK_EQ(node.op(), "GatherNd");
1582   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1583   const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1584   CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1585   auto* op = new GatherNdOperator;
1586   op->inputs.push_back(node.input(0));
1587   op->inputs.push_back(node.input(1));
1588   op->outputs.push_back(node.name());
1589   model->operators.emplace_back(op);
1590   return tensorflow::Status::OK();
1591 }
1592 
1593 template <typename Op>
ConvertArgMinMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1594 tensorflow::Status ConvertArgMinMaxOperator(
1595     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1596     Model* model) {
1597   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1598   const auto axis_data_type =
1599       HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
1600   const auto output_type = HasAttr(node, "output_type")
1601                                ? GetDataTypeAttr(node, "output_type")
1602                                : DT_INT64;
1603   CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
1604   CHECK(output_type == DT_INT64 || output_type == DT_INT32);
1605   auto* op = new Op;
1606   op->output_data_type = ConvertDataType(output_type);
1607   op->inputs.push_back(node.input(0));
1608   op->inputs.push_back(node.input(1));
1609   op->outputs.push_back(node.name());
1610   model->operators.emplace_back(op);
1611   return tensorflow::Status::OK();
1612 }
1613 
ConvertArgMaxOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1614 tensorflow::Status ConvertArgMaxOperator(
1615     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1616     Model* model) {
1617   CHECK_EQ(node.op(), "ArgMax");
1618   return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags, model);
1619 }
1620 
ConvertArgMinOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1621 tensorflow::Status ConvertArgMinOperator(
1622     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1623     Model* model) {
1624   CHECK_EQ(node.op(), "ArgMin");
1625   return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags, model);
1626 }
1627 
ConvertResizeBilinearOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1628 tensorflow::Status ConvertResizeBilinearOperator(
1629     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1630     Model* model) {
1631   CHECK_EQ(node.op(), "ResizeBilinear");
1632   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1633   auto* op = new ResizeBilinearOperator;
1634 
1635   op->align_corners = false;
1636   if (HasAttr(node, "align_corners")) {
1637     op->align_corners = GetBoolAttr(node, "align_corners");
1638   }
1639 
1640   op->inputs.push_back(node.input(0));
1641   op->inputs.push_back(node.input(1));
1642   op->outputs.push_back(node.name());
1643   model->operators.emplace_back(op);
1644   return tensorflow::Status::OK();
1645 }
1646 
ConvertResizeNearestNeighborOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1647 tensorflow::Status ConvertResizeNearestNeighborOperator(
1648     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1649     Model* model) {
1650   CHECK_EQ(node.op(), "ResizeNearestNeighbor");
1651   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1652   auto* op = new ResizeNearestNeighborOperator;
1653 
1654   op->align_corners = false;
1655   if (HasAttr(node, "align_corners")) {
1656     op->align_corners = GetBoolAttr(node, "align_corners");
1657   }
1658 
1659   op->inputs.push_back(node.input(0));
1660   op->inputs.push_back(node.input(1));
1661   op->outputs.push_back(node.name());
1662   model->operators.emplace_back(op);
1663   return tensorflow::Status::OK();
1664 }
1665 
ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1666 tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
1667     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1668     Model* model) {
1669   CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1670   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1671 
1672   // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
1673   // to the input, before feeding it into TensorFlowRsqrtOperator.
1674   // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1675 
1676   string multiplier = node.name() + "_mul";
1677   if (GetBoolAttr(node, "scale_after_normalization")) {
1678     // Create graph:
1679     //   v -> RSQRT ->
1680     //                 MUL  -> multiplier
1681     //   gamma  ----->
1682     string rsqrt = node.name() + "_rsqrt";
1683 
1684     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1685     rsqrt_op->inputs.push_back(node.input(2));
1686     rsqrt_op->outputs.push_back(rsqrt);
1687     model->operators.emplace_back(rsqrt_op);
1688 
1689     auto* mul_op = new MulOperator;
1690     mul_op->inputs.push_back(rsqrt);
1691     mul_op->inputs.push_back(node.input(4));
1692     mul_op->outputs.push_back(multiplier);
1693     model->operators.emplace_back(mul_op);
1694   } else {
1695     // Create graph:
1696     //   v -> RSQRT -> multiplier
1697     auto* rsqrt_op = new TensorFlowRsqrtOperator;
1698     rsqrt_op->inputs.push_back(node.input(2));
1699     rsqrt_op->outputs.push_back(multiplier);
1700     model->operators.emplace_back(rsqrt_op);
1701   }
1702 
1703   auto* op = new BatchNormalizationOperator;
1704   op->global_normalization = true;
1705 
1706   op->inputs.push_back(node.input(0));
1707   op->inputs.push_back(node.input(1));
1708   op->inputs.push_back(multiplier);
1709   op->inputs.push_back(node.input(3));
1710   op->outputs.push_back(node.name());
1711 
1712   model->operators.emplace_back(op);
1713   return tensorflow::Status::OK();
1714 }
1715 
ConvertFusedBatchNormOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1716 tensorflow::Status ConvertFusedBatchNormOperator(
1717     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1718     Model* model) {
1719   CHECK_EQ(node.op(), "FusedBatchNorm");
1720   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1721 
1722   // Declare shortcuts for the inputs.
1723   const string& gamma_input = node.input(1);
1724   const string& beta_input = node.input(2);
1725   const string& moving_mean_input = node.input(3);
1726   const string& moving_variance_input = node.input(4);
1727 
1728   // Create an array holding the epsilon value (typically, 0.001).
1729   const string epsilon_array_name = CreateConstArray<ArrayDataType::kFloat>(
1730       model, node.name() + "_epsilon_array", {GetFloatAttr(node, "epsilon")});
1731 
1732   // Add epsilon to the moving variance.
1733   const string epsilon_add_op_name = node.name() + "_epsilon";
1734   auto* epsilon_add_op = new AddOperator;
1735   epsilon_add_op->inputs.push_back(moving_variance_input);
1736   epsilon_add_op->inputs.push_back(epsilon_array_name);
1737   epsilon_add_op->outputs.push_back(epsilon_add_op_name);
1738   model->operators.emplace_back(epsilon_add_op);
1739 
1740   // Take the inverse square root of the (variance + epsilon).
1741   const string rsqrt_op_name = node.name() + "_rsqrt";
1742   auto* rsqrt_op = new TensorFlowRsqrtOperator;
1743   rsqrt_op->inputs.push_back(epsilon_add_op_name);
1744   rsqrt_op->outputs.push_back(rsqrt_op_name);
1745   model->operators.emplace_back(rsqrt_op);
1746 
1747   // Multiply the result by gamma.
1748   const string multiplier = node.name() + "_mul";
1749   auto* mul_op = new MulOperator;
1750   mul_op->inputs.push_back(rsqrt_op_name);
1751   mul_op->inputs.push_back(gamma_input);
1752   mul_op->outputs.push_back(multiplier);
1753   model->operators.emplace_back(mul_op);
1754 
1755   // Now we have all required inputs for the BatchNormalizationOperator.
1756   auto* op = new BatchNormalizationOperator;
1757   op->global_normalization = true;
1758 
1759   op->inputs.push_back(node.input(0));
1760   op->inputs.push_back(moving_mean_input);
1761   op->inputs.push_back(multiplier);
1762   op->inputs.push_back(beta_input);
1763   op->outputs.push_back(node.name());
1764 
1765   model->operators.emplace_back(op);
1766   return tensorflow::Status::OK();
1767 }
1768 
ConvertSpaceToBatchNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1769 tensorflow::Status ConvertSpaceToBatchNDOperator(
1770     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1771     Model* model) {
1772   CHECK_EQ(node.op(), "SpaceToBatchND");
1773   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1774   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1775   CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1776   auto* op = new SpaceToBatchNDOperator;
1777   op->inputs.push_back(node.input(0));
1778   op->inputs.push_back(node.input(1));
1779   op->inputs.push_back(node.input(2));
1780   op->outputs.push_back(node.name());
1781   model->operators.emplace_back(op);
1782   return tensorflow::Status::OK();
1783 }
1784 
ConvertBatchToSpaceNDOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1785 tensorflow::Status ConvertBatchToSpaceNDOperator(
1786     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1787     Model* model) {
1788   CHECK_EQ(node.op(), "BatchToSpaceND");
1789   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1790   CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1791   CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1792   auto* op = new BatchToSpaceNDOperator;
1793   op->inputs.push_back(node.input(0));
1794   op->inputs.push_back(node.input(1));
1795   op->inputs.push_back(node.input(2));
1796   op->outputs.push_back(node.name());
1797   model->operators.emplace_back(op);
1798   return tensorflow::Status::OK();
1799 }
1800 
1801 template <typename T>
ConvertReduceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1802 tensorflow::Status ConvertReduceOperator(
1803     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1804     Model* model) {
1805   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1806   auto* op = new T;
1807   op->inputs.push_back(node.input(0));
1808   op->inputs.push_back(node.input(1));
1809   op->outputs.push_back(node.name());
1810   model->operators.emplace_back(op);
1811   if (HasAttr(node, "keepdims")) {
1812     op->keep_dims = GetBoolAttr(node, "keepdims");
1813   } else if (HasAttr(node, "keep_dims")) {
1814     op->keep_dims = GetBoolAttr(node, "keep_dims");
1815   }
1816   return tensorflow::Status::OK();
1817 }
1818 
ConvertSvdfOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1819 tensorflow::Status ConvertSvdfOperator(
1820     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1821     Model* model) {
1822   CHECK_EQ(node.op(), "Svdf");
1823   const int input_size = GetInputsCount(node, tf_import_flags);
1824   QCHECK(input_size == 3 || input_size == 4)
1825       << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1826       << node.DebugString();
1827   bool has_bias = (input_size == 4);
1828   auto* op = new SvdfOperator;
1829   op->inputs.push_back(node.input(0));
1830   op->inputs.push_back(node.input(1));
1831   op->inputs.push_back(node.input(2));
1832   if (has_bias) {
1833     op->inputs.push_back(node.input(3));
1834   }
1835   op->outputs.push_back(node.name() + "_state");
1836   op->outputs.push_back(node.name());
1837   if (node.attr().at("ActivationFunction").s() == "Relu") {
1838     op->fused_activation_function = FusedActivationFunctionType::kRelu;
1839   } else {
1840     op->fused_activation_function = FusedActivationFunctionType::kNone;
1841   }
1842   op->rank = node.attr().at("Rank").i();
1843   model->operators.emplace_back(op);
1844   return tensorflow::Status::OK();
1845 }
1846 
1847 // This is just bare bones support to get the shapes to propagate.
ConvertTransposeConvOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1848 tensorflow::Status ConvertTransposeConvOperator(
1849     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1850     Model* model) {
1851   CHECK_EQ(node.op(), "Conv2DBackpropInput");
1852   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1853   auto* op = new TransposeConvOperator;
1854   op->inputs.push_back(node.input(0));
1855   op->inputs.push_back(node.input(1));
1856   op->inputs.push_back(node.input(2));
1857   op->outputs.push_back(node.name());
1858   const auto& strides = GetListAttr(node, "strides");
1859   op->stride_height = strides.i(1);
1860   op->stride_width = strides.i(2);
1861   CHECK_EQ(strides.i_size(), 4)
1862       << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
1863       << node.name() << "\" has " << strides.i_size() << "D strides.";
1864   CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
1865       << "Can only import TransposeConv ops with striding along the height "
1866          "(1st) or width (2nd) axis. TensorFlow op \""
1867       << node.name() << "\" had strides:[ " << strides.i(0) << ", "
1868       << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
1869   op->stride_height = strides.i(1);
1870   op->stride_width = strides.i(2);
1871   if (HasAttr(node, "dilations")) {
1872     const auto& dilations = GetListAttr(node, "dilations");
1873     CHECK_EQ(dilations.i_size(), 4)
1874         << "Dilation unsupported in TransposeConv. TensorFlow op \""
1875         << node.name() << "\" had dilations";
1876     CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
1877           (dilations.i(1) == 1) && (dilations.i(3) == 1))
1878         << "Dilation unsupported in TransposeConv. TensorFlow op \""
1879         << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
1880         << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
1881         << "].";
1882   }
1883 
1884   const string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
1885   const string& transposed_weights_name = weights_name + "_transposed";
1886   // Check if a TransposeOperator was already created for these weights
1887   // (can happen when multiple layers share the same weights).
1888   const Operator* existing_transpose =
1889       GetOpWithOutput(*model, transposed_weights_name);
1890   if (existing_transpose) {
1891     CHECK(existing_transpose->type == OperatorType::kTranspose);
1892   } else {
1893     // Transpose weights from HWOI order to OHWI order, which is more efficient
1894     // for computation. (Note that TensorFlow considers the order as HWIO
1895     // because they consider this a backward conv, inverting the sense of
1896     // input/output.)
1897     TransposeOperator* transpose = new TransposeOperator;
1898     string perm_array = CreateConstArray<ArrayDataType::kInt32>(
1899         model, node.name() + "_transpose_perm", {2, 0, 1, 3});
1900     transpose->inputs = {weights_name, perm_array};
1901     transpose->outputs = {transposed_weights_name};
1902     model->operators.emplace_back(transpose);
1903   }
1904   op->inputs[1] = transposed_weights_name;
1905 
1906   auto const& padding = GetStringAttr(node, "padding");
1907   if (padding == "SAME") {
1908     op->padding.type = PaddingType::kSame;
1909   } else if (padding == "VALID") {
1910     op->padding.type = PaddingType::kValid;
1911   } else {
1912     LOG(FATAL) << "Only SAME and VALID padding supported on "
1913                   "Conv2DBackpropInput nodes.";
1914   }
1915   model->operators.emplace_back(op);
1916   return tensorflow::Status::OK();
1917 }
1918 
ConvertRangeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1919 tensorflow::Status ConvertRangeOperator(
1920     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1921     Model* model) {
1922   CHECK_EQ(node.op(), "Range");
1923   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1924   auto* op = new RangeOperator;
1925   if (HasAttr(node, "Tidx")) {
1926     const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
1927     CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
1928           dtype == DT_FLOAT);
1929     op->dtype = ConvertDataType(dtype);
1930   }
1931   op->inputs.push_back(node.input(0));
1932   op->inputs.push_back(node.input(1));
1933   op->inputs.push_back(node.input(2));
1934   op->outputs.push_back(node.name());
1935 
1936   model->operators.emplace_back(op);
1937   return tensorflow::Status::OK();
1938 }
1939 
1940 // Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
1941 // they aren't the same thing.  tf.stack results in a "Pack" operator.  "Stack"
1942 // operators also exist, but involve manipulating the TF runtime stack, and are
1943 // not directly related to tf.stack() usage.
ConvertPackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1944 tensorflow::Status ConvertPackOperator(
1945     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1946     Model* model) {
1947   CHECK_EQ(node.op(), "Pack");
1948   auto op = absl::make_unique<PackOperator>();
1949   const int num_inputs = GetInputsCount(node, tf_import_flags);
1950   QCHECK_GE(num_inputs, 1)
1951       << node.op()
1952       << " node expects at least 1 input other than control dependencies: "
1953       << node.DebugString();
1954   CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
1955   for (int i = 0; i < num_inputs; ++i) {
1956     op->inputs.push_back(node.input(i));
1957   }
1958   op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
1959   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
1960   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
1961   op->outputs.push_back(node.name());
1962   model->operators.emplace_back(std::move(op));
1963   return tensorflow::Status::OK();
1964 }
1965 
ConvertUnpackOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1966 tensorflow::Status ConvertUnpackOperator(
1967     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1968     Model* model) {
1969   CHECK_EQ(node.op(), "Unpack");
1970   auto op = absl::make_unique<UnpackOperator>();
1971   const int num_inputs = GetInputsCount(node, tf_import_flags);
1972   QCHECK_EQ(num_inputs, 1);
1973   op->inputs.push_back(node.input(0));
1974   op->num = GetIntAttr(node, "num");
1975   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
1976   op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
1977 
1978   op->outputs.push_back(node.name());  // Implicit :0.
1979   for (int i = 1; i < op->num; ++i) {
1980     op->outputs.push_back(node.name() + ":" + std::to_string(i));
1981   }
1982   model->operators.emplace_back(std::move(op));
1983   return tensorflow::Status::OK();
1984 }
1985 
1986 // Some TensorFlow ops only occur in graph cycles, representing
1987 // control flow. We do not currently support control flow, so we wouldn't
1988 // be able to fully support such graphs, including performing inference,
1989 // anyway. However, rather than erroring out early on graphs being cyclic,
1990 // it helps to at least support these just enough to allow getting a
1991 // graph visualization. This is not trivial, as we require graphs to be
1992 // acyclic aside from RNN back-edges. The solution is to special-case
1993 // such ops as RNN back-edges, which is technically incorrect (does not
1994 // allow representing the op's semantics) but good enough to get a
1995 // graph visualization.
ConvertOperatorSpecialCasedAsRNNBackEdge(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)1996 tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
1997     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1998     Model* model) {
1999   // At the moment, the only type of operator special-cased in this way is
2000   // NextIteration, occurring only in control-flow cycles.
2001   CHECK_EQ(node.op(), "NextIteration");
2002   CHECK_EQ(node.input_size(), 1);
2003   auto* rnn_state = model->flags.add_rnn_states();
2004   // This RNN state is not explicitly created by the user, so it's
2005   // OK for some later graph transformation to discard it.
2006   rnn_state->set_discardable(true);
2007   rnn_state->set_state_array(node.name());
2008   rnn_state->set_back_edge_source_array(node.input(0));
2009   return tensorflow::Status::OK();
2010 }
2011 
ConvertShapeOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2012 tensorflow::Status ConvertShapeOperator(
2013     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2014     Model* model) {
2015   CHECK_EQ(node.op(), "Shape");
2016   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2017   const auto out_type =
2018       HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
2019   CHECK(out_type == DT_INT64 || out_type == DT_INT32);
2020   auto op = absl::make_unique<TensorFlowShapeOperator>();
2021   op->output_data_type = ConvertDataType(out_type);
2022   op->inputs.push_back(node.input(0));
2023   op->outputs.push_back(node.name());
2024   model->operators.push_back(std::move(op));
2025   return tensorflow::Status::OK();
2026 }
2027 
ConvertReverseSequenceOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2028 tensorflow::Status ConvertReverseSequenceOperator(
2029     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2030     Model* model) {
2031   CHECK_EQ(node.op(), "ReverseSequence");
2032   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2033   auto op = absl::make_unique<ReverseSequenceOperator>();
2034   if (HasAttr(node, "seq_dim")) {
2035     op->seq_dim = GetIntAttr(node, "seq_dim");
2036   }
2037   // In tf.reverse_sequence, batch_dim defaults to 0.
2038   op->batch_dim =
2039       HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
2040   const int num_inputs = GetInputsCount(node, tf_import_flags);
2041   for (int i = 0; i < num_inputs; ++i) {
2042     op->inputs.push_back(node.input(i));
2043   }
2044   op->outputs.push_back(node.name());
2045   model->operators.push_back(std::move(op));
2046   return tensorflow::Status::OK();
2047 }
2048 
StripCaretFromArrayNames(Model * model)2049 void StripCaretFromArrayNames(Model* model) {
2050   for (auto& op : model->operators) {
2051     for (auto& input : op->inputs) {
2052       input = string(absl::StripPrefix(input, "^"));
2053     }
2054     for (auto& output : op->outputs) {
2055       output = string(absl::StripPrefix(output, "^"));
2056     }
2057   }
2058   for (auto& array : model->GetArrayMap()) {
2059     if (absl::StartsWith(array.first, "^")) {
2060       LOG(FATAL) << "What?";
2061     }
2062   }
2063 }
2064 
StripZeroOutputIndexFromInputs(NodeDef * node)2065 void StripZeroOutputIndexFromInputs(NodeDef* node) {
2066   for (auto& input : *node->mutable_input()) {
2067     input = string(absl::StripSuffix(input, ":0"));
2068   }
2069 }
2070 
2071 // In TensorFlow GraphDef, when a node has multiple outputs, they are named
2072 // name:0, name:1, ...
2073 // where 'name' is the node's name(). Just 'name' is an equivalent shorthand
2074 // form for name:0.
2075 // A TensorFlow GraphDef does not explicitly list all the outputs of each node
2076 // (unlike inputs), it being implied by the node's name and operator type
2077 // (the latter implies the number of outputs).
2078 // This makes it non-trivial for us to reconstruct the list of all arrays
2079 // present in the graph and, for each operator, the list of its outputs.
2080 // We do that by taking advantage of the fact that
2081 // at least each node lists explicitly its inputs, so after we've loaded
2082 // all nodes, we can use that information.
AddExtraOutputs(Model * model)2083 void AddExtraOutputs(Model* model) {
2084   // Construct the list of all arrays consumed by anything in the graph.
2085   std::vector<string> consumed_arrays;
2086   // Add arrays consumed by an op.
2087   for (const auto& consumer_op : model->operators) {
2088     for (const string& input : consumer_op->inputs) {
2089       consumed_arrays.push_back(input);
2090     }
2091   }
2092   // Add global outputs of the model.
2093   for (const string& output_array : model->flags.output_arrays()) {
2094     consumed_arrays.push_back(output_array);
2095   }
2096   // Add arrays consumed by a RNN back-edge.
2097   for (const auto& rnn_state : model->flags.rnn_states()) {
2098     consumed_arrays.push_back(rnn_state.back_edge_source_array());
2099   }
2100   // Now add operator outputs so that all arrays that are consumed,
2101   // are produced.
2102   for (const string& consumed_array : consumed_arrays) {
2103     // Split the consumed array name into the form name:output_index.
2104     const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
2105     // If not of the form name:output_index, then this is not an additional
2106     // output of a node with multiple outputs, so nothing to do here.
2107     if (split.size() != 2) {
2108       continue;
2109     }
2110     int output_index = 0;
2111     if (!absl::SimpleAtoi(split[1], &output_index)) {
2112       continue;
2113     }
2114     // Each op is initially recorded as producing at least the array that
2115     // has its name. We use that to identify the producer node.
2116     auto* producer_op = GetOpWithOutput(*model, split[0]);
2117     if (!producer_op) {
2118       continue;
2119     }
2120     // Add extra outputs to that producer node, all the way to the
2121     // output_index.
2122     while (producer_op->outputs.size() <= output_index) {
2123       using toco::port::StringF;
2124       producer_op->outputs.push_back(
2125           StringF("%s:%d", split[0], producer_op->outputs.size()));
2126     }
2127   }
2128 }
2129 
InlineAllFunctions(GraphDef * graphdef)2130 bool InlineAllFunctions(GraphDef* graphdef) {
2131   if (graphdef->library().function().empty()) {
2132     VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
2133     return false;
2134   }
2135 
2136   // Override "_noinline" attribute on all functions
2137   GraphDef graphdef_copy(*graphdef);
2138   for (auto& function :
2139        (*graphdef_copy.mutable_library()->mutable_function())) {
2140     auto* attributes = function.mutable_attr();
2141     if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
2142       (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
2143     }
2144   }
2145 
2146   // Construct minimum resources needed to use ExpandInlineFunctions().
2147   tensorflow::SessionOptions options;
2148   auto* device_count = options.config.mutable_device_count();
2149   device_count->insert({"CPU", 1});
2150   std::vector<std::unique_ptr<tensorflow::Device>> devices;
2151   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
2152       options, "/job:localhost/replica:0/task:0", &devices));
2153 
2154   tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
2155                                             graphdef_copy.library());
2156   tensorflow::DeviceMgr device_mgr(std::move(devices));
2157   tensorflow::OptimizerOptions o_opts;
2158   tensorflow::ProcessFunctionLibraryRuntime pflr(
2159       &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
2160       o_opts, nullptr);
2161   tensorflow::FunctionLibraryRuntime* flr;
2162   flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
2163 
2164   tensorflow::Graph graph(fld);
2165   tensorflow::ImportGraphDefOptions gc_opts;
2166   gc_opts.validate_shape = false;
2167   const auto& tf_convert_status = tensorflow::ImportGraphDef(
2168       gc_opts, graphdef_copy, &graph, nullptr, nullptr);
2169   if (!tf_convert_status.ok()) {
2170     LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
2171                << tf_convert_status.ToString();
2172     return false;
2173   }
2174 
2175   // Iterate over the graph until there are no more nodes to be inlined.
2176   bool graph_modified = false;
2177   while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
2178     graph_modified = true;
2179   }
2180 
2181   // Output inlined graph
2182   if (graph_modified) {
2183     LOG(INFO) << "Found and inlined TensorFlow functions.";
2184     graph.ToGraphDef(graphdef);
2185   }
2186   return graph_modified;
2187 }
2188 
ConvertTopKV2Operator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2189 tensorflow::Status ConvertTopKV2Operator(
2190     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2191     Model* model) {
2192   CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
2193   auto op = absl::make_unique<TopKV2Operator>();
2194   op->inputs.push_back(node.input(0));
2195   // K can be encoded as attr (TopK) convert it to a const.
2196   if (HasAttr(node, "k")) {
2197     string k_array = CreateConstArray<ArrayDataType::kInt32>(
2198         model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
2199     op->inputs.push_back(k_array);
2200   } else {
2201     TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2202     op->inputs.push_back(node.input(1));
2203   }
2204   // The op has two outputs.
2205   op->outputs.push_back(node.name());
2206   op->outputs.push_back(node.name() + ":1");
2207   model->operators.emplace_back(op.release());
2208   return tensorflow::Status::OK();
2209 }
2210 
ConvertDynamicPartitionOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2211 tensorflow::Status ConvertDynamicPartitionOperator(
2212     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2213     Model* model) {
2214   auto op = absl::make_unique<DynamicPartitionOperator>();
2215   CHECK(HasAttr(node, "num_partitions"));
2216   op->num_partitions = GetIntAttr(node, "num_partitions");
2217   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2218   op->inputs.push_back(node.input(0));
2219   op->inputs.push_back(node.input(1));
2220   CHECK_GT(op->num_partitions, 1);
2221   op->outputs.push_back(node.name());  // Implicit :0.
2222   for (int i = 1; i < op->num_partitions; ++i) {
2223     op->outputs.push_back(node.name() + ":" + std::to_string(i));
2224   }
2225   model->operators.emplace_back(op.release());
2226   return tensorflow::Status::OK();
2227 }
2228 
ConvertDynamicStitchOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2229 tensorflow::Status ConvertDynamicStitchOperator(
2230     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2231     Model* model) {
2232   // The parallel and non-parallel variants are the same besides whether they
2233   // have a parallel loop; there are no behavioral differences.
2234   CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
2235   auto op = absl::make_unique<DynamicStitchOperator>();
2236   CHECK(HasAttr(node, "N"));
2237   op->num_partitions = GetIntAttr(node, "N");
2238   // Expect all ID partitions + all value partitions.
2239   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
2240   for (int i = 0; i < op->num_partitions * 2; ++i) {
2241     op->inputs.push_back(node.input(i));
2242   }
2243   op->outputs.push_back(node.name());
2244   model->operators.emplace_back(op.release());
2245   return tensorflow::Status::OK();
2246 }
2247 
ConvertSparseToDenseOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2248 tensorflow::Status ConvertSparseToDenseOperator(
2249     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2250     Model* model) {
2251   CHECK_EQ(node.op(), "SparseToDense");
2252   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2253 
2254   auto* op = new SparseToDenseOperator;
2255   for (const string& input : node.input()) {
2256     op->inputs.push_back(input);
2257   }
2258   op->outputs.push_back(node.name());
2259 
2260   op->validate_indices = HasAttr(node, "validate_indices")
2261                              ? GetBoolAttr(node, "validate_indices")
2262                              : true;
2263   model->operators.emplace_back(op);
2264   return tensorflow::Status::OK();
2265 }
2266 
ConvertOneHotOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2267 tensorflow::Status ConvertOneHotOperator(
2268     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2269     Model* model) {
2270   CHECK_EQ(node.op(), "OneHot");
2271   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2272 
2273   const auto dtype = GetDataTypeAttr(node, "T");
2274   // TODO(b/111744875): Support DT_UINT8 and quantization.
2275   CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
2276         dtype == DT_BOOL);
2277 
2278   auto op = absl::make_unique<OneHotOperator>();
2279   op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
2280   for (const string& input : node.input()) {
2281     op->inputs.push_back(input);
2282   }
2283   op->outputs.push_back(node.name());
2284   model->operators.emplace_back(op.release());
2285   return tensorflow::Status::OK();
2286 }
2287 
ConvertCTCBeamSearchDecoderOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2288 tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
2289     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2290     Model* model) {
2291   CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
2292   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2293 
2294   auto* op = new CTCBeamSearchDecoderOperator;
2295   for (const string& input : node.input()) {
2296     op->inputs.push_back(input);
2297   }
2298 
2299   op->beam_width =
2300       HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
2301   op->top_paths =
2302       HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
2303   op->merge_repeated = HasAttr(node, "merge_repeated")
2304                            ? GetBoolAttr(node, "merge_repeated")
2305                            : true;
2306 
2307   // There are top_paths + 1 outputs.
2308   op->outputs.push_back(node.name());  // Implicit :0.
2309   for (int i = 0; i < op->top_paths; ++i) {
2310     op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
2311   }
2312   model->operators.emplace_back(op);
2313   return tensorflow::Status::OK();
2314 }
2315 
2316 // This isn't a TensorFlow builtin op. Currently this node can only be generated
2317 // with TfLite OpHint API.
ConvertUnidirectionalSequenceLstm(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2318 tensorflow::Status ConvertUnidirectionalSequenceLstm(
2319     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2320     Model* model) {
2321   DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
2322 
2323   auto* op = new UnidirectionalSequenceLstmOperator();
2324   const auto& indices = GetListAttr(node, "_tflite_input_indices");
2325   if (indices.i_size() != node.input().size()) {
2326     return tensorflow::errors::InvalidArgument("Input size does not match.");
2327   }
2328 
2329   // The input size needs to be the same as the TfLite UniDirectionalSequence
2330   // Lstm implementation.
2331   const int kInputsSize = 20;
2332 
2333   op->inputs.resize(kInputsSize);
2334   std::vector<bool> done(kInputsSize);
2335   int idx = 0;
2336   for (const string& input : node.input()) {
2337     int real_index = indices.i(idx);
2338     op->inputs[real_index] = (input);
2339     done[real_index] = true;
2340     idx++;
2341   }
2342 
2343   for (int idx = 0; idx < done.size(); idx++) {
2344     if (!done[idx]) {
2345       string optional_name = node.name() + "_" + std::to_string(idx);
2346       model->CreateOptionalArray(optional_name);
2347       op->inputs[idx] = optional_name;
2348     }
2349   }
2350 
2351   // There're three outputs, only the last one is required.
2352   op->outputs.push_back(node.name() + ":2");
2353   model->operators.emplace_back(op);
2354 
2355   return tensorflow::Status::OK();
2356 }
2357 
ConvertLeakyReluOperator(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2358 tensorflow::Status ConvertLeakyReluOperator(
2359     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2360     Model* model) {
2361   CHECK_EQ(node.op(), "LeakyRelu");
2362   TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2363   CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
2364   const auto& input_name = node.input(0);
2365   auto* op = new LeakyReluOperator;
2366   op->inputs.push_back(input_name);
2367   op->outputs.push_back(node.name());
2368   op->alpha = GetFloatAttr(node, "alpha");
2369   model->operators.emplace_back(op);
2370   return tensorflow::Status::OK();
2371 }
2372 
ConvertUnidirectionalSequenceRnn(const NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model)2373 tensorflow::Status ConvertUnidirectionalSequenceRnn(
2374     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2375     Model* model) {
2376   DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
2377 
2378   auto* op = new UnidirectionalSequenceRnnOperator();
2379   const auto& indices = GetListAttr(node, "_tflite_input_indices");
2380   if (indices.i_size() != node.input().size()) {
2381     return tensorflow::errors::InvalidArgument("Input size does not match.");
2382   }
2383 
2384   for (const string& input : node.input()) {
2385     op->inputs.push_back(input);
2386   }
2387   // Only use the last one as input.
2388   op->outputs.push_back(node.name() + ":1");
2389   model->operators.emplace_back(op);
2390 
2391   return tensorflow::Status::OK();
2392 }
2393 
2394 }  // namespace
2395 
2396 namespace internal {
2397 
2398 using ConverterType = tensorflow::Status (*)(
2399     const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2400     Model* model);
2401 using ConverterMapType = std::unordered_map<std::string, ConverterType>;
2402 
GetTensorFlowNodeConverterMapForFlex()2403 ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
2404   return std::unordered_map<std::string, ConverterType>({
2405       // We need to let TOCO convert Placeholder information into
2406       // array data, so that the data types are correct.
2407       {"LegacyFedInput", ConvertPlaceholderOperator},
2408       {"Placeholder", ConvertPlaceholderOperator},
2409       {"Const", ConditionallyConvertConstOperator},
2410   });
2411 }
2412 
GetTensorFlowNodeConverterMap()2413 ConverterMapType GetTensorFlowNodeConverterMap() {
2414   return std::unordered_map<std::string, ConverterType>({
2415       {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
2416       {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
2417       {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
2418       {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
2419       {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
2420       {"ArgMax", ConvertArgMaxOperator},
2421       {"ArgMin", ConvertArgMinOperator},
2422       {"Assert",
2423        ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
2424       {"AvgPool", ConvertAvgPoolOperator},
2425       {"BatchMatMul", ConvertBatchMatMulOperator},
2426       {"BatchNormWithGlobalNormalization",
2427        ConvertBatchNormWithGlobalNormalizationOperator},
2428       {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
2429       {"BiasAdd", ConvertBiasAddOperator},
2430       {"Cast", ConvertCastOperator},
2431       {"Ceil", ConvertCeilOperator},
2432       {"CheckNumerics", ConvertIdentityOperator},
2433       {"Concat", ConvertConcatOperator},
2434       {"ConcatV2", ConvertConcatOperator},
2435       {"Const", ConvertConstOperator},
2436       {"Conv2D", ConvertConvOperator},
2437       {"Conv2DBackpropInput", ConvertTransposeConvOperator},
2438       {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
2439       {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
2440       {"DepthToSpace", ConvertDepthToSpaceOperator},
2441       {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
2442       {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
2443       {"DynamicPartition", ConvertDynamicPartitionOperator},
2444       {"DynamicStitch", ConvertDynamicStitchOperator},
2445       {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
2446       {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
2447       {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
2448       {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
2449       {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
2450       {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
2451       {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
2452       {"Floor", ConvertFloorOperator},
2453       {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
2454       {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
2455       {"FusedBatchNorm", ConvertFusedBatchNormOperator},
2456       {"Gather", ConvertGatherOperator},
2457       {"GatherV2", ConvertGatherOperator},
2458       {"GatherNd", ConvertGatherNdOperator},
2459       {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
2460       {"GreaterEqual",
2461        ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
2462       {"Identity", ConvertIdentityOperator},
2463       {"LRN", ConvertLRNOperator},
2464       {"LeakyRelu", ConvertLeakyReluOperator},
2465       {"LegacyFedInput", ConvertPlaceholderOperator},
2466       {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
2467       {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
2468       {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
2469       {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
2470       {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
2471       {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
2472       {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
2473       {"MatMul", ConvertMatMulOperator},
2474       {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
2475       {"MaxPool", ConvertMaxPoolOperator},
2476       {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
2477       {"Mean", ConvertReduceOperator<MeanOperator>},
2478       {"Merge",
2479        ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
2480       {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
2481       {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
2482       {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
2483       {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
2484       {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
2485       {"NoOp", ConvertNoOpOperator},
2486       {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
2487       {"OneHot", ConvertOneHotOperator},
2488       {"Pack", ConvertPackOperator},
2489       {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
2490       {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
2491       {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
2492       {"Placeholder", ConvertPlaceholderOperator},
2493       {"PlaceholderWithDefault", ConvertIdentityOperator},
2494       {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
2495       {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
2496       {"RandomUniform", ConvertRandomUniform},
2497       {"Range", ConvertRangeOperator},
2498       {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
2499       {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
2500       {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
2501       {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
2502       {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
2503       {"ResizeBilinear", ConvertResizeBilinearOperator},
2504       {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
2505       {"ReverseSequence", ConvertReverseSequenceOperator},
2506       {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
2507       {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
2508       {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
2509       {"Shape", ConvertShapeOperator},
2510       {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
2511       {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
2512       {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
2513       {"Softmax", ConvertSoftmaxOperator},
2514       {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
2515       {"SpaceToDepth", ConvertSpaceToDepthOperator},
2516       {"SparseToDense", ConvertSparseToDenseOperator},
2517       {"Split", ConvertSplitOperator},
2518       {"SplitV", ConvertSplitVOperator},
2519       {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
2520       {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
2521       {"SquaredDifference",
2522        ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
2523       {"Squeeze", ConvertSqueezeOperator},
2524       {"StopGradient", ConvertIdentityOperator},
2525       {"StridedSlice", ConvertStridedSliceOperator},
2526       {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
2527       {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
2528       {"Svdf", ConvertSvdfOperator},
2529       {"Switch", ConvertSwitchOperator},
2530       {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
2531       {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
2532       {"TopK", ConvertTopKV2Operator},
2533       {"TopKV2", ConvertTopKV2Operator},
2534       {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
2535       {"Unpack", ConvertUnpackOperator},
2536       {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
2537       {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
2538       {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
2539       {"MirrorPad", ConvertMirrorPadOperator},
2540       {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
2541       {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
2542   });
2543 }
2544 
ImportTensorFlowNode(const tensorflow::NodeDef & node,const TensorFlowImportFlags & tf_import_flags,Model * model,const ConverterMapType & converter_map)2545 tensorflow::Status ImportTensorFlowNode(
2546     const tensorflow::NodeDef& node,
2547     const TensorFlowImportFlags& tf_import_flags, Model* model,
2548     const ConverterMapType& converter_map) {
2549   auto converter = converter_map.find(node.op());
2550   if (converter == converter_map.end()) {
2551     return ConvertUnsupportedOperator(node, tf_import_flags, model);
2552   } else {
2553     return converter->second(node, tf_import_flags, model);
2554   }
2555 }
2556 }  // namespace internal
2557 
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const GraphDef & tf_graph)2558 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2559     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2560     const GraphDef& tf_graph) {
2561   LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
2562 
2563   GraphDef inlined_graph(tf_graph);
2564   if (InlineAllFunctions(&inlined_graph)) {
2565     LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
2566   }
2567 
2568   // Check input and output specification.
2569   for (const auto& specified_input_array : model_flags.input_arrays()) {
2570     CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
2571         << "Unsupported explicit zero output index: "
2572         << specified_input_array.name();
2573   }
2574   for (const string& specified_output_array : model_flags.output_arrays()) {
2575     CHECK(!absl::EndsWith(specified_output_array, ":0"))
2576         << "Unsupported explicit zero output index: " << specified_output_array;
2577   }
2578 
2579   Model* model = new Model;
2580   internal::ConverterMapType converter_map;
2581 
2582   // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
2583   // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
2584   // converted to TFLite Flex ops.
2585   if (!tf_import_flags.import_all_ops_as_unsupported) {
2586     converter_map = internal::GetTensorFlowNodeConverterMap();
2587   } else {
2588     converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
2589   }
2590 
2591   for (auto node : inlined_graph.node()) {
2592     StripZeroOutputIndexFromInputs(&node);
2593     auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
2594                                                  converter_map);
2595     CHECK(status.ok()) << status.error_message();
2596   }
2597 
2598   ResolveModelFlags(model_flags, model);
2599 
2600   StripCaretFromArrayNames(model);
2601   AddExtraOutputs(model);
2602   FixNoMissingArray(model);
2603   FixNoOrphanedArray(model);
2604   FixOperatorOrdering(model);
2605   CheckInvariants(*model);
2606 
2607   // if rnn state arrays are constant, make them transient
2608   for (const auto& rnn_state : model->flags.rnn_states()) {
2609     model->GetArray(rnn_state.state_array()).buffer = nullptr;
2610   }
2611 
2612   return std::unique_ptr<Model>(model);
2613 }
2614 
ImportTensorFlowGraphDef(const ModelFlags & model_flags,const TensorFlowImportFlags & tf_import_flags,const string & input_file_contents)2615 std::unique_ptr<Model> ImportTensorFlowGraphDef(
2616     const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2617     const string& input_file_contents) {
2618   std::unique_ptr<GraphDef> tf_graph(new GraphDef);
2619   CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
2620 
2621   std::unique_ptr<GraphDef> pruned_graph =
2622       MaybeReplaceCompositeSubgraph(*tf_graph);
2623   if (pruned_graph) {
2624     tf_graph = std::move(pruned_graph);
2625   }
2626   return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
2627 }
2628 }  // namespace toco
2629