/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H

#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include <set>
#include <stack>
#include <atomic>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/contract.h"
#include "include/backend/device_type.h"
#include "include/backend/kernel_info.h"
#include "include/backend/device_address.h"
#include "include/backend/visible.h"

namespace mindspore {
namespace session {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
struct KernelWithIndexCmp {
  bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
    if (key1.first != key2.first) {
      return key1.first < key2.first;
    }
    if (key1.second != key2.second) {
      return key1.second < key2.second;
    }
    return false;
  }
};
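// Illustrative usage (editor's sketch, not part of the original header): the comparator
// above yields a strict weak ordering over (node, output index) pairs, so it can key
// ordered containers, e.g.:
//   std::map<KernelWithIndex, size_t, KernelWithIndexCmp> output_sizes;
//   output_sizes[{node, 0}] = 1024;  // hypothetical node and size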

struct SomasInfo {
  // A whole_block_size_ of 0 indicates that somas did not allocate memory for this graph.
  size_t whole_block_size_{0};
  // offset -> aligned_size_
  std::map<size_t, size_t> merged_blocks_map_;

  // The base address of the graph allocated during execution, which is variable.
  void *base_address_{nullptr};
  // Block offset -> address.
  std::map<size_t, void *> merged_base_addresses_;

  // Used to keep the graph output addresses when the somas block memory is freed.
  void InsertGraphOutputInfo(device::DeviceAddress *graph_output_device_address, size_t graph_output_address_offset,
                             size_t graph_output_address_size) {
    // Do not insert a repeated offset.
    if (graph_output_address_offsets_set_.count(graph_output_address_offset) > 0) {
      MS_LOG(INFO) << "The graph:" << graph_id_
                   << " output somas device address is the same for offset: " << graph_output_address_offset;
      return;
    }
    (void)graph_output_device_addresses_.emplace_back(graph_output_device_address);
    (void)graph_output_address_sizes_.emplace_back(graph_output_address_size);
    (void)graph_output_address_offsets_set_.insert(graph_output_address_offset);
  }
  std::vector<device::DeviceAddress *> graph_output_device_addresses_;
  std::vector<size_t> graph_output_address_sizes_;
  std::set<size_t> graph_output_address_offsets_set_;

  // The owner graph id.
  uint32_t graph_id_{0};
};
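// Illustrative address resolution (editor's sketch, not part of the original header):
// once somas fixes base_address_ at execution time, a merged block's device address is
// derived from its offset, e.g.:
//   void *block_addr = static_cast<char *>(info.base_address_) + block_offset;
//   info.merged_base_addresses_[block_offset] = block_addr;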

using DeviceType = device::DeviceType;
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;

class BACKEND_EXPORT KernelGraph : public FuncGraph {
 public:
  KernelGraph()
      : inputs_(std::make_shared<AnfNodePtrList>()),
        somas_info_(std::make_shared<SomasInfo>()),
        graph_id_(0),
        stream_distinction_label_(kInvalidDistincLabel),
        device_target_(DeviceType::kUnknown),
        executable_(true),
        summary_node_exist_(false),
        need_inline_(false),
        start_label_(nullptr),
        end_goto_(nullptr),
        current_epoch_(0),
        is_dynamic_shape_(false) {}

  KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
    inputs_ = graph.inputs_;
    somas_info_ = graph.somas_info_;
    child_graph_result_ = graph.child_graph_result_;
    execution_order_ = graph.execution_order_;
    mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
    graph_id_ = graph.graph_id_;
    device_target_ = graph.device_target_;
    stream_distinction_label_ = graph.stream_distinction_label_;
    front_backend_anf_map_ = graph.front_backend_anf_map_;
    backend_front_anf_map_ = graph.backend_front_anf_map_;
    tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
    graph_value_nodes_ = graph.graph_value_nodes_;
    ref_out_in_map_ = graph.ref_out_in_map_;
    node_output_edges_ = graph.node_output_edges_;
    summary_nodes_ = graph.summary_nodes_;
    updated_parameters_ = graph.updated_parameters_;
    executable_ = graph.executable_;
    summary_node_exist_ = graph.summary_node_exist_;
    need_inline_ = graph.need_inline_;
    valid_inputs_ = graph.valid_inputs_;
    child_graph_order_ = graph.child_graph_order_;
    device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
    device_loop_ctrl_params_ = graph.device_loop_ctrl_params_;
    parent_graph_ = graph.parent_graph_;
    start_label_ = graph.start_label_;
    end_goto_ = graph.end_goto_;
    internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_;
    graph_output_to_front_node_map_ = graph.graph_output_to_front_node_map_;
    front_node_to_graph_output_map_ = graph.front_node_to_graph_output_map_;
    front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
    internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
    internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
    current_epoch_ = graph.current_epoch_;
    tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_;
    input_nodes_ = graph.input_nodes_;
    pre_graphs_ = graph.pre_graphs_;
    post_graphs_ = graph.post_graphs_;
    send_recv_pairs_for_parallel_op_inputs_ = graph.send_recv_pairs_for_parallel_op_inputs_;
    send_recv_pairs_for_parallel_op_outputs_ = graph.send_recv_pairs_for_parallel_op_outputs_;
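    // Editor's note: std::atomic has no copy assignment from another atomic, so the
    // counters below are loaded into locals first and then stored.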
    size_t pre_graph_finished_count = graph.pre_graph_finished_count_;
    pre_graph_finished_count_ = pre_graph_finished_count;
    size_t post_graph_finished_count = graph.post_graph_finished_count_;
    post_graph_finished_count_ = post_graph_finished_count;
    first_step_ = graph.first_step_;
    has_optimizer_ = graph.has_optimizer_;
    is_dynamic_shape_ = graph.is_dynamic_shape_;
    front_outputs_ = graph.front_outputs_;
    has_kernel_need_user_data_ = graph.has_kernel_need_user_data_;
  }

  ~KernelGraph() override;

  MS_DECLARE_PARENT(KernelGraph, FuncGraph);

  const AnfNodePtrList &inputs() const;
  AnfNodePtrList *MutableInputs() const { return inputs_.get(); }
  void SetGraphInputs(const AnfNodePtrList &inputs) { inputs_ = std::make_shared<AnfNodePtrList>(inputs); }
  void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter);
  AnfNodePtrList outputs() const;
  CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) override;
  CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs = AnfNodeWeakPtrList()) override;
  // NewCNodeWeak is recommended.
  CNodePtr NewCNode(AnfNodePtrList &&inputs) override;
  // NewCNodeWeak is recommended.
  CNodePtr NewCNode(const AnfNodePtrList &inputs = AnfNodePtrList()) override;
  CNodePtr NewCNodeWithInfos(const AnfNodePtrList &inputs, const CNodePtr &ori_cnode = nullptr);
  void CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const;
  CNodePtr NewCNode(const CNodePtr &cnode);
  void ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const;
  ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
  ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
  ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr) const;
  ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);
  ValueNodePtr NewValueNode(const tensor::TensorPtr &input_tensor);
  ValueNodePtr NewValueNode(const ValuePtr &input_value);
  // Transform a tuple output into a MakeTuple of non-tuple outputs.
  AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node);
  void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
  void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); }
  const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
  // Set a new execution order for memory reuse.
  void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; }
  const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; }
  void SetExecOrderByDefault();
  void SetNodeOutputEdges();
  uint32_t graph_id() const { return graph_id_; }
  void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
  uint32_t root_graph_id() const { return root_graph_id_; }
  void set_root_graph_id(uint32_t root_graph_id) { root_graph_id_ = root_graph_id; }
  DeviceType device_target() const { return device_target_; }
  void set_device_target(DeviceType target) { device_target_ = target; }

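  // Illustrative construction flow (editor's sketch; MakeValue and prim::kPrimAdd are
  // assumed helpers from the wider codebase, not declared in this header):
  //   auto graph = std::make_shared<KernelGraph>();
  //   auto param = graph->NewParameter();                      // backend input
  //   auto one = graph->NewValueNode(MakeValue<int64_t>(1));   // constant input
  //   auto add = graph->NewCNode({NewValueNode(prim::kPrimAdd), param, one});
  //   graph->set_execution_order({add});
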
  // Add a new front-to-backend anf relation to the map.
  void FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
  // Replace the old backend anf with a new backend anf.
  void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
  // Get the backend anf by a front anf.
  AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
  // Get the front anf by a backend anf.
  AnfNodePtr GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) const;
  const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &backend_front_anf_map() const { return backend_front_anf_map_; }
  // Check whether the backend node exists in the front-backend map.
  bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
  // Get the value node by a tensor.
  ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
  // Add a tensor-to-value-node relation to the map.
  void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
  // Get all value nodes of the graph.
  mindspore::HashSet<ValueNodePtr> graph_value_nodes() const;
  // Add a value node to the graph.
  void AddValueNodeToGraph(const ValueNodePtr &value_node);
  // Remove a value node from the graph.
  bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
  void ClearAllValueNode() { graph_value_nodes_.clear(); }
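  // Illustrative mapping round-trip (editor's sketch; the nodes are hypothetical):
  //   graph->FrontBackendMapAdd(front_node, backend_node);
  //   auto backend = graph->GetBackendAnfByFrontAnf(front_node);   // == backend_node
  //   auto front = graph->GetFrontAnfByBackendAnf(backend_node);   // == front_node
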
  // Check whether a ref output is in the map.
  bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
  // Whether the value corresponds to a ref output.
  bool IsRefOutputMapValue(const AnfWithOutIndex &pair) const;
  // Get the corresponding ref pair.
  AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
  // Support cascaded ref nodes and get the first ref output recursively.
  AnfWithOutIndex GetRefNodeRecursive(const AnfWithOutIndex &out_pair) const;
  // Add corresponding ref pairs.
  void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
  // Replace a ref pair.
  void ReplaceRefPair(const AnfWithOutIndex &old_pair, const AnfWithOutIndex &new_pair);
  // Get the ref map.
  const std::map<AnfWithOutIndex, AnfWithOutIndex> &GetRefMap() const { return ref_out_in_map_; }
  // Update the ref map.
  void set_ref_out_in_map(const std::map<AnfWithOutIndex, AnfWithOutIndex> &ref_out_in_map) {
    ref_out_in_map_ = ref_out_in_map;
  }
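  // Illustrative ref-pair semantics (editor's sketch; the Assign example is
  // hypothetical): an in-place kernel writes through its input, so its output aliases
  // that input and the alias is recorded here:
  //   graph->AddRefCorrespondPairs({assign_cnode, 0}, {parameter, 0});
  //   auto origin = graph->GetRefNodeRecursive({assign_cnode, 0});  // follows cascades
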
  // Check whether the graph is executable.
  bool executable() const { return executable_; }
  // Set whether the graph is executable.
  void set_executable(bool executable) { executable_ = executable; }
#ifndef ENABLE_SECURITY
  // Set whether a summary node exists in the graph.
  void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
#endif
  // Check whether a summary node exists in the graph.
  bool summary_node_exist() const { return summary_node_exist_; }
  // Set whether the graph needs to be inlined.
  void set_need_inline(bool need_inline) { need_inline_ = need_inline; }
  // Check whether the graph needs to be inlined.
  bool need_inline() const { return need_inline_; }
  // Set invalid inputs for control sink.
  std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
  std::vector<bool> valid_inputs() const { return valid_inputs_; }
  // Replace a node in the graph.
  void ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node);
  // Set the stream label of the graph.
  void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
  // Get the stream label of the graph.
  uint32_t stream_distinction_label() const { return stream_distinction_label_; }
  // Refresh the stream labels of the execution kernels.
  void UpdateExecuteKernelStreamLabel();
  // Calculate the leaf graph order of the root graph.
  std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
  // The child graphs of the current graph.
  const std::vector<std::weak_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
  void set_child_graph_order(const std::vector<std::weak_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
  // Check whether the current graph is a leaf graph.
  bool IsLeafGraph() const;

  void set_device_loop_ctrl_tensors(const std::map<std::string, tensor::TensorPtr> &device_loop_ctrl_tensors) {
    device_loop_ctrl_tensors_ = device_loop_ctrl_tensors;
  }
  const std::map<std::string, tensor::TensorPtr> &device_loop_control_tensors() const {
    return device_loop_ctrl_tensors_;
  }

  void set_device_loop_ctrl_params(const std::map<std::string, mindspore::ParameterPtr> &device_loop_ctrl_params) {
    device_loop_ctrl_params_ = device_loop_ctrl_params;
  }
  const std::map<std::string, mindspore::ParameterPtr> &device_loop_control_params() const {
    return device_loop_ctrl_params_;
  }

  // Get the parent kernel graph.
  std::weak_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
  // Set the parent kernel graph.
  void set_parent_graph(const std::weak_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
  // Find anf nodes in the graph by primitive.
  std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
  std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const;
  // Used to dump IR.
  std::string ToString() const override;

  bool FrontendNodeExistInFrontBackendMap(const AnfNodePtr &frontend_anf);

  void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
  CNodePtr get_start_label() { return start_label_; }
  void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
  CNodePtr get_end_goto() { return end_goto_; }
  void PrintGraphExecuteOrder() const;
  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
  void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx, bool unique_target);
  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
                             size_t dst_output_idx);
  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
  AnfWithOutIndex GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
  bool IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const;
  bool IsInternalOutput(const AnfNodePtr &node) const;
  bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const;
  void AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor);
  tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx);
  AnfWithOutIndex GetGraphOutputByFrontNode(const AnfWithOutIndex &front_node) const;

  // Cache an internal parameter and its corresponding front node into internal_parameter_to_front_node_map_.
  void CacheInternalParameterToFrontNode(const AnfNodePtr &parameter, const AnfWithOutIndex &front_node_with_index);
  void UpdateInternalParameter();
  // This function gets the real node, skipping monad control nodes.
  AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
  // This function gets the origin node used to connect monad controls between subgraphs.
  AnfWithOutIndex GetOriginFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;

  // Get the func graph to which the kernel graph belongs.
  FuncGraphPtr GetFuncGraph();
  // Cache the backend graph output nodes and their corresponding front nodes with output index into
  // graph_output_to_front_node_map_.
  void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtrList &backend_outputs, const AnfNodePtrList &front_outputs);
  AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const;

  void SetKernelObjectTypesForUnrealNodes() const;

  uint32_t current_epoch() const { return current_epoch_; }
  void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
  void UpdateChildGraphOrder();
  const AnfNodePtrList &child_graph_result() const { return child_graph_result_; }
  void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
  bool IsChildGraphResult(const AnfNodePtr &node);
  void set_child_graph_result(const AnfNodePtrList &child_graph_result) { child_graph_result_ = child_graph_result; }

  void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
    if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
      return;
    }
    tuple_parameter_to_make_tuple_map_[param] = make_tuple;
  }
  AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr &param) const {
    auto iter = tuple_parameter_to_make_tuple_map_.find(param);
    if (iter != tuple_parameter_to_make_tuple_map_.end()) {
      return iter->second;
    }
    return nullptr;
  }
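  // Illustrative lookup (editor's sketch): the finder returns nullptr when no MakeTuple
  // has been recorded for the tuple parameter:
  //   if (auto make_tuple = graph->FindTupleParameterToMakeTupleMap(param)) { /* ... */ }
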
  void RemoveNodeFromGraph(const AnfNodePtr &node);
  void EnableRuntimeCache() const;
  void DisableRuntimeCache() const;
  void UpdateGraphDynamicAttr();
  void SetGraphDynamicAttr(bool is_dynamic_shape) { is_dynamic_shape_ = is_dynamic_shape; }
  bool is_dynamic_shape() const { return is_dynamic_shape_; }
  void UpdateGraphAquireGilAttr();
  void SetOptimizerFlag();
  void SetInputNodes();
  const AnfNodePtrList &input_nodes() const { return input_nodes_; }
  void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; }
  const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }

  void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor);

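  // Editor's note: the lookup below first tries the direct (node, index) key and, on a
  // miss, retries with the real output recorded for a nop node in nop_node_output_map_.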
  tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const {
    auto iter = output_node_to_tensor_.find(output_index);
    if (iter != output_node_to_tensor_.end()) {
      return utils::cast<tensor::TensorPtr>(iter->second);
    }
    auto nop_node_output_iter = nop_node_output_map_.find(output_index);
    if (nop_node_output_iter != nop_node_output_map_.end()) {
      iter = output_node_to_tensor_.find(nop_node_output_iter->second);
      if (iter != output_node_to_tensor_.end()) {
        return utils::cast<tensor::TensorPtr>(iter->second);
      }
    }
    return nullptr;
  }

  bool has_optimizer() const { return has_optimizer_; }
  bool IsUpdatedParameter(const ParameterPtr &param) const {
    return updated_parameters_.find(param) != updated_parameters_.end();
  }
  // Handle graph dependencies.
  void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
    if (graph != nullptr) {
      pre_graphs_[graph->graph_id()] = graph;
    }
  }

  const mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &get_pre_graphs() const {
    return pre_graphs_;
  }
  void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) {
    if (graph != nullptr) {
      post_graphs_[graph->graph_id()] = graph;
    }
  }
  const mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &GetPostGraphs() const {
    return post_graphs_;
  }

  bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; }
  bool IsPostGraphFinished() const {
    if (first_step_) {
      return true;
    }
    return post_graphs_.size() == post_graph_finished_count_;
  }

  bool HasPostGraph() const { return !post_graphs_.empty(); }

  void IncPreGraphFinishedCount() { ++pre_graph_finished_count_; }
  void IncPostGraphFinishedCount() { ++post_graph_finished_count_; }
  void ResetGraphRunningStatus() {
    first_step_ = false;
    post_graph_finished_count_ = 0;
    pre_graph_finished_count_ = 0;
  }
  void OnRunGraphFinished() const {
    for (const auto &post_graph : post_graphs_) {
      auto post_graph_ptr = post_graph.second.lock();
      if (post_graph_ptr != nullptr) {
        post_graph_ptr->IncPreGraphFinishedCount();
      }
    }
    for (const auto &pre_graph : pre_graphs_) {
      auto pre_graph_ptr = pre_graph.second.lock();
      if (pre_graph_ptr != nullptr) {
        pre_graph_ptr->IncPostGraphFinishedCount();
      }
    }
  }
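  // Illustrative dependency handshake (editor's sketch; graphs a and b are
  // hypothetical). If a must run before b:
  //   a->AddPostGraph(b);
  //   b->AddPreGraph(a);
  //   a->OnRunGraphFinished();               // bumps b's pre-graph finished count
  //   bool ready = b->IsPreGraphFinished();  // true once all of b's pre-graphs ran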
  // End of handling graph dependencies.

  // The interface of the parallel op send/recv pairs map.
  void InsertSendRecvPairForParallelOpInputs(const CNodePtr &parallel_op,
                                             const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
    auto iter = send_recv_pairs_for_parallel_op_inputs_.find(parallel_op);
    if (iter == send_recv_pairs_for_parallel_op_inputs_.end()) {
      send_recv_pairs_for_parallel_op_inputs_[parallel_op] = {send_recv_pair};
    } else {
      iter->second.emplace_back(send_recv_pair);
    }
  }

  void InsertSendRecvPairForParallelOpOutputs(const CNodePtr &parallel_op,
                                              const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
    auto iter = send_recv_pairs_for_parallel_op_outputs_.find(parallel_op);
    if (iter == send_recv_pairs_for_parallel_op_outputs_.end()) {
      send_recv_pairs_for_parallel_op_outputs_[parallel_op] = {send_recv_pair};
    } else {
      iter->second.emplace_back(send_recv_pair);
    }
  }

  const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
    &send_recv_pairs_for_parallel_op_inputs() const {
    return send_recv_pairs_for_parallel_op_inputs_;
  }
  const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>>
    &send_recv_pairs_for_parallel_op_outputs() const {
    return send_recv_pairs_for_parallel_op_outputs_;
  }

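  // Illustrative bookkeeping (editor's sketch; the cnodes are hypothetical): record
  // that a parallel op's input arrives through an inserted <Send, Recv> pair:
  //   graph->InsertSendRecvPairForParallelOpInputs(parallel_cnode, {send_cnode, recv_cnode});
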
  uint32_t label_num() const { return label_num_; }
  void set_label_num(uint32_t num) { label_num_ = num; }
  // Whether the graph has recursion.
  bool recursive_call() const { return has_recursive_call_; }
  // Whether the graph has subgraph multi-call.
  bool subgraph_multi_call() const { return has_subgraph_multicall_; }
  // Set the flag indicating whether the graph has recursion.
  void set_recursive_call(bool flag) { has_recursive_call_ = flag; }
  // Set the flag indicating whether the graph has multi-call.
  void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; }

  const std::map<AnfWithOutIndex, AnfWithOutIndex> &graph_output_map() const { return graph_output_to_front_node_map_; }
  const std::map<AnfWithOutIndex, AnfWithOutIndex> &front_node_to_graph_output_map() const {
    return front_node_to_graph_output_map_;
  }

  // The interface to set/get the graph GIL flag.
  void set_is_need_gil(bool flag) { is_need_gil_ = flag; }
  bool is_need_gil() const { return is_need_gil_; }

  bool IsDatasetGraph() const;

  void set_is_from_single_op(bool is_from_single_op) { is_from_single_op_ = is_from_single_op; }
  bool is_from_single_op() const { return is_from_single_op_; }
  void set_is_any_type_input(bool is_any_type_input) { is_any_type_input_ = is_any_type_input; }
  bool is_any_type_input() const { return is_any_type_input_; }
  void set_run_mode(device::RunMode run_mode) { run_mode_ = run_mode; }
  device::RunMode RunMode() const { return run_mode_; }
  bool is_graph_run_mode() const { return run_mode_ == device::RunMode::kGraphMode; }
  bool is_loop_count_sink() const { return is_loop_count_sink_; }
  void set_memory_managed_by_ge(bool memory_managed_by_ge) {
    if (IsEnableRefMode()) {
      memory_managed_by_ge_ = memory_managed_by_ge;
    }
  }
  bool memory_managed_by_ge() const { return memory_managed_by_ge_; }
  void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; }
  const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() const { return front_backend_anf_map_; }

  AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) const {
    auto iter = tuple_backend_front_anf_index_map_.find(back_node);
    if (iter == tuple_backend_front_anf_index_map_.end()) {
      return AnfWithOutIndex{nullptr, 0};
    }
    return iter->second;
  }
  const HashMap<AnfNodePtr, AnfWithOutIndex> &InternalParameterToFrontNodeMap() const {
    return internal_parameter_to_front_node_map_;
  }
  void SetInternalParameterToFrontNodeMap(const HashMap<AnfNodePtr, AnfWithOutIndex> &ipf_map) {
    internal_parameter_to_front_node_map_ = ipf_map;
  }

  const AnfNodePtrList &front_outputs() const { return front_outputs_; }
  void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; }
  bool IsCommSubGraph(uint32_t id) const { return comm_sub_graph_ids_.find(id) != comm_sub_graph_ids_.end(); }
  void RecordNewCommSubGraphId(uint32_t id) { comm_sub_graph_ids_.insert(id); }
  const std::set<uint32_t> &CommSubGraphIds() const { return comm_sub_graph_ids_; }

  // Somas memory info accessors.
  SomasInfo *MutableSomasInfo() const { return somas_info_.get(); }
  size_t somas_whole_block_size() const { return somas_info_->whole_block_size_; }
  const std::map<size_t, size_t> &somas_merged_blocks_map() const { return somas_info_->merged_blocks_map_; }

  void set_graph_info(const std::string &graph_info) { graph_info_ = graph_info; }
  // Infer the cnode abstract from its parameters.
  void InferType();
  void PostNewCNode(const CNodePtr &cnode) const;
  void SetKernelInfoForNode(const AnfNodePtr &node) const;
  void AddInlineSubgraphKernel(const AnfNodePtr &node, const std::string &graph_name) {
    inline_sub_graph_kernels_[node] = graph_name;
  }
  const mindspore::HashMap<AnfNodePtr, std::string> &inline_sub_graph_kernels() const {
    return inline_sub_graph_kernels_;
  }
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch() const { return condition_gather_to_switch_; }
  void AddConditionGatherSwitchPair(const AnfNodePtr &condition_gather, const AnfNodePtr &condition_switch) {
    condition_gather_to_switch_[condition_gather] = condition_switch;
  }
  void RemoveConditionGatherSwitchPair(const AnfNodePtr &condition_gather) {
    condition_gather_to_switch_.erase(condition_gather);
  }

  void set_is_from_pynative(const bool &from_pynative) { from_pynative_ = from_pynative; }
  bool is_from_pynative() const { return from_pynative_; }

  bool enable_multi_stream() const { return enable_multi_stream_; }
  void set_enable_multi_stream(bool enable_multi_stream) { enable_multi_stream_ = enable_multi_stream; }

  bool has_kernel_need_user_data() const { return has_kernel_need_user_data_; }
  void set_has_kernel_need_user_data(bool has_kernel_need_user_data) {
    has_kernel_need_user_data_ = has_kernel_need_user_data;
  }
  void CacheRootWeight(const std::vector<AnfNodePtr> &weights);
  const std::vector<AnfNodePtr> &GetRootWeights() const { return root_weights_; }

 private:
  AnfNodePtr MakeValueNode(const AnfNodePtr &node) const;

  AnfNodePtr TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value);
  AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
  AnfNodePtr TransCNodeTuple(const CNodePtr &node);
  AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);
  std::vector<CNodePtr> SortStartLabelAndEndGoto();

  // Members.
  std::shared_ptr<AnfNodePtrList> inputs_;
  std::shared_ptr<SomasInfo> somas_info_;
  AnfNodePtrList child_graph_result_;
  std::vector<CNodePtr> execution_order_;
  std::vector<CNodePtr> mem_reuse_exec_order_;
  uint32_t graph_id_;
  uint32_t stream_distinction_label_;
  DeviceType device_target_;
  uint32_t root_graph_id_{0};

  // Record the map between front and backend anf nodes; two maps implement a bidirectional mapping.
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
  mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> tuple_backend_front_anf_index_map_;
  // A tensor may come from the ME backend; a value node is created from such a tensor, and this map records the
  // relation.
  mindspore::HashMap<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
  // Includes all value nodes; the size_t records how many times each value node is used in the graph.
  mindspore::HashMap<ValueNodePtr, size_t> graph_value_nodes_;
  // Record the map between the ref final output anf with index and the ref origin input with index.
  std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  mindspore::HashMap<AnfNodePtr, AnfNodePtrList> node_output_edges_;
  std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
  // Parameters that will be updated when the graph is executed.
  mindspore::HashSet<ParameterPtr> updated_parameters_;
  // Kernels in inline subgraphs for switch nodes.
  mindspore::HashMap<AnfNodePtr, std::string> inline_sub_graph_kernels_;
  // Record the relationship between condition gather and condition switch.
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch_;

  // Whether the graph is executable.
  bool executable_{false};
  // Whether a summary node exists in the graph.
  bool summary_node_exist_{false};
  // Valid inputs.
  std::vector<bool> valid_inputs_;
  // Whether the graph needs to be inlined.
  bool need_inline_;

  // Child graph execution order in the parent graph.
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;

  // Device loop control frontend tensors.
  std::map<std::string, tensor::TensorPtr> device_loop_ctrl_tensors_;
  // Device loop control backend nodes.
  std::map<std::string, mindspore::ParameterPtr> device_loop_ctrl_params_;

  // The parent graph.
  std::weak_ptr<KernelGraph> parent_graph_;

  CNodePtr start_label_;
  CNodePtr end_goto_;

  AnfNodePtrList front_outputs_;
  // An internal parameter is not an origin parameter of the func graph; it is the output of a previous kernel graph
  // that is related to an input of this kernel graph. The key of the map is the input of this kernel graph, and the
  // value is the front node corresponding to the output of the previous kernel graph.
  mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
  // The key of the map is the backend graph output of this kernel graph, and the value is the corresponding front
  // node with index.
  std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_to_front_node_map_;
  std::map<AnfWithOutIndex, AnfWithOutIndex> front_node_to_graph_output_map_;

  mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> front_to_internal_outputs_map_;
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, std::pair<AnfNodePtr, bool>>>
    internal_outputs_to_front_map_;
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, tensor::TensorPtr>> internal_outputs_tensor_map_;
  uint32_t current_epoch_;
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
  AnfNodePtrList input_nodes_;
  std::vector<tensor::TensorPtr> input_tensors_;
  KernelMapTensor output_node_to_tensor_;
  std::map<session::KernelWithIndex, session::KernelWithIndex, session::KernelWithIndexCmp> nop_node_output_map_;
  mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
  mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;

  // Key: parallel op ptr; value: vector of <send op, recv op> pairs.
  mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
  // Key: parallel op ptr; value: vector of <send op, recv op> pairs.
  mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;

  std::atomic<size_t> pre_graph_finished_count_{0};
  std::atomic<size_t> post_graph_finished_count_{0};
  bool first_step_{true};
  bool has_optimizer_{false};
  bool is_dynamic_shape_{false};

  // Indicate whether the graph, as the root graph, has recursion or multi-call.
  bool has_recursive_call_{false};
  bool has_subgraph_multicall_{false};

  // Number of labels. This is also the 'batch_num' for DavinciModel.
  // It should be 1 if no labels are used for control flow.
  uint32_t label_num_ = 1;

  // Indicate whether the kernels in the graph acquire the Python GIL.
  bool is_need_gil_{false};

  // Memory is managed by GE.
  bool memory_managed_by_ge_{false};

  // Indicate whether the kernel graph is constructed from a single op in a function graph.
  bool is_from_single_op_{false};
  // Indicate whether the kernel graph has an any-type input.
  bool is_any_type_input_{false};
  // Indicate whether the kernel graph sink runs on the graph executor or the kernel executor.
  device::RunMode run_mode_{device::RunMode::kUnknown};

  // Indicate whether the kernel graph loop-sinks to device execution.
  bool is_loop_count_sink_{false};
  // Save the communication sub-graph ids for comm op reuse.
  std::set<uint32_t> comm_sub_graph_ids_{};
  // Graph info for a single op.
  std::string graph_info_;
  bool from_pynative_{false};

  bool enable_multi_stream_{false};
  // Whether this graph contains kernels that need user data.
  bool has_kernel_need_user_data_{false};
  std::vector<AnfNodePtr> root_weights_;
};
}  // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H