/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H

#include <unordered_map>
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
#include "backend/session/pynative_task_manager.h"

namespace mindspore {
namespace session {
// graph types used to mark child graphs in the merged execution order
enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 };

// Session implementation for compiling and running kernel graphs on the Ascend backend.
class AscendSession : public SessionBasic {
 public:
  AscendSession() { final_graph_id_ = kInvalidGraphId; }
  ~AscendSession() = default;
  void Init(uint32_t device_id) override;
  // get the graph id of the final graph
  GraphId GetFinalRunGraph() const override { return final_graph_id_; }
  void SyncStream() const override;

  static void BatchBuildKernel(const std::vector<std::shared_ptr<SessionTask>> &build_tasks);

 protected:
  void UnifyMindIR(const KernelGraphPtr &graph) override;
  GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
  bool IsSupportSummary() override;
  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                     const std::vector<tensor::TensorPtr> &inputs_const) const override;
  void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                       VectorRef *const outputs) override;
  void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                        VectorRef *const outputs) override;
  void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
  void BuildGraphImpl(GraphId) override;

  KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<int64_t> &tensors_mask) override;

  void BindAddressToTensor(const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) const;
  void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                       std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                       const std::vector<int64_t> &tensors_mask) override;

  void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                 VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
  void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs,
                       const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
  std::string GetCommWorldGroup() override { return kHcclWorldGroup; }
  void ReportWarningMessage() override;
  void ReportErrorMessage() override;
  void SetThreadContext() override;
  void ExecuteAllTaskInQueue() override;
  void UpdateOutputTensors(const VectorRef *outputs,
                           const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                           std::map<DeviceAddressPtr, DeviceAddressPtr> *) override;
  DeviceAddressPtr AssignExtraMemForGraphOutput(const tensor::TensorPtr &tensor, const AnfNodePtr &node,
                                                size_t index) const;

 private:
  // compile a child graph when the session has multiple child graphs
  void CompileChildGraph(const KernelGraphPtr &child_graph);
#ifndef ENABLE_SECURITY
  void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  void SetSummaryNodes(KernelGraph *graph) override;
#endif
  void InitRuntimeResource();
  void SelectKernel(const KernelGraph &kernel_graph) const;
  void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
  void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  static void BuildKernel(const std::vector<CNodePtr> &kernels);
  void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void MemoryAlloc(KernelGraph *kernel_graph) const;
  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
  void RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
                           const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                           const KernelGraph &kernel_graph) const;
  void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
  void RunOpGenKernelEvent(const KernelGraph *graph) const;
  void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
#ifndef ENABLE_SECURITY
  void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void DumpSetup(const std::shared_ptr<KernelGraph> &kernel_graph) const;
#endif
  void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
  void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  // the functions below are used for running a single op
  void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;

  void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph, const std::vector<KernelGraphPtr> &all_graphs);
  // merge the execution order lists of the child graphs
  void MergeGraphExecOrder();
  // get graph order vector by graph id
  const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
  // get graph order type vector by graph id
  const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
  // sync initial tensors' data to device
  void SyncInitialTenosrToDevice();
#ifndef ENABLE_SECURITY
  void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
#endif
  // create a parameter to receive data from multiple branch outputs
  void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  void SelectKernel(NotNull<KernelGraphPtr> root_graph);
  void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
                               size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
  void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
#ifdef ENABLE_DEBUGGER
  void LoadGraphsToDbg(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
#endif
  void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  void CacheCNodeOutputInfo(const KernelGraph &graph) const;
  KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::vector<int64_t> &tensors_mask);
  void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
                             const std::vector<tensor::TensorPtr> &graph_inputs,
                             const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
                             InputTensorInfo *input_tensor_info);
  void PrepareForOutputTensor(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
                              std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                              VectorRef *outputs) const;
  std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;

  void LaunchFunc(const KernelGraphPtr &graph,
                  const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, bool is_dynamic_shape,
                  const std::vector<tensor::TensorPtr> &input_tensors);
  KernelGraphPtr CreateKernelGraph(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                                   std::vector<tensor::TensorPtr> *input_tensors,
                                   const std::vector<int64_t> &tensors_mask, bool cache_miss);
  static bool DisableLazyBuild(const OpRunInfo &op_run_info);
  // key is final_graph_id, value is the execution order of that final graph's child graphs
  std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
  // key is final_graph_id, value is the graph types of its child graphs
  std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
  // initial tensors; their data is synced to device before the graph runs
  std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
  // final_graph_id is used when every root graph has its own session
  GraphId final_graph_id_;
  // graph ids of bp graphs that have been built in PyNative mode
  std::set<GraphId> built_graph_id_;
  // map from tensor to its newly assigned device address
  std::map<tensor::TensorPtr, DeviceAddressPtr> tensor_device_addr_map_;
};
MS_REG_SESSION(kAscendDevice, AscendSession);
}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H