1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_ 18 #define MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_ 19 20 #include <map> 21 #include <string> 22 #include <vector> 23 #include <memory> 24 #include <utility> 25 #include <set> 26 #include <list> 27 #include <mutex> 28 #include "schema/inner/model_generated.h" 29 #include "ops/primitive_c.h" 30 #include "ir/func_graph.h" 31 #include "tools/lite_exporter/fetch_content.h" 32 #include "tools/converter/converter_context.h" 33 #include "tools/optimizer/common/gllo_utils.h" 34 #include "tools/common/node_util.h" 35 #include "tools/common/persist_future.h" 36 37 using mindspore::ops::PrimitiveC; 38 39 namespace mindspore::lite { 40 class AnfExporter { 41 public: 42 AnfExporter() = default; 43 virtual ~AnfExporter() = default; 44 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, 45 bool train_flag = false); 46 int SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 47 schema::CNodeT *fb_node); 48 int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 49 schema::CNodeT *fb_node); 50 51 protected: 52 int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode); 53 int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode); 54 int ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, 55 const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node, 56 size_t *tensor_index_ptr); 57 int ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, 58 const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node); 59 int SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index); 60 int SetSubGraphOutputIndex(const CNodePtr &cnode, size_t subgraph_index, 61 const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node); 62 int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, 63 const std::shared_ptr<mindspore::Primitive> &primitive, 64 const std::unique_ptr<schema::CNodeT> &dst_node); 65 int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, const CNodePtr &cnode, 66 const std::shared_ptr<mindspore::Primitive> &primitive, 67 const std::unique_ptr<schema::CNodeT> &dst_node); 68 69 int SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> &primitive, const AnfNodePtr &input_node, 70 mindspore::schema::TensorT *tensor_input); 71 int Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 72 const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive); 73 int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 74 bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode = nullptr); 75 static CNodePtr CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &cnode); 76 static CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node); 77 bool HasExported(const FuncGraphPtr &func_graph); 78 int ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const bool &keep_graph, 79 const bool ©_primitive, const CNodePtr &partial_cnode, 80 const std::unique_ptr<schema::CNodeT> &schema_cnode); 81 std::list<CNodePtr> InsertCallNode(const FuncGraphPtr &func_graph); 82 int SetMetaGraphInput(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT); 83 int SetMetaGraphOutput(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT); 84 int CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const AnfNodePtr &input, 85 size_t *tensor_index_ptr); 86 bool CaseToContinue(const string &prim_name); 87 88 private: 89 void SetNonTailCall(const CNodePtr &cnode, schema::CNodeT *node); 90 int SetTailCallForReturn(const CNodePtr &return_cnode); 91 // To deal witch case which call node has not output. 92 int SetTailCallForNonOutput(); 93 size_t GetNodeId(const std::pair<AnfNodePtr, size_t> &key); 94 void SetNodeId(const std::pair<AnfNodePtr, size_t> &key, size_t value); 95 bool HasNodeIdKey(const std::pair<AnfNodePtr, size_t> &key); 96 97 // meta graph all tensor op functions 98 // insert tensor to allTensor and return the index of the tensor 99 size_t NewFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, mindspore::schema::TensorT *tensor); 100 // insert tensor to allTensor 101 void InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, mindspore::schema::TensorT *tensor); 102 // get the allTensor size 103 size_t GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); 104 // get the tensor in allTensor 105 mindspore::schema::TensorT *GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, 106 size_t index); 107 108 // Key is a pair of node and its output id. Value is the mapped tensor id of meta_graph. 109 std::map<std::pair<AnfNodePtr, size_t>, size_t> node_id_map_; 110 // The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph 111 std::map<FuncGraphPtr, size_t> fg_subgraph_map_; 112 std::vector<AnfNodePtr> graph_inputs_; 113 std::map<AnfNodePtr, size_t> graph_inputs_map_; 114 std::map<AnfNodePtr, schema::CNodeT *> call_node_map_; 115 std::mutex fb_graph_node_mutex_; 116 std::mutex fb_graph_all_tensors_mutex_; 117 std::mutex node_id_map_mutex_; 118 std::map<AnfNodePtr, PersistFuture<bool>> batch_cnode_map_; 119 uint32_t node_idx_ = 0; 120 bool train_flag_ = false; 121 }; 122 // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. 123 // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify 124 // the schema::PrimitiveT and cause bug; If all the passes have been done in func_graph, every thing would be simple 125 // and clear. 126 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false, 127 bool train_flag = false); 128 } // namespace mindspore::lite 129 #endif // MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_ 130