/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H

#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <map>
#include <set>
#include "backend/session/session_context.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "runtime/device/bucket.h"
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
#include "debug/debugger/debugger.h"
#endif
#include "runtime/hardware/device_context.h"
#include "backend/session/pynative_task_manager.h"

namespace mindspore {
namespace runtime {
// Forward declaration only; GraphCompiler is befriended by SessionBasic below.
class GraphCompiler;
}  // namespace runtime
}  // namespace mindspore

namespace mindspore {
using GraphId = uint32_t;
using GraphInfo = std::string;
const char kSessionBasic[] = "SessionBasic";

namespace session {
// Summary callback: receives a graph id and the named summary tensors of that graph.
using CallBackFunc = uint32_t (*)(uint32_t graph_id,
                                  const std::map<std::string, mindspore::tensor::TensorPtr> &params_list);
using AnyList = std::vector<Any>;
using AnyListPtr = std::shared_ptr<AnyList>;

// Describes one single-operator launch (PyNative-style RunOp path).
struct OpRunInfo {
  std::string op_name;
  PrimitivePtr primitive;
  AbstractBasePtr abstract;
  bool is_dynamic_shape = false;
  bool is_auto_mixed_precision = false;
  bool lazy_build = false;
  std::string next_op_name = "";
  // NOTE(review): platform-dependent index type — presumably to work around an
  // Apple toolchain issue; confirm before unifying.
#if defined(__APPLE__)
  int next_input_index = 0;
#else
  size_t next_input_index = 0;
#endif
};

// Input tensors gathered for a single operator, together with a mask and the
// set of (kernel, output-index) pairs they came from.
struct InputTensorInfo {
  std::vector<tensor::TensorPtr> input_tensors;
  std::vector<int64_t> input_tensors_mask;
  std::set<KernelWithIndex> input_kernel;
};

// A placeholder ("stub") output tensor and whether it refers to a weight.
struct OutputTensorInfo {
  tensor::TensorPtr output_stub_tensor;
  bool is_weight;
};

// Bookkeeping for scattering per-op outputs back into the graph-level output
// structure: maps each producing (kernel, index) to its positions in the
// nested output VectorRef.
struct GraphOutputInfo {
  VectorRef *graph_outputs;
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
  std::vector<tensor::TensorPtr> graph_output_tensors;
};

class Executor;

// Base class for backend sessions: compiles front-end ANF graphs into
// KernelGraphs, builds and runs them (whole-graph or op-by-op), and manages
// per-graph state (buckets, summaries, debugger). Device-specific sessions
// override the *Impl virtual hooks.
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
 public:
  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
    debugger_ = nullptr;
#endif
  }

  virtual void Init(uint32_t device_id) { device_id_ = device_id; }
  void InitExecutor(const std::string &device_name, uint32_t device_id);
  virtual void SyncStream() const {}
  virtual ~SessionBasic() { summary_callback_ = nullptr; }

  // Graph-level entry points; these delegate to the protected *Impl hooks
  // (typically via the Executor, which is a friend).
  GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph);
  void BuildGraph(GraphId graphId);
  void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
             const std::vector<int64_t> &tensors_mask);
  void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);

#ifndef ENABLE_SECURITY
  virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
#endif

  bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);

  std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
                                                    bool common_opt = true);
  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                    std::vector<KernelGraphPtr> *all_out_graph);

  CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
  CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);

  // get graph id in child graphs by ME front anf node pointer
  virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
  virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
  void AssignParamKey(const KernelGraphPtr &kernel_graph);
  void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
  bool IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name);
  // Default implementation accepts any inputs; device sessions may override to
  // validate and report a message through error_msg.
  virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                std::string *error_msg) const {
    return true;
  }
  void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
                          std::vector<std::string> *inputs_name) const;
  void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
                           std::vector<std::string> *outputs_name) const;
  std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
                                                         const std::vector<tensor::TensorPtr> &inputs);
  // Get graph by graph id, if not exist return null ptr
  KernelGraphPtr GetGraph(GraphId graph_id) const;
  void ClearGraph();
  // create a single run op graph
  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                      const std::vector<tensor::TensorPtr> &input_tensors,
                                                      const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
  void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
                            std::vector<tensor::TensorPtr> *input_tensors) const;
  void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
  static void RunOpHideNopNode(const KernelGraphPtr &kernel_graph);
  virtual void ReportWarningMessage() {}
  virtual void ReportErrorMessage() {}
  virtual void SetThreadContext() {}
#ifdef ENABLE_DEBUGGER
  // set debugger
  void SetDebugger() {
    debugger_ = Debugger::GetInstance();
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    MS_EXCEPTION_IF_NULL(debugger_);
    debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
  }
