/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H

#include <unordered_map>
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include <stack>
#include <map>
#include <tuple>
#include <set>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
#include "backend/session/pynative_task_manager.h"

namespace mindspore {
namespace session {
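// Role of each child graph in the merged execution order of the final graph.
// The meanings below are inferred from the enumerator names: COMMON_GRAPH is an
// ordinary computation graph, CONDITION_GRAPH computes a branch condition, and
// BRANCH_START/BRANCH_END delimit the graphs of a conditional branch.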
enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 };

class AscendSession : public SessionBasic {
 public:
  AscendSession() { final_graph_id_ = kInvalidGraphId; }
  ~AscendSession() = default;
  void Init(uint32_t device_id) override;
  // get graph id of final graph
  GraphId GetFinalRunGraph() const override { return final_graph_id_; }
  void SyncStream() const override;

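  // Compiles a batch of queued kernel-build tasks in one pass; judging from the
  // SessionTask parameter, this presumably serves the PyNative lazy-build queue
  // declared in pynative_task_manager.h.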
  static void BatchBuildKernel(const std::vector<std::shared_ptr<SessionTask>> &build_tasks);

 protected:
  void UnifyMindIR(const KernelGraphPtr &graph) override;
  GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
  bool IsSupportSummary() override;
  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                     const std::vector<tensor::TensorPtr> &inputs_const) const override;
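  // Graph execution appears to be split into three phases: PreExecuteGraph for
  // pre-run setup, ExecuteGraph for the actual launch, and PostExecuteGraph for
  // post-run work such as summary and output handling.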
  void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                       VectorRef *const outputs) override;
  void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                        VectorRef *const outputs) override;
  void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
  void BuildGraphImpl(GraphId) override;

  KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<int64_t> &tensors_mask) override;

  void BindAddressToTensor(const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) const;
  void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                       std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                       const std::vector<int64_t> &tensors_mask) override;

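  // RunOpImpl is the default single-op execution path; the naming suggests that
  // RunOpImplOrigin above is the eager variant that bypasses the lazy-build
  // task queue, though that is an inference from the declarations alone.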
  void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                 VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
  void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs,
                       const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
  std::string GetCommWorldGroup() override { return kHcclWorldGroup; }
  void ReportWarningMessage() override;
  void ReportErrorMessage() override;
  void SetThreadContext() override;
  void ExecuteAllTaskInQueue() override;
  void UpdateOutputTensors(const VectorRef *outputs,
                           const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                           std::map<DeviceAddressPtr, DeviceAddressPtr> *) override;
  DeviceAddressPtr AssignExtraMemForGraphOutput(const tensor::TensorPtr &tensor, const AnfNodePtr &node,
                                                size_t index) const;

 private:
  // compile a child graph when the session has multiple child graphs
  void CompileChildGraph(const KernelGraphPtr &child_graph);
#ifndef ENABLE_SECURITY
  void RecurseSetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
  void SetSummaryNodes(KernelGraph *graph) override;
#endif
  void InitRuntimeResource();
  void SelectKernel(const KernelGraph &kernel_graph) const;
  void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
  void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  static void BuildKernel(const std::vector<CNodePtr> &kernels);
  void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void MemoryAlloc(KernelGraph *kernel_graph) const;
  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
  void RunOpMemoryAllocNew(const std::vector<tensor::TensorPtr> &input_tensors,
                           const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                           const KernelGraph &kernel_graph) const;
  void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
  void RunOpGenKernelEvent(const KernelGraph *graph) const;
  void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
#ifndef ENABLE_SECURITY
  void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  void DumpSetup(const std::shared_ptr<KernelGraph> &kernel_graph) const;
#endif
  void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
  void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
  // the functions below are used for running single ops
  void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;

  void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph, const std::vector<KernelGraphPtr> &all_graphs);
  // merge the execution order lists of the child graphs
  void MergeGraphExecOrder();
  // get graph order vector by graph id
  const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
  // get graph order type vector by graph id
  const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
  // sync initial tensors' data to device
  void SyncInitialTensorToDevice();
#ifndef ENABLE_SECURITY
  void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
#endif
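  // In the recursive helpers below, the NotNull<std::set<KernelGraphPtr> *> memo
  // parameter acts as a visited set so that each child graph in the graph
  // hierarchy is processed exactly once.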
  // create a parameter to receive data from multiple branch outputs
  void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  void SelectKernel(NotNull<KernelGraphPtr> root_graph);
  void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
                               size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
  void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
  void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
#ifdef ENABLE_DEBUGGER
  void LoadGraphsToDbg(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
#endif
  void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
  void CacheCNodeOutputInfo(const KernelGraph &graph) const;
  KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::vector<int64_t> &tensors_mask);
  void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
                             const std::vector<tensor::TensorPtr> &graph_inputs,
                             const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
                             InputTensorInfo *input_tensor_info);
  void PrepareForOutputTensor(const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &input_tensors,
                              std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                              VectorRef *outputs) const;
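  // Creates a communication bucket; going by the device::Bucket type, this is
  // presumably used to fuse gradient all-reduce operations in distributed
  // training.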
  std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;

  void LaunchFunc(const KernelGraphPtr &graph,
                  const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, bool is_dynamic_shape,
                  const std::vector<tensor::TensorPtr> &input_tensors);
  KernelGraphPtr CreateKernelGraph(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                                   std::vector<tensor::TensorPtr> *input_tensors,
                                   const std::vector<int64_t> &tensors_mask, bool cache_miss);
  static bool DisableLazyBuild(const OpRunInfo &op_run_info);
  // key is final_graph_id, value is the child-graph execution order of the final graph
  std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
  // key is final_graph_id, value is the graph types of the child graphs
  std::unordered_map<GraphId, std::vector<GraphType>> graph_order_types_;
  // initial tensors; their data is synced to the device before the graph runs
  std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tensors_;
  // final_graph_id is used when every root graph has its own session
  GraphId final_graph_id_;
  // graph ids of bp graphs that have been built in PyNative mode
  std::set<GraphId> built_graph_id_;
  // map from tensor to its newly assigned device address
  std::map<tensor::TensorPtr, DeviceAddressPtr> tensor_device_addr_map_;
};
MS_REG_SESSION(kAscendDevice, AscendSession);
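// MS_REG_SESSION registers AscendSession with the session factory under the
// kAscendDevice key. A minimal usage sketch, assuming the factory interface
// declared in session_factory.h:
//
//   auto session = session::SessionFactory::Get().Create(kAscendDevice);
//   session->Init(device_id);  // device_id: hypothetical Ascend device index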
}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H