1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include <unordered_map> 25 #include <unordered_set> 26 #include <mutex> 27 #include <stack> 28 #include <set> 29 #include <map> 30 31 #include "pybind11/pybind11.h" 32 #include "pybind11/numpy.h" 33 #include "pybind_api/ir/base_ref_py.h" 34 #include "ir/anf.h" 35 #include "frontend/optimizer/ad/kpynative.h" 36 #include "frontend/operator/composite/composite.h" 37 #include "pipeline/jit/resource.h" 38 #include "pipeline/pynative/base.h" 39 #include "pipeline/pynative/pynative_cache.h" 40 #include "utils/ms_context.h" 41 42 namespace mindspore::pynative { 43 namespace py = pybind11; 44 using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>; 45 using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>; 46 using OpInfoWithMsFuncForwardTensors = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>; 47 48 py::object RealRunOp(const py::args &args); 49 50 struct GraphInfo { 51 std::string cell_id; 52 AnfNodePtr output; 53 OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weights 54 std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map; 55 GraphInfo() = default; GraphInfoGraphInfo56 explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} 57 }; 58 using GraphInfoPtr = std::shared_ptr<GraphInfo>; 59 60 class TopCellInfo { 61 public: 62 TopCellInfo() = default; 63 ~TopCellInfo() = default; TopCellInfo(bool topest,size_t grad_order,pipeline::ResourcePtr r,FuncGraphPtr df,std::string cellid,std::string alread_run_cell_id)64 TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr df, std::string cellid, 65 std::string alread_run_cell_id) 66 : is_topest_(topest), 67 grad_order_(grad_order), 68 resource_(std::move(r)), 69 df_builder_(std::move(df)), 70 cell_id_(std::move(cellid)), 71 alread_run_cell_id_(std::move(alread_run_cell_id)) {} 72 is_init_kpynative()73 bool is_init_kpynative() const { return is_init_kpynative_; } set_init_kpynative(bool init)74 void set_init_kpynative(bool init) { is_init_kpynative_ = init; } is_topest()75 bool is_topest() const { return is_topest_; } grad_order()76 size_t grad_order() const { return grad_order_; } set_grad_order(size_t grad_order)77 void set_grad_order(size_t grad_order) { grad_order_ = grad_order; } is_dynamic()78 bool is_dynamic() const { return is_dynamic_; } set_is_dynamic(bool is_dynamic)79 void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } vm_compiled()80 bool vm_compiled() const { return vm_compiled_; } set_vm_compiled(bool vm_compiled)81 void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; } ms_function_flag()82 bool ms_function_flag() const { return ms_function_flag_; } set_ms_function_flag(bool ms_function_flag)83 void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; } need_compile_graph()84 bool need_compile_graph() const { return need_compile_graph_; } set_need_compile_graph(bool need_compile_graph)85 void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; } forward_already_run()86 bool forward_already_run() const { return forward_already_run_; } set_forward_already_run(bool set_forward_already_run)87 void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; } resource()88 pipeline::ResourcePtr resource() const { return resource_; } df_builder()89 FuncGraphPtr df_builder() const { return df_builder_; } op_num()90 size_t op_num() const { return op_num_; } set_op_num(size_t op_num)91 void set_op_num(size_t op_num) { op_num_ = op_num; } cell_id()92 const std::string &cell_id() const { return cell_id_; } already_run_cell_id()93 const std::string &already_run_cell_id() const { return alread_run_cell_id_; } input_args_id()94 const std::string &input_args_id() const { return input_args_id_; } set_input_args_id(const std::string & input_args_id)95 void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; } all_op_info()96 std::string &all_op_info() { return all_op_info_; } grad_operation()97 const std::string &grad_operation() const { return grad_operation_; } set_grad_operation(const std::string & grad_operation)98 void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; } sub_cell_list()99 std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; } 100 bool IsSubCell(const std::string &cell_id) const; graph_info_map()101 OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; } op_info_with_tensor_id()102 OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; } tensor_id_with_tensor_object()103 TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; } k_pynative_cell_ptr()104 ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; } set_k_pynative_cell_ptr(const ad::KPynativeCellPtr & k_pynative_cell_ptr)105 void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) { 106 k_pynative_cell_ptr_ = k_pynative_cell_ptr; 107 } op_info_with_ms_func_forward_tensors()108 const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const { 109 return op_info_with_ms_func_forward_tensors_; 110 } set_op_info_with_ms_func_forward_tensors(const std::string & op_info,const std::vector<tensor::TensorPtr> & forward_tensors)111 void set_op_info_with_ms_func_forward_tensors(const std::string &op_info, 112 const std::vector<tensor::TensorPtr> &forward_tensors) { 113 op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors; 114 } 115 void ClearDeviceMemory(); 116 void Clear(); 117 118 private: 119 bool is_topest_{false}; 120 bool is_dynamic_{false}; 121 bool vm_compiled_{false}; 122 bool ms_function_flag_{false}; 123 bool is_init_kpynative_{false}; 124 bool forward_already_run_{false}; 125 bool need_compile_graph_{false}; 126 size_t op_num_{0}; 127 size_t grad_order_{0}; 128 pipeline::ResourcePtr resource_{nullptr}; 129 FuncGraphPtr df_builder_{nullptr}; 130 ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr}; 131 std::string cell_id_; 132 std::string alread_run_cell_id_; 133 std::string input_args_id_; 134 std::string all_op_info_; 135 std::string grad_operation_; 136 OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_; 137 std::unordered_set<std::string> sub_cell_list_; 138 OpInfoWithTensorId op_info_with_tensor_id_; 139 TensorIdWithTensorObject tensor_id_with_tensor_object_; 140 OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_; 141 }; 142 using TopCellInfoPtr = std::shared_ptr<TopCellInfo>; 143 144 class ForwardExecutor; 145 using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>; 146 using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>; 147 148 class GradExecutor; 149 using GradExecutorPtr = std::shared_ptr<GradExecutor>; 150 using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>; 151 152 class GradExecutor { 153 public: 154 GradExecutor() = default; 155 ~GradExecutor() = default; 156 explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr) forward_executor_(ForwardExecutorWeakPtr (forward_executor))157 : forward_executor_(ForwardExecutorWeakPtr(forward_executor)) {} 158 159 std::function<void(py::object *, const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2, 160 auto &&PH3) { 161 NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3)); 162 }; 163 std::function<void(py::object *, const py::object &, const py::object &, const py::args &)> LinkGraph = 164 [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) { 165 EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), 166 std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4)); 167 }; 168 std::function<void(py::object *, const prim::GradOperationPtr &, const py::object &, const py::object &, 169 const py::args &)> 170 GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) { 171 GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3), 172 std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5)); 173 }; 174 std::function<void(py::object *, const py::object &, const py::tuple &)> RunGraph = [this](auto &&PH1, auto &&PH2, 175 auto &&PH3) { 176 RunGradGraph(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3)); 177 }; 178 179 FuncGraphPtr curr_g() const; 180 TopCellInfoPtr top_cell() const; 181 void CheckNeedCompileGraph(); 182 void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell); GetHighOrderStackSize()183 size_t GetHighOrderStackSize() const { return high_order_stack_.size(); } 184 TopCellInfoPtr GetTopCell(const string &already_run_cell_id); 185 void EnableOpGraphCache(bool is_enable); need_renormalize()186 bool need_renormalize() const { return need_renormalize_; } enable_op_cache()187 bool enable_op_cache() const { return enable_op_cache_; } set_top_cell(TopCellInfoPtr top_cell)188 void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); } grad_flag()189 bool grad_flag() const { return grad_flag_; } set_grad_flag(bool flag)190 void set_grad_flag(bool flag) { grad_flag_ = flag; } set_graph_phase(const std::string & graph_phase)191 void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; } in_cell_with_custom_bprop_()192 bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; } 193 AnfNodePtr GetInput(const py::object &obj, bool op_mask); 194 std::string GetCellId(const py::object &obj, const py::args &args); 195 void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out); need_construct_graph()196 bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; } 197 // Construct grad graph for ms_function eliminate_forward()198 bool eliminate_forward() const { return eliminate_forward_; } set_eliminate_forward(bool eliminate_forward)199 void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; } 200 py::object GradMsFunction(const py::object &out, const py::args &args); 201 void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args, 202 const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph); 203 void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const ValuePtr &new_forward_value); 204 void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph, 205 const py::object &actual_out, const py::args &args, const ValuePtr &actual_out_v); 206 void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values, 207 CNodePtr *ms_function_cnode); 208 void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode); 209 void DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out); 210 // Update forward tensors info 211 void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out); 212 void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const; 213 py::object CheckGraph(const py::object &cell, const py::args &args); 214 void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args); 215 py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args); 216 void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell); 217 void ClearGrad(const py::object &cell, const py::args &args); 218 void ClearRes(); 219 void ClearCellRes(const std::string &cell_id = ""); 220 221 private: 222 ForwardExecutorPtr forward() const; 223 // Higher derivative 224 inline bool IsNestedGrad() const; 225 void SwitchTopcell(); 226 void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args, 227 std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args); 228 void MakeNestedCnode(const py::object &cell, const py::tuple &forward_args, const pipeline::ResourcePtr &resource, 229 const py::object &out); 230 void PushCellStack(const std::string &cell_id); 231 void PopCellStack(); 232 TopCellInfoPtr PopHighOrderGraphStack(); 233 void HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top); 234 void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args); 235 void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest); 236 void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled); 237 // Manage resource when run grad process. 238 bool IsBpropGraph(const std::string &cell_id); 239 bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const; 240 void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); 241 void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args); 242 void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args); 243 void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args); 244 std::string GetAlreadyRunCellId(const std::string &cell_id); 245 std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args); 246 void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell, 247 const py::object &weights, const py::args &args); 248 FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell, 249 const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args); 250 std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); 251 abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph); 252 // Manage resource for construct forward graph. graph_phase()253 const std::string &graph_phase() const { return graph_phase_; } 254 AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id); 255 AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); 256 void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, 257 const std::vector<int64_t> &index_sequence, bool is_param = false); 258 void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, 259 bool is_param = false); SetParamNodeMapInGraphInfoMap(const FuncGraphPtr & g,const std::string & id,const ParameterPtr & param)260 void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) const { 261 auto &graph_info = top_cell()->graph_info_map()[g]; 262 MS_EXCEPTION_IF_NULL(graph_info); 263 graph_info->params[id] = param; 264 } 265 void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, 266 int64_t index = -1) const { 267 auto &graph_info = top_cell()->graph_info_map()[g]; 268 MS_EXCEPTION_IF_NULL(graph_info); 269 graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index}); 270 } SetNodeMapInGraphInfoMap(const FuncGraphPtr & g,const std::string & id,const AnfNodePtr & node,const std::vector<int64_t> & index)271 void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, 272 const std::vector<int64_t> &index) const { 273 auto &graph_info = top_cell()->graph_info_map()[g]; 274 MS_EXCEPTION_IF_NULL(graph_info); 275 graph_info->node_map[id] = std::make_pair(node, index); 276 } 277 void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id); 278 279 private: 280 bool grad_flag_{false}; 281 bool enable_op_cache_{true}; 282 bool grad_is_running_{false}; 283 bool need_renormalize_{false}; 284 bool eliminate_forward_{true}; 285 int custom_bprop_cell_count_{0}; 286 size_t grad_order_{0}; 287 size_t top_cell_switch_counts_{0}; 288 289 // The graph phase is used to obtain backend graph that is complied by ms_function 290 std::string graph_phase_; 291 // The cell run check graph which will be top cell 292 std::string check_graph_cell_id_; 293 std::string grad_operation_; 294 // Only set in high grad 295 FuncGraphPtr curr_g_{nullptr}; 296 // For clear pre top res 297 TopCellInfoPtr top_cell_{nullptr}; 298 // Records forwrad cell, the bottom is top cell 299 std::stack<std::string> cell_stack_; 300 // For high grad of bprop 301 std::stack<std::pair<std::string, bool>> bprop_grad_stack_; 302 std::vector<std::string> bprop_cell_list_; 303 // For high grad order 304 std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>> high_order_stack_; 305 // Use vector for keep order 306 std::vector<TopCellInfoPtr> top_cell_list_; 307 // Record all top cell which has been ran 308 std::unordered_map<std::string, TopCellInfoPtr> already_run_top_cell_; 309 // Use vector for keep order 310 ForwardExecutorWeakPtr forward_executor_; 311 }; 312 313 class ForwardExecutor { 314 public: 315 ForwardExecutor() = default; 316 ~ForwardExecutor() = default; 317 318 std::function<void(py::object *, const OpExecInfoPtr &)> RunOpS = [this](auto &&PH1, auto &&PH2) { 319 RunOpInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2)); 320 }; 321 322 void RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info); 323 OpExecInfoPtr GenerateOpExecInfo(const py::args &args); set_grad_executor(const GradExecutorPtr & grad_executor)324 void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); } node_abs_map()325 std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; } 326 void ClearRes(); 327 CNodePtr ConstructForwardGraph(const OpExecInfoPtr &op_exec_info); set_lazy_build(bool lazy_build)328 void set_lazy_build(bool lazy_build) { lazy_build_ = lazy_build; } 329 330 private: 331 GradExecutorPtr grad() const; 332 MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info); 333 py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info); 334 void RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret); 335 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); 336 py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); 337 py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, 338 PynativeStatusCode *status); 339 void SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id); 340 void GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info, abstract::AbstractBasePtrList *args_spec_list); 341 void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, 342 bool *prim_cache_hit); 343 void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, 344 const CNodePtr &cnode, bool prim_cache_hit, py::object *ret); 345 // Mix precision and Implicit transform 346 void SetCastForInputs(const OpExecInfoPtr &op_exec_info); 347 void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info); 348 void SetImplicitCast(const OpExecInfoPtr &op_exec_info); 349 py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index); 350 py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name, 351 size_t index); 352 py::object DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name, size_t index); 353 py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index); 354 void DoSignatrueCast(const PrimitivePyPtr &prim, const std::unordered_map<SignatureEnumDType, TypeId> &dst_type, 355 const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info); 356 357 private: 358 GradExecutorWeakPtr grad_executor_; 359 PrimAbsCache prim_abs_list_; 360 ImplicitCastCache implicit_cast_map_; 361 std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_; 362 bool lazy_build_{false}; 363 }; 364 365 class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { 366 public: GetInstance()367 static std::shared_ptr<PynativeExecutor> GetInstance() { 368 std::lock_guard<std::mutex> i_lock(instance_lock_); 369 if (executor_ == nullptr) { 370 executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor()); 371 forward_executor_ = std::make_shared<ForwardExecutor>(); 372 grad_executor_ = std::make_shared<GradExecutor>(forward_executor_); 373 forward_executor_->set_grad_executor(grad_executor_); 374 } 375 return executor_; 376 } 377 ~PynativeExecutor() = default; 378 PynativeExecutor(const PynativeExecutor &) = delete; 379 PynativeExecutor &operator=(const PynativeExecutor &) = delete; 380 GradExecutorPtr grad_executor() const; 381 ForwardExecutorPtr forward_executor() const; 382 383 bool grad_flag() const; 384 void set_grad_flag(bool flag); 385 void set_graph_phase(const std::string &graph_phase); 386 void set_py_exe_path(const py::object &py_exe_path); 387 void set_kernel_build_server_dir(const py::object &kernel_build_server_dir); 388 void NewGraph(const py::object &cell, const py::args &args); 389 void EndGraph(const py::object &cell, const py::object &out, const py::args &args); 390 void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights, 391 const py::args &args); 392 py::object GradMsFunction(const py::object &out, const py::args &args); 393 py::object CheckGraph(const py::object &cell, const py::args &args); 394 py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args); 395 py::object Run(const py::object &cell, const py::tuple &args); 396 397 // Used by graph clean 398 // Cell destruct will call 399 void ClearCell(const std::string &cell_id); 400 void ClearGrad(const py::object &cell, const py::args &args); 401 // Abnormal existed 402 void ClearRes(); 403 // Sync stream 404 void Sync(); 405 void SetLazyBuild(bool enable); 406 void ExecuteAllTask(); 407 void EnterCell(); 408 void ExitCell(); 409 bool IsTopCell() const; 410 411 private: 412 PynativeExecutor() = default; 413 414 static std::shared_ptr<PynativeExecutor> executor_; 415 static std::mutex instance_lock_; 416 static ForwardExecutorPtr forward_executor_; 417 static GradExecutorPtr grad_executor_; 418 uint32_t cell_depth_{0}; 419 }; 420 421 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>; 422 } // namespace mindspore::pynative 423 424 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 425