#endif

 private:
  // Helpers used while translating front-end control-flow constructs
  // (switch / partial / call / switch-layer) into kernel-graph nodes.
  CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
  std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
  void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
  std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
  void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const;
  void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                         std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
  std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
  void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
  void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
                            const FuncGraphManagerPtr &front_func_graph_manager,
                            const std::shared_ptr<KernelGraph> &backend_graph);
  std::string AddPartialParametersMap(const AnfNodePtr &partial_node);
  void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                         std::map<AnfNodePtr, size_t> *parameter_index);
  void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
                               VectorRef *const outputs,
                               std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes);
  // Reference counting for op-by-op execution: how many consumers each
  // (kernel, output) pair has, so cached op outputs can be released early.
  void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count);
  void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
                      std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map);

  void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                       const std::map<KernelWithIndex, size_t> &ref_count,
                       std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
                       GraphOutputInfo *const graph_output_info);

 protected:
  // The Executor and its task types drive the protected *Impl hooks.
  friend class Executor;
  friend class CompileNodesTask;
  friend class CompileGraphTask;
  friend class BuildGraphTask;
  friend class RunGraphTask;
  friend class RunOpTask;
  friend class RunOpsInGraphTask;
  friend class mindspore::runtime::GraphCompiler;
  virtual bool IsSupportSummary() { return true; }
  virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                   VectorRef *outputs,
                                   std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node);
  // When the device address of the node is used as the output of the graph, the device address will be passed
  // to the output tensor, and the output node will recreate a new device address. This third parameter records
  // the relationship between the new and old device address.
  virtual void UpdateOutputTensors(const VectorRef *outputs,
                                   const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                   std::map<DeviceAddressPtr, DeviceAddressPtr> *);
  virtual void UnifyMindIR(const KernelGraphPtr &graph);
  virtual void FinalOptimize(const KernelGraphPtr &graph) const;
  // Device-specific hooks; the base implementations are no-ops or return
  // trivial values so that subclasses only override what they support.
  virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; }
  virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
  virtual void BuildGraphImpl(GraphId) {}
  virtual void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
                               const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
  virtual void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph,
                                const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
  virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {}
  void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  virtual KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                     const std::vector<tensor::TensorPtr> &input_tensors,
                                     const std::vector<int64_t> &tensors_mask) {
    return nullptr;
  }
  virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                         std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                         const std::vector<int64_t> &tensors_mask) {}
  virtual void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                               std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                               const std::vector<int64_t> &tensors_mask) {}
  void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
  virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                               const std::vector<tensor::TensorPtr> &graph_inputs,
                               const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
#ifndef ENABLE_SECURITY
  virtual void SetSummaryNodes(KernelGraph *graph);
#endif

  // Load input tensors into the graph's device memory; skipped when the graph
  // is marked non-executable.
  void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {
    auto kernel_graph = GetGraph(graph_id);
    MS_EXCEPTION_IF_NULL(kernel_graph);
    if (!kernel_graph->executable()) {
      return;
    }
    MS_LOG(INFO) << "Load inputs";
    LoadInputData(kernel_graph, inputs_const);
  }

  virtual void ExecuteAllTaskInQueue() {}

  virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                             const std::vector<tensor::TensorPtr> &inputs_const) const {}
  void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                     const std::vector<tensor::TensorPtr> &input_tensors,
                     std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const;
  void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
#ifndef ENABLE_SECURITY
  void Summary(KernelGraph *graph);
#endif
  // create graph output for RunOp
  void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
  CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
  // Generate graph info for a single op graph
  GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
  void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info);
  // Resolve an operator input tensor from, respectively: a value node, a graph
  // parameter (via the parameter_index map), or a previously-run op's cached
  // output (via the op_output map).
  tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
  tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
                                             const std::map<AnfNodePtr, size_t> &parameter_index,
                                             const std::vector<tensor::TensorPtr> &graph_inputs);
  tensor::TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
                                         const std::map<KernelWithIndex, tensor::TensorPtr> &op_output);
  void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                         const std::map<AnfNodePtr, size_t> &parameter_index,
                         const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
  tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
                                            const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<tensor::TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info, size_t input_index);

  // create a new kernel graph and update the graph sum
  KernelGraphPtr NewKernelGraph();
  AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
  ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
  ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
  AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
  void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
  void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
  AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
  void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
  void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
  // Gradient all-reduce bucket management for distributed training; base
  // implementation creates no bucket.
  virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
  void InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context = nullptr);
  void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
  void ClearAllBucket(const GraphId &graph_id);
  std::vector<uint32_t> GetAllReduceSplitIndex();
  virtual std::string GetCommWorldGroup() { return std::string(); }
  void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph);
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  // Parameter-server mode helpers (CPU, non-Windows builds only).
  void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
  void GetBatchElements(const AnfNodePtr &kernel_node) const;
  void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif

  // Per-graph bucket pools and the next free bucket id for each graph.
  std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
  std::map<uint32_t, uint32_t> free_bucket_id_map_;
  // Graph registries: by id, by single-op graph info, and by front-end graph.
  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
  std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
  std::unordered_map<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
  std::unordered_map<AnfNodePtr, AnfNodePtr> partial_parameters_map_;
  std::unordered_map<AnfNodePtr, std::string> partial_target_map_;
  std::shared_ptr<Context> context_;
  CallBackFunc summary_callback_;
  // Shared across all sessions: running count of created kernel graphs.
  static GraphId graph_sum_;
  uint32_t device_id_;
  // rank id of physical device
  uint32_t rank_id_{0};
  std::shared_ptr<Executor> executor_;
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
  std::shared_ptr<Debugger> debugger_;
#endif
};

using SessionPtr = std::shared_ptr<session::SessionBasic>;
using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
}  // namespace session
void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
                       const std::vector<CNodePtr> &execution_order);
uint32_t GetRankId();
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H