/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include <stack>
#include <set>
#include <map>

#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/resource.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/pynative_cache.h"
#include "utils/ms_context.h"

namespace mindspore::pynative {
namespace py = pybind11;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
using OpInfoWithMsFuncForwardTensors = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;

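// Editorial note: entry point invoked from the Python frontend to execute a single primitive op in
// PyNative mode; the py::args are expected to carry the op information that
// ForwardExecutor::GenerateOpExecInfo (declared below) unpacks.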
py::object RealRunOp(const py::args &args);

struct GraphInfo {
  std::string cell_id;
  AnfNodePtr output;
  OrderedMap<std::string, ParameterPtr> params;  // holds input parameters and cell weights
  std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
  GraphInfo() = default;
  explicit GraphInfo(std::string id) : cell_id(std::move(id)) {}
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
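
// Editorial sketch (not part of the original header): GraphInfo is the per-FuncGraph bookkeeping
// record. node_map associates an object id with the node that produced it plus the tuple-index
// path used to extract it; the SetNodeMapInGraphInfoMap helpers below store {-1} when no index
// applies. A hypothetical fill-in, assuming `param` and `cnode` already exist:
//
//   auto info = std::make_shared<GraphInfo>("cell_id_0");
//   info->params["weight_id"] = param;                                             // cell weight
//   info->node_map["out_id"] = std::make_pair(cnode, std::vector<int64_t>{-1});    // whole output
//   info->node_map["out_id_1"] = std::make_pair(cnode, std::vector<int64_t>{1});   // tuple item 1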

class TopCellInfo {
 public:
  TopCellInfo() = default;
  ~TopCellInfo() = default;
  TopCellInfo(bool topest, size_t grad_order, pipeline::ResourcePtr r, FuncGraphPtr df, std::string cellid,
              std::string alread_run_cell_id)
      : is_topest_(topest),
        grad_order_(grad_order),
        resource_(std::move(r)),
        df_builder_(std::move(df)),
        cell_id_(std::move(cellid)),
        alread_run_cell_id_(std::move(alread_run_cell_id)) {}

  bool is_init_kpynative() const { return is_init_kpynative_; }
  void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
  bool is_topest() const { return is_topest_; }
  size_t grad_order() const { return grad_order_; }
  void set_grad_order(size_t grad_order) { grad_order_ = grad_order; }
  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
  bool vm_compiled() const { return vm_compiled_; }
  void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
  bool ms_function_flag() const { return ms_function_flag_; }
  void set_ms_function_flag(bool ms_function_flag) { ms_function_flag_ = ms_function_flag; }
  bool need_compile_graph() const { return need_compile_graph_; }
  void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  bool forward_already_run() const { return forward_already_run_; }
  void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  pipeline::ResourcePtr resource() const { return resource_; }
  FuncGraphPtr df_builder() const { return df_builder_; }
  size_t op_num() const { return op_num_; }
  void set_op_num(size_t op_num) { op_num_ = op_num; }
  const std::string &cell_id() const { return cell_id_; }
  const std::string &already_run_cell_id() const { return alread_run_cell_id_; }
  const std::string &input_args_id() const { return input_args_id_; }
  void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
  std::string &all_op_info() { return all_op_info_; }
  const std::string &grad_operation() const { return grad_operation_; }
  void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
  std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; }
  bool IsSubCell(const std::string &cell_id) const;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
  OpInfoWithTensorId &op_info_with_tensor_id() { return op_info_with_tensor_id_; }
  TensorIdWithTensorObject &tensor_id_with_tensor_object() { return tensor_id_with_tensor_object_; }
  ad::KPynativeCellPtr k_pynative_cell_ptr() const { return k_pynative_cell_ptr_; }
  void set_k_pynative_cell_ptr(const ad::KPynativeCellPtr &k_pynative_cell_ptr) {
    k_pynative_cell_ptr_ = k_pynative_cell_ptr;
  }
  const OpInfoWithMsFuncForwardTensors &op_info_with_ms_func_forward_tensors() const {
    return op_info_with_ms_func_forward_tensors_;
  }
  void set_op_info_with_ms_func_forward_tensors(const std::string &op_info,
                                                const std::vector<tensor::TensorPtr> &forward_tensors) {
    op_info_with_ms_func_forward_tensors_[op_info] = forward_tensors;
  }
  void ClearDeviceMemory();
  void Clear();

 private:
  bool is_topest_{false};
  bool is_dynamic_{false};
  bool vm_compiled_{false};
  bool ms_function_flag_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  size_t op_num_{0};
  size_t grad_order_{0};
  pipeline::ResourcePtr resource_{nullptr};
  FuncGraphPtr df_builder_{nullptr};
  ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
  std::string cell_id_;
  std::string alread_run_cell_id_;
  std::string input_args_id_;
  std::string all_op_info_;
  std::string grad_operation_;
  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
  std::unordered_set<std::string> sub_cell_list_;
  OpInfoWithTensorId op_info_with_tensor_id_;
  TensorIdWithTensorObject tensor_id_with_tensor_object_;
  OpInfoWithMsFuncForwardTensors op_info_with_ms_func_forward_tensors_;
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
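
// Editorial sketch (not part of the original header): a TopCellInfo bundles everything the grad
// executor keeps per top-level cell -- the compile resource, the df_builder graph, and the
// bookkeeping ids. A hypothetical construction, assuming `resource` and `df_builder` are already
// created and the id strings are available:
//
//   auto top_cell = std::make_shared<TopCellInfo>(true, 1, resource, df_builder, cell_id,
//                                                 already_run_cell_id);
//   top_cell->set_forward_already_run(true);
//   top_cell->set_need_compile_graph(true);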

class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;

class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;

class GradExecutor {
 public:
  GradExecutor() = default;
  ~GradExecutor() = default;
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)) {}

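  // The four std::function members below are thin forwarders to the private implementations:
  // InitGraph -> NewGraphInner, LinkGraph -> EndGraphInner, GradGraph -> GradNetInner,
  // RunGraph -> RunGradGraph. Editorial illustration of a call (hypothetical `cell`/`args` objects):
  //
  //   py::object ret;
  //   grad_executor->InitGraph(&ret, cell, args);  // dispatches to NewGraphInner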
  std::function<void(py::object *, const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2,
                                                                                             auto &&PH3) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
  };
  std::function<void(py::object *, const py::object &, const py::object &, const py::args &)> LinkGraph =
    [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) {
      EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                    std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4));
    };
  std::function<void(py::object *, const prim::GradOperationPtr &, const py::object &, const py::object &,
                     const py::args &)>
    GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
      GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
                   std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5));
    };
  std::function<void(py::object *, const py::object &, const py::tuple &)> RunGraph = [this](auto &&PH1, auto &&PH2,
                                                                                             auto &&PH3) {
    RunGradGraph(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
  };

  FuncGraphPtr curr_g() const;
  TopCellInfoPtr top_cell() const;
  void CheckNeedCompileGraph();
  void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell);
  size_t GetHighOrderStackSize() const { return high_order_stack_.size(); }
  TopCellInfoPtr GetTopCell(const string &already_run_cell_id);
  void EnableOpGraphCache(bool is_enable);
  bool need_renormalize() const { return need_renormalize_; }
  bool enable_op_cache() const { return enable_op_cache_; }
  void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  bool grad_flag() const { return grad_flag_; }
  void set_grad_flag(bool flag) { grad_flag_ = flag; }
  void set_graph_phase(const std::string &graph_phase) { graph_phase_ = graph_phase; }
  bool in_cell_with_custom_bprop_() const { return custom_bprop_cell_count_ > 0; }
  AnfNodePtr GetInput(const py::object &obj, bool op_mask);
  std::string GetCellId(const py::object &obj, const py::args &args);
  void RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
  bool need_construct_graph() const { return !cell_stack_.empty() && grad_flag_; }
  // Construct grad graph for ms_function
  bool eliminate_forward() const { return eliminate_forward_; }
  void set_eliminate_forward(bool eliminate_forward) { eliminate_forward_ = eliminate_forward; }
  py::object GradMsFunction(const py::object &out, const py::args &args);
  void GradMsFunctionInner(const std::string &phase, const py::object &out, const py::args &args,
                           const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph);
  void UpdateMsFunctionForwardTensors(const OpExecInfoPtr &op_exec_info, const ValuePtr &new_forward_value);
  void MakeAdjointForMsFunction(const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph,
                                const py::object &actual_out, const py::args &args, const ValuePtr &actual_out_v);
  void MakeCNodeForMsFunction(const FuncGraphPtr &ms_func_graph, const py::args &args, ValuePtrList *input_values,
                              CNodePtr *ms_function_cnode);
  void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const CNodePtr &cnode);
  void DoOpGrad(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const ValuePtr &op_out);
  // Update forward tensor info
  void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const ValuePtr &op_out);
  void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
  py::object CheckGraph(const py::object &cell, const py::args &args);
  void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
  void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);
  void ClearGrad(const py::object &cell, const py::args &args);
  void ClearRes();
  void ClearCellRes(const std::string &cell_id = "");

 private:
  ForwardExecutorPtr forward() const;
  // Higher-order derivative
  inline bool IsNestedGrad() const;
  void SwitchTopcell();
  void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const py::tuple &forward_args,
                          std::vector<AnfNodePtr> *inputs, ValuePtrList *weights_args);
  void MakeNestedCnode(const py::object &cell, const py::tuple &forward_args, const pipeline::ResourcePtr &resource,
                       const py::object &out);
  void PushCellStack(const std::string &cell_id);
  void PopCellStack();
  TopCellInfoPtr PopHighOrderGraphStack();
  void HandleInputArgsForTopCell(const py::args &args, bool is_bprop_top);
  void InitResourceAndDfBuilder(const std::string &cell_id, const py::args &args);
  void MakeNewTopGraph(const string &cell_id, const py::args &args, bool is_topest);
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compiled);
  // Manage resources when running the grad process.
  bool IsBpropGraph(const std::string &cell_id);
  bool IsCellObjIdEq(const std::string &l_cell_id, const std::string &r_cell_id) const;
  void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
  void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
  void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
  void DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args);
  std::string GetAlreadyRunCellId(const std::string &cell_id);
  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args);
  void GradNetInner(py::object *ret, const prim::GradOperationPtr &grad, const py::object &cell,
                    const py::object &weights, const py::args &args);
  FuncGraphPtr GetBpropGraph(const prim::GradOperationPtr &grad, const py::object &cell,
                             const std::vector<AnfNodePtr> &weights, size_t arg_size, const py::args &args);
  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
  abstract::AbstractBasePtrList GetArgsSpec(const py::list &args, const FuncGraphPtr &bprop_graph);
  // Manage resources for constructing the forward graph.
  const std::string &graph_phase() const { return graph_phase_; }
  AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                  bool is_param = false);
  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->params[id] = param;
  }
  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                int64_t index = -1) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
  }
  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                const std::vector<int64_t> &index) const {
    auto &graph_info = top_cell()->graph_info_map()[g];
    MS_EXCEPTION_IF_NULL(graph_info);
    graph_info->node_map[id] = std::make_pair(node, index);
  }
  void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id);

 private:
  bool grad_flag_{false};
  bool enable_op_cache_{true};
  bool grad_is_running_{false};
  bool need_renormalize_{false};
  bool eliminate_forward_{true};
  int custom_bprop_cell_count_{0};
  size_t grad_order_{0};
  size_t top_cell_switch_counts_{0};

  // The graph phase is used to obtain the backend graph that is compiled by ms_function
  std::string graph_phase_;
  // The cell whose graph is being checked and will become the top cell
  std::string check_graph_cell_id_;
  std::string grad_operation_;
  // Only set in higher-order grad
  FuncGraphPtr curr_g_{nullptr};
  // For clearing the previous top cell resource
  TopCellInfoPtr top_cell_{nullptr};
  // Records forward cells; the bottom of the stack is the top cell
  std::stack<std::string> cell_stack_;
  // For higher-order grad of bprop
  std::stack<std::pair<std::string, bool>> bprop_grad_stack_;
  std::vector<std::string> bprop_cell_list_;
  // For higher-order grad
  std::stack<std::pair<FuncGraphPtr, TopCellInfoPtr>> high_order_stack_;
  // Use a vector to keep the order
  std::vector<TopCellInfoPtr> top_cell_list_;
  // Records all top cells that have been run
  std::unordered_map<std::string, TopCellInfoPtr> already_run_top_cell_;
  // Weak reference back to the forward executor
  ForwardExecutorWeakPtr forward_executor_;
};
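
// Editorial sketch (not part of the original header): the typical PyNative grad cycle, inferred
// from the forwarding functors above. The names `cell`, `out`, `weights`, `grad_op` and `run_args`
// are hypothetical placeholders handed over from the Python frontend:
//
//   py::object ret;
//   grad_executor->InitGraph(&ret, cell, args);                     // NewGraphInner: start recording
//   // ... forward ops run; each is recorded via DoOpGrad ...
//   grad_executor->LinkGraph(&ret, cell, out, args);                // EndGraphInner: close the graph
//   grad_executor->GradGraph(&ret, grad_op, cell, weights, args);   // GradNetInner: build bprop graph
//   grad_executor->RunGraph(&ret, cell, run_args);                  // RunGradGraph: execute it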

class ForwardExecutor {
 public:
  ForwardExecutor() = default;
  ~ForwardExecutor() = default;

  std::function<void(py::object *, const OpExecInfoPtr &)> RunOpS = [this](auto &&PH1, auto &&PH2) {
    RunOpInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
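  // RunOpS forwards to RunOpInner. Editorial illustration of a single-op execution, assuming the
  // py::args `args` come from the Python frontend:
  //
  //   auto op_exec_info = forward_executor->GenerateOpExecInfo(args);
  //   py::object ret;
  //   forward_executor->RunOpS(&ret, op_exec_info);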

  void RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info);
  OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
  std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
  void ClearRes();
  CNodePtr ConstructForwardGraph(const OpExecInfoPtr &op_exec_info);
  void set_lazy_build(bool lazy_build) { lazy_build_ = lazy_build; }

 private:
  GradExecutorPtr grad() const;
  MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
  py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
  void RunMixedPrecisionCastOp(const OpExecInfoPtr &op_exec_info, py::object *ret);
  py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
  py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                    PynativeStatusCode *status);
  void SetNonCostantValueAbs(const AbstractBasePtr &abs, size_t i, const std::string &id);
  void GetInputsArgsSpec(const OpExecInfoPtr &op_exec_info, abstract::AbstractBasePtrList *args_spec_list);
  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                           bool *prim_cache_hit);
  void GetOpOutput(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                   const CNodePtr &cnode, bool prim_cache_hit, py::object *ret);
  // Mixed precision and implicit cast
  void SetCastForInputs(const OpExecInfoPtr &op_exec_info);
  void SetTensorMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
  void SetImplicitCast(const OpExecInfoPtr &op_exec_info);
  py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
  py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
                                          size_t index);
  py::object DoAutoCastTuple(const py::tuple &tuple, const TypeId &type_id, const std::string &op_name, size_t index);
  py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
  void DoSignatrueCast(const PrimitivePyPtr &prim, const std::unordered_map<SignatureEnumDType, TypeId> &dst_type,
                       const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);

 private:
  GradExecutorWeakPtr grad_executor_;
  PrimAbsCache prim_abs_list_;
  ImplicitCastCache implicit_cast_map_;
  std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
  bool lazy_build_{false};
};

class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
 public:
  static std::shared_ptr<PynativeExecutor> GetInstance() {
    std::lock_guard<std::mutex> i_lock(instance_lock_);
    if (executor_ == nullptr) {
      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
      forward_executor_ = std::make_shared<ForwardExecutor>();
      grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
      forward_executor_->set_grad_executor(grad_executor_);
    }
    return executor_;
  }
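  // GetInstance lazily constructs the singleton under instance_lock_ and cross-wires the two
  // executors: the GradExecutor keeps a weak pointer to the ForwardExecutor (via its constructor)
  // and the ForwardExecutor keeps a weak pointer back to the GradExecutor (via set_grad_executor).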
  ~PynativeExecutor() = default;
  PynativeExecutor(const PynativeExecutor &) = delete;
  PynativeExecutor &operator=(const PynativeExecutor &) = delete;
  GradExecutorPtr grad_executor() const;
  ForwardExecutorPtr forward_executor() const;

  bool grad_flag() const;
  void set_grad_flag(bool flag);
  void set_graph_phase(const std::string &graph_phase);
  void set_py_exe_path(const py::object &py_exe_path);
  void set_kernel_build_server_dir(const py::object &kernel_build_server_dir);
  void NewGraph(const py::object &cell, const py::args &args);
  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
  void GradNet(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
               const py::args &args);
  py::object GradMsFunction(const py::object &out, const py::args &args);
  py::object CheckGraph(const py::object &cell, const py::args &args);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, const py::args &args);
  py::object Run(const py::object &cell, const py::tuple &args);

  // Used for graph cleanup; called when a Cell is destructed
  void ClearCell(const std::string &cell_id);
  void ClearGrad(const py::object &cell, const py::args &args);
  // Called on abnormal exit
  void ClearRes();
  // Sync stream
  void Sync();
  void SetLazyBuild(bool enable);
  void ExecuteAllTask();
  void EnterCell();
  void ExitCell();
  bool IsTopCell() const;

 private:
  PynativeExecutor() = default;

  static std::shared_ptr<PynativeExecutor> executor_;
  static std::mutex instance_lock_;
  static ForwardExecutorPtr forward_executor_;
  static GradExecutorPtr grad_executor_;
  uint32_t cell_depth_{0};
};

using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
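
// Editorial sketch (not part of the original header): how the singleton is typically reached from
// C++ glue code. Here `cell`, `out` and `args` are hypothetical Python objects handed over by pybind11:
//
//   const auto &executor = PynativeExecutor::GetInstance();
//   executor->set_grad_flag(true);
//   executor->NewGraph(cell, args);
//   // ... forward ops run here ...
//   executor->EndGraph(cell, out, args);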
}  // namespace mindspore::pynative

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_