• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
19 
#include <memory>
#include <mutex>
#include <new>
#include <string>
#include <vector>
#include "pipeline/pynative/forward/forward.h"
#include "pipeline/pynative/grad/grad.h"

#include "pybind11/pybind11.h"
#include "frontend/operator/composite/composite.h"
#include "ir/anf.h"
#include "mindrt/include/fork_utils.h"
30 
31 namespace mindspore::pynative {
32 namespace py = pybind11;
33 
34 class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> {
35  public:
GetInstance()36   static const std::shared_ptr<PyNativeExecutor> &GetInstance() {
37     std::lock_guard<std::mutex> i_lock(instance_lock_);
38     if (executor_ == nullptr) {
39       executor_ = std::shared_ptr<PyNativeExecutor>(new (std::nothrow) PyNativeExecutor());
40       Init();
41     }
42     return executor_;
43   }
44   ~PyNativeExecutor() = default;
45   static void Init();
46   PyNativeExecutor(const PyNativeExecutor &) = delete;
47   PyNativeExecutor &operator=(const PyNativeExecutor &) = delete;
grad_executor()48   static inline const GradExecutorPtr &grad_executor() {
49     MS_EXCEPTION_IF_NULL(grad_executor_);
50     return grad_executor_;
51   }
forward_executor()52   static inline const ForwardExecutorPtr &forward_executor() {
53     MS_EXCEPTION_IF_NULL(forward_executor_);
54     return forward_executor_;
55   }
56 
57   void StoreAsyncStatus(const FrontendOpRunInfoPtr &op_run_info) const;
58   // Generate stub tensor and dispatch async task.
59   py::object RunOpStub(const py::args &args) const;
60   py::object RealRunOp(const py::args &args) const;
61   void SetAsyncForGraph(bool flag) const;
62   py::object CallConstantFolding(const py::args &args) const;
63   bool grad_flag() const;
64   void set_grad_flag(bool flag) const;
65   bool enable_grad() const;
66   void set_enable_grad(bool enable_grad) const;
67   void set_py_exe_path(const py::object &py_exe_path) const;
68   void set_kernel_build_server_dir(const py::object &kernel_build_server_dir) const;
69   void SetHookChanged(const py::object &cell) const;
70   void NewGraph(const py::object &obj, const py::args &args) const;
71   void EndGraph(const py::object &obj, const py::object &out, const py::args &args) const;
72   py::object RunGrad(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights,
73                      const py::object &grad_position, const py::args &args) const;
74   py::object GradJit(const py::object &out, const py::args &args) const;
75   void SetDynamicInput(const py::object &obj, const py::args &args) const;
76   py::object GetDynamicInput(const py::object &actual_input) const;
77 
78   py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
79                              const py::object &grad_hash_id, const py::args &args) const;
80   void ClearRes() const;
81   // Sync stream
82   void Sync() const;
83   bool IsFirstCell() const;
84   void WorkerJoin();
85   void SetJitCompileStatus(bool is_compiling, const std::string &phase) const;
86   void SetIsRunRecompute(bool is_runing_recompute) const;
87   void ParentBeforeFork();
88   void ChildAfterFork();
89   py::object RunSliceOpStub(const std::vector<ValuePtr> &input_v,
90                             const std::vector<SliceOpInfoPtr> &slice_op_infos) const;
91 
92  private:
93   PyNativeExecutor() = default;
94   static std::shared_ptr<PyNativeExecutor> executor_;
95   static std::mutex instance_lock_;
96   static ForwardExecutorPtr forward_executor_;
97   static GradExecutorPtr grad_executor_;
98 };
99 }  // namespace mindspore::pynative
100 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
101