/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H

#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include <set>
#include <stack>
#include <atomic>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "include/common/utils/contract.h"
#include "include/backend/device_type.h"
#include "include/backend/kernel_info.h"
#include "include/backend/device_address.h"
#include "include/backend/visible.h"

namespace mindspore {
namespace session {
// A (node, output index) pair: identifies one output of an ANF node.
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
// Strict weak ordering for KernelWithIndex so it can be used as an ordered
// map/set key: compares the node pointer first, then the output index.
struct KernelWithIndexCmp {
  bool operator()(const KernelWithIndex &key1, const KernelWithIndex &key2) const {
    if (key1.first != key2.first) {
      return key1.first < key2.first;
    }
    if (key1.second != key2.second) {
      return key1.second < key2.second;
    }
    // Equal keys: neither orders before the other.
    return false;
  }
};

// Bookkeeping for SOMAS (memory-planning) allocation of one kernel graph:
// the planned whole-block size, the merged sub-blocks (offset -> size), the
// runtime base address, and the graph outputs that must be kept alive when
// the somas block memory is freed.
struct SomasInfo {
  // whole_block_size_ is 0 indicating that somas did not allocate memory for this graph.
  size_t whole_block_size_{0};
  // offset -> aligned_size_
  std::map<size_t, size_t> merged_blocks_map_;

  // Alloc the base address of graph during execution, which is variable.
  void *base_address_{nullptr};
  // Block offset -> address.
  std::map<size_t, void *> merged_base_addresses_;

  // Used to keep the graph output address when somas block memory free.
  // Records one graph-output device address with its size; duplicate offsets
  // are ignored (each somas offset is registered at most once).
  void InsertGraphOutputInfo(device::DeviceAddress *graph_output_device_address, size_t graph_output_address_offset,
                             size_t graph_output_address_size) {
    // Not insert the repeat size.
    if (graph_output_address_offsets_set_.count(graph_output_address_offset) > 0) {
      MS_LOG(INFO) << "The graph:" << graph_id_
                   << " output somas device is same for offset: " << graph_output_address_offset;
      return;
    }
    (void)graph_output_device_addresses_.emplace_back(graph_output_device_address);
    (void)graph_output_address_sizes_.emplace_back(graph_output_address_size);
    (void)graph_output_address_offsets_set_.insert(graph_output_address_offset);
  }
  // Parallel vectors: graph_output_device_addresses_[i] has size
  // graph_output_address_sizes_[i]; offsets_set_ deduplicates insertions.
  std::vector<device::DeviceAddress *> graph_output_device_addresses_;
  std::vector<size_t> graph_output_address_sizes_;
  std::set<size_t> graph_output_address_offsets_set_;

  // The owner graph id.
  uint32_t graph_id_{0};
};

using DeviceType = device::DeviceType;
using KernelMapTensor = std::map<session::KernelWithIndex, BaseRef, session::KernelWithIndexCmp>;

class BACKEND_EXPORT KernelGraph : public FuncGraph {
 public:
  KernelGraph()
      : inputs_(std::make_shared<AnfNodePtrList>()),
        somas_info_(std::make_shared<SomasInfo>()),
        graph_id_(0),
        stream_distinction_label_(kInvalidDistincLabel),
        device_target_(DeviceType::kUnknown),
        executable_(true),
        summary_node_exist_(false),
        need_inline_(false),
        start_label_(nullptr),
        end_goto_(nullptr),
        current_epoch_(0),
        is_dynamic_shape_(false) {}

