• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <vector>
25 #include <set>
26 #include <utility>
27 #include "proto/graph.pb.h"
28 #include "proto/node_def.pb.h"
29 #include "schema/inner/model_generated.h"
30 #include "securec/include/securec.h"
31 #include "tools/common/tensor_util.h"
32 #include "include/registry/model_parser.h"
33 #include "include/registry/model_parser_registry.h"
34 #include "ops/primitive_c.h"
35 
36 namespace mindspore {
37 namespace lite {
38 class TFModelParser : public converter::ModelParser {
39  public:
40   TFModelParser() = default;
41   ~TFModelParser() override = default;
42 
43   api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
44 
45   static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs,
46                           std::map<AnfNodePtr, int> *ineffective_if_op_map = nullptr);
47 
48  private:
49   static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
50   STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
51                             const TypeId &type, const ParameterPtr &parameter, std::vector<int64_t> *shape_vector);
52   STATUS SetInt64TensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info,
53                             const std::string &node_name);
54   STATUS SetInt64TensorToInt64Tensor(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
55   STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info,
56                                const std::string &node_name);
57   STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,
58                           std::unordered_map<std::string, AnfNodePtr> *anf_node_map, bool root_graph = false);
59   STATUS ConvertGraphInputsAndConsts(const std::vector<const tensorflow::NodeDef *> &tf_graph_nodes,
60                                      const FuncGraphPtr &anf_graph,
61                                      std::unordered_map<std::string, AnfNodePtr> *anf_node_map,
62                                      bool root_graph = false);
63   static STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names,
64                                   const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
65                                   const std::unordered_map<std::string, AnfNodePtr> &anf_node_map,
66                                   std::vector<AnfNodePtr> *inputs, std::vector<std::string> *input_name_not_found);
67   static STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node,
68                                     std::unordered_map<std::string, AnfNodePtr> *anf_node_map,
69                                     const FuncGraphPtr &anf_graph, int output_size);
70   STATUS ConvertOps(const tensorflow::NodeDef &node_def,
71                     const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
72                     const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
73   STATUS ResetAbstractTensorToInt64(const std::string &op_type, const std::vector<std::string> &input_names,
74                                     const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
75                                     const std::unordered_map<std::string, AnfNodePtr> &anf_node_map);
76   STATUS ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def);
77 
78   bool IsIneffectiveIfOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def);
79 
80   bool IsEmptyTfFunction(const CNodePtr &anf_node, std::string branch_name);
81 
82   std::set<std::string> GetAllNodeInputs();
83 
84   STATUS GetGraphOutputNames(std::vector<AnfNodePtr> *output_nodes);
85 
86   STATUS ConvertRootGraphOutputs();
87 
88   void UpdateMap(const CNodePtr &cnode, const FuncGraphPtr &sub_func_graph, const std::string &sub_graph_name);
89 
90   STATUS ConvertSubgraph();
91 
92   STATUS ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
93                                std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
94                                const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
95                                const FuncGraphPtr &sub_func_graph);
96 
97   static STATUS ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
98                                        const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
99                                        const tensorflow::FunctionDef &tf_sub_fuction,
100                                        const FuncGraphPtr &sub_func_graph);
101 
102   STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
103                                     const std::map<CNodePtr, FuncGraphPtr> &second_func_map);
104 
105   static STATUS MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_nodes, const FuncGraphPtr &anf_graph);
106 
107   STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);
108 
109   STATUS ConnectNullInput();
110 
111   std::unique_ptr<tensorflow::GraphDef> tf_root_graph_;                     // tf root graph def
112   std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_;  // tf root graph node map
113   std::vector<const tensorflow::NodeDef *> tf_root_graph_nodes_vec_;
114   std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
115   std::vector<std::string> graph_input_names_;
116   std::vector<std::string> graph_output_names_;
117   std::map<std::string, AnfNodePtr> function_while_map_;  // tf function name->while_node_name
118   std::map<std::string, AnfNodePtr> function_if_map_;     // tf function name->if_node
119   std::map<AnfNodePtr, int> ineffective_if_op_map_;
120   std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{};
121   std::vector<std::string> while_cond_branch_name_;
122   std::vector<std::string> if_then_branch_name_;
123   std::unordered_map<std::string, int> node_output_num_;
124   std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_;
125 };
126 }  // namespace lite
127 }  // namespace mindspore
128 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_
129