• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_
18 #define MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_
19 
20 #include <map>
21 #include <string>
22 #include <vector>
23 #include <memory>
24 #include <utility>
25 #include <set>
26 #include <list>
27 #include <mutex>
28 #include "schema/inner/model_generated.h"
29 #include "ops/primitive_c.h"
30 #include "ir/func_graph.h"
31 #include "tools/lite_exporter/fetch_content.h"
32 #include "tools/converter/converter_context.h"
33 #include "tools/optimizer/common/gllo_utils.h"
34 #include "tools/common/node_util.h"
35 #include "tools/common/persist_future.h"
36 
37 using mindspore::ops::PrimitiveC;
38 
39 namespace mindspore::lite {
40 class AnfExporter {
41  public:
42   AnfExporter() = default;
43   virtual ~AnfExporter() = default;
44   schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
45                              bool train_flag = false);
46   int SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
47                       schema::CNodeT *fb_node);
48   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
49                      schema::CNodeT *fb_node);
50 
51  protected:
52   int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
53   int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode);
54   int ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
55                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node,
56                             size_t *tensor_index_ptr);
57   int ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
58                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node);
59   int SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
60   int SetSubGraphOutputIndex(const CNodePtr &cnode, size_t subgraph_index,
61                              const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
62   int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
63                         const std::shared_ptr<mindspore::Primitive> &primitive,
64                         const std::unique_ptr<schema::CNodeT> &dst_node);
65   int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, const CNodePtr &cnode,
66                         const std::shared_ptr<mindspore::Primitive> &primitive,
67                         const std::unique_ptr<schema::CNodeT> &dst_node);
68 
69   int SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> &primitive, const AnfNodePtr &input_node,
70                                   mindspore::schema::TensorT *tensor_input);
71   int Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
72              const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive);
73   int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
74                      bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode = nullptr);
75   static CNodePtr CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &cnode);
76   static CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node);
77   bool HasExported(const FuncGraphPtr &func_graph);
78   int ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const bool &keep_graph,
79                         const bool &copy_primitive, const CNodePtr &partial_cnode,
80                         const std::unique_ptr<schema::CNodeT> &schema_cnode);
81   std::list<CNodePtr> InsertCallNode(const FuncGraphPtr &func_graph);
82   int SetMetaGraphInput(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
83   int SetMetaGraphOutput(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
84   int CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const AnfNodePtr &input,
85                                   size_t *tensor_index_ptr);
86   bool CaseToContinue(const string &prim_name);
87 
88  private:
89   void SetNonTailCall(const CNodePtr &cnode, schema::CNodeT *node);
90   int SetTailCallForReturn(const CNodePtr &return_cnode);
91   // To deal witch case which call node has not output.
92   int SetTailCallForNonOutput();
93   size_t GetNodeId(const std::pair<AnfNodePtr, size_t> &key);
94   void SetNodeId(const std::pair<AnfNodePtr, size_t> &key, size_t value);
95   bool HasNodeIdKey(const std::pair<AnfNodePtr, size_t> &key);
96 
97   // meta graph all tensor op functions
98   // insert tensor to allTensor and return the index of the tensor
99   size_t NewFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, mindspore::schema::TensorT *tensor);
100   // insert tensor to allTensor
101   void InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, mindspore::schema::TensorT *tensor);
102   // get the allTensor size
103   size_t GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
104   // get the tensor in allTensor
105   mindspore::schema::TensorT *GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
106                                                      size_t index);
107 
108   // Key is a pair of node and its output id. Value is the mapped tensor id of meta_graph.
109   std::map<std::pair<AnfNodePtr, size_t>, size_t> node_id_map_;
110   // The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph
111   std::map<FuncGraphPtr, size_t> fg_subgraph_map_;
112   std::vector<AnfNodePtr> graph_inputs_;
113   std::map<AnfNodePtr, size_t> graph_inputs_map_;
114   std::map<AnfNodePtr, schema::CNodeT *> call_node_map_;
115   std::mutex fb_graph_node_mutex_;
116   std::mutex fb_graph_all_tensors_mutex_;
117   std::mutex node_id_map_mutex_;
118   std::map<AnfNodePtr, PersistFuture<bool>> batch_cnode_map_;
119   uint32_t node_idx_ = 0;
120   bool train_flag_ = false;
121 };
122 // by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
123 // but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify
124 // the schema::PrimitiveT and cause bug; If all the passes have been done in func_graph, every thing would be simple
125 // and clear.
126 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
127                            bool train_flag = false);
128 }  // namespace mindspore::lite
129 #endif  // MINDSPORE_LITE_TOOLS_LITE_EXPORTER_ANF_EXPORTER_H_
130