• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
19 
20 #define DRAW_GE_GRAPH
21 
#include <cstdlib>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/common/utils/config_manager.h"
#include "mindspore/core/ops/structure_ops.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/tensor.h"
#include "transform/graph_ir/df_graph_manager.h"
#include "transform/graph_ir/op_adapter.h"
#include "graph/operator_reg.h"
#include "ge/ge_api.h"
44 
45 namespace mindspore {
46 namespace transform {
class BaseOpAdapter;

// Ordered index -> index mapping. In this header it backs body_cond_map_, after_cond_map_ and the
// branch_to_parent_node_map argument of DfGraphConvertor::ProcessSubgraph.
using ParamIndexMap = std::map<std::size_t, std::size_t>;

// Kind of graph a DfGraphConvertor is producing (stored in DfGraphConvertor::graph_type_).
// kCond and kBody are the two kinds DfGraphConvertor::IsSubGraph() reports as sub-graphs.
enum class GraphType {
  kNormal,
  kCond,
  kBody,
  kAfter,
  kBranch
};

// Three-state marker for a depth-first traversal: not yet seen, on the current path, finished.
enum class DfsVisitFlag {
  kUnVisited,
  kVisiting,
  kVisited
};

// Selects which graph Parameters are emitted as GE RefData.
enum class RefModeFlag {
  // No parameter is treated as RefData.
  kRefModeNone,
  // Only Variables are treated as RefData.
  kRefModeVariable,
  // Every Parameter, Variables and Constants alike, is treated as RefData.
  kRefModeAll,
  // Resolved by the DfGraphConvertor constructor: ref mode follows IsEnableRefMode() (REF_MODE, on by
  // default per the original note) and the effective type becomes kRefModeAll.
  kRefModeEnv
};

// Graph flag names.
constexpr char kGraphFlagHasGetNext[] = "graph_has_getnext";
constexpr char kGraphNeedIteration[] = "graph_need_iteration";
60 
61 struct GEInputList {
62   std::vector<AnfNodeWeakPtr> ge_inputs;
63   constexpr static char key[] = "GEInputs";
64 };
65 
66 class GeOpConvertor {
67  public:
68   static std::map<std::string, ValuePtr> GetAttrAndValue(const AnfNodePtr &node, const bool training);
69 
70   static std::string GetOpType(const AnfNodePtr &node, const bool training);
71 
72   static std::shared_ptr<GeTensorDesc> GetTensorDesc(const ShapeVector &dev_shape, const TypeId &dev_type,
73                                                      const std::string &dev_format, const ShapeVector &ori_shape,
74                                                      const std::string &ori_format);
75 
76   static mindspore::HashMap<std::string, std::string> GetNeedAddInput(const AnfNodePtr &node, const bool training);
77 
78   static bool IsDynamicInput(const AnfNodePtr &node, const size_t idx);
79 
80   static std::map<int, std::string> GetAclInputNames(const AnfNodePtr &node);
81 
82   static std::map<int, std::string> GetAclOutputNames(const AnfNodePtr &node);
83 
84   static std::map<int, std::string> GetAclDynamicInputNames(const AnfNodePtr &node);
85 
86   static std::map<int, std::string> GetAclDynamicOutputNames(const AnfNodePtr &node);
87 };
88 
89 DfGraphPtr GenExampleGraph(const std::string &name);
90 
91 using SetDynRefDataFunc = std::function<ShapeVector(const AnfNodePtr &, const ShapeVector &)>;
92 
93 class DfGraphConvertor {
94  public:
95   explicit DfGraphConvertor(const AnfGraphPtr &anf_graph, const std::string &phase_prefix,
96                             RefModeFlag ref_mode_type = RefModeFlag::kRefModeEnv,
97                             const std::vector<std::string> &extra_variables_names = {},
98                             SetDynRefDataFunc dyn_ref_data_func = nullptr, bool offline_convert = false)
anf_graph_(anf_graph)99       : anf_graph_(anf_graph),
100         extra_variables_names_(extra_variables_names),
101         phase_prefix_(phase_prefix),
102         offline_convert_(offline_convert) {
103     MS_EXCEPTION_IF_NULL(anf_graph);
104     if (ref_mode_type == RefModeFlag::kRefModeEnv) {
105       ref_mode_ = IsEnableRefMode();
106       ref_mode_type_ = RefModeFlag::kRefModeAll;
107     } else {
108       ref_mode_ = (ref_mode_type != RefModeFlag::kRefModeNone);
109       ref_mode_type_ = ref_mode_type;
110     }
111     dyn_ref_data_func_ = dyn_ref_data_func;
112     auto context = MsContext::GetInstance();
113     MS_EXCEPTION_IF_NULL(context);
114     bool enable_ge = context->backend_policy() == "ge";
115     bool enable_training = phase_prefix_ == "train";
116     static bool is_training = false;
117     if (enable_ge && enable_training) {
118       is_training = true;
119     }
120     if (is_training) {
121       training_ = true;
122     } else {
123       training_ = anf_graph->has_flag("training");
124     }
125     distribute_ = anf_graph->has_flag("broadcast_flag");
126     if (anf_graph->has_flag("broadcast_flag")) {
127       ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION);
128     } else {
129       ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE);
130     }
131     is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphTypeName;
132     df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString());
133 
134     std::string graph_type = is_kernel_graph_ ? "kernel_graph" : "func_graph";
135     std::string graph_name = anf_graph_->ToString();
136     graph_manager_ = Manage(anf_graph_, true);
137     MS_EXCEPTION_IF_NULL(graph_manager_);
138     MS_LOG(INFO) << "Create DfGraphConvertor with graph: " << graph_name << "(type: " << graph_type << ")"
139                  << ", training: " << training_ << ", dynamic input: " << dynamic_shape_inputs_
140                  << ", distribute: " << distribute_;
141   }
142 
~DfGraphConvertor()143   ~DfGraphConvertor() {}
144 
145   static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt);
146   static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt);
147 
DrawComputeGraph(const std::string & name)148   void DrawComputeGraph(const std::string &name) {
149 #ifndef ENABLE_SECURITY
150     std::ofstream fout(name);
151     if (!fout.is_open()) {
152       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
153       return;
154     }
155     fout << compute_sout_.str();
156     fout.close();
157 #endif
158   }
159 
DrawInitGraph(const std::string & name)160   void DrawInitGraph(const std::string &name) {
161 #ifndef ENABLE_SECURITY
162     std::ofstream fout(name);
163     if (!fout.is_open()) {
164       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
165       return;
166     }
167     fout << init_sout_.str();
168     fout.close();
169 #endif
170   }
171 
DrawSaveCheckpointGraph(const std::string & name)172   void DrawSaveCheckpointGraph(const std::string &name) {
173     std::ofstream fout(name);
174     if (!fout.is_open()) {
175       MS_LOG(ERROR) << "Open file '" << name << "' failed!";
176       return;
177     }
178     fout << checkpoint_sout_.str();
179     fout.close();
180   }
181 
182   DfGraphConvertor &ConvertAllNode();
183   void GenFakeGraph(const std::string &name);
184   DfGraphConvertor &BuildGraph(const std::string &name);
185   DfGraphConvertor &InitParam(const TensorOrderMap &tensors);
186   DfGraphConvertor &GenerateCheckpointGraph();
187   DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors);
188   void InitParamWithData(const TensorOrderMap &tensors);
189   bool NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
190   OutHandler GetNormalOpInput(const AnfNodePtr &node, const AnfNodePtr &pred);
191   void DrawOpInput(const AnfNodePtr &node, const AnfNodePtr &pred, size_t i);
192   void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node);
193   void SetOpAttrToInput(const OpAdapterPtr &adpt, const CNodePtr &node);
194   void SetupBroadcast(const OperatorPtr &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
195                       const DfGraphPtr &broadcast_graph, std::vector<::ge::Operator> broadcast_input);
196   void SetupParamInitSubGraph(const TensorOrderMap &tensors, const std::vector<::ge::Operator> *init_input,
197                               bool is_sink_size_repeat);
198   void SetupParamInitSubGraph();
199   void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it);
200 
201   DfGraphPtr GetComputeGraph();
202   DfGraphPtr GetInitGraph();
GetInitDataNames()203   std::vector<std::string> GetInitDataNames() const { return init_data_names_; }
GetRefDataNames()204   std::vector<std::string> GetRefDataNames() const { return ref_data_names_; }
205   DfGraphPtr GetSaveCheckpointGraph();
206   DfGraphPtr GetBroadcastGraph();
ErrCode()207   int ErrCode() const { return static_cast<int>(error_); }
208 
is_training()209   bool is_training() const { return training_; }
set_training(bool is_training)210   void set_training(bool is_training) { training_ = is_training; }
211 
export_air()212   bool export_air() const { return export_air_; }
set_export_air(bool export_air)213   void set_export_air(bool export_air) { export_air_ = export_air; }
dynamic_shape_inputs()214   bool dynamic_shape_inputs() const { return dynamic_shape_inputs_; }
input_shapes()215   std::vector<ShapeVector> input_shapes() { return input_shapes_; }
216 
217   void SetupInputFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
218 
219  protected:
220   bool InitLoopVar(std::vector<::ge::Operator> *init_input);
221 
222  private:
223   std::ostringstream compute_sout_;
224   std::ostringstream init_sout_;
225   std::ostringstream checkpoint_sout_;
226   std::ostringstream restore_checkpoint_sout_;
227   mindspore::HashMap<AnfNode *, std::string> op_draw_name_;
228   std::map<std::string, std::string> param_format_;
229 
230   OutHandler GetHandler(const AnfNodePtr &node);
231   OperatorPtr Convert(AnfNodePtr node);
232   OperatorPtr ConvertCNode(CNodePtr node);
233   OperatorPtr ConvertParameter(AnfNodePtr node);
234   void SetNodeAbstract(const CNodePtr &node) const;
235   Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
236   OperatorPtr ConvertValueNode(ValueNodePtr node);
237   void SaveParamFormat(CNodePtr node);
238   void GetBranchNodeInput(const CNodePtr node);
239   void ConvertTopK(const CNodePtr &node);
240   void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) const;
241   AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const;
242   void ConvertReshape(const CNodePtr &node);
243   void ConvertHcomFusionId(const CNodePtr &node);
244   void ConvertHcclNode(const CNodePtr &node);
245   void ConvertAllToAllv(const CNodePtr &node);
246   void ConvertUniformReal(const CNodePtr &node);
247   void ConvertUpdateState(const CNodePtr &node);
248   void AddCommAttrForHcclNode(const CNodePtr &node, const OperatorPtr &converted_op) const;
249   void ConvertOCRRecPreHandle(const CNodePtr &node);
250   void ConvertConv2D(const CNodePtr &node);
251   void ConvertDynamicStitch(const CNodePtr &node);
252   void ConvertParallelGroupToHcom(const CNodePtr &node);
253   void ConvertParallelGroupIdToHcom(const CNodePtr &node);
254   std::vector<int64_t> CastToInt(const ValuePtr &value) const;
255   void TransDataType(const FuncGraphPtr &anf_graph) const;
256   void TransInputDataType(const CNodePtr &node, const std::string &node_name) const;
257   void TransAttrDataType(const CNodePtr &node, const std::string &node_name) const;
258   bool CheckCNode(const std::string &name, const CNodePtr node);
259   void SetNodeInput(AnfNodePtr node);
260   void UpdateOpDesc(AnfNodePtr node);
261   void SetSubgraph(const AnfNodePtr &node);
262   void ProcessSubgraph(const AnfNodePtr &node, const AnfNodePtr &branch_node, ParamIndexMap &branch_to_parent_node_map);
263   void BuildSaveCheckpointGraph();
264   void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
265   void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
266   void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
267   void AddGraphConstInput(const OperatorPtr &op);
268   AnfNodePtr ParseLoadInput(const CNodePtr &cnode) const;
269   void SetGraphInputs(std::vector<Operator> *inputs);
270   void SetGraphInputs(std::vector<Operator> *inputs, AnfNodeWeakPtrList *ge_inputs);
271   void TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred);
272   void ProcessInputData(std::vector<Operator> *init_input,
273                         std::unordered_set<std::string> *infer_need_update_parameter_names, const OperatorPtr &param_op,
274                         const string &name, const std::shared_ptr<GeTensorDesc> &desc);
275   AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input);
276 
277   void ConvertWhileNode(const CNodePtr &node);
278   void CacheWhileGraph(const CNodePtr &cnode);
279   void ConvertWhileBody(const AnfNodePtr &node);
280   std::shared_ptr<std::vector<Operator>> GetWhileSubGraphInput();
281   void BuildWhileSubGraph();
282   void ConvertWhileCond(const AnfNodePtr &node);
283   void ConvertWhileAfter(const AnfNodePtr &node);
284   void BuildWhileAfterSubGraph();
285   void BuildCallSubGraphs(const AnfNodePtr &node);
286   void GetCallNodeInputs(const CNodePtr &node);
287   std::vector<Operator> GetWhileBodyOutputs();
IsSubGraph()288   bool IsSubGraph() const { return graph_type_ == GraphType::kCond || graph_type_ == GraphType::kBody; }
IsCondGraph()289   bool IsCondGraph() const { return graph_type_ == GraphType::kCond; }
IsBodyGraph()290   bool IsBodyGraph() const { return graph_type_ == GraphType::kBody; }
IsBranchGraph()291   bool IsBranchGraph() const { return graph_type_ == GraphType::kBranch; }
IsAfterGraph()292   bool IsAfterGraph() const { return graph_type_ == GraphType::kAfter; }
IsNormalGraph()293   bool IsNormalGraph() const { return graph_type_ == GraphType::kNormal; }
294   void SetParamIndexMap(const std::vector<AnfNodePtr> &graphs);
295   void SetWhileOutputHandle(const OperatorPtr &prev_while_op);
296   void GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs);
297 
298   bool IsDataInput(const AnfNodePtr &node, const AnfNodePtr &input, size_t input_index);
299   void SetMakeTupleInput(const OpAdapterPtr &adpt, const CNodePtr &make_tuple_node);
300   void SetMergeInput(const OpAdapterPtr &adpt, const CNodePtr &merge_node);
301   bool IsMergeOrSwitchLayerInput(const CNodePtr &node) const;
302   void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
303                                          const CNodePtr &from_node_input);
304   void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input);
305   void SetGraphOutputs(bool is_main_graph = false);
306   std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input);
307   void FillEmptyInputsWithNoInputOp(std::vector<Operator> *);
308   bool IsDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, int *ge_input_size,
309                                        mindspore::HashMap<int, int> *ge_input_to_ms_input);
310   void SetDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, const CNodePtr &node,
311                                         const std::vector<AnfNodePtr> &inputs, const int &ge_input_size,
312                                         const mindspore::HashMap<int, int> &ge_input_to_ms_input,
313                                         std::vector<int64_t> *dyn_input_sizes);
314 
315   // Identity Optimization
316   void IdentityOptimization();
317   std::string GetGNodeName(const ::ge::GNode &node) const;
318   std::string GetGNodeType(const ::ge::GNode &node) const;
319   bool IsIdentityRedundant(const ::ge::GNode &node) const;
320   void RemoveIdentity(::ge::GNode identity_node);
321   void NoOpOptimization();
322   bool IsNoOpRedundant(const ::ge::GNode &node) const;
323   void RemoveNoOp(::ge::GNode noop);
324   std::shared_ptr<std::vector<DfGraph>> BuildBranchGraphs(const CNodePtr &cnode);
325   void BuildInitDataGraph(const std::string &name);
326   bool IsConstantOp(const OperatorPtr &op) const;
327   void JudgeParamTransType(const bool &node_will_update, bool *as_ref_data, bool *as_constant) const;
328   OperatorPtr SetGraphInputsForNotVar(const AnfNodePtr &it, int64_t *index, std::vector<Operator> *inputs);
329   void GenFakeGraphInRefMode();
330   void AddInputAttrsForESNode(const CNodePtr &node, const AnfNodePtr &input);
331   void RemoveIdentityForES(::ge::GNode node);
332   void ESOptimization();
333   void ReplaceAllParameterToRefData();
334 
335   std::shared_ptr<AnfGraph> anf_graph_{nullptr};
336   FuncGraphManagerPtr graph_manager_{nullptr};
337   RefModeFlag ref_mode_type_ = RefModeFlag::kRefModeNone;
338   bool ref_mode_ = false;
339   std::vector<std::string> extra_variables_names_;
340   std::vector<std::string> ref_data_names_;
341   std::set<std::string> unsupported_ops_names_;
342   SetDynRefDataFunc dyn_ref_data_func_ = nullptr;
343 
344   std::shared_ptr<DfGraph> df_graph_{nullptr};
345   std::shared_ptr<DfGraph> init_graph_{nullptr};
346   std::shared_ptr<DfGraph> save_ckp_graph_{nullptr};
347   std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr};
348   std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
349   mindspore::HashMap<AnfNode *, DfGraph> branches_map_;
350   mindspore::HashMap<AnfNode *, OperatorPtr> op_cache_;
351   /* record "getnext"<->"out_handler" mapping */
352   mindspore::HashMap<AnfNode *, OutHandler> out_handle_cache_;
353   /* record "value tuple"<->"out_handler vector" mapping */
354   mindspore::HashMap<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_;
355   mindspore::HashMap<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> branch_input_handle_cache_;
356   mindspore::HashMap<std::string, AnfNodePtr> params_;
357   mindspore::HashMap<std::string, OperatorPtr> vars_;
358   std::vector<OperatorPtr> ref_datas_;
359   std::vector<std::pair<::ge::Operator, std::string>> graph_outputs_;
360   std::vector<AnfNodePtr> graph_anf_outputs_;
361   std::vector<OperatorPtr> graph_const_inputs_;
362   std::vector<OperatorPtr> init_ops_;
363   std::vector<std::string> init_data_names_;
364   std::vector<OperatorPtr> broadcast_ops_;
365   std::vector<AnfNodePtr> inputs_;
366   ShapeArray input_shapes_;
367   Status error_ = SUCCESS;
368   bool training_ = false;
369   bool export_air_ = false;
370   bool distribute_ = false;
371   bool use_inputs_ = false;
372   bool dynamic_shape_inputs_ = false;
373   bool has_es_node_ = false;
374 
375   AnfNodePtr while_cond_node_ = nullptr;
376   mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_;
377   mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> call_dfgraph_cache_;
378   CNodePtr cur_while_node_ = nullptr;
379   size_t cur_while_node_out_size_ = 0;
380   mindspore::HashMap<size_t, OutHandler> while_const_input_index_;
381   mindspore::HashMap<size_t, OutHandler> prev_while_const_input_index_;
382   mindspore::HashMap<size_t, size_t> prev_cond_to_while_out_index_;
383   mindspore::HashMap<OperatorPtr, std::shared_ptr<tensor::Tensor>> const_op_to_value_;
384   AnfNodePtr prev_while_node_ = nullptr;
385   size_t prev_while_node_out_size_ = 0;
386 
387   mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>> while_graph_cache_;
388   mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> call_input_handle_cache_;
389   mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> while_output_handle_cache_;
390   AnfNodePtr call_node_in_while_body_ = nullptr;
391   GraphType graph_type_ = GraphType::kNormal;
392 
393   ParamIndexMap body_cond_map_;
394   ParamIndexMap after_cond_map_;
395   ParamIndexMap prev_after_cond_map_;
396   mindspore::HashMap<size_t, OperatorPtr> subgraph_input_cache_;
397 
398   std::set<size_t> while_used_input_index_;
399   std::set<size_t> prev_while_used_input_index_;
400 
401   mindspore::HashMap<size_t, OutHandler> bypass_node_prev_handle_cache_;
402   mindspore::HashMap<size_t, OutHandler> bypass_node_handle_cache_;
403   size_t case_call_input_size_ = 0;
404   bool is_kernel_graph_ = false;
405 
406   std::string phase_prefix_;
407   bool offline_convert_ = false;
408   void AddInputInDataSink(std::vector<Operator> *inputs);
409 };
410 }  // namespace transform
411 }  // namespace mindspore
412 
413 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_
414