• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include "tools/converter/parser/tf/tf_model_parser.h"
19 #include <algorithm>
20 #include <functional>
21 #include <queue>
22 #include <set>
23 #include "abstract/utils.h"
24 #include "include/registry/node_parser_registry.h"
25 #include "ir/anf.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "mindspore/core/ops/lite_ops.h"
28 #include "mindspore/core/ops/structure_ops.h"
29 #include "ops/make_tuple.h"
30 #include "ops/return.h"
31 #include "ops/tuple_get_item.h"
32 #include "src/common/log_adapter.h"
33 #include "src/common/log_util.h"
34 #include "src/common/utils.h"
35 #include "tools/common/graph_util.h"
36 #include "tools/common/protobuf_utils.h"
37 #include "tools/common/tensor_util.h"
38 #include "tools/converter/converter_context.h"
39 #include "tools/converter/parser/lite_model_parser_creator.h"
40 #include "tools/converter/parser/parser_utils.h"
41 #include "tools/converter/parser/tf/functionalize_control_op_pass.h"
42 #include "tools/converter/parser/tf/remove_ineffective_control_flow.h"
43 #include "tools/converter/parser/tf/tf_fake_quant_adjust.h"
44 #include "tools/converter/parser/tf/tf_input_adjust.h"
45 #include "tools/converter/parser/tf/tf_node_parser_registry.h"
46 #include "tools/converter/parser/tf/tf_util.h"
47 #include "tools/converter/parser/unify_format.h"
48 #include "tools/converter/quantizer/quant_param_holder.h"
49 #include "tools/optimizer/common/gllo_utils.h"
50 
51 using mindspore::converter::kFmkTypeTf;
52 namespace mindspore {
53 namespace lite {
54 namespace {
IsTensorListOp(const AnfNodePtr & anf_node)55 bool IsTensorListOp(const AnfNodePtr &anf_node) {
56   return opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListFromTensor) ||
57          opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListSetItem) ||
58          opt::CheckPrimitiveType(anf_node, prim::kPrimTensorListReserve);
59 }
60 
// Index constant for a convolution weight input. It is not referenced in the
// visible part of this file; presumably it is the position of the weight among
// a conv CNode's inputs -- TODO confirm against its users.
61 constexpr size_t kConvWeightIndex = 2;
62 
GetAnfNode(const std::string & name,const std::unordered_map<std::string,AnfNodePtr> & anf_node_map,int index=0)63 AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
64                       int index = 0) {
65   AnfNodePtr ret = nullptr;
66   auto flat_anf_name = TensorFlowUtils::GetFlattenNodeName(name);
67   if (anf_node_map.find(flat_anf_name) != anf_node_map.end()) {
68     ret = anf_node_map.at(flat_anf_name);
69   } else if (anf_node_map.find(name + ":" + std::to_string(index)) != anf_node_map.end()) {
70     ret = anf_node_map.at(flat_anf_name + ":" + std::to_string(index));
71   }
72   return ret;
73 }
74 
GetOriginInputName(const tensorflow::NodeDef & node,const std::map<std::string,const tensorflow::NodeDef * > & tf_graph_nodes)75 std::string GetOriginInputName(const tensorflow::NodeDef &node,
76                                const std::map<std::string, const tensorflow::NodeDef *> &tf_graph_nodes) {
77   if (!TensorFlowUtils::OutputIsInputOp(node.op())) {
78     return node.name();
79   }
80   auto tmp_node = &node;
81   while (TensorFlowUtils::OutputIsInputOp(tmp_node->op())) {
82     auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0));
83     if (tf_graph_nodes.find(flatten_input_name) == tf_graph_nodes.end()) {
84       return flatten_input_name;
85     }
86     tmp_node = tf_graph_nodes.at(flatten_input_name);
87   }
88   return tmp_node->name();
89 }
90 
CheckStrView(std::string_view str_view,uint64_t * scratch)91 STATUS CheckStrView(std::string_view str_view, uint64_t *scratch) {
92   if (!TensorFlowUtils::DecodeInt64(&str_view, scratch)) {
93     return RET_ERROR;
94   }
95   for (size_t i = 0; i < static_cast<size_t>(*scratch); ++i) {
96     if (!TensorFlowUtils::DecodeInt64(&str_view, scratch)) {
97       return RET_ERROR;
98     }
99   }
100   if (!TensorFlowUtils::DecodeInt64(&str_view, scratch)) {
101     return RET_ERROR;
102   }
103   if (!TensorFlowUtils::DecodeInt64(&str_view, scratch)) {
104     return RET_ERROR;
105   }
106   return RET_OK;
107 }
108 
GetShapeSize(const tensorflow::TensorProto & tensor_proto)109 int GetShapeSize(const tensorflow::TensorProto &tensor_proto) {
110   auto &tensor_shape = tensor_proto.tensor_shape();
111   int shape_size = 1;
112   for (int i = 0; i < tensor_shape.dim_size(); i++) {
113     MS_CHECK_INT_MUL_NOT_OVERFLOW(shape_size, tensor_shape.dim(i).size(), 0);
114     shape_size *= tensor_shape.dim(i).size();
115   }
116   return shape_size;
117 }
118 
SetFloatTensorInfo(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)119 STATUS SetFloatTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) {
120   auto shape_size = GetShapeSize(tensor_proto);
121   auto &tensor_shape = tensor_proto.tensor_shape();
122   ShapeVector shape_vector{};
123   for (int i = 0; i < tensor_shape.dim_size(); i++) {
124     shape_vector.push_back(tensor_shape.dim(i).size());
125   }
126   *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32);
127   if (*tensor_info == nullptr) {
128     MS_LOG(ERROR) << "create tensor data failed.";
129     return RET_ERROR;
130   }
131   auto tensor_data = reinterpret_cast<float *>((*tensor_info)->data_c());
132   if (tensor_data == nullptr) {
133     MS_LOG(ERROR) << "new data failed";
134     return RET_ERROR;
135   }
136 
137   if (tensor_proto.float_val_size() == 1) {
138     for (int i = 0; i < shape_size; i++) {
139       tensor_data[i] = tensor_proto.float_val(0);
140     }
141   }
142   if (INT_MUL_OVERFLOW_THRESHOLD(shape_size, sizeof(float), SIZE_MAX)) {
143     MS_LOG(ERROR) << "data_size overflow.";
144     return RET_ERROR;
145   }
146   if (tensor_proto.tensor_content().size() == shape_size * sizeof(float)) {
147     const auto addr = reinterpret_cast<const float *>(tensor_proto.tensor_content().data());
148     if (::memcpy_s(tensor_data, (*tensor_info)->Size(), addr, shape_size * sizeof(float)) != EOK) {
149       MS_LOG(ERROR) << "memcpy_s failed";
150       return RET_ERROR;
151     }
152   }
153 
154   return RET_OK;
155 }
156 
SetInt32TensorInfo(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)157 STATUS SetInt32TensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) {
158   auto shape_size = GetShapeSize(tensor_proto);
159   auto &tensor_shape = tensor_proto.tensor_shape();
160   ShapeVector shape_vector{};
161   for (int i = 0; i < tensor_shape.dim_size(); i++) {
162     shape_vector.push_back(tensor_shape.dim(i).size());
163   }
164   *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeInt32);
165   if (*tensor_info == nullptr) {
166     MS_LOG(ERROR) << "create tensor data failed.";
167     return RET_ERROR;
168   }
169   auto tensor_data = reinterpret_cast<int *>((*tensor_info)->data_c());
170   if (tensor_data == nullptr) {
171     MS_LOG(ERROR) << "new data failed";
172     return RET_ERROR;
173   }
174   if (shape_size == 0) {
175     return RET_OK;
176   }
177   if (tensor_proto.tensor_content().empty()) {
178     const auto &origin_data = tensor_proto.int_val();
179     if (tensor_proto.int_val_size() == 1) {
180       for (int i = 0; i < shape_size; ++i) {
181         tensor_data[i] = origin_data[0];
182       }
183     } else {
184       MS_CHECK_GE(tensor_proto.int_val_size(), shape_size, RET_ERROR);
185       for (int i = 0; i < shape_size; ++i) {
186         tensor_data[i] = origin_data[i];
187       }
188     }
189   } else {
190     if (INT_MUL_OVERFLOW_THRESHOLD(shape_size, sizeof(int32_t), SIZE_MAX)) {
191       MS_LOG(ERROR) << "data_size overflow.";
192       return RET_ERROR;
193     }
194     MS_CHECK_GE(tensor_proto.tensor_content().size(), shape_size * sizeof(int32_t), RET_ERROR);
195     const auto addr = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
196     if (::memcpy_s(tensor_data, (*tensor_info)->Size(), addr, shape_size * sizeof(int32_t)) != EOK) {
197       MS_LOG(ERROR) << "memcpy_s failed";
198       return RET_ERROR;
199     }
200   }
201   return RET_OK;
202 }
203 
SetBoolTensorInfo(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)204 STATUS SetBoolTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) {
205   auto shape_size = GetShapeSize(tensor_proto);
206   auto &tensor_shape = tensor_proto.tensor_shape();
207   ShapeVector shape_vector{};
208   for (int i = 0; i < tensor_shape.dim_size(); i++) {
209     shape_vector.push_back(tensor_shape.dim(i).size());
210   }
211   *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeBool);
212   if (*tensor_info == nullptr) {
213     MS_LOG(ERROR) << "create tensor data failed.";
214     return RET_ERROR;
215   }
216   auto tensor_data = reinterpret_cast<bool *>((*tensor_info)->data_c());
217   if (tensor_data == nullptr) {
218     MS_LOG(ERROR) << "new data failed";
219     return RET_ERROR;
220   }
221   if (tensor_proto.bool_val_size() != shape_size) {
222     MS_LOG(ERROR) << "shape size:[" << shape_size << "] not equal bool val size:[" << tensor_proto.bool_val_size()
223                   << "]";
224     return RET_ERROR;
225   }
226   for (int i = 0; i < shape_size; i++) {
227     int value = tensor_proto.bool_val(i);
228     tensor_data[i] = value;
229   }
230   return RET_OK;
231 }
232 
SetStringTensorInfo(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)233 STATUS SetStringTensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) {
234   auto &tensor_shape = tensor_proto.tensor_shape();
235   ShapeVector shape_vector{};
236   for (int i = 0; i < tensor_shape.dim_size(); i++) {
237     shape_vector.push_back(tensor_shape.dim(i).size());
238   }
239 
240   if (shape_vector.empty()) {
241     *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kObjectTypeString);
242     if (*tensor_info == nullptr) {
243       MS_LOG(ERROR) << "create tensor info failed.";
244       return RET_ERROR;
245     }
246     return RET_OK;
247   }
248 
249   std::string shape_str;
250   shape_str += std::to_string(shape_vector.size()) + ",";
251   for (auto &dim : shape_vector) {
252     shape_str += std::to_string(dim) + ",";
253   }
254 
255   auto tensor_data = new (std::nothrow) string;
256   CHECK_NULL_RETURN(tensor_data);
257   if (tensor_proto.string_val_size() == 1) {
258     *tensor_data = tensor_proto.string_val(0);
259   } else {
260     MS_LOG(ERROR) << "string size bigger than one, not support.";
261     delete tensor_data;
262     return RET_ERROR;
263   }
264   if (INT_ADD_OVERFLOW(shape_str.size(), (*tensor_data).size())) {
265     MS_LOG(ERROR) << "data_size overflow.";
266     delete tensor_data;
267     return RET_ERROR;
268   }
269   shape_vector = {static_cast<int64_t>(shape_str.size() + (*tensor_data).size())};
270   *tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kObjectTypeString);
271   if (*tensor_info == nullptr) {
272     MS_LOG(ERROR) << "create tensor info failed.";
273     delete tensor_data;
274     return RET_ERROR;
275   }
276   auto tensor_info_data = reinterpret_cast<uint8_t *>((*tensor_info)->data_c());
277   if (memcpy_s(tensor_info_data, (*tensor_info)->Size(), shape_str.data(), shape_str.size()) != EOK) {
278     MS_LOG(ERROR) << "memcpy failed.";
279     delete tensor_data;
280     return RET_ERROR;
281   }
282   MS_CHECK_TRUE_RET((*tensor_info)->Size() >= (*tensor_data).size(), RET_ERROR);
283   if (memcpy_s(tensor_info_data + shape_str.size(), (*tensor_info)->Size() - (*tensor_data).size(),
284                (*tensor_data).data(), (*tensor_data).size()) != EOK) {
285     MS_LOG(ERROR) << "memcpy failed.";
286     delete tensor_data;
287     return RET_ERROR;
288   }
289 
290   delete tensor_data;
291   return RET_OK;
292 }
293 
ConvertGraph(api::FuncGraphPtr func_graph)294 FuncGraphPtr ConvertGraph(api::FuncGraphPtr func_graph) {
295   auto impl = func_graph->impl();
296   return std::dynamic_pointer_cast<FuncGraph>(impl);
297 }
298 }  // namespace
299 
SetInt64TensorToInt64Tensor(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)300 STATUS TFModelParser::SetInt64TensorToInt64Tensor(const tensorflow::TensorProto &tensor_proto,
301                                                   tensor::TensorPtr *tensor_info) {
302   auto &tensor_shape = tensor_proto.tensor_shape();
303   ShapeVector shape_vector{};
304   for (int i = 0; i < tensor_shape.dim_size(); i++) {
305     shape_vector.push_back(tensor_shape.dim(i).size());
306   }
307   tensor::TensorPtr tensor_info_int64;
308   if (tensor_proto.tensor_content().empty()) {
309     tensor_info_int64 = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeInt64);
310     if (tensor_info_int64 == nullptr) {
311       MS_LOG(ERROR) << "CreateTensorInfo failed.";
312       return RET_ERROR;
313     }
314     auto tensor_int64_data = reinterpret_cast<int64_t *>(tensor_info_int64->data_c());
315     if (tensor_int64_data == nullptr) {
316       MS_LOG(ERROR) << "new data failed";
317       return RET_ERROR;
318     }
319     const auto &origin_data = tensor_proto.int64_val();
320     for (int i = 0; i < tensor_proto.int64_val_size(); ++i) {
321       tensor_int64_data[i] = origin_data[i];
322     }
323   } else {
324     const auto origin_data = reinterpret_cast<const int64_t *>(tensor_proto.tensor_content().data());
325     tensor_info_int64 =
326       CreateTensorInfo(origin_data, tensor_proto.tensor_content().size(), shape_vector, kNumberTypeInt64);
327     if (tensor_info_int64 == nullptr) {
328       MS_LOG(ERROR) << "CreateTensorInfo failed.";
329       return RET_ERROR;
330     }
331   }
332   *tensor_info = tensor_info_int64;
333   return RET_OK;
334 }
335 
SetInt64TensorInfo(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info,const std::string & node_name)336 STATUS TFModelParser::SetInt64TensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info,
337                                          const std::string &node_name) {
338   if (SetInt64TensorToInt64Tensor(tensor_proto, tensor_info) != RET_OK) {
339     MS_LOG(ERROR) << "SetInt64TensorInfoMap failed.";
340     return RET_ERROR;
341   }
342 
343   return RET_OK;
344 }
345 
ConvertConstVariant(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info)346 STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info) {
347   if (tensor_proto.variant_val_size() != 1) {
348     MS_LOG(ERROR) << "only support variant_val_size == 1 now";
349     return RET_ERROR;
350   }
351   auto &variant = tensor_proto.variant_val(0);
352   if (variant.type_name() != "tensorflow::TensorList" || variant.tensors_size() <= 0) {
353     MS_LOG(DEBUG) << "Only nonempty TensorList type is supported now";
354   }
355   auto descriptor = variant.GetMetadata().descriptor;
356   auto reflection = variant.GetMetadata().reflection;
357   if (descriptor == nullptr || reflection == nullptr) {
358     MS_LOG(ERROR) << "descriptor or reflection is nullptr";
359     return RET_ERROR;
360   }
361   auto field_descriptor = descriptor->field(1);
362   if (field_descriptor == nullptr) {
363     MS_LOG(ERROR) << "field_descriptor is nullptr";
364     return RET_ERROR;
365   }
366   if (field_descriptor->type() != google::protobuf::FieldDescriptor::TYPE_BYTES) {
367     MS_LOG(ERROR) << "metadata type is not TYPE_BYTES";
368     return RET_ERROR;
369   }
370   auto origin_str = reflection->GetString(variant, field_descriptor);
371   std::string_view str_view(origin_str);
372   uint64_t scratch;
373   if (CheckStrView(str_view, &scratch) != RET_OK) {
374     return RET_ERROR;
375   }
376   auto element_dtype = static_cast<size_t>(scratch);
377 
378   tensorflow::TensorShapeProto element_shape_proto;
379   element_shape_proto.ParseFromString(origin_str);
380   auto dim_size = element_shape_proto.dim_size();
381   std::vector<int> tensor_list_data(dim_size + 2);
382   tensor_list_data[0] = TensorFlowUtils::GetTFDataType(tensorflow::DataType(element_dtype));
383   if (tensor_list_data[0] == kNumberTypeFloat64) {
384     tensor_list_data[0] = kNumberTypeFloat32;
385   }
386   tensor_list_data[1] = element_shape_proto.dim_size();
387   for (int i = 0; i < dim_size; i++) {
388     auto dim = element_shape_proto.dim(i).size();
389     if (dim > static_cast<int64_t>(INT32_MAX) || dim < static_cast<int64_t>(INT32_MIN)) {
390       MS_LOG(ERROR) << "int64 data " << dim << " too big to fit into int32";
391       return RET_ERROR;
392     } else {
393       tensor_list_data[i + 2] = static_cast<int>(dim);
394     }
395   }
396   tensor_list_data.emplace_back(variant.tensors_size());
397   for (const auto &tensor : variant.tensors()) {
398     std::vector<int> single_tensor_data;
399     single_tensor_data.emplace_back(tensor.tensor_shape().dim_size());
400     for (int i = 0; i < tensor.tensor_shape().dim_size(); i++) {
401       single_tensor_data.emplace_back(tensor.tensor_shape().dim(i).size());
402     }
403     tensor_list_data.insert(tensor_list_data.end(), single_tensor_data.begin(), single_tensor_data.end());
404   }
405   if (INT_MUL_OVERFLOW_THRESHOLD(tensor_list_data.size(), sizeof(int), INT_MAX)) {
406     MS_LOG(ERROR) << "tensor_list_data's size overflow.";
407     return RET_ERROR;
408   }
409   *tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int),
410                                   {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType);
411   if (*tensor_info == nullptr) {
412     MS_LOG(ERROR) << "create tensor data failed.";
413     return RET_ERROR;
414   }
415   return RET_OK;
416 }
417 
SetTensorInfoFromType(const tensorflow::TensorProto & tensor_proto,tensor::TensorPtr * tensor_info,const std::string & node_name)418 STATUS TFModelParser::SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info,
419                                             const std::string &node_name) {
420   auto type = (*tensor_info)->data_type();
421   if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) {
422     return SetFloatTensorInfo(tensor_proto, tensor_info);
423   } else if (type == kNumberTypeInt32 || type == kNumberTypeInt) {
424     return SetInt32TensorInfo(tensor_proto, tensor_info);
425   } else if (type == kNumberTypeInt64) {
426     return SetInt64TensorInfo(tensor_proto, tensor_info, node_name);
427   } else if (type == kNumberTypeBool) {
428     return SetBoolTensorInfo(tensor_proto, tensor_info);
429   } else if (type == kObjectTypeTensorType) {
430     return ConvertConstVariant(tensor_proto, tensor_info);
431   } else if (type == kObjectTypeString) {
432     return SetStringTensorInfo(tensor_proto, tensor_info);
433   } else {
434     MS_LOG(ERROR) << "Unsupported dataType: " << type;
435     return RET_ERROR;
436   }
437   return RET_OK;
438 }
439 
ConvertConstTensor(const tensorflow::NodeDef & node_def,const tensorflow::AttrValue & attr_value,const TypeId & type,const ParameterPtr & parameter,std::vector<int64_t> * shape_vector)440 STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
441                                          const TypeId &type, const ParameterPtr &parameter,
442                                          std::vector<int64_t> *shape_vector) {
443   MSLITE_CHECK_PTR(parameter);
444   MSLITE_CHECK_PTR(shape_vector);
445   const tensorflow::TensorProto &tensor_proto = attr_value.tensor();
446   const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape();
447   shape_vector->clear();
448   for (int i = 0; i < tensor_shape.dim_size(); i++) {
449     shape_vector->push_back(tensor_shape.dim(i).size());
450   }
451   auto tensor_info = std::make_shared<tensor::Tensor>(type, *shape_vector);
452   if (tensor_info == nullptr) {
453     MS_LOG(ERROR) << "tensor info is nullptr";
454     return RET_ERROR;
455   }
456   auto status = SetTensorInfoFromType(tensor_proto, &tensor_info, node_def.name());
457   if (status != RET_OK) {
458     MS_LOG(ERROR) << "set tensor data from type failed.";
459     return RET_ERROR;
460   }
461   status = InitParameterFromTensorInfo(parameter, tensor_info);
462   if (status != RET_OK) {
463     MS_LOG(ERROR) << "init parameter from tensor info failed.";
464     return RET_ERROR;
465   }
466   return RET_OK;
467 }
468 
ConvertParameter(const tensorflow::NodeDef & node,const ParameterPtr & parameter,std::unordered_map<std::string,AnfNodePtr> * anf_node_map,bool root_graph)469 STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,
470                                        std::unordered_map<std::string, AnfNodePtr> *anf_node_map, bool root_graph) {
471   MSLITE_CHECK_PTR(parameter);
472   MSLITE_CHECK_PTR(anf_node_map);
473 
474   tensorflow::AttrValue attr_value;
475   TypeId type = kNumberTypeFloat32;
476   if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) {
477     type = TensorFlowUtils::GetTFDataType(attr_value.type());
478   }
479 
480   std::vector<int64_t> shape;
481   if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
482     shape = ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(node.name());
483     if (ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() > 0 && shape.empty()) {
484       MS_LOG(WARNING) << "Can not find name in map. name is " << node.name();
485     }
486     if (shape.empty()) {
487       auto &shape_attr = attr_value.shape();
488       for (int i = 0; i < shape_attr.dim_size(); ++i) {
489         shape.push_back(shape_attr.dim(i).size());
490       }
491     }
492   }
493 
494   if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
495     MS_LOG(INFO) << "Found value attr, means it has default value";
496     auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape);
497     if (status != RET_OK) {
498       MS_LOG(ERROR) << "convert const tensor failed.";
499       return status;
500     }
501   } else {
502     if (root_graph) {
503       graph_input_names_.emplace_back(node.name());  // only root graph need set graph input names
504     }
505   }
506 
507   auto abstract_tensor = CreateTensorAbstract(shape, type);
508   if (abstract_tensor == nullptr) {
509     MS_LOG(ERROR) << "Create tensor abstarct failed";
510     return RET_ERROR;
511   }
512   parameter->set_name(node.name());
513   parameter->set_abstract(abstract_tensor);
514 
515   (*anf_node_map)[node.name()] = parameter;
516   (*anf_node_map)[node.name() + ":0"] = parameter;
517   return RET_OK;
518 }
519 
ConvertGraphInputsAndConsts(const std::vector<const tensorflow::NodeDef * > & tf_graph_nodes,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_node_map,bool root_graph)520 STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensorflow::NodeDef *> &tf_graph_nodes,
521                                                   const FuncGraphPtr &anf_graph,
522                                                   std::unordered_map<std::string, AnfNodePtr> *anf_node_map,
523                                                   bool root_graph) {
524   MSLITE_CHECK_PTR(anf_graph);
525   MSLITE_CHECK_PTR(anf_node_map);
526   for (auto &node : tf_graph_nodes) {
527     bool have_data_depend = false;
528     for (int i = 0; i < node->input_size(); ++i) {
529       auto name = node->input(i);
530       if (!name.empty() && name[0] != '^') {  // control_depend input start with "^"
531         have_data_depend = true;
532         break;
533       }
534     }
535     if (!have_data_depend && node->op() != "NoOp") {
536       auto parameter = anf_graph->add_parameter();
537       CHECK_NULL_RETURN(parameter);
538       if (ConvertParameter(*node, parameter, anf_node_map, root_graph) != RET_OK) {
539         MS_LOG(ERROR) << "convert Parameter Node failed";
540         return RET_ERROR;
541       }
542     }
543   }
544   return RET_OK;
545 }
546 
Parse(const converter::ConverterParameters & flag)547 api::FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
548   auto modelFile = flag.model_file;
549   NotSupportOp::GetInstance()->set_fmk_type("TF");
550   auto status = ValidateFileStr(modelFile, ".pb");
551   if (status != RET_OK) {
552     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
553     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
554     return nullptr;
555   }
556   tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
557   if (tf_root_graph_ == nullptr) {
558     MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
559     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
560     return nullptr;
561   }
562   status = ReadProtoFromBinaryFile(modelFile, tf_root_graph_.get());
563   if (status != RET_OK) {
564     MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
565     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
566     return nullptr;
567   }
568   auto graph = std::make_shared<FuncGraph>();
569   MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "create FuncGraph failed");
570   res_graph_ = api::MakeShared<api::FuncGraph>(graph);
571   if (res_graph_ == nullptr) {
572     MS_LOG(ERROR) << "funGraphPtr is nullptr";
573     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
574     return nullptr;
575   }
576   graph->set_attr("graph_name", MakeValue("main_graph"));
577   graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
578 
579   for (int i = 0; i < tf_root_graph_->node_size(); i++) {
580     auto &node_def = tf_root_graph_->node(i);
581     tf_root_graph_nodes_[node_def.name()] = &node_def;
582     tf_root_graph_nodes_vec_.emplace_back(&node_def);
583   }
584 
585   status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_vec_, graph, &anf_root_node_map_, true);
586   if (status != RET_OK) {
587     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
588     return nullptr;
589   }
590   bool success_flag = true;
591   ineffective_if_op_map_.clear();
592   for (int i = 0; i < tf_root_graph_->node_size(); i++) {
593     auto &node_def = tf_root_graph_->node(i);
594     status = ConvertOps(node_def, tf_root_graph_nodes_, graph, &anf_root_node_map_);
595     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
596     if (status != RET_OK) {
597       success_flag = false;
598     }
599   }
600   if (!success_flag) {
601     MS_LOG(ERROR) << "Convert ops failed.";
602     return nullptr;
603   }
604 
605   if (!nodes_with_null_input_.empty()) {
606     status = ConnectNullInput();
607     if (status != RET_OK) {
608       MS_LOG(ERROR) << "Connect null inputs failed.";
609       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
610       return nullptr;
611     }
612   }
613 
614   if ((status = ConvertRootGraphOutputs()) != RET_OK) {
615     MS_LOG(ERROR) << "Convert graph outputs failed.";
616     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
617     return nullptr;
618   }
619 
620   status = ConvertSubgraph();
621   if (status != RET_OK) {
622     MS_LOG(ERROR) << "Convert subgraph failed.";
623     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
624     return nullptr;
625   }
626 
627   if ((status = CommonAnfAdjust(graph)) != RET_OK) {
628     MS_LOG(ERROR) << "AdjustForAnf failed.";
629     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
630     return nullptr;
631   }
632   std::set<FuncGraphPtr> all_func_graphs = {};
633   GetAllFuncGraph(graph, &all_func_graphs);
634   if ((status = TF2AnfAdjust(all_func_graphs, &ineffective_if_op_map_)) != RET_OK) {
635     MS_LOG(ERROR) << "TF2AnfAdjust failed.";
636     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
637     return nullptr;
638   }
639   auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false, flag.save_type);
640   MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
641   if (!unify_format->Run(graph)) {
642     MS_LOG(ERROR) << "Run insert transpose failed.";
643     return nullptr;
644   }
645   graph->set_manager(nullptr);
646   static auto root_func_manager = Manage(graph);
647   if (root_func_manager == nullptr) {
648     MS_LOG(ERROR) << "root_func_manager is nullptr.";
649     return nullptr;
650   }
651   return res_graph_;
652 }
653 
ConvertSubgraphInputs(std::map<std::string,const tensorflow::NodeDef * > * tf_sub_node_map,std::unordered_map<std::string,AnfNodePtr> * anf_sub_node_map,const tensorflow::FunctionDef & tf_sub_fuction,const CNodePtr & cnode,const FuncGraphPtr & sub_func_graph)654 STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
655                                             std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
656                                             const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
657                                             const FuncGraphPtr &sub_func_graph) {
658   MSLITE_CHECK_PTR(anf_sub_node_map);
659   MSLITE_CHECK_PTR(cnode);
660   MSLITE_CHECK_PTR(sub_func_graph);
661   MSLITE_CHECK_PTR(tf_sub_node_map);
662   std::vector<ParameterPtr> sub_graph_inputs;
663   auto &tf_sub_signature = tf_sub_fuction.signature();
664   auto &sub_graph_name = tf_sub_signature.name();
665   auto input_arg_size = tf_sub_signature.input_arg_size();
666   for (int j = 0; j < input_arg_size; j++) {
667     auto &input_arg = tf_sub_signature.input_arg(j);
668     auto parameter = sub_func_graph->add_parameter();
669     CHECK_NULL_RETURN(parameter);
670     parameter->set_name(input_arg.name());
671     (*anf_sub_node_map)[input_arg.name()] = parameter;
672     auto root_inputs = cnode->inputs();
673     if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) {
674       parameter->set_abstract(root_inputs[j + 1]->abstract());
675     } else {
676       parameter->set_abstract(root_inputs[j + 2]->abstract());
677     }
678     sub_graph_inputs.emplace_back(parameter);
679   }
680   std::vector<const tensorflow::NodeDef *> subgraph_tf_node_vec;
681   for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
682     auto &node_def = tf_sub_fuction.node_def(j);
683     (*tf_sub_node_map)[node_def.name()] = &node_def;
684     subgraph_tf_node_vec.emplace_back(&node_def);
685   }
686   if (ConvertGraphInputsAndConsts(subgraph_tf_node_vec, sub_func_graph, anf_sub_node_map, false) != RET_OK) {
687     MS_LOG(ERROR) << "Convert subgraph consts failed";
688     return RET_ERROR;
689   }
690 
691   // hardcode subgraph inputs name
692   for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
693     sub_graph_inputs[j]->set_name(sub_graph_name + "_input_" + std::to_string(j) + "_parameter");
694   }
695 
696   return RET_OK;
697 }
698 
ConvertSubgraphOutputs(std::map<std::string,const tensorflow::NodeDef * > * tf_sub_node_map,const std::unordered_map<std::string,AnfNodePtr> & anf_sub_node_map,const tensorflow::FunctionDef & tf_sub_fuction,const FuncGraphPtr & sub_func_graph)699 STATUS TFModelParser::ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
700                                              const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
701                                              const tensorflow::FunctionDef &tf_sub_fuction,
702                                              const FuncGraphPtr &sub_func_graph) {
703   MSLITE_CHECK_PTR(sub_func_graph);
704   MSLITE_CHECK_PTR(tf_sub_node_map);
705   auto &tf_sub_signature = tf_sub_fuction.signature();
706   auto &sub_graph_name = tf_sub_signature.name();
707 
708   std::vector<AnfNodePtr> sub_output_nodes;
709   auto &subgraph_ret = tf_sub_fuction.ret();
710   for (auto &output_arg : tf_sub_signature.output_arg()) {
711     auto &signature_name = output_arg.name();
712     if (subgraph_ret.find(signature_name) == subgraph_ret.end()) {
713       MS_LOG(ERROR) << "can't find signature_name: " << signature_name;
714       return RET_ERROR;
715     }
716     auto t = subgraph_ret.find(signature_name);
717     MS_LOG(INFO) << "subret " << t->first << " " << t->second;
718     auto tf_output_name = TensorFlowUtils::GetFlattenNodeName(t->second);
719     AnfNodePtr anf_node = nullptr;
720     if (tf_sub_node_map->find(tf_output_name) == tf_sub_node_map->end()) {
721       anf_node = GetAnfNode(tf_output_name, anf_sub_node_map);
722     } else {
723       auto tf_real_name = GetOriginInputName(*tf_sub_node_map->at(tf_output_name), *tf_sub_node_map);
724       anf_node = GetAnfNode(tf_real_name, anf_sub_node_map);
725     }
726     if (anf_node == nullptr) {
727       MS_LOG(ERROR) << "can't find anf node,tf node flatten name" << tf_output_name;
728       return RET_ERROR;
729     }
730     sub_output_nodes.push_back(anf_node);
731   }
732   if (MakeAnfGraphOutputs(sub_output_nodes, sub_func_graph) != RET_OK) {
733     MS_LOG(ERROR) << "cmake anf graph outputs node error";
734     return RET_ERROR;
735   }
736 
737   // hardcode subgraph outputs name
738   for (size_t j = 0; j < sub_output_nodes.size(); j++) {
739     if (utils::isa<CNodePtr>(sub_output_nodes[j])) {
740       sub_output_nodes[j]->cast<CNodePtr>()->set_fullname_with_scope(sub_graph_name + "_output_" + std::to_string(j) +
741                                                                      "_cnode");
742     } else if (utils::isa<ParameterPtr>(sub_output_nodes[j])) {
743       sub_output_nodes[j]->cast<ParameterPtr>()->set_name(sub_graph_name + "_output_" + std::to_string(j) +
744                                                           "_parameter");
745     }
746   }
747   return RET_OK;
748 }
749 
UpdateMap(const CNodePtr & cnode,const FuncGraphPtr & sub_func_graph,const std::string & sub_graph_name)750 void TFModelParser::UpdateMap(const CNodePtr &cnode, const FuncGraphPtr &sub_func_graph,
751                               const std::string &sub_graph_name) {
752   CHECK_NULL_RETURN_VOID(cnode);
753   CHECK_NULL_RETURN_VOID(sub_func_graph);
754   if (opt::CheckPrimitiveType(cnode, prim::kPrimWhile)) {
755     if (find(while_cond_branch_name_.begin(), while_cond_branch_name_.end(), sub_graph_name) !=
756         while_cond_branch_name_.end()) {
757       while_cond_map_[cnode] = sub_func_graph;
758     } else {
759       while_body_map_[cnode] = sub_func_graph;
760     }
761   }
762   if (opt::CheckPrimitiveType(cnode, prim::kPrimIf)) {
763     if (find(if_then_branch_name_.begin(), if_then_branch_name_.end(), sub_graph_name) != if_then_branch_name_.end()) {
764       if_then_map_[cnode] = sub_func_graph;
765     } else {
766       if_else_map_[cnode] = sub_func_graph;
767     }
768   }
769 }
770 
ConvertSubgraph()771 STATUS TFModelParser::ConvertSubgraph() {
772   bool success_flag = true;
773   std::queue<int> tf_graph_index_q{};
774   for (int i = 0; i < tf_root_graph_->library().function_size(); i++) {
775     tf_graph_index_q.push(i);
776   }
777   int max_move_times = tf_root_graph_->library().function_size();
778   // key is graph index, value is the time move to the queue back.
779   std::unordered_map<int, int> move_times_map{};
780   while (!tf_graph_index_q.empty()) {
781     auto cur_index = tf_graph_index_q.front();
782     tf_graph_index_q.pop();
783     auto &tf_sub_fuction = tf_root_graph_->library().function(cur_index);
784     auto &tf_sub_signature = tf_sub_fuction.signature();
785     auto input_arg_size = tf_sub_signature.input_arg_size();
786     auto &sub_graph_name = tf_sub_signature.name();
787     CNodePtr cnode = nullptr;
788     if (function_while_map_.count(sub_graph_name)) {
789       cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>();
790       if (cnode == nullptr || static_cast<int>(cnode->size()) != input_arg_size + 1) {
791         MS_LOG(ERROR) << "while cnode  not equal input arg size";
792         return RET_ERROR;
793       }
794     } else if (function_if_map_.count(sub_graph_name)) {
795       cnode = function_if_map_[sub_graph_name]->cast<CNodePtr>();
796       if (cnode == nullptr || static_cast<int>(cnode->size()) != input_arg_size + 2) {
797         MS_LOG(ERROR) << "if cnode  not equal input arg size";
798         return RET_ERROR;
799       }
800     } else {
801       if (move_times_map.find(cur_index) == move_times_map.end()) {
802         move_times_map[cur_index] = 1;
803         tf_graph_index_q.push(cur_index);
804       } else {
805         move_times_map[cur_index]++;
806         if (move_times_map[cur_index] >= max_move_times) {
807           MS_LOG(WARNING) << "This function is not belong to any while op or if op, graph name: " << sub_graph_name;
808         } else {
809           tf_graph_index_q.push(cur_index);
810         }
811       }
812       continue;
813     }
814     FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
815     MS_CHECK_TRUE_RET(sub_func_graph != nullptr, RET_ERROR);
816     sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
817     sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
818     std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
819     std::map<std::string, const tensorflow::NodeDef *> tf_sub_node_map;
820 
821     if (ConvertSubgraphInputs(&tf_sub_node_map, &anf_sub_node_map, tf_sub_fuction, cnode, sub_func_graph) != RET_OK) {
822       MS_LOG(ERROR) << "Convert subgraph inputs failed.";
823       return RET_ERROR;
824     }
825 
826     // convert sub graph ops
827     STATUS status = RET_OK;
828     for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) {
829       auto &node_def = tf_sub_fuction.node_def(j);
830       status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map);
831       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
832       if (status != RET_OK) {
833         MS_LOG(ERROR) << "Convert subgraph ops failed.";
834         success_flag = false;
835       }
836     }
837     if (!success_flag) {
838       MS_LOG(ERROR) << "Convert subgraph is failed.";
839       return RET_ERROR;
840     }
841 
842     if (ConvertSubgraphOutputs(&tf_sub_node_map, anf_sub_node_map, tf_sub_fuction, sub_func_graph) != RET_OK) {
843       MS_LOG(ERROR) << "Convert subgraph outputs failed.";
844       return RET_ERROR;
845     }
846 
847     // add while cond body function to while node input
848     UpdateMap(cnode, sub_func_graph, sub_graph_name);
849   }
850 
851   if (ControlFlowNodePostProcess(while_cond_map_, while_body_map_) != RET_OK ||
852       (ControlFlowNodePostProcess(if_then_map_, if_else_map_) != RET_OK)) {
853     MS_LOG(ERROR) << "while/if node post process failed";
854     return RET_ERROR;
855   }
856   return RET_OK;
857 }
858 
ControlFlowNodePostProcess(const std::map<CNodePtr,FuncGraphPtr> & first_func_map,const std::map<CNodePtr,FuncGraphPtr> & second_func_map)859 STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
860                                                  const std::map<CNodePtr, FuncGraphPtr> &second_func_map) {
861   if (first_func_map.size() != second_func_map.size()) {
862     MS_LOG(ERROR) << "first_func_map.size(): " << first_func_map.size()
863                   << " second_func_map.size(): " << second_func_map.size();
864     return RET_ERROR;
865   }
866   auto main_graph = ConvertGraph(res_graph_);
867   MS_CHECK_TRUE_RET(main_graph != nullptr, RET_ERROR);
868   static auto root_func_manager = Manage(main_graph);
869   MS_CHECK_TRUE_RET(root_func_manager != nullptr, RET_ERROR);
870 
871   for (auto &kv : first_func_map) {
872     auto control_flow_node = kv.first;
873     MS_CHECK_TRUE_RET(control_flow_node != nullptr, RET_ERROR);
874     auto func_graph = control_flow_node->func_graph();
875     MS_CHECK_TRUE_RET(func_graph != nullptr, RET_ERROR);
876 
877     auto &first_sub_graph = kv.second;
878     auto &second_sub_graph = second_func_map.at(control_flow_node);
879     CHECK_NULL_RETURN(control_flow_node);
880     CHECK_NULL_RETURN(first_sub_graph);
881     CHECK_NULL_RETURN(second_sub_graph);
882     first_sub_graph->set_manager(root_func_manager);
883     second_sub_graph->set_manager(root_func_manager);
884     auto first_value_node = NewValueNode(first_sub_graph);
885     CHECK_NULL_RETURN(first_value_node);
886     auto second_value_node = NewValueNode(second_sub_graph);
887     CHECK_NULL_RETURN(second_value_node);
888     auto inputs = control_flow_node->inputs();
889     inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node});
890     auto new_node = func_graph->NewCNode(inputs);  // must create new node, otherwise node_users won't update
891     if (new_node == nullptr) {
892       MS_LOG(ERROR) << "new node failed";
893       return RET_ERROR;
894     }
895     new_node->set_abstract(control_flow_node->abstract()->Clone());
896     new_node->set_fullname_with_scope(control_flow_node->fullname_with_scope());
897     if (!root_func_manager->Replace(control_flow_node, new_node)) {
898       MS_LOG(ERROR) << "replace new node failed";
899       return RET_ERROR;
900     }
901   }
902   return RET_OK;
903 }
904 
ConvertInputNodes(const tensorflow::NodeDef & node_def,const std::vector<std::string> & input_names,const std::map<std::string,const tensorflow::NodeDef * > & tf_node_map,const std::unordered_map<std::string,AnfNodePtr> & anf_node_map,std::vector<AnfNodePtr> * inputs,std::vector<std::string> * input_name_not_found)905 STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
906                                         const std::vector<std::string> &input_names,
907                                         const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
908                                         const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
909                                         std::vector<AnfNodePtr> *inputs,
910                                         std::vector<std::string> *input_name_not_found) {
911   CHECK_NULL_RETURN(inputs);
912   CHECK_NULL_RETURN(input_name_not_found);
913   // parse inputs
914   for (size_t j = 0; j < input_names.size(); j++) {
915     std::string input_name = input_names[j];  // input may be produced by multi-outputs node
916     // subgraph input name x:output:index,need flatten
917     auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name);
918     if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) {
919       auto input_node = tf_node_map.at(flatten_input_name);
920       flatten_input_name = GetOriginInputName(*input_node, tf_node_map);
921     }
922     auto input = GetAnfNode(flatten_input_name, anf_node_map);
923     if (input == nullptr) {
924       MS_LOG(WARNING) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
925       (*input_name_not_found).push_back(flatten_input_name);
926     }
927     inputs->emplace_back(input);
928   }
929   return RET_OK;
930 }
931 
ConvertOutputTensor(const tensorflow::NodeDef & op,const CNodePtr & anf_node,std::unordered_map<std::string,AnfNodePtr> * anf_node_map,const FuncGraphPtr & anf_graph,int output_size)932 STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
933                                           std::unordered_map<std::string, AnfNodePtr> *anf_node_map,
934                                           const FuncGraphPtr &anf_graph, int output_size) {
935   MSLITE_CHECK_PTR(anf_node);
936   MSLITE_CHECK_PTR(anf_node_map);
937   MSLITE_CHECK_PTR(anf_graph);
938   if (IsTensorListOp(anf_node) && output_size != 1) {
939     MS_LOG(ERROR) << "tensorlist output op output_size !=1";
940     return RET_ERROR;
941   }
942   if (output_size == 0) {
943     return RET_OK;
944   } else if (output_size == 1) {
945     auto type = kNumberTypeFloat32;
946     if (IsTensorListOp(anf_node)) {
947       type = kObjectTypeTensorType;
948     }
949     auto abstract_tensor = CreateTensorAbstract({}, type);
950     if (abstract_tensor == nullptr) {
951       MS_LOG(ERROR) << "Create tensor abstarct failed";
952       return RET_ERROR;
953     }
954     anf_node->set_abstract(abstract_tensor);
955     anf_node_map->insert(std::pair(op.name(), anf_node));
956   } else {
957     AbstractBasePtrList abstract_list;
958     for (int output_idx = 0; output_idx < output_size; output_idx++) {
959       auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
960       if (abstract_tensor == nullptr) {
961         MS_LOG(ERROR) << "Create tensor abstarct failed";
962         return RET_ERROR;
963       }
964       abstract_list.emplace_back(abstract_tensor);
965       auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
966       if (tuple_get_item_prim_ptr == nullptr) {
967         MS_LOG(ERROR) << "new TupleGetItem failed";
968         return RET_NULL_PTR;
969       }
970       auto prim_c = tuple_get_item_prim_ptr->GetPrim();
971       CHECK_NULL_RETURN(prim_c);
972       auto tuple_get_item_prim = NewValueNode(prim_c);
973       CHECK_NULL_RETURN(tuple_get_item_prim);
974       auto get_item_value = NewValueNode(MakeValue<int64_t>(output_idx));
975       CHECK_NULL_RETURN(get_item_value);
976       std::vector<AnfNodePtr> inputs{tuple_get_item_prim, anf_node, get_item_value};
977       CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
978       CHECK_NULL_RETURN(get_item_cnode);
979       std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
980       auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
981       if (get_item_abstract == nullptr) {
982         MS_LOG(ERROR) << "Create tensor abstarct failed";
983         return RET_ERROR;
984       }
985       get_item_cnode->set_abstract(get_item_abstract);
986       get_item_cnode->set_fullname_with_scope(output_item_name);
987       anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode));
988     }
989     anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
990   }
991   return RET_OK;
992 }
993 
RecordNullInput(const CNodePtr & node,const std::vector<std::string> & input_name_not_found)994 STATUS TFModelParser::RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found) {
995   CHECK_NULL_RETURN(node);
996   nodes_with_null_input_.emplace_back(node, input_name_not_found);
997   return RET_OK;
998 }
999 
ConnectNullInput()1000 STATUS TFModelParser::ConnectNullInput() {
1001   for (auto &it : nodes_with_null_input_) {
1002     auto &cnode = it.first;
1003     auto &input_name_not_found = it.second;
1004     auto &inputs = cnode->inputs();
1005     int i = 0;
1006     for (size_t j = 0; j < inputs.size(); ++j) {
1007       if (inputs[j] == nullptr) {
1008         cnode->set_input(j, GetAnfNode(input_name_not_found[i], anf_root_node_map_));
1009         ++i;
1010       }
1011     }
1012   }
1013   return RET_OK;
1014 }
1015 
ConvertOps(const tensorflow::NodeDef & node_def,const std::map<std::string,const tensorflow::NodeDef * > & tf_node_map,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_node_map)1016 STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
1017                                  const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
1018                                  const FuncGraphPtr &func_graph_ptr,
1019                                  std::unordered_map<std::string, AnfNodePtr> *anf_node_map) {
1020   MS_ASSERT(node_def != nullptr);
1021   MSLITE_CHECK_PTR(func_graph_ptr);
1022   MSLITE_CHECK_PTR(anf_node_map);
1023   STATUS status = RET_OK;
1024   const auto &op_type = node_def.op();
1025   if (TensorFlowUtils::OutputIsInputOp(op_type)) {
1026     return RET_OK;
1027   } else if (op_type == "Placeholder" || op_type == "Const") {
1028     node_output_num_[node_def.name()] = 1;
1029     return RET_OK;
1030   }
1031   MS_LOG(INFO) << "parse op : " << op_type;
1032   ops::PrimitiveCPtr primitive_c;
1033   auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeTf, op_type);
1034   int output_size;
1035   std::vector<std::string> input_names;
1036   if (node_parser != nullptr) {
1037     auto parser_result = node_parser->Parse(node_def, tf_node_map, &input_names, &output_size);
1038     if (parser_result == nullptr) {
1039       MS_LOG(ERROR) << "Node parse result nullptr!Please check system memory.";
1040       return RET_ERROR;
1041     }
1042     primitive_c = parser_result->GetPrim();
1043   } else {
1044     auto node_parser_builtin = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
1045     if (node_parser_builtin == nullptr) {
1046       NotSupportOp::GetInstance()->InsertOp(op_type);
1047       MS_LOG(ERROR) << "cannot find node parser: " << node_def.name() << " in "
1048                     << func_graph_ptr->get_attr("graph_name")->ToString();
1049       return RET_NOT_FIND_OP;
1050     }
1051     primitive_c = node_parser_builtin->Parse(node_def, tf_node_map, &input_names, &output_size);
1052   }
1053   if (primitive_c == nullptr) {
1054     MS_LOG(ERROR) << "node " << op_type << " parser failed!";
1055     return RET_ERROR;
1056   }
1057   node_output_num_[node_def.name()] = output_size;
1058   for (int i = 0; i < output_size; i++) {
1059     node_output_num_[node_def.name() + ":" + std::to_string(i)] = 1;
1060   }
1061   auto value_node = NewValueNode(primitive_c);
1062   if (value_node == nullptr) {
1063     MS_LOG(ERROR) << "value_node is nullptr";
1064     return RET_ERROR;
1065   }
1066 
1067   std::vector<AnfNodePtr> inputs = {value_node};
1068   std::vector<std::string> input_name_not_found{};
1069   status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs, &input_name_not_found);
1070   if (status != RET_OK) {
1071     return status;
1072   }
1073   // control_depends are not processed currently
1074   auto anf_node = func_graph_ptr->NewCNode(inputs);
1075   CHECK_NULL_RETURN(anf_node);
1076   anf_node->set_fullname_with_scope(node_def.name());
1077   status = ProcessControlFlowOp(anf_node, op_type, node_def);
1078   if (status != RET_OK) {
1079     MS_LOG(ERROR) << "ProcessControlFlowOp failed.";
1080     return RET_ERROR;
1081   }
1082 
1083   if (!input_name_not_found.empty()) {
1084     status = RecordNullInput(anf_node, input_name_not_found);
1085     if (status != RET_OK) {
1086       MS_LOG(ERROR) << "RecordNullInput for " << anf_node->fullname_with_scope() << " failed.";
1087       return status;
1088     }
1089   }
1090 
1091   status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size);
1092   if (status != RET_OK) {
1093     MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
1094     return status;
1095   }
1096 
1097   return status;
1098 }
1099 
IsEmptyTfFunction(const CNodePtr & anf_node,std::string branch_name)1100 bool TFModelParser::IsEmptyTfFunction(const CNodePtr &anf_node, std::string branch_name) {
1101   for (int i = 0; i < tf_root_graph_->library().function_size(); i++) {
1102     auto &tf_sub_fuction = tf_root_graph_->library().function(i);
1103     auto &tf_sub_signature = tf_sub_fuction.signature();
1104     auto &sub_graph_name = tf_sub_signature.name();
1105 
1106     if (branch_name != sub_graph_name) {
1107       continue;
1108     }
1109     auto &tf_sub_signature_output_arg = tf_sub_signature.output_arg();
1110     if (tf_sub_signature_output_arg.size() != 1) {
1111       return false;
1112     }
1113     auto &tf_sub_signature_output_name = tf_sub_signature_output_arg.Get(0).name();
1114     auto input_arg_size = tf_sub_signature.input_arg_size();
1115     if (tf_sub_fuction.node_def_size() == 0) {
1116       for (int index = 0; index < input_arg_size; index++) {
1117         auto &input_arg = tf_sub_signature.input_arg(index);
1118         if (input_arg.name() == tf_sub_signature_output_name &&
1119             ineffective_if_op_map_.find(anf_node) == ineffective_if_op_map_.end()) {
1120           ineffective_if_op_map_[anf_node] = index + C2NUM;
1121           return true;
1122         }
1123       }
1124     } else if (tf_sub_fuction.node_def_size() == 1) {
1125       auto &node_def = tf_sub_fuction.node_def(0);
1126       if (!TensorFlowUtils::OutputIsInputOp(node_def.name())) {
1127         return false;
1128       }
1129       for (int index = 0; index < input_arg_size; index++) {
1130         auto &input_arg = tf_sub_signature.input_arg(index);
1131         if (input_arg.name() == node_def.input(0)) {
1132           auto output_name = node_def.name();
1133           transform(output_name.begin(), output_name.end(), output_name.begin(), ::tolower);
1134           if (output_name == tf_sub_signature_output_name &&
1135               ineffective_if_op_map_.find(anf_node) == ineffective_if_op_map_.end()) {
1136             ineffective_if_op_map_[anf_node] = index + C2NUM;
1137             return true;
1138           }
1139         }
1140       }
1141     }
1142   }
1143   return false;
1144 }  // namespace lite
1145 
IsIneffectiveIfOp(const CNodePtr & anf_node,const string & op_type,const tensorflow::NodeDef & node_def)1146 bool TFModelParser::IsIneffectiveIfOp(const CNodePtr &anf_node, const string &op_type,
1147                                       const tensorflow::NodeDef &node_def) {
1148   if (op_type != "If") {
1149     return false;
1150   }
1151   lite::DataInfo if_cond_info;
1152   auto if_cond = anf_node->input(1);
1153   if (if_cond == nullptr) {
1154     return false;
1155   }
1156   int status = lite::RET_ERROR;
1157   if (if_cond->isa<Parameter>()) {
1158     status = lite::FetchDataFromParameterNode(anf_node, 1, converter::kFmkTypeMs, &if_cond_info, true);
1159   } else if (utils::isa<CNodePtr>(if_cond)) {
1160     auto input_cnode = if_cond->cast<CNodePtr>();
1161     if (input_cnode == nullptr) {
1162       return false;
1163     }
1164     if (!opt::CheckPrimitiveType(input_cnode, prim::kPrimConstant)) {
1165       return false;
1166     }
1167 
1168     auto input_cnode_in1 = input_cnode->input(1);
1169     if (input_cnode_in1 == nullptr) {
1170       return false;
1171     }
1172     if (input_cnode_in1->isa<Parameter>()) {
1173       status = lite::FetchDataFromParameterNode(input_cnode, 1, converter::kFmkTypeMs, &if_cond_info, true);
1174     } else if (input_cnode_in1->isa<ValueNode>()) {
1175       status = lite::FetchDataFromValueNode(input_cnode, 1, converter::kFmkTypeMs, false, &if_cond_info, true);
1176     }
1177   }
1178 
1179   if (status != lite::RET_OK) {
1180     return false;
1181   }
1182   if (static_cast<TypeId>(if_cond_info.data_type_) == kNumberTypeBool && if_cond_info.data_.size() == 1) {
1183     tensorflow::AttrValue attr_value;
1184     if (static_cast<bool>(if_cond_info.data_[0])) {
1185       if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) {
1186         auto then_name = attr_value.func().name();
1187         if (IsEmptyTfFunction(anf_node, then_name)) {
1188           return true;
1189         }
1190       }
1191     } else {
1192       if (TensorFlowUtils::FindAttrValue(node_def, "else_branch", &attr_value)) {
1193         auto else_name = attr_value.func().name();
1194         if (IsEmptyTfFunction(anf_node, else_name)) {
1195           return true;
1196         }
1197       }
1198     }
1199   }
1200 
1201   return false;
1202 }
1203 
ProcessControlFlowOp(const CNodePtr & anf_node,const string & op_type,const tensorflow::NodeDef & node_def)1204 STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type,
1205                                            const tensorflow::NodeDef &node_def) {
1206   MSLITE_CHECK_PTR(anf_node);
1207   if (IsIneffectiveIfOp(anf_node, op_type, node_def)) {
1208     return RET_OK;
1209   }
1210   if (op_type == "StatelessWhile" || op_type == "While") {
1211     MS_LOG(INFO) << "find while node:" << node_def.name();
1212     tensorflow::AttrValue attr_value;
1213     if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) {
1214       auto body_name = attr_value.func().name();
1215       function_while_map_[body_name] = anf_node;
1216       MS_LOG(DEBUG) << "parse body name:" << body_name;
1217     }
1218     if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
1219       auto cond_name = attr_value.func().name();
1220       function_while_map_[cond_name] = anf_node;
1221       while_cond_branch_name_.push_back(cond_name);
1222       MS_LOG(DEBUG) << "parse cond name:" << cond_name;
1223     }
1224   } else if (op_type == "StatelessIf" || op_type == "If") {
1225     MS_LOG(INFO) << "find if node:" << node_def.name();
1226     tensorflow::AttrValue attr_value;
1227     if (TensorFlowUtils::FindAttrValue(node_def, "then_branch", &attr_value)) {
1228       auto then_name = attr_value.func().name();
1229       if_then_branch_name_.push_back(then_name);
1230       function_if_map_[then_name] = anf_node;
1231       MS_LOG(DEBUG) << "parse then name:" << then_name;
1232     }
1233     if (TensorFlowUtils::FindAttrValue(node_def, "else_branch", &attr_value)) {
1234       auto else_name = attr_value.func().name();
1235       function_if_map_[else_name] = anf_node;
1236       MS_LOG(DEBUG) << "parse else name:" << else_name;
1237     }
1238   }
1239   return RET_OK;
1240 }
1241 
GetAllNodeInputs()1242 std::set<std::string> TFModelParser::GetAllNodeInputs() {
1243   std::set<std::string> all_node_inputs;
1244   for (auto &node : tf_root_graph_nodes_vec_) {
1245     for (int i = 0; i < node->input_size(); ++i) {
1246       all_node_inputs.insert(TensorFlowUtils::GetNodeName(node->input(i)));
1247       auto input_name = node->input(i);
1248       if (input_name[0] == '^') {
1249         input_name.erase(0, 1);
1250       }
1251       all_node_inputs.insert(input_name);
1252     }
1253   }
1254   return all_node_inputs;
1255 }
1256 
GetGraphOutputNames(std::vector<AnfNodePtr> * output_nodes)1257 STATUS TFModelParser::GetGraphOutputNames(std::vector<AnfNodePtr> *output_nodes) {
1258   MS_CHECK_TRUE_RET(output_nodes->empty(), RET_ERROR);
1259   std::set<std::string> all_node_inputs = GetAllNodeInputs();
1260   for (auto &node : tf_root_graph_nodes_vec_) {
1261     if (node->op() == "Assert") {
1262       continue;
1263     }
1264     auto it = all_node_inputs.find(node->name());
1265     if (it != all_node_inputs.end() || node->input_size() <= 0) {  // output node not constraint to Identity
1266       continue;
1267     }
1268     auto origin_name = GetOriginInputName(*(node), tf_root_graph_nodes_);
1269     // node with multiple outputs has been changed to tupleGetItem, and the original name changes to be name:idx.
1270     for (int i = 0; i < node_output_num_[origin_name]; i++) {
1271       auto anf_node = GetAnfNode(origin_name, anf_root_node_map_, i);
1272       if (anf_node == nullptr) {
1273         MS_LOG(ERROR) << "can't find anf node: " << origin_name;
1274         return RET_ERROR;
1275       }
1276       output_nodes->push_back(anf_node);
1277       if (TensorFlowUtils::OutputIsInputOp(node->op())) {
1278         auto tmp_node = node;
1279         bool found_input = true;
1280         while (tmp_node->name().empty() && TensorFlowUtils::OutputIsInputOp(tmp_node->op())) {
1281           auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0));
1282           if (tf_root_graph_nodes_.find(flatten_input_name) != tf_root_graph_nodes_.end()) {
1283             tmp_node = tf_root_graph_nodes_.at(flatten_input_name);
1284           } else {
1285             found_input = false;
1286             break;
1287           }
1288         }
1289         origin_name = found_input ? tmp_node->name() : origin_name;
1290       }
1291       graph_output_names_.push_back(origin_name);
1292     }
1293   }
1294   return RET_OK;
1295 }
1296 
ConvertRootGraphOutputs()1297 STATUS TFModelParser::ConvertRootGraphOutputs() {
1298   // because output of intermediate node in anf graph may also be output tensors, we search output tensors in
1299   // tf_root_graph_nodes_ but not anf_root_node_map_
1300   std::vector<AnfNodePtr> output_nodes;
1301   auto status = GetGraphOutputNames(&output_nodes);
1302   if (status != RET_OK) {
1303     MS_LOG(ERROR) << "get graph outputs node error";
1304     return status;
1305   }
1306   auto func_graph = ConvertGraph(res_graph_);
1307   if (func_graph == nullptr) {
1308     MS_LOG(ERROR) << "unc graph is invalid.";
1309     return RET_ERROR;
1310   }
1311   status = MakeAnfGraphOutputs(output_nodes, func_graph);
1312   if (status != RET_OK) {
1313     MS_LOG(ERROR) << "make anf graph outputs node error";
1314     return status;
1315   }
1316   // save original output tensor names.
1317   ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(graph_output_names_);
1318   return RET_OK;
1319 }
MakeAnfGraphOutputs(const std::vector<AnfNodePtr> & output_nodes,const FuncGraphPtr & anf_graph)1320 STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_nodes, const FuncGraphPtr &anf_graph) {
1321   if (output_nodes.empty() || anf_graph == nullptr) {
1322     MS_LOG(ERROR) << "anf output nodes empty or  null anf graph";
1323     return RET_ERROR;
1324   }
1325   if (output_nodes.size() > 1) {
1326     std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
1327     auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
1328     if (make_tuple_prim_ptr == nullptr) {
1329       MS_LOG(ERROR) << "new MakeTuple failed";
1330       return RET_NULL_PTR;
1331     }
1332     auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
1333     CHECK_NULL_RETURN(make_tuple_prim_c);
1334     auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
1335     CHECK_NULL_RETURN(make_tuple_prim);
1336     make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
1337     auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_inputs);
1338     CHECK_NULL_RETURN(make_tuple_cnode);
1339     make_tuple_cnode->set_fullname_with_scope("return_tuple");
1340 
1341     auto return_prim_ptr = std::make_shared<ops::Return>();
1342     if (return_prim_ptr == nullptr) {
1343       MS_LOG(ERROR) << "new Return failed";
1344       return RET_NULL_PTR;
1345     }
1346     auto return_prim_c = return_prim_ptr->GetPrim();
1347     CHECK_NULL_RETURN(return_prim_c);
1348     auto value_node = NewValueNode(return_prim_c);
1349     CHECK_NULL_RETURN(value_node);
1350     std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode};
1351     auto cnode = anf_graph->NewCNode(op_inputs);
1352     CHECK_NULL_RETURN(cnode);
1353     cnode->set_fullname_with_scope("Return");
1354     anf_graph->set_return(cnode);
1355   } else {
1356     auto return_prim_ptr = std::make_shared<ops::Return>();
1357     if (return_prim_ptr == nullptr) {
1358       MS_LOG(ERROR) << "new Return failed";
1359       return RET_NULL_PTR;
1360     }
1361     auto return_prim_c = return_prim_ptr->GetPrim();
1362     CHECK_NULL_RETURN(return_prim_c);
1363     auto value_node = NewValueNode(return_prim_c);
1364     CHECK_NULL_RETURN(value_node);
1365     std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()};
1366     auto return_cnode = anf_graph->NewCNode(op_inputs);
1367     CHECK_NULL_RETURN(return_cnode);
1368     return_cnode->set_fullname_with_scope("Return");
1369     anf_graph->set_return(return_cnode);
1370   }
1371   return RET_OK;
1372 }
1373 
TF2AnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs,std::map<AnfNodePtr,int> * ineffective_if_op_map)1374 int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs,
1375                                 std::map<AnfNodePtr, int> *ineffective_if_op_map) {
1376   MSLITE_CHECK_PTR(ineffective_if_op_map);
1377   for (const auto &func_graph : all_func_graphs) {
1378     if (!TfInputAdjust::Adjust(func_graph)) {
1379       MS_LOG(ERROR) << "Do TfInputAdjust failed.";
1380       return RET_ERROR;
1381     }
1382     auto remove_ineffective_control_flow = std::make_shared<RemoveIneffectiveControlFlow>();
1383     MS_CHECK_TRUE_RET(remove_ineffective_control_flow != nullptr, RET_ERROR);
1384     if (!remove_ineffective_control_flow->Run(func_graph, ineffective_if_op_map)) {
1385       MS_LOG(ERROR) << "Do RemoveIneffectiveControlFlow failed.";
1386       return RET_ERROR;
1387     }
1388     auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
1389     MS_CHECK_TRUE_RET(functionalize_control_op_pass != nullptr, RET_ERROR);
1390     if (!functionalize_control_op_pass->Run(func_graph)) {
1391       MS_LOG(ERROR) << "functionalize control op pass failed.";
1392       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
1393       return RET_ERROR;
1394     }
1395     auto fake_quant_adjust = std::make_shared<TFFakeQuantAdjust>();
1396     if (!fake_quant_adjust->Adjust(func_graph)) {
1397       MS_LOG(ERROR) << "tf fake quant adjust failed.";
1398       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
1399       return RET_ERROR;
1400     }
1401   }
1402   return RET_OK;
1403 }
1404 
// Registers LiteModelParserCreator<TFModelParser> under kFmkTypeTf through REG_MODEL_PARSER.
// NOTE(review): presumably this is what makes the TF parser selectable by framework type;
// confirm against the macro definition.
1405 REG_MODEL_PARSER(kFmkTypeTf, LiteModelParserCreator<TFModelParser>)
1406 }  // namespace lite
1407 }  // namespace mindspore
1408