• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
19 
20 #define DRAW_GE_GRAPH
21 
22 #include <memory>
23 #include <map>
24 #include <set>
25 #include <vector>
26 #include <unordered_map>
27 #include <string>
28 #include <utility>
29 #include <stack>
30 #include <fstream>
31 #include <sstream>
32 
33 #include "ir/anf.h"
34 #include "ir/func_graph.h"
35 #include "transform/graph_ir/util.h"
36 #include "ir/tensor.h"
37 #include "transform/graph_ir/df_graph_manager.h"
38 #include "utils/config_manager.h"
39 #include "transform/graph_ir/op_adapter.h"
40 #include "graph/operator_reg.h"
41 #ifdef OPEN_SOURCE
42 #include "ge/client/ge_api.h"
43 #else
44 #include "external/ge/ge_api.h"
45 #endif
46 #include "graph/tensor.h"
47 #include "ops/hcom_ops.h"
48 
49 namespace mindspore {
50 namespace transform {
51 using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>;
52 using HcomBroadcast = ge::op::HcomBroadcast;
53 class DfGraphConvertor {
54  public:
DfGraphConvertor(const AnfGraphPtr & anf_graph)55   explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
56     MS_EXCEPTION_IF_NULL(anf_graph);
57     df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
58 #if (!defined ENABLE_GE) || (defined ENABLE_INFER)
59     training_ = anf_graph->has_flag("training");
60 #else
61     training_ = ENABLE_TRAIN;
62 #endif
63     distribute_ = anf_graph->has_flag("broadcast_flag");
64     if (anf_graph->has_flag("broadcast_flag")) {
65       ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION);
66     } else {
67       ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE);
68     }
69 
70     MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_;
71   }
72 
~DfGraphConvertor()73   ~DfGraphConvertor() {}
74 
75   static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt);
76   static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt);
77 
DrawComputeGraph(const std::string & name)78   void DrawComputeGraph(const std::string &name) {
79 #ifndef ENABLE_SECURITY
80     std::ofstream fout(name);
81     if (!fout.is_open()) {
82       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
83       return;
84     }
85     fout << compute_sout_.str();
86     fout.close();
87 #endif
88   }
89 
DrawInitGraph(const std::string & name)90   void DrawInitGraph(const std::string &name) {
91 #ifndef ENABLE_SECURITY
92     std::ofstream fout(name);
93     if (!fout.is_open()) {
94       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
95       return;
96     }
97     fout << init_sout_.str();
98     fout.close();
99 #endif
100   }
DrawSaveCheckpointGraph(const std::string & name)101   void DrawSaveCheckpointGraph(const std::string &name) {
102     std::ofstream fout(name);
103     if (!fout.is_open()) {
104       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
105       return;
106     }
107     fout << checkpoint_sout_.str();
108     fout.close();
109   }
110 
111   DfGraphConvertor &ConvertAllNode();
112   DfGraphConvertor &BuildGraph();
113   DfGraphConvertor &InitParam(const TensorOrderMap &tensors);
114   DfGraphConvertor &GenerateCheckpointGraph();
115   DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors);
116   void InitParamWithData(const TensorOrderMap &tensors);
117   void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node);
118   void SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
119                       const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input);
120   void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it);
121   void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input);
122   void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it);
123 
124   DfGraphPtr GetComputeGraph();
125   DfGraphPtr GetInitGraph();
126   DfGraphPtr GetSaveCheckpointGraph();
127   DfGraphPtr GetBroadcastGraph();
128   static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false);
129   static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false);
ErrCode()130   int ErrCode() const { return static_cast<int>(error_); }
131 
is_training()132   bool is_training() const { return training_; }
set_training(bool is_training)133   void set_training(bool is_training) { training_ = is_training; }
134 
135  protected:
136   void InitLoopVar(std::vector<ge::Operator> *init_input);
137 
138  private:
139   std::ostringstream compute_sout_;
140   std::ostringstream init_sout_;
141   std::ostringstream checkpoint_sout_;
142   std::ostringstream restore_checkpoint_sout_;
143   std::unordered_map<AnfNode *, std::string> op_draw_name_;
144   std::map<std::string, std::string> param_format_;
145 
146   AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index);
147   AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
148   AnfNodePtr TraceDepend(const CNodePtr &node);
149   OutHandler TraceRealOp(AnfNodePtr node);
150   OutHandler GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack, AnfNode *const draw_index);
151   OperatorPtr Convert(AnfNodePtr node);
152   OperatorPtr ConvertCNode(CNodePtr node);
153   std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node);
154   AnfNodePtr GetRealOpNode(AnfNodePtr node);
155   OperatorPtr ConvertParameter(AnfNodePtr node);
156   Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
157   OperatorPtr ConvertValueNode(ValueNodePtr node);
158   void SaveParamFormat(CNodePtr node);
159   void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
160   void ConvertTupleGetItem(const CNodePtr node);
161   void ConvertMakeTuple(const CNodePtr node);
162   void ConvertTopK(const CNodePtr node);
163   void ConvertReshape(const CNodePtr node);
164   std::vector<int64_t> CastToInt(const ValuePtr &value);
165   bool CheckCNode(const std::string &name, const CNodePtr node);
166   void TraceOutput(AnfNodePtr node);
167   void TraceOutputFromParameter(const AnfNodePtr &anf_out);
168   void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out);
169   void SetNodeInput(AnfNodePtr node);
170   void SetOpControlInput(const AnfNodePtr &node);
171   void UpdateOpDesc(AnfNodePtr node);
172   void SetSubgraph(AnfNodePtr node);
173   void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs);
174   void BuildSaveCheckpointGraph();
175   void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
176   void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
177   void AddGraphConstInput(const OperatorPtr &op);
178   OperatorPtr ToOperatorPtr(const AnfNodePtr &node);
179   bool IsSourceEdgeNode(const AnfNodePtr &node);
180   bool IsControlEdgeNode(const AnfNodePtr &node);
181   void AddEdgeForLoad(const AnfNodePtr &node);
182   void AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest);
183   void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top);
184   AnfNodePtr ParseLoadInput(const CNodePtr &cnode);
185   void AutoMonadSetControlInput(const AnfNodePtr &node);
186   void AutoMonadCollectInput(const AnfNodePtr &node);
187   void AutoMonadSetInput(const AnfNodePtr &node);
188   void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src,
189                        int index);
190   void UpdateTupleOutCache(void);
191   AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input);
192 
193   std::shared_ptr<AnfGraph> anf_graph_{nullptr};
194   std::shared_ptr<DfGraph> df_graph_{nullptr};
195   std::shared_ptr<DfGraph> init_graph_{nullptr};
196   std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
197   std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
198   std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
199   std::unordered_map<AnfNode *, DfGraph> branches_map_;
200   std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
201   std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_;
202   std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_;
203   /* record "tuple_getitem"<->"out_handler" mapping */
204   std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
205   /* record "make_tuple"<->"out_handler vector" mapping */
206   std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
207   std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_;
208   std::unordered_map<std::string, AnfNodePtr> params_;
209   std::unordered_map<std::string, OperatorPtr> vars_;
210   std::vector<std::pair<ge::Operator, std::string>> graph_outputs_;
211   std::vector<OperatorPtr> graph_const_inputs_;
212   std::vector<OperatorPtr> init_ops_;
213   std::vector<OperatorPtr> broadcast_ops_;
214   std::vector<AnfNodePtr> inputs_;
215   OperatorPtr dataset_iter_getnext_;
216   Status error_ = SUCCESS;
217   bool training_ = false;
218   bool distribute_ = false;
219   bool use_inputs_ = false;
220 };
221 }  // namespace transform
222 }  // namespace mindspore
223 
224 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
225