• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H

#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <map>
#include <set>
#include "backend/session/session_context.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "runtime/device/bucket.h"
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
#include "debug/debugger/debugger.h"
#endif
#include "runtime/hardware/device_context.h"
#include "backend/session/pynative_task_manager.h"

namespace mindspore {
namespace runtime {
// Forward declaration only; the full definition is not needed in this header.
class GraphCompiler;
}  // namespace runtime
}  // namespace mindspore

48 namespace mindspore {
49 using GraphId = uint32_t;
50 using GraphInfo = std::string;
51 const char kSessionBasic[] = "SessionBasic";
52 
53 namespace session {
54 using CallBackFunc = uint32_t (*)(uint32_t graph_id,
55                                   const std::map<std::string, mindspore::tensor::TensorPtr> &params_list);
56 using AnyList = std::vector<Any>;
57 using AnyListPtr = std::shared_ptr<AnyList>;
58 
59 struct OpRunInfo {
60   std::string op_name;
61   PrimitivePtr primitive;
62   AbstractBasePtr abstract;
63   bool is_dynamic_shape = false;
64   bool is_auto_mixed_precision = false;
65   bool lazy_build = false;
66   std::string next_op_name = "";
67 #if defined(__APPLE__)
68   int next_input_index = 0;
69 #else
70   size_t next_input_index = 0;
71 #endif
72 };
73 
74 struct InputTensorInfo {
75   std::vector<tensor::TensorPtr> input_tensors;
76   std::vector<int64_t> input_tensors_mask;
77   std::set<KernelWithIndex> input_kernel;
78 };
79 
80 struct OutputTensorInfo {
81   tensor::TensorPtr output_stub_tensor;
82   bool is_weight;
83 };
84 
85 struct GraphOutputInfo {
86   VectorRef *graph_outputs;
87   std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
88   std::vector<tensor::TensorPtr> graph_output_tensors;
89 };
90 
// Forward declaration; Executor drives session tasks (see friend list below).
class Executor;

93 class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
94  public:
SessionBasic()95   SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
96 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
97     debugger_ = nullptr;
98 #endif
99   }
100 
Init(uint32_t device_id)101   virtual void Init(uint32_t device_id) { device_id_ = device_id; }
102   void InitExecutor(const std::string &device_name, uint32_t device_id);
SyncStream()103   virtual void SyncStream() const {}
~SessionBasic()104   virtual ~SessionBasic() { summary_callback_ = nullptr; }
105 
106   GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
107   GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
108   void BuildGraph(GraphId graphId);
109   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
110   void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
111   void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
112              const std::vector<int64_t> &tensors_mask);
113   void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
114 
115 #ifndef ENABLE_SECURITY
116   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
117 #endif
118 
119   bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
120 
121   std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
122                                                     bool common_opt = true);
123   std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
124                                                     std::vector<KernelGraphPtr> *all_out_graph);
125 
126   CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
127                           std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
128   CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
129 
130   // get graph id in child graphs by ME front anf node pointer
131   virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
GetFinalRunGraph()132   virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
133   void AssignParamKey(const KernelGraphPtr &kernel_graph);
134   void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
135   bool IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name);
CheckModelInputs(uint32_t graph_id,const std::vector<tensor::TensorPtr> & inputs,std::string * error_msg)136   virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
137                                 std::string *error_msg) const {
138     return true;
139   }
140   void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
141                           std::vector<std::string> *inputs_name) const;
142   void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
143                            std::vector<std::string> *outputs_name) const;
144   std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
145                                                          const std::vector<tensor::TensorPtr> &inputs);
146   // Get graph by graph id, if not exist return null ptr
147   KernelGraphPtr GetGraph(GraphId graph_id) const;
148   void ClearGraph();
149   // create a single run op graph
150   std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
151                                                       const std::vector<tensor::TensorPtr> &input_tensors,
152                                                       const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
153   void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
154                             std::vector<tensor::TensorPtr> *input_tensors) const;
155   void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
156   static void RunOpHideNopNode(const KernelGraphPtr &kernel_graph);
ReportWarningMessage()157   virtual void ReportWarningMessage() {}
ReportErrorMessage()158   virtual void ReportErrorMessage() {}
SetThreadContext()159   virtual void SetThreadContext() {}
160 #ifdef ENABLE_DEBUGGER
161   // set debugger
SetDebugger()162   void SetDebugger() {
163     debugger_ = Debugger::GetInstance();
164     auto ms_context = MsContext::GetInstance();
165     MS_EXCEPTION_IF_NULL(ms_context);
166     MS_EXCEPTION_IF_NULL(debugger_);
167     debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
168   }
169 #endif
170 
171  private:
172   CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
173   std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
174   std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
175   void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
176   std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
177   void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
178   void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
179                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
180   std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
181   void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
182   void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
183                             const FuncGraphManagerPtr &front_func_graph_manager,
184                             const std::shared_ptr<KernelGraph> &backend_graph);
185   std::string AddPartialParametersMap(const AnfNodePtr &partial_node);
186   void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
187                          std::map<AnfNodePtr, size_t> *parameter_index);
188   void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
189                                VectorRef *const outputs,
190                                std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
191   void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
192   void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
193                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);
194 
195   void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
196                        const std::map<KernelWithIndex, size_t> &ref_count,
197                        std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
198                        GraphOutputInfo *const graph_output_info);
199 
200  protected:
201   friend class Executor;
202   friend class CompileNodesTask;
203   friend class CompileGraphTask;
204   friend class BuildGraphTask;
205   friend class RunGraphTask;
206   friend class RunOpTask;
207   friend class RunOpsInGraphTask;
208   friend class mindspore::runtime::GraphCompiler;
IsSupportSummary()209   virtual bool IsSupportSummary() { return true; }
210   virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
211                                    VectorRef *outputs,
212                                    std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
213   // When the device address of the node is used as the output of the graph, the device address will be passed
214   // to the output tensor, and the output node will recreate a new device address. This third parameter records
215   // the relationship between the new and old device address.
216   virtual void UpdateOutputTensors(const VectorRef *outputs,
217                                    const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
218                                    std::map<DeviceAddressPtr, DeviceAddressPtr> *);
219   virtual void UnifyMindIR(const KernelGraphPtr &graph);
220   virtual void FinalOptimize(const KernelGraphPtr &graph) const;
CompileGraphImpl(const AnfNodePtrList & lst,const AnfNodePtrList & outputs)221   virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
CompileGraphImpl(NotNull<FuncGraphPtr> func_graph)222   virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
BuildGraphImpl(GraphId)223   virtual void BuildGraphImpl(GraphId) {}
PreExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)224   virtual void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
225                                const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
PostExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)226   virtual void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
227                                 const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
ExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph)228   virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {}
229   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
BuildOpImpl(const OpRunInfo & op_run_info,const GraphInfo & graph_info,const std::vector<tensor::TensorPtr> & input_tensors,const std::vector<int64_t> & tensors_mask)230   virtual KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
231                                      const std::vector<tensor::TensorPtr> &input_tensors,
232                                      const std::vector<int64_t> &tensors_mask) {
233     return nullptr;
234   }
RunOpImpl(const GraphInfo & graph_info,OpRunInfo * op_run_info,std::vector<tensor::TensorPtr> * input_tensors,VectorRef * outputs,const std::vector<int64_t> & tensors_mask)235   virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
236                          std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
237                          const std::vector<int64_t> &tensors_mask) {}
RunOpImplOrigin(const GraphInfo & graph_info,OpRunInfo * op_run_info,std::vector<tensor::TensorPtr> * input_tensors,VectorRef * outputs,const std::vector<int64_t> & tensors_mask)238   virtual void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
239                                std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
240                                const std::vector<int64_t> &tensors_mask) {}
241   void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
BuildOpsInGraph(const GraphId & graph_id,const std::map<AnfNodePtr,size_t> & parameter_index,const std::vector<tensor::TensorPtr> & graph_inputs,const std::map<KernelWithIndex,size_t> & cnode_refcount)242   virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
243                                const std::vector<tensor::TensorPtr> &graph_inputs,
244                                const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
245 #ifndef ENABLE_SECURITY
246   virtual void SetSummaryNodes(KernelGraph *graph);
247 #endif
248 
LoadInputs(const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs_const)249   void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
250     auto kernel_graph = GetGraph(graph_id);
251     MS_EXCEPTION_IF_NULL(kernel_graph);
252     if (!kernel_graph->executable()) {
253       return;
254     }
255     MS_LOG(INFO) << "Load inputs";
256     LoadInputData(kernel_graph, inputs_const);
257   }
258 
ExecuteAllTaskInQueue()259   virtual void ExecuteAllTaskInQueue() {}
260 
LoadInputData(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs_const)261   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
262                              const std::vector<tensor::TensorPtr> &inputs_const) const {}
263   void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
264                      const std::vector<tensor::TensorPtr> &input_tensors,
265                      std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
266   void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
267 #ifndef ENABLE_SECURITY
268   void Summary(KernelGraph *graph);
269 #endif
270   // create graph output for RunOp
271   void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
272   CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
273   // Generate graph info for a single op graph
274   GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
275   void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info);
276   tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
277   tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
278                                              const std::map<AnfNodePtr, size_t> &parameter_index,
279                                              const std::vector<tensor::TensorPtr> &graph_inputs);
280   tensor::TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
281                                          const std::map<KernelWithIndex, tensor::TensorPtr> &op_output);
282   void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
283                          const std::map<AnfNodePtr, size_t> &parameter_index,
284                          const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
285   tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
286                                             const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
287                                             const std::map<AnfNodePtr, size_t> &parameter_index,
288                                             const std::vector<tensor::TensorPtr> &graph_inputs,
289                                             InputTensorInfo *const input_tensor_info, size_t input_index);
290 
291   // create a new kernel graph and update the graph sum
292   KernelGraphPtr NewKernelGraph();
293   AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
294   virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
295   ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
296   ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
297   AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
298   void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
299   void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
300   AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
301   void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
302   void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
CreateBucket(uint32_t bucket_id,uint32_t bucket_size)303   virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
304   void InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context = nullptr);
305   void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
306   void ClearAllBucket(const GraphId &graph_id);
307   std::vector<uint32_t> GetAllReduceSplitIndex();
GetCommWorldGroup()308   virtual std::string GetCommWorldGroup() { return std::string(); }
309   void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
310 #if ((defined ENABLE_CPU) && (!defined _WIN32))
311   void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
312   void GetBatchElements(const AnfNodePtr &kernel_node) const;
313   void InitPsWorker(const KernelGraphPtr &kernel_graph);
314 #endif
315 
316   std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
317   std::map<uint32_t, uint32_t> free_bucket_id_map_;
318   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
319   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
320   std::unordered_map<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
321   std::unordered_map<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
322   std::unordered_map<AnfNodePtr, std::string> partial_target_map_;
323   std::shared_ptr<Context> context_;
324   CallBackFunc summary_callback_;
325   static GraphId graph_sum_;
326   uint32_t device_id_;
327   // rank id of physical device
328   uint32_t rank_id_{0};
329   std::shared_ptr<Executor> executor_;
330 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
331   std::shared_ptr<Debugger> debugger_;
332 #endif
333 };
334 
335 using SessionPtr = std::shared_ptr<session::SessionBasic>;
336 using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
337 }  // namespace session
338 void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
339                        const std::vector<CNodePtr> &execution_order);
340 uint32_t GetRankId();
341 }  // namespace mindspore
342 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
343