  // Copy constructor: shares inputs_/somas_info_ (shared_ptr copy, not deep
  // copy) and copies the graph bookkeeping maps; note the atomic counters are
  // read into locals further below because std::atomic is not copyable.
  KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
    inputs_ = graph.inputs_;
    somas_info_ = graph.somas_info_;
    child_graph_result_ = graph.child_graph_result_;
    execution_order_ = graph.execution_order_;
    mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
    graph_id_ = graph.graph_id_;
    device_target_ = graph.device_target_;
    stream_distinction_label_ = graph.stream_distinction_label_;
    front_backend_anf_map_ = graph.front_backend_anf_map_;
    backend_front_anf_map_ = graph.backend_front_anf_map_;
    tensor_to_value_node_map_ = graph.tensor_to_value_node_map_;
    graph_value_nodes_ = graph.graph_value_nodes_;
    ref_out_in_map_ = graph.ref_out_in_map_;
    node_output_edges_ = graph.node_output_edges_;
    summary_nodes_ = graph.summary_nodes_;
    updated_parameters_ = graph.updated_parameters_;
    executable_ = graph.executable_;
    summary_node_exist_ = graph.summary_node_exist_;
    need_inline_ = graph.need_inline_;
    valid_inputs_ = graph.valid_inputs_;
    child_graph_order_ = graph.child_graph_order_;
    device_loop_ctrl_tensors_ = graph.device_loop_ctrl_tensors_;
    device_loop_ctrl_params_ = graph.device_loop_ctrl_params_;
    parent_graph_ = graph.parent_graph_;
    start_label_ = graph.start_label_;
end_goto_ = graph.end_goto_; 133 internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_; 134 graph_output_to_front_node_map_ = graph.graph_output_to_front_node_map_; 135 front_node_to_graph_output_map_ = graph.front_node_to_graph_output_map_; 136 front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_; 137 internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_; 138 internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_; 139 current_epoch_ = graph.current_epoch_; 140 tuple_parameter_to_make_tuple_map_ = graph.tuple_parameter_to_make_tuple_map_; 141 input_nodes_ = graph.input_nodes_; 142 pre_graphs_ = graph.pre_graphs_; 143 post_graphs_ = graph.post_graphs_; 144 send_recv_pairs_for_parallel_op_inputs_ = graph.send_recv_pairs_for_parallel_op_inputs_; 145 send_recv_pairs_for_parallel_op_outputs_ = graph.send_recv_pairs_for_parallel_op_outputs_; 146 size_t pre_graph_finished_count = graph.pre_graph_finished_count_; 147 pre_graph_finished_count_ = pre_graph_finished_count; 148 size_t post_graph_finished_count = graph.post_graph_finished_count_; 149 post_graph_finished_count_ = post_graph_finished_count; 150 first_step_ = graph.first_step_; 151 has_optimizer_ = graph.has_optimizer_; 152 is_dynamic_shape_ = graph.is_dynamic_shape_; 153 front_outputs_ = graph.front_outputs_; 154 has_kernel_need_user_data_ = graph.has_kernel_need_user_data_; 155 } 156 157 ~KernelGraph() override; 158 159 MS_DECLARE_PARENT(KernelGraph, FuncGraph); 160 161 const AnfNodePtrList &inputs() const; MutableInputs()162 AnfNodePtrList *MutableInputs() const { return inputs_.get(); } SetGraphInputs(const AnfNodePtrList & inputs)163 void SetGraphInputs(const AnfNodePtrList &inputs) { inputs_ = std::make_shared<AnfNodePtrList>(inputs); } 164 void ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter); 165 AnfNodePtrList outputs() const; 166 CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) override; 
167 CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs = AnfNodeWeakPtrList()) override; 168 // NewCNodeWeak is recommended. 169 CNodePtr NewCNode(AnfNodePtrList &&inputs) override; 170 // NewCNodeWeak is recommended. 171 CNodePtr NewCNode(const AnfNodePtrList &inputs = AnfNodePtrList()) override; 172 CNodePtr NewCNodeWithInfos(const AnfNodePtrList &inputs, const CNodePtr &ori_cnode = nullptr); 173 void CreateKernelInfoFromNewParameter(const CNodePtr &cnode) const; 174 CNodePtr NewCNode(const CNodePtr &cnode); 175 void ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const; 176 ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); 177 ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract); 178 ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr) const; 179 ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value); 180 ValueNodePtr NewValueNode(const tensor::TensorPtr &input_tensor); 181 ValueNodePtr NewValueNode(const ValuePtr &input_value); 182 // trans tuple output to maketuple + no_tuple out 183 AnfNodePtr TransTupleToMakeTuple(const AnfNodePtr &node); set_execution_order(const std::vector<CNodePtr> & order)184 void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; } set_execution_order(std::vector<CNodePtr> && order)185 void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); } execution_order()186 const std::vector<CNodePtr> &execution_order() const { return execution_order_; } 187 // Set new exec_order for mem_reuse set_mem_reuse_exec_order(const std::vector<CNodePtr> & order)188 void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; } mem_reuse_exec_order()189 const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; } 190 void SetExecOrderByDefault(); 191 void SetNodeOutputEdges(); graph_id()192 uint32_t graph_id() const { return 
graph_id_; } set_graph_id(uint32_t graph_id)193 void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } root_graph_id()194 uint32_t root_graph_id() const { return root_graph_id_; } set_root_graph_id(uint32_t root_graph_id)195 void set_root_graph_id(uint32_t root_graph_id) { root_graph_id_ = root_graph_id; } device_target()196 DeviceType device_target() const { return device_target_; } set_device_target(DeviceType target)197 void set_device_target(DeviceType target) { device_target_ = target; } 198 199 // and a new front to backend anf relation to maop 200 void FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf); 201 // replace old backend anf with new backend anf 202 void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); 203 // get backend anf by front anf 204 AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); 205 // get front anf by backend anf 206 AnfNodePtr GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) const; backend_front_anf_map()207 const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &backend_front_anf_map() const { return backend_front_anf_map_; } 208 // check backend node whether exist in map 209 bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); 210 // get value node by tensor 211 ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor); 212 // add value node tensor relation map 213 void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node); 214 // get all value nodes of graph 215 mindspore::HashSet<ValueNodePtr> graph_value_nodes() const; 216 // add value node to graph 217 void AddValueNodeToGraph(const ValueNodePtr &value_node); 218 // remove value node form graph 219 bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); ClearAllValueNode()220 void ClearAllValueNode() { graph_value_nodes_.clear(); } 221 // ref output is in map 222 bool IsInRefOutputMap(const AnfWithOutIndex &pair) const; 
223 // Whether the value corresponds to ref output. 224 bool IsRefOutputMapValue(const AnfWithOutIndex &pair) const; 225 // get ref correspond pairs 226 AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const; 227 // Support the cascade ref node and get the first ref output recursive. 228 AnfWithOutIndex GetRefNodeRecursive(const AnfWithOutIndex &out_pair) const; 229 // add ref correspond pairs 230 void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); 231 // Replace ref pair 232 void ReplaceRefPair(const AnfWithOutIndex &old_pair, const AnfWithOutIndex &new_pair); 233 // get map GetRefMap()234 const std::map<AnfWithOutIndex, AnfWithOutIndex> &GetRefMap() const { return ref_out_in_map_; } 235 // update ref map set_ref_out_in_map(const std::map<AnfWithOutIndex,AnfWithOutIndex> & ref_out_in_map)236 void set_ref_out_in_map(const std::map<AnfWithOutIndex, AnfWithOutIndex> &ref_out_in_map) { 237 ref_out_in_map_ = ref_out_in_map; 238 } 239 // check whether graph is executable executable()240 bool executable() const { return executable_; } 241 // set executable of graph set_executable(bool executable)242 void set_executable(bool executable) { executable_ = executable; } 243 #ifndef ENABLE_SECURITY 244 // set summary_node of graph set_summary_node_exist(bool summary_node_exist)245 void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; } 246 #endif 247 // check whether exist summary node in graph summary_node_exist()248 bool summary_node_exist() const { return summary_node_exist_; } 249 // set need inline set_need_inline(bool need_inline)250 void set_need_inline(bool need_inline) { need_inline_ = need_inline; } 251 // check whether need inline need_inline()252 bool need_inline() const { return need_inline_; } 253 // set invalid inputs for control sink MutableValidInputs()254 std::vector<bool> *MutableValidInputs() { return &valid_inputs_; } valid_inputs()255 
std::vector<bool> valid_inputs() const { return valid_inputs_; } 256 // replace node in graph 257 void ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node); 258 // set stream label of graph set_stream_distinction_label(uint32_t stream_label)259 void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } 260 // get stream label of graph stream_distinction_label()261 uint32_t stream_distinction_label() const { return stream_distinction_label_; } 262 // refresh execute kernel stream label 263 void UpdateExecuteKernelStreamLabel(); 264 // calculate the leaf graph order of root graph 265 std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder(); 266 // the child graph of current graph child_graph_order()267 const std::vector<std::weak_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; } set_child_graph_order(const std::vector<std::weak_ptr<KernelGraph>> & order)268 void set_child_graph_order(const std::vector<std::weak_ptr<KernelGraph>> &order) { child_graph_order_ = order; } 269 // checkout whether current graph is leaf graph 270 bool IsLeafGraph() const; 271 set_device_loop_ctrl_tensors(const std::map<std::string,tensor::TensorPtr> & device_loop_ctrl_tensors)272 void set_device_loop_ctrl_tensors(const std::map<std::string, tensor::TensorPtr> &device_loop_ctrl_tensors) { 273 device_loop_ctrl_tensors_ = device_loop_ctrl_tensors; 274 } device_loop_control_tensors()275 const std::map<std::string, tensor::TensorPtr> &device_loop_control_tensors() const { 276 return device_loop_ctrl_tensors_; 277 } 278 set_device_loop_ctrl_params(const std::map<std::string,mindspore::ParameterPtr> & device_loop_ctrl_params)279 void set_device_loop_ctrl_params(const std::map<std::string, mindspore::ParameterPtr> &device_loop_ctrl_params) { 280 device_loop_ctrl_params_ = device_loop_ctrl_params; 281 } device_loop_control_params()282 const std::map<std::string, mindspore::ParameterPtr> 
&device_loop_control_params() const { 283 return device_loop_ctrl_params_; 284 } 285 286 // get parent kernel graph parent_graph()287 std::weak_ptr<KernelGraph> parent_graph() const { return parent_graph_; } 288 // set parent kernel graph set_parent_graph(const std::weak_ptr<KernelGraph> & parent_graph)289 void set_parent_graph(const std::weak_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } 290 // find anf node in graph 291 std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; 292 std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const; 293 // used to dump ir 294 std::string ToString() const override; 295 296 bool FrontendNodeExistInFrontBackendMap(const AnfNodePtr &frontend_anf); 297 set_start_label(const CNodePtr & start_label)298 void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } get_start_label()299 CNodePtr get_start_label() { return start_label_; } set_end_goto(const CNodePtr & end_goto)300 void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } get_end_goto()301 CNodePtr get_end_goto() { return end_goto_; } 302 void PrintGraphExecuteOrder() const; summary_nodes()303 const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; } set_summary_nodes(const std::map<std::string,std::pair<AnfNodePtr,int>> & nodes)304 void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; } 305 void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx, bool unique_target); 306 void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx, 307 size_t dst_output_idx); 308 void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); 309 AnfWithOutIndex GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; 310 bool IsInternalOutput(const AnfNodePtr &node, 
size_t output_idx) const; 311 bool IsInternalOutput(const AnfNodePtr &node) const; 312 bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const; 313 void AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor); 314 tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx); 315 AnfWithOutIndex GetGraphOutputByFrontNode(const AnfWithOutIndex &front_node) const; 316 317 // Cache the internal parameter and corresponding to front node into internal_parameter_to_front_node_map_. 318 void CacheInternalParameterToFrontNode(const AnfNodePtr ¶meter, const AnfWithOutIndex &front_node_with_index); 319 void UpdateInternalParameter(); 320 // This function gets the real node that skip the monad control node. 321 AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const; 322 // This function gets the origin node used to connect monad controls between subgraphs. 323 AnfWithOutIndex GetOriginFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const; 324 325 // Get the funcgraph to which the kernel graph belongs. 326 FuncGraphPtr GetFuncGraph(); 327 // Cache the backend graph output nodes and corresponding to front nodes with output index into 328 // graph_output_to_front_node_map_. 
329 void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtrList &backend_outputs, const AnfNodePtrList &front_outputs); 330 AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const; 331 332 void SetKernelObjectTypesForUnrealNodes() const; 333 current_epoch()334 uint32_t current_epoch() const { return current_epoch_; } set_current_epoch(uint32_t epoch)335 void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } 336 void UpdateChildGraphOrder(); child_graph_result()337 const AnfNodePtrList &child_graph_result() const { return child_graph_result_; } AddChildGraphResult(const AnfNodePtr & parameter)338 void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); } 339 bool IsChildGraphResult(const AnfNodePtr &node); set_child_graph_result(const AnfNodePtrList & child_graph_result)340 void set_child_graph_result(const AnfNodePtrList &child_graph_result) { child_graph_result_ = child_graph_result; } 341 InsertTupleParameterToMakeTupleMap(const AnfNodePtr & param,const AnfNodePtr & make_tuple)342 void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) { 343 if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { 344 return; 345 } 346 tuple_parameter_to_make_tuple_map_[param] = make_tuple; 347 } FindTupleParameterToMakeTupleMap(const AnfNodePtr & param)348 AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr ¶m) const { 349 if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { 350 return tuple_parameter_to_make_tuple_map_.at(param); 351 } else { 352 return nullptr; 353 } 354 } 355 void RemoveNodeFromGraph(const AnfNodePtr &node); 356 void EnableRuntimeCache() const; 357 void DisableRuntimeCache() const; 358 void UpdateGraphDynamicAttr(); SetGraphDynamicAttr(bool is_dynamic_shape)359 void SetGraphDynamicAttr(bool is_dynamic_shape) { 
is_dynamic_shape_ = is_dynamic_shape; } is_dynamic_shape()360 bool is_dynamic_shape() const { return is_dynamic_shape_; } 361 void UpdateGraphAquireGilAttr(); 362 void SetOptimizerFlag(); 363 void SetInputNodes(); input_nodes()364 const AnfNodePtrList &input_nodes() const { return input_nodes_; } SetInputTensors(const std::vector<tensor::TensorPtr> & input_tensors)365 void SetInputTensors(const std::vector<tensor::TensorPtr> &input_tensors) { input_tensors_ = input_tensors; } input_tensors()366 const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; } 367 368 void SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor); 369 GetNodeOutputTensor(const session::KernelWithIndex & output_index)370 tensor::TensorPtr GetNodeOutputTensor(const session::KernelWithIndex &output_index) const { 371 auto iter = output_node_to_tensor_.find(output_index); 372 if (iter != output_node_to_tensor_.end()) { 373 return utils::cast<tensor::TensorPtr>(iter->second); 374 } 375 auto nop_node_output_iter = nop_node_output_map_.find(output_index); 376 if (nop_node_output_iter != nop_node_output_map_.end()) { 377 iter = output_node_to_tensor_.find(nop_node_output_iter->second); 378 if (iter != output_node_to_tensor_.end()) { 379 return utils::cast<tensor::TensorPtr>(iter->second); 380 } 381 } 382 return nullptr; 383 } 384 has_optimizer()385 bool has_optimizer() const { return has_optimizer_; } IsUpdatedParameter(const ParameterPtr & param)386 bool IsUpdatedParameter(const ParameterPtr ¶m) const { 387 return updated_parameters_.find(param) != updated_parameters_.end(); 388 } 389 // handle graph dependency AddPreGraph(const std::shared_ptr<session::KernelGraph> & graph)390 void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) { 391 if (graph != nullptr) { 392 pre_graphs_[graph->graph_id()] = graph; 393 } 394 } 395 get_pre_graphs()396 const mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &get_pre_graphs() const { 397 return 
pre_graphs_; 398 } AddPostGraph(const std::shared_ptr<session::KernelGraph> & graph)399 void AddPostGraph(const std::shared_ptr<session::KernelGraph> &graph) { 400 if (graph != nullptr) { 401 post_graphs_[graph->graph_id()] = graph; 402 } 403 } GetPostGraphs()404 const mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> &GetPostGraphs() const { 405 return post_graphs_; 406 } 407 IsPreGraphFinished()408 bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; } IsPostGraphFinished()409 bool IsPostGraphFinished() const { 410 if (first_step_) { 411 return true; 412 } 413 return post_graphs_.size() == post_graph_finished_count_; 414 } 415 HasPostGraph()416 bool HasPostGraph() const { return !post_graphs_.empty(); } 417 IncPreGraphFinishedCount()418 void IncPreGraphFinishedCount() { ++pre_graph_finished_count_; } IncPostGraphFinishedCount()419 void IncPostGraphFinishedCount() { ++post_graph_finished_count_; } ResetGraphRunningStatus()420 void ResetGraphRunningStatus() { 421 first_step_ = false; 422 post_graph_finished_count_ = 0; 423 pre_graph_finished_count_ = 0; 424 } OnRunGraphFinished()425 void OnRunGraphFinished() const { 426 for (const auto &post_graph : post_graphs_) { 427 auto post_graph_ptr = post_graph.second.lock(); 428 if (post_graph_ptr != nullptr) { 429 post_graph_ptr->IncPreGraphFinishedCount(); 430 } 431 } 432 for (const auto &pre_graph : pre_graphs_) { 433 auto pre_graph_ptr = pre_graph.second.lock(); 434 if (pre_graph_ptr != nullptr) { 435 pre_graph_ptr->IncPostGraphFinishedCount(); 436 } 437 } 438 } 439 // end of handle graph dependency 440 441 // The interface of parallel op send/recv pairs map. 
InsertSendRecvPairForParallelOpInputs(const CNodePtr & parallel_op,const std::pair<CNodePtr,CNodePtr> & send_recv_pair)442 void InsertSendRecvPairForParallelOpInputs(const CNodePtr ¶llel_op, 443 const std::pair<CNodePtr, CNodePtr> &send_recv_pair) { 444 auto iter = send_recv_pairs_for_parallel_op_inputs_.find(parallel_op); 445 if (iter == send_recv_pairs_for_parallel_op_inputs_.end()) { 446 send_recv_pairs_for_parallel_op_inputs_[parallel_op] = {send_recv_pair}; 447 } else { 448 iter->second.emplace_back(send_recv_pair); 449 } 450 } 451 InsertSendRecvPairForParallelOpOutputs(const CNodePtr & parallel_op,const std::pair<CNodePtr,CNodePtr> & send_recv_pair)452 void InsertSendRecvPairForParallelOpOutputs(const CNodePtr ¶llel_op, 453 const std::pair<CNodePtr, CNodePtr> &send_recv_pair) { 454 auto iter = send_recv_pairs_for_parallel_op_outputs_.find(parallel_op); 455 if (iter == send_recv_pairs_for_parallel_op_outputs_.end()) { 456 send_recv_pairs_for_parallel_op_outputs_[parallel_op] = {send_recv_pair}; 457 } else { 458 iter->second.emplace_back(send_recv_pair); 459 } 460 } 461 462 const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs()463 &send_recv_pairs_for_parallel_op_inputs() const { 464 return send_recv_pairs_for_parallel_op_inputs_; 465 } 466 const mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs()467 &send_recv_pairs_for_parallel_op_outputs() const { 468 return send_recv_pairs_for_parallel_op_outputs_; 469 } 470 label_num()471 uint32_t label_num() const { return label_num_; } set_label_num(uint32_t num)472 void set_label_num(uint32_t num) { label_num_ = num; } 473 // The graphs has recursion. recursive_call()474 bool recursive_call() const { return has_recursive_call_; } 475 // The graphs has subgraph multi-call. 
subgraph_multi_call()476 bool subgraph_multi_call() const { return has_subgraph_multicall_; } 477 // set flag to indicate whether has recursion. set_recursive_call(bool flag)478 void set_recursive_call(bool flag) { has_recursive_call_ = flag; } 479 // set flag to indicate whether has multi-call. set_subgraph_multi_call(bool flag)480 void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; } 481 graph_output_map()482 const std::map<AnfWithOutIndex, AnfWithOutIndex> &graph_output_map() const { return graph_output_to_front_node_map_; } front_node_to_graph_output_map()483 const std::map<AnfWithOutIndex, AnfWithOutIndex> &front_node_to_graph_output_map() const { 484 return front_node_to_graph_output_map_; 485 } 486 487 // The interface to set/get the graph GIL flag. set_is_need_gil(bool flag)488 void set_is_need_gil(bool flag) { is_need_gil_ = flag; } is_need_gil()489 bool is_need_gil() const { return is_need_gil_; } 490 491 bool IsDatasetGraph() const; 492 set_is_from_single_op(bool is_from_single_op)493 void set_is_from_single_op(bool is_from_single_op) { is_from_single_op_ = is_from_single_op; } is_from_single_op()494 bool is_from_single_op() const { return is_from_single_op_; } set_is_any_type_input(bool is_any_type_input)495 void set_is_any_type_input(bool is_any_type_input) { is_any_type_input_ = is_any_type_input; } is_any_type_input()496 bool is_any_type_input() const { return is_any_type_input_; } set_run_mode(device::RunMode run_mode)497 void set_run_mode(device::RunMode run_mode) { run_mode_ = run_mode; } RunMode()498 device::RunMode RunMode() const { return run_mode_; } is_graph_run_mode()499 bool is_graph_run_mode() const { return run_mode_ == device::RunMode::kGraphMode; } is_loop_count_sink()500 bool is_loop_count_sink() const { return is_loop_count_sink_; } set_memory_managed_by_ge(bool memory_managed_by_ge)501 void set_memory_managed_by_ge(bool memory_managed_by_ge) { 502 if (IsEnableRefMode()) { 503 memory_managed_by_ge_ = 
memory_managed_by_ge; 504 } 505 } memory_managed_by_ge()506 bool memory_managed_by_ge() const { return memory_managed_by_ge_; } set_is_loop_count_sink(bool is_loop_count_sink)507 void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; } front_backend_anf_map()508 const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() const { return front_backend_anf_map_; } 509 GetElementInTupleBackendFrontIndexMap(const AnfNodePtr & back_node)510 AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) const { 511 auto iter = tuple_backend_front_anf_index_map_.find(back_node); 512 if (iter == tuple_backend_front_anf_index_map_.end()) { 513 return AnfWithOutIndex{nullptr, 0}; 514 } 515 return iter->second; 516 } InternalParameterToFrontNodeMap()517 const HashMap<AnfNodePtr, AnfWithOutIndex> &InternalParameterToFrontNodeMap() const { 518 return internal_parameter_to_front_node_map_; 519 } SetInternalParameterToFrontNodeMap(const HashMap<AnfNodePtr,AnfWithOutIndex> & ipf_map)520 void SetInternalParameterToFrontNodeMap(const HashMap<AnfNodePtr, AnfWithOutIndex> &ipf_map) { 521 internal_parameter_to_front_node_map_ = ipf_map; 522 } 523 front_outputs()524 const AnfNodePtrList &front_outputs() const { return front_outputs_; } set_front_outputs(const AnfNodePtrList & outputs)525 void set_front_outputs(const AnfNodePtrList &outputs) { front_outputs_ = outputs; } IsCommSubGraph(uint32_t id)526 bool IsCommSubGraph(uint32_t id) const { return comm_sub_graph_ids_.find(id) != comm_sub_graph_ids_.end(); } RecordNewCommSubGraphId(uint32_t id)527 void RecordNewCommSubGraphId(uint32_t id) { comm_sub_graph_ids_.insert(id); } CommSubGraphIds()528 const std::set<uint32_t> &CommSubGraphIds() const { return comm_sub_graph_ids_; } 529 530 // somas total memory size MutableSomasInfo()531 SomasInfo *MutableSomasInfo() const { return somas_info_.get(); } somas_whole_block_size()532 size_t somas_whole_block_size() const { 
return somas_info_->whole_block_size_; } somas_merged_blocks_map()533 const std::map<size_t, size_t> &somas_merged_blocks_map() const { return somas_info_->merged_blocks_map_; } 534 set_graph_info(const std::string & graph_info)535 void set_graph_info(const std::string &graph_info) { graph_info_ = graph_info; } 536 // Infer cnode abstract by parameter. 537 void InferType(); 538 void PostNewCNode(const CNodePtr &cnode) const; 539 void SetKernelInfoForNode(const AnfNodePtr &node) const; AddInlineSubgraphKernel(const AnfNodePtr & node,const std::string & graph_name)540 void AddInlineSubgraphKernel(const AnfNodePtr &node, const std::string &graph_name) { 541 inline_sub_graph_kernels_[node] = graph_name; 542 } inline_sub_graph_kernels()543 const mindspore::HashMap<AnfNodePtr, std::string> &inline_sub_graph_kernels() const { 544 return inline_sub_graph_kernels_; 545 } condition_gather_to_switch()546 mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch() const { return condition_gather_to_switch_; } AddConditionGatherSwitchPair(const AnfNodePtr & condition_gather,const AnfNodePtr & condition_switch)547 void AddConditionGatherSwitchPair(const AnfNodePtr &condition_gather, const AnfNodePtr &condition_switch) { 548 condition_gather_to_switch_[condition_gather] = condition_switch; 549 } RemoveConditionGatherSwitchPair(const AnfNodePtr & condition_gather)550 void RemoveConditionGatherSwitchPair(const AnfNodePtr &condition_gather) { 551 condition_gather_to_switch_.erase(condition_gather); 552 } 553 set_is_from_pynative(const bool & from_pynative)554 void set_is_from_pynative(const bool &from_pynative) { from_pynative_ = from_pynative; } is_from_pynative()555 bool is_from_pynative() const { return from_pynative_; } 556 enable_multi_stream()557 bool enable_multi_stream() const { return enable_multi_stream_; } set_enable_multi_stream(bool enable_multi_stream)558 void set_enable_multi_stream(bool enable_multi_stream) { enable_multi_stream_ = enable_multi_stream; } 
559 has_kernel_need_user_data()560 bool has_kernel_need_user_data() const { return has_kernel_need_user_data_; } set_has_kernel_need_user_data(bool has_kernel_need_user_data)561 void set_has_kernel_need_user_data(bool has_kernel_need_user_data) { 562 has_kernel_need_user_data_ = has_kernel_need_user_data; 563 } 564 void CacheRootWeight(const std::vector<AnfNodePtr> &weights); GetRootWeights()565 const std::vector<AnfNodePtr> &GetRootWeights() const { return root_weights_; } 566 567 private: 568 AnfNodePtr MakeValueNode(const AnfNodePtr &node) const; 569 570 AnfNodePtr TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value); 571 AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); 572 AnfNodePtr TransCNodeTuple(const CNodePtr &node); 573 AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx); 574 std::vector<CNodePtr> SortStartLabelAndEndGoto(); 575 576 // members 577 std::shared_ptr<AnfNodePtrList> inputs_; 578 std::shared_ptr<SomasInfo> somas_info_; 579 AnfNodePtrList child_graph_result_; 580 std::vector<CNodePtr> execution_order_; 581 std::vector<CNodePtr> mem_reuse_exec_order_; 582 uint32_t graph_id_; 583 uint32_t stream_distinction_label_; 584 DeviceType device_target_; 585 uint32_t root_graph_id_{0}; 586 587 // record map between front anf and backend anf,use two map implement bidirectional map 588 mindspore::HashMap<AnfNodePtr, AnfNodePtr> front_backend_anf_map_; 589 mindspore::HashMap<AnfNodePtr, AnfNodePtr> backend_front_anf_map_; 590 mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> tuple_backend_front_anf_index_map_; 591 // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record 592 mindspore::HashMap<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_; 593 // include all value nodes, this second size_t represents the number of times value_node is used in the graph. 
mindspore::HashMap<ValueNodePtr, size_t> graph_value_nodes_;
  // record map between ref final output anf with index and ref origin input with index
  std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
  // node -> nodes that consume its outputs
  mindspore::HashMap<AnfNodePtr, AnfNodePtrList> node_output_edges_;
  std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
  // parameters that will be updated when graph is executed
  mindspore::HashSet<ParameterPtr> updated_parameters_;
  // Kernel in inline subgraph for switch node.
  mindspore::HashMap<AnfNodePtr, std::string> inline_sub_graph_kernels_;
  // Record the relationship between condition gather and condition switch.
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> condition_gather_to_switch_;

  // graph needn't execute
  bool executable_{false};
  // exist summary node in graph
  bool summary_node_exist_{false};
  // valid inputs
  std::vector<bool> valid_inputs_;
  // need inline
  bool need_inline_;

  // child graph execute order in parent graph
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;

  // device loop control frontend tensors
  std::map<std::string, tensor::TensorPtr> device_loop_ctrl_tensors_;
  // device loop control backend nodes
  std::map<std::string, mindspore::ParameterPtr> device_loop_ctrl_params_;

  // parent graph (weak_ptr avoids an ownership cycle between parent and child graphs)
  std::weak_ptr<KernelGraph> parent_graph_;

  // Control-flow label boundary nodes of this graph.
  CNodePtr start_label_;
  CNodePtr end_goto_;

  AnfNodePtrList front_outputs_;
  // Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
  // related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
  // of unordered map is front node corresponding to the output of previous kernel graph.
mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
  // The first of map is the backend graph output of this kernel graph, the second of map is front node corresponding to
  // the backend node with index.
  std::map<AnfWithOutIndex, AnfWithOutIndex> graph_output_to_front_node_map_;
  std::map<AnfWithOutIndex, AnfWithOutIndex> front_node_to_graph_output_map_;

  // Front node -> internal output (and the reverse, per output index with a validity flag).
  mindspore::HashMap<AnfNodePtr, AnfWithOutIndex> front_to_internal_outputs_map_;
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, std::pair<AnfNodePtr, bool>>>
    internal_outputs_to_front_map_;
  mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, tensor::TensorPtr>> internal_outputs_tensor_map_;
  uint32_t current_epoch_;
  mindspore::HashMap<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
  AnfNodePtrList input_nodes_;
  std::vector<tensor::TensorPtr> input_tensors_;
  KernelMapTensor output_node_to_tensor_;
  std::map<session::KernelWithIndex, session::KernelWithIndex, session::KernelWithIndexCmp> nop_node_output_map_;
  // Predecessor / successor graphs keyed by graph id (weak_ptr: no ownership of sibling graphs).
  mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> pre_graphs_;
  mindspore::HashMap<uint32_t, std::weak_ptr<session::KernelGraph>> post_graphs_;

  // key: parallel op ptr, value: vector of <send op, receive op> pairs
  mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_inputs_;
  // key: parallel op ptr, value: vector of <send op, receive op> pairs
  mindspore::HashMap<CNodePtr, std::vector<std::pair<CNodePtr, CNodePtr>>> send_recv_pairs_for_parallel_op_outputs_;

  // Atomic counters: presumably updated from multiple execution threads — TODO confirm against callers.
  std::atomic<size_t> pre_graph_finished_count_{0};
  std::atomic<size_t> post_graph_finished_count_{0};
  bool first_step_{true};
  bool has_optimizer_{false};
  bool is_dynamic_shape_{false};

  // Indicate the graphs has recursion or multi-call or not as the root graph.
