1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include "pipeline/pynative/forward/forward.h" 24 #include "pipeline/pynative/grad/grad.h" 25 26 #include "pybind11/pybind11.h" 27 #include "frontend/operator/composite/composite.h" 28 #include "ir/anf.h" 29 #include "mindrt/include/fork_utils.h" 30 31 namespace mindspore::pynative { 32 namespace py = pybind11; 33 34 class PyNativeExecutor : public std::enable_shared_from_this<PyNativeExecutor> { 35 public: GetInstance()36 static const std::shared_ptr<PyNativeExecutor> &GetInstance() { 37 std::lock_guard<std::mutex> i_lock(instance_lock_); 38 if (executor_ == nullptr) { 39 executor_ = std::shared_ptr<PyNativeExecutor>(new (std::nothrow) PyNativeExecutor()); 40 Init(); 41 } 42 return executor_; 43 } 44 ~PyNativeExecutor() = default; 45 static void Init(); 46 PyNativeExecutor(const PyNativeExecutor &) = delete; 47 PyNativeExecutor &operator=(const PyNativeExecutor &) = delete; grad_executor()48 static inline const GradExecutorPtr &grad_executor() { 49 MS_EXCEPTION_IF_NULL(grad_executor_); 50 return grad_executor_; 51 } forward_executor()52 static inline const ForwardExecutorPtr &forward_executor() { 53 MS_EXCEPTION_IF_NULL(forward_executor_); 54 return forward_executor_; 55 } 56 57 void StoreAsyncStatus(const FrontendOpRunInfoPtr &op_run_info) const; 58 // Generate stub tensor and dispatch async task. 59 py::object RunOpStub(const py::args &args) const; 60 py::object RealRunOp(const py::args &args) const; 61 void SetAsyncForGraph(bool flag) const; 62 py::object CallConstantFolding(const py::args &args) const; 63 bool grad_flag() const; 64 void set_grad_flag(bool flag) const; 65 bool enable_grad() const; 66 void set_enable_grad(bool enable_grad) const; 67 void set_py_exe_path(const py::object &py_exe_path) const; 68 void set_kernel_build_server_dir(const py::object &kernel_build_server_dir) const; 69 void SetHookChanged(const py::object &cell) const; 70 void NewGraph(const py::object &obj, const py::args &args) const; 71 void EndGraph(const py::object &obj, const py::object &out, const py::args &args) const; 72 py::object RunGrad(const prim::GradOperationPtr &grad, const py::object &cell, const py::object &weights, 73 const py::object &grad_position, const py::args &args) const; 74 py::object GradJit(const py::object &out, const py::args &args) const; 75 void SetDynamicInput(const py::object &obj, const py::args &args) const; 76 py::object GetDynamicInput(const py::object &actual_input) const; 77 78 py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights, 79 const py::object &grad_hash_id, const py::args &args) const; 80 void ClearRes() const; 81 // Sync stream 82 void Sync() const; 83 bool IsFirstCell() const; 84 void WorkerJoin(); 85 void SetJitCompileStatus(bool is_compiling, const std::string &phase) const; 86 void SetIsRunRecompute(bool is_runing_recompute) const; 87 void ParentBeforeFork(); 88 void ChildAfterFork(); 89 py::object RunSliceOpStub(const std::vector<ValuePtr> &input_v, 90 const std::vector<SliceOpInfoPtr> &slice_op_infos) const; 91 92 private: 93 PyNativeExecutor() = default; 94 static std::shared_ptr<PyNativeExecutor> executor_; 95 static std::mutex instance_lock_; 96 static ForwardExecutorPtr forward_executor_; 97 static GradExecutorPtr grad_executor_; 98 }; 99 } // namespace mindspore::pynative 100 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_ 101