1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 #include <set> 26 #include <utility> 27 #include "proto/graph.pb.h" 28 #include "proto/node_def.pb.h" 29 #include "schema/inner/model_generated.h" 30 #include "securec/include/securec.h" 31 #include "tools/common/tensor_util.h" 32 #include "include/registry/model_parser.h" 33 #include "include/registry/model_parser_registry.h" 34 #include "ops/primitive_c.h" 35 36 namespace mindspore { 37 namespace lite { 38 class TFModelParser : public converter::ModelParser { 39 public: 40 TFModelParser() = default; 41 ~TFModelParser() override = default; 42 43 api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; 44 45 static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs, 46 std::map<AnfNodePtr, int> *ineffective_if_op_map = nullptr); 47 48 private: 49 static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); 50 STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value, 51 const TypeId &type, const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector); 52 STATUS SetInt64TensorInfo(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info, 53 const std::string &node_name); 54 STATUS SetInt64TensorToInt64Tensor(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); 55 STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info, 56 const std::string &node_name); 57 STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, 58 std::unordered_map<std::string, AnfNodePtr> *anf_node_map, bool root_graph = false); 59 STATUS ConvertGraphInputsAndConsts(const std::vector<const tensorflow::NodeDef *> &tf_graph_nodes, 60 const FuncGraphPtr &anf_graph, 61 std::unordered_map<std::string, AnfNodePtr> *anf_node_map, 62 bool root_graph = false); 63 static STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, 64 const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, 65 const std::unordered_map<std::string, AnfNodePtr> &anf_node_map, 66 std::vector<AnfNodePtr> *inputs, std::vector<std::string> *input_name_not_found); 67 static STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, 68 std::unordered_map<std::string, AnfNodePtr> *anf_node_map, 69 const FuncGraphPtr &anf_graph, int output_size); 70 STATUS ConvertOps(const tensorflow::NodeDef &node_def, 71 const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, 72 const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map); 73 STATUS ResetAbstractTensorToInt64(const std::string &op_type, const std::vector<std::string> &input_names, 74 const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map, 75 const std::unordered_map<std::string, AnfNodePtr> &anf_node_map); 76 STATUS ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def); 77 78 bool IsIneffectiveIfOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def); 79 80 bool IsEmptyTfFunction(const CNodePtr &anf_node, std::string branch_name); 81 82 std::set<std::string> GetAllNodeInputs(); 83 84 STATUS GetGraphOutputNames(std::vector<AnfNodePtr> *output_nodes); 85 86 STATUS ConvertRootGraphOutputs(); 87 88 void UpdateMap(const CNodePtr &cnode, const FuncGraphPtr &sub_func_graph, const std::string &sub_graph_name); 89 90 STATUS ConvertSubgraph(); 91 92 STATUS ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map, 93 std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map, 94 const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode, 95 const FuncGraphPtr &sub_func_graph); 96 97 static STATUS ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map, 98 const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map, 99 const tensorflow::FunctionDef &tf_sub_fuction, 100 const FuncGraphPtr &sub_func_graph); 101 102 STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map, 103 const std::map<CNodePtr, FuncGraphPtr> &second_func_map); 104 105 static STATUS MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_nodes, const FuncGraphPtr &anf_graph); 106 107 STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found); 108 109 STATUS ConnectNullInput(); 110 111 std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def 112 std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map 113 std::vector<const tensorflow::NodeDef *> tf_root_graph_nodes_vec_; 114 std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; 115 std::vector<std::string> graph_input_names_; 116 std::vector<std::string> graph_output_names_; 117 std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name 118 std::map<std::string, AnfNodePtr> function_if_map_; // tf function name->if_node 119 std::map<AnfNodePtr, int> ineffective_if_op_map_; 120 std::vector<std::pair<CNodePtr, std::vector<std::string>>> nodes_with_null_input_{}; 121 std::vector<std::string> while_cond_branch_name_; 122 std::vector<std::string> if_then_branch_name_; 123 std::unordered_map<std::string, int> node_output_num_; 124 std::map<CNodePtr, FuncGraphPtr> while_cond_map_, while_body_map_, if_then_map_, if_else_map_; 125 }; 126 } // namespace lite 127 } // namespace mindspore 128 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MODEL_PARSER_H_ 129