bool has_recursive_call_{false};
  bool has_subgraph_multicall_{false};

  // Number of labels. This is also the 'batch_num' for DavinciModel,
  // It should be 1 if no labels used for control flow.
  uint32_t label_num_ = 1;

  // Indicate whether the kernels in the graphs acquire Python GIL.
  bool is_need_gil_{false};

  // Memory is managed by GE
  bool memory_managed_by_ge_{false};

  // Indicate whether the kernel graph is constructed from single op in function graph
  bool is_from_single_op_{false};
  // Indicate whether the kernel graph has an any type input.
  bool is_any_type_input_{false};
  // Indicate whether the kernel graph sink will run on graph executor or kernel executor
  device::RunMode run_mode_{device::RunMode::kUnknown};

  // Indicate whether the kernel graph loop sink to the device executing.
  bool is_loop_count_sink_{false};
  // save the communication sub-graph id for comm op reuse
  std::set<uint32_t> comm_sub_graph_ids_{};
  // graph info for single op
  std::string graph_info_;
  bool from_pynative_{false};

  bool enable_multi_stream_{false};
  // Whether this graph contains kernel which need user data.
  bool has_kernel_need_user_data_{false};
  // Cached root-graph weight parameters (see CacheRootWeight / GetRootWeights).
  std::vector<AnfNodePtr> root_weights_;
};
}  // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_GRAPH_H