/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_

#include <memory>
#include <string>
#include <map>
#include <utility>
#include <stack>
#include <vector>
#include "pipeline/pynative/forward/do_cast.h"
#include "pipeline/pynative/forward/do_pyboost_cast.h"
#include "pipeline/pynative/forward/do_infer.h"
#include "backend/graph_compiler/backend.h"
#include "ir/cell.h"
#include "runtime/pipeline/async_hqueue.h"
#include "ops/view/view_strides_calculator.h"
#include "runtime/pipeline/async_rqueue.h"

namespace mindspore {
namespace pynative {
class GradExecutor;
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;

using MindrtBackendMap = std::map<std::string, std::shared_ptr<compile::MindRTBackend>>;
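// Note: keys of MindrtBackendMap are device-target names; MindSpore targets
// include "CPU", "GPU", and "Ascend". A minimal lookup sketch (illustrative
// only, assuming a populated map):
//   auto iter = mindrt_backends_.find("Ascend");
//   if (iter != mindrt_backends_.end()) { /* reuse the cached backend */ }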

class ForwardExecutor {
 public:
  ForwardExecutor()
      : cast_operation_(std::make_shared<CastOperation>()),
        pyboost_cast_operation_(std::make_shared<PyBoostCastOperation>()),
        infer_operation_(std::make_shared<InferOperation>()) {}
  ~ForwardExecutor() = default;

  void Init();
  std::function<void(const FrontendOpRunInfoPtr &)> RunOpS = [this](auto &&PH1) {
    RunOpFrontend(std::forward<decltype(PH1)>(PH1));
  };
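  // RunOpS wraps RunOpFrontend in a perfect-forwarding generic lambda so the
  // frontend entry point can be passed around as a plain std::function
  // callback. A minimal sketch of the same pattern (names are illustrative,
  // not part of this class):
  //   struct Runner {
  //     void Run(int v);
  //     std::function<void(int)> cb = [this](auto &&arg) {
  //       Run(std::forward<decltype(arg)>(arg));
  //     };
  //   };
  //   Runner r;
  //   r.cb(42);  // forwards to r.Run(42)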

  void DispatchFrontendTask(const FrontendOpRunInfoPtr &op_run_info);
  void RunOpFrontend(const FrontendOpRunInfoPtr &op_run_info);
  // If stub is true, this function will not convert StubTensor to Tensor,
  // which reduces the overhead of StubTensor WaitValue.
  FrontendOpRunInfoPtr GenerateOpRunInfo(const py::args &args, bool stub = false);
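  // Hedged usage sketch (call sites are assumptions, not taken from this
  // header):
  //   auto op_info = GenerateOpRunInfo(args);          // StubTensor converted to Tensor
  //   auto stub_info = GenerateOpRunInfo(args, true);  // keep StubTensor, skip WaitValue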
  ValuePtr RunSliceOpFrontend(const std::vector<ValuePtr> &input_values,
                              const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                              const stub::StubNodePtr &stub_output);
  void DispatchSilceOpFrontendTask(const std::vector<ValuePtr> &input_values,
                                   const std::vector<SliceOpInfoPtr> &slice_op_infos, bool requires_grad,
                                   const stub::StubNodePtr &stub_output);
  void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
  void RefreshForwardCallback();
  void ClearNodeAbsMap() const;
  void SetNodeAbsMapByValue(const FrontendOpRunInfoPtr &op_run_info) const;
  void SetNodeAbsMapById(const std::string &id, const abstract::AbstractBasePtr &abs) const;
  AbstractBasePtr GetNodeAbsById(const std::string &id) const;
  void ClearRes();
  bool EnablePipeline(const std::string &op_name) const;
  bool enable_async() const;
  inline const std::string &device_target() const { return device_target_; }
  const MindrtBackendMap &mindrt_backend() const { return mindrt_backends_; }
  inline bool IsFirstCell() const { return forward_cell_stack_.empty(); }
  void PushForwardCell(const CellPtr &cell) { forward_cell_stack_.push(cell); }
  void PopForwardCell() { forward_cell_stack_.pop(); }
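  // forward_cell_stack_ tracks nested cell execution: IsFirstCell() is true
  // only for the outermost cell. A hedged sketch of the intended pairing
  // (callers and cell names are assumptions):
  //   PushForwardCell(outer);  // enter outer cell; IsFirstCell() was true
  //   PushForwardCell(inner);  // enter nested cell; IsFirstCell() now false
  //   PopForwardCell();        // leave inner cell
  //   PopForwardCell();        // leave outer cell; stack is empty again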
  void ExecuteLazyTask() const;
  void Sync();
  void PrintPyObjInfo(const py::object &obj, const std::string &str, bool is_cell) const;
  void ProcessBeforeNewGraph(const py::object &obj, const py::args &args);
  void ProcessAfterNewGraph(const py::object &obj) const;
  void ProcessBeforeEndGraph(const py::object &obj, bool is_cell);
  void ProcessAfterEndGraph(const py::object &obj, bool is_cell) const;
  bool CellNotSetMixedPrecision(const FrontendOpRunInfoPtr &op_run_info);
  inline InferOperationPtr infer_operation() const {
    MS_EXCEPTION_IF_NULL(infer_operation_);
    return infer_operation_;
  }
  inline void set_is_jit_compiling(bool is_jit_compiling) { is_jit_compiling_ = is_jit_compiling; }
  bool is_jit_compiling() const { return is_jit_compiling_; }

  void WaitForwardTask();
  bool IsVmOp(const std::string &op_name) const;
  std::string GetCurrentCellObjId() const;
  std::string GetCurrentDeviceTarget(const PrimitivePtr &op_prim) const;
  void ReInit();
  void ForwardOpGradImpl(const FrontendOpRunInfoPtr &op_run_info) const;
  GradExecutorPtr grad() const;
  void InitOpRunInfo(const FrontendOpRunInfoPtr &op_run_info);
  // Mixed precision and implicit type conversion for op inputs
  void SetCastForInputs(const FrontendOpRunInfoPtr &op_run_info) const;
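  // Illustration of what this cast step covers (an assumption based on the
  // names here, not a verbatim rule from the implementation): for a cell
  // configured with float16 mixed precision, float32 inputs are cast down to
  // float16 before the kernel runs; for implicit conversion, an int32 tensor
  // combined with a float32 tensor is promoted to float32.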
  inline const PyBoostCastOperationPtr &pyboost_cast_operation() const {
    MS_EXCEPTION_IF_NULL(pyboost_cast_operation_);
    return pyboost_cast_operation_;
  }
  void ChildAfterFork();

 private:
  compile::MindRTBackendPtr GetMindRtBackend(const std::string &cur_device_target);
  inline CastOperationPtr cast_operation() const {
    MS_EXCEPTION_IF_NULL(cast_operation_);
    return cast_operation_;
  }
  ValuePtr RunOpInVM(const FrontendOpRunInfoPtr &op_run_info) const;
  ValuePtr RunOpInMs(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  ValuePtr RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  ValuePtr RunOpWithBackendPolicy(const FrontendOpRunInfoPtr &op_run_info,
                                  const BackendOpRunInfoPtr &backend_op_run_info);
  void RunOpBackend(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  void RunOpBackendSync(const FrontendOpRunInfoPtr &op_run_info);
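  // Presumed dispatch relationship (inferred from the names, not verified
  // against the implementation): RunOpWithBackendPolicy routes ops that
  // IsVmOp() reports as VM ops to RunOpInVM, and everything else through the
  // MindRT backend via RunOpInMs/RunOpInMsInner.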

  VectorRef RunOpBackendInner(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info);
  // Infer output abstract
  void InferOutputAbstract(const FrontendOpRunInfoPtr &op_run_info) const;
  void PrepareOpInputs(const FrontendOpRunInfoPtr &op_run_info);
  void OpRunInfoUsePrimC(const FrontendOpRunInfoPtr &op_run_info) const;
  void CreateInputAddressForViewOp(const tensor::BaseTensorPtr &input_tensor, const FrontendOpRunInfoPtr &op_run_info);
  void DispatchViewKernelTask(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type);
  void ForwardRunViewKernelTask(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type,
                                bool enable_async);

  bool ProcessViewOp(const FrontendOpRunInfoPtr &op_run_info, const ops::StridesCalcFunc &func_info,
                     bool is_tuple_output);
  device::DeviceAddressPtr TensorContiguousCallback(const DeviceSyncPtr &device_address,
                                                    const TensorStorageInfoPtr &storage_info);

  void CreateViewOutputTensor(const FrontendOpRunInfoPtr &op_run_info, const tensor::BaseTensorPtr &input_tensor,
                              const TensorStorageInfoPtr &storage_info, runtime::KernelTaskType task_type);

  void DispatchAllocateMemTask(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor,
                               const size_t &input_idx, bool need_wait = false);
  PrimitivePtr GetSlicePrimFromCache(const std::string &op_name);
  FrontendOpRunInfoPtr GenerateSliceOpRunInfo(const std::string &op_name, bool requires_grad,
                                              const stub::StubNodePtr &stub_output);
  void CreateViewOpOutputs(const FrontendOpRunInfoPtr &op_run_info, const tensor::BaseTensorPtr &view_input_tensor,
                           runtime::KernelTaskType task_type, const TensorStorageInfoPtrList &storage_infos,
                           bool is_tuple_output);

 private:
  bool init_{false};
  bool enable_async_{true};
  bool is_jit_compiling_{false};
  std::string device_target_;
  std::string last_target_{"Unknown"};
  std::stack<CellPtr> forward_cell_stack_;
  GradExecutorWeakPtr grad_executor_;
  CastOperationPtr cast_operation_;
  PyBoostCastOperationPtr pyboost_cast_operation_;
  InferOperationPtr infer_operation_;
  MindrtBackendMap mindrt_backends_;
  mindspore::HashMap<std::string, PrimitivePtr> slice_prim_cache_;
};
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_H_