/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
17 #define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
18 
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <memory>
23 #include <map>
24 #include <set>
25 #include "utils/hash_map.h"
26 #include "backend/common/session/kernel_graph_mgr.h"
27 #include "backend/common/session/session_context.h"
28 #include "include/backend/kernel_graph.h"
29 #include "include/backend/anf_runtime_algorithm.h"
30 #include "include/common/utils/anfalgo.h"
31 #include "include/common/utils/tensor_future.h"
32 #include "ir/anf.h"
33 #include "ir/tensor.h"
34 #include "utils/any.h"
35 #include "include/common/utils/contract.h"
36 #include "include/backend/kernel_info.h"
37 #include "utils/ms_context.h"
38 #include "pipeline/pynative/base.h"
39 
40 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
41 #include "include/backend/debug/debugger/debugger.h"
42 #endif
43 #include "mindspore/ccsrc/debug/summary/summary.h"
44 #include "runtime/hardware/device_context.h"
45 #include "include/backend/visible.h"
46 
namespace mindspore {
namespace runtime {
// Forward declaration only: SessionBasic befriends GraphCompiler below
// without pulling in the full graph-compiler header.
class GraphCompiler;
}  // namespace runtime
}  // namespace mindspore
52 
53 namespace mindspore {
54 const char kSessionBasic[] = "SessionBasic";
55 
56 namespace session {
57 using mindspore::debug::CallBackFunc;
58 #ifndef ENABLE_SECURITY
59 using mindspore::debug::Summary;
60 #endif
61 
62 using AnyList = std::vector<Any>;
63 using AnyListPtr = std::shared_ptr<AnyList>;
64 
65 struct BackendOpRunInfo {
66   ~BackendOpRunInfo() = default;
BackendOpRunInfoBackendOpRunInfo67   BackendOpRunInfo(pynative::BaseOpRunInfo base_op_run_info, PrimitivePtr prim, bool is_infer, bool is_gradient_out)
68       : base_op_run_info(std::move(base_op_run_info)),
69         op_prim(std::move(prim)),
70         is_infer(is_infer),
71         is_gradient_out(is_gradient_out) {}
72 
73   pynative::BaseOpRunInfo base_op_run_info;
74   PrimitivePtr op_prim;
75   bool is_infer = false;
76   bool is_gradient_out = false;
77 };
78 using BackendOpRunInfoPtr = std::shared_ptr<BackendOpRunInfo>;
79 
80 struct InputInfo {
81   std::vector<ValuePtr> input_values;
82   std::vector<InputType> input_types;
83   std::set<KernelWithIndex> input_kernel;
84   abstract::AbstractBasePtrList input_abs;
85 };
86 
87 struct OutputTensorInfo {
88   tensor::TensorPtr output_stub_tensor;
89   bool is_weight;
90 };
91 
92 struct GraphOutputInfo {
93   VectorRef *graph_outputs;
94   std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
95   std::vector<tensor::BaseTensorPtr> graph_output_tensors;
96 };
97 
98 class Executor;
99 
100 class BACKEND_EXPORT SessionBasic : public KernelGraphMgr, public std::enable_shared_from_this<SessionBasic> {
101  public:
102   using KernelGraphMgr::ConstructKernelGraph;
SessionBasic()103   SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
104 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
105     debugger_ = nullptr;
106 #endif
107   }
108 
Init(uint32_t device_id)109   virtual void Init(uint32_t device_id) { device_id_ = device_id; }
110   void InitExecutor(const std::string &device_name, uint32_t device_id);
SyncStream()111   virtual void SyncStream() const {}
~SessionBasic()112   virtual ~SessionBasic() { summary_callback_ = nullptr; }
113 
114   GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
115   GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
116   void BuildGraph(GraphId graphId);
117   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
118   void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
119 
120 #ifndef ENABLE_SECURITY
121   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
122 #endif
GetFinalRunGraph()123   virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
124   bool IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) const;
CheckModelInputs(uint32_t graph_id,const std::vector<tensor::TensorPtr> & inputs,std::string * error_msg)125   virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
126                                 std::string *error_msg) const {
127     return true;
128   }
129   void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
130                           std::vector<std::string> *inputs_name) const;
131   void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
132                            std::vector<std::string> *output_names) const;
133   // create a single run op graph
134   std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const BackendOpRunInfoPtr &op_run_info,
135                                                       const std::vector<ValuePtr> &input_values,
136                                                       const std::vector<InputType> &input_type);
137   void EraseValueNodeTensor(const std::vector<InputType> &input_types,
138                             std::vector<tensor::TensorPtr> *input_tensors) const;
139   void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
140   static void RunOpHideNopNode(const KernelGraphPtr &kernel_graph);
ReportWarningMessage()141   virtual void ReportWarningMessage() {}
ReportErrorMessage()142   virtual void ReportErrorMessage() {}
SetThreadContext()143   virtual void SetThreadContext() {}
144 #ifdef ENABLE_DEBUGGER
145   // set debugger
SetDebugger()146   void SetDebugger() {
147     debugger_ = Debugger::GetInstance();
148     auto ms_context = MsContext::GetInstance();
149     MS_EXCEPTION_IF_NULL(ms_context);
150     MS_EXCEPTION_IF_NULL(debugger_);
151     debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
152   }
153 #endif
154   static BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
155                                          const std::vector<tensor::TensorPtr> &input_tensors,
156                                          std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
157                                          KernelMapTensor *node_to_tensor);
158 
159  private:
160   void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
161                          std::map<AnfNodePtr, size_t> *parameter_index) const;
162   void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
163                                VectorRef *const outputs,
164                                std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) const;
165   void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) const;
166   // Cut op not flatten, so we need calculate maketuple input ref count.
167   void CalculateRefCount(const AnfNodePtr &node, std::map<KernelWithIndex, size_t> *ref_count) const;
168   void GetForwardOpOutputRefCount(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
169                                   std::map<std::string, size_t> *forward_op_output_tensor_id,
170                                   const std::map<AnfNodePtr, size_t> &parameter_index) const;
171   void ReleaseForwardOpOutput(const std::vector<ValuePtr> &input_tensors,
172                               std::map<std::string, size_t> *forward_op_output_tensor_id) const;
173   void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
174                       std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const;
175 
176   void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
177                        const std::map<KernelWithIndex, size_t> &ref_count,
178                        std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map,
179                        GraphOutputInfo *const graph_output_info) const;
180 
181  protected:
182   friend class Executor;
183   friend class CompileNodesTask;
184   friend class CompileGraphTask;
185   friend class BuildGraphTask;
186   friend class RunGraphTask;
187   friend class mindspore::runtime::GraphCompiler;
IsSupportSummary()188   virtual bool IsSupportSummary() { return true; }
189   virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
190                                    VectorRef *outputs,
191                                    std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
192                                    KernelMapTensor *node_to_tensor);
193   // When the device address of the node is used as the output of the graph, the device address will be passed
194   // to the output tensor, and the output node will recreate a new device address. This third parameter records
195   // the relationship between the new and old device address.
196   virtual void UpdateOutputTensors(const VectorRef *outputs,
197                                    const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
198                                    std::map<DeviceAddressPtr, DeviceAddressPtr> *);
199   virtual void FinalOptimize(const KernelGraphPtr &graph) const;
CompileGraphImpl(const AnfNodePtrList & lst,const AnfNodePtrList & outputs)200   virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
CompileGraphImpl(NotNull<FuncGraphPtr>)201   virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr>) { return kInvalidGraphId; }
BuildGraphImpl(GraphId)202   virtual void BuildGraphImpl(GraphId) {}
PreExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)203   virtual void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
204                                const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
205     MS_EXCEPTION_IF_NULL(kernel_graph);
206     MS_EXCEPTION_IF_NULL(outputs);
207     MS_LOG(INFO) << "Call default PreExecuteGraph with input size: " << inputs.size();
208   }
209 
PostExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)210   virtual void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
211                                 const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {
212     MS_EXCEPTION_IF_NULL(kernel_graph);
213     MS_EXCEPTION_IF_NULL(outputs);
214     MS_LOG(INFO) << "Call default PostExecuteGraph with input size: " << inputs.size();
215   }
216 
ExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph)217   virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); }
218 
219   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
220 
221   void ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
222                                            const std::vector<tensor::TensorPtr> &input_tensors) const;
223 #ifndef ENABLE_SECURITY
224   virtual void SetSummaryNodes(KernelGraph *graph);
225   void RecurseSetSummaryNodesForAllGraphs(KernelGraph *graph);
226 #endif
227 
LoadInputs(const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs_const)228   void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) const {
229     MS_LOG(INFO) << "Status record: start load input. graph id: " << graph_id;
230     auto kernel_graph = GetGraph(graph_id);
231     MS_EXCEPTION_IF_NULL(kernel_graph);
232     if (!kernel_graph->executable()) {
233       return;
234     }
235     LoadInputData(kernel_graph, inputs_const);
236     MS_LOG(INFO) << "Status record: end load input. graph id: " << graph_id;
237   }
238 
LoadInputData(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs_const)239   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
240                              const std::vector<tensor::TensorPtr> &inputs_const) const {
241     MS_EXCEPTION_IF_NULL(kernel_graph);
242     MS_LOG(INFO) << "Call default LoadInputData with input size: " << inputs_const.size();
243   }
244 
245   void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
246                      const std::vector<tensor::TensorPtr> &input_tensors,
247                      std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
248 #ifndef ENABLE_SECURITY
249   void Summary(KernelGraph *graph);
250 #endif
251   // create graph output for RunOp
252   void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) const;
253 
254   BackendOpRunInfoPtr GetSingleOpRunInfo(const CNodePtr &cnode, const InputInfo &input_info,
255                                          const GraphOutputInfo *const graph_output_info) const;
256   ValuePtr GetValueNodeOutput(const AnfNodePtr &node, size_t output_index) const;
257   tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
258                                              const std::map<AnfNodePtr, size_t> &parameter_index,
259                                              const std::vector<tensor::TensorPtr> &graph_inputs) const;
260   tensor::BaseTensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
261                                              const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output) const;
262   void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
263                          const std::map<AnfNodePtr, size_t> &parameter_index,
264                          const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const;
265   void GetOpInputTensorsFromCNode(const CNodePtr &cnode,
266                                   const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
267                                   const std::map<AnfNodePtr, size_t> &parameter_index,
268                                   const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const;
269   tensor::BaseTensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
270                                                 const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
271                                                 const std::map<AnfNodePtr, size_t> &parameter_index,
272                                                 const std::vector<tensor::TensorPtr> &graph_inputs,
273                                                 InputInfo *input_info, size_t input_index) const;
274 
275   AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) const;
276   std::vector<uint32_t> GetAllReduceSplitIndex();
GetCommWorldGroup()277   virtual std::string GetCommWorldGroup() { return std::string(); }
278   void DumpGraphs(const std::vector<KernelGraphPtr> &graphs) const;
279   void GetConstValueDepend(const CNodePtr &cnode, std::set<int64_t> *const_input_attr_index) const;
280   mindspore::HashMap<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
281   std::shared_ptr<Context> context_;
282   CallBackFunc summary_callback_;
283   uint32_t device_id_;
284   // rank id of physical device
285   uint32_t rank_id_{0};
286   std::shared_ptr<Executor> executor_;
287 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
288   std::shared_ptr<Debugger> debugger_;
289 #endif
290 };
291 
292 using SessionPtr = std::shared_ptr<session::SessionBasic>;
293 using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
294 }  // namespace session
295 BACKEND_EXPORT void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
296                                       const std::vector<CNodePtr> &execution_order);
297 BACKEND_EXPORT uint32_t GetRankId();
298 }  // namespace mindspore
299 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
300