/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_

#include <memory>
#include <string>
#include <utility>
#include <stack>
#include <set>
#include <vector>
#include <map>
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/grad/jit/jit_grad.h"
#include "runtime/pipeline/async_hqueue.h"
#include "pipeline/pynative/grad/bprop_task.h"
#include "pipeline/pynative/grad/ir/dynamic_shape.h"
#include "pipeline/pynative/grad/variable.h"
#include "pipeline/jit/ps/resource.h"
namespace mindspore {
namespace pynative {
namespace py = pybind11;
class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;

class GradExecutor {
  // key: already run cell id, value: the already run top cell
  using TopCellIdWithTopCell = std::map<std::string, TopCellInfoPtr>;
  // key: already run cell id, value: pipeline top cells
  using PipelineTopCellMap = std::map<std::string, std::vector<TopCellInfoPtr>>;

 public:
  GradExecutor() = default;
  ~GradExecutor() = default;
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
        jit_(std::make_shared<Jit>()),
        dynamic_shape_(std::make_shared<DynamicShape>()),
        bprop_queue_(std::make_shared<runtime::AsyncHqueue>("bprop_queue")),
        assist_queue_(std::make_shared<runtime::AsyncHqueue>("assist_queue")) {}
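
  // Construction sketch (hypothetical usage; ForwardExecutor is only forward-declared
  // here, so its own constructor arguments are not shown):
  //   auto forward_executor = std::make_shared<ForwardExecutor>(/* ... */);
  //   auto grad_executor = std::make_shared<GradExecutor>(forward_executor);
  // Only a weak reference to the ForwardExecutor is kept, so the caller must keep its
  // shared_ptr alive for the GradExecutor's lifetime.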

  void Init();
  std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
  std::function<void(const py::object &, const py::object &, const py::args &)> LinkGraph = [this](auto &&PH1,
                                                                                                   auto &&PH2,
                                                                                                   auto &&PH3) {
    EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
  };
  std::function<py::object(const prim::GradOperationPtr &, const py::object &, const py::object &, const py::object &,
                           const py::args &)>
    Run = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
      return RunGrad(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                     std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4),
                     std::forward<decltype(PH5)>(PH5));
    };
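
  // The three std::function members above are thin perfect-forwarding wrappers over the
  // private NewGraphInner/EndGraphInner/RunGrad entry points; the generic lambdas use
  // std::forward to preserve the arguments' value categories. Equivalent direct calls,
  // as a sketch (`cell`, `output`, `grad_op`, `weights`, `grad_position`, and `args`
  // are assumed py::object/py::args values):
  //   executor.InitGraph(cell, args);                             // -> NewGraphInner
  //   executor.LinkGraph(cell, output, args);                     // -> EndGraphInner
  //   executor.Run(grad_op, cell, weights, grad_position, args);  // -> RunGrad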
  inline TopCellInfoPtr top_cell() const {
    MS_EXCEPTION_IF_NULL(top_cell_);
    return top_cell_;
  }
  inline DynamicShapePtr dynamic_shape() const {
    MS_EXCEPTION_IF_NULL(dynamic_shape_);
    return dynamic_shape_;
  }
  inline JitPtr jit() const {
    MS_EXCEPTION_IF_NULL(jit_);
    return jit_;
  }

  inline bool TopCellHasNotBeenCreate() const { return top_cell_ == nullptr; }
  inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  inline bool grad_flag() const { return grad_flag_; }
  inline void set_grad_flag(bool flag) { grad_flag_ = flag; }
  inline bool enable_grad() const { return enable_grad_; }
  inline void set_enable_grad(bool enable_grad) { enable_grad_ = enable_grad; }
  inline bool RequiresGrad() const { return enable_grad() && grad_flag(); }
  inline void set_is_run_recompute(bool is_run_recompute) { is_run_recompute_ = is_run_recompute; }
  // Construct grad graph for jit
  inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
  inline runtime::AsyncHqueuePtr bprop_queue() const { return bprop_queue_; }
  TopCellIdWithTopCell &already_run_top_cell() { return already_run_top_cell_; }
  void SetHookChanged(const py::object &cell) const;
  py::object RunGrad(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                     const py::object &grad_position, const py::args &args);
  py::object RunGradFunc(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                         const std::vector<size_t> &p_args);
  py::object RunGradGraph();
  CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  void RecordForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  void RecordForwardGraphForInput(const ValuePtr &value, const string &input_id,
                                  const abstract::AbstractBasePtr &param_abs);
  void RecordNestedGraph(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
                         const std::vector<ValuePtr> &forward_args, const ValuePtr &out);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                             const py::object &grad_hash_id, const py::args &args);
  TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
  TopCellInfoPtr GetPipelineRunTopCell(const std::string &already_run_cell_id) const;
  TopCellInfoPtr GetPipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
                                    bool is_reverse_match) const;
  void ErasePipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
                            bool is_pipeline_ir_top_cell);
  void GetTopCellWithInputArgsRespectTo(const prim::GradOperationPtr &grad, const py::object &obj,
                                        const py::args &args);
  bool ReplacePipelineTopCellForwardOutput();
  void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
  AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
  void UpdateTopCellForwardTensorInfoInBpropGraph(const string &op_info, const ValuePtr &v,
                                                  const size_t &stream_id) const;
  void ClearRes();
  void AsyncClearTopCell();
  void AsyncClearAutoGradCell(const TopCellInfoPtr &top_cell);
  void WorkerJoin();
  void WaitBpropTask() const;
  void SaveDynamicInputsCells(const py::object &obj, const py::args &args);
  void SetTopCellDynamicAttr(const py::object &cell);
  bool use_dynamic_shape_process() const {
    if (top_cell_ == nullptr) {
      return false;
    }
    return top_cell()->use_dynamic_shape_process();
  }

  void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
    if (top_cell_ == nullptr) {
      return;
    }
    top_cell()->set_use_dynamic_shape_process(use_dynamic_shape_process);
  }
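
  // Note: both accessors above tolerate a missing top cell: the getter reports a
  // non-dynamic shape process and the setter is a no-op until a top cell exists.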

  inline bool forward_use_dynamic_shape_process() const { return forward_use_dynamic_shape_process_; }
  inline void set_forward_use_dynamic_shape_process(bool forward_use_dynamic_shape_process) {
    forward_use_dynamic_shape_process_ = forward_use_dynamic_shape_process;
  }
  inline bool is_cell_has_dynamic_inputs(const std::string &obj_id) const {
    return dynamic_inputs_cells_.count(obj_id) > 0;
  }
  std::string GetAlreadyRunCellId(const std::string &obj_id) const;

  inline bool is_high_order_top_cell() const { return top_cell_ != nullptr && top_cell_->is_high_order_top_cell(); }
  void ChildAfterFork();

 private:
  ForwardExecutorPtr forward() const;
  inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
  inline void PushTopCellStack(const TopCellInfoPtr &top_cell) {
    top_cell_stack_.push(top_cell);
    MS_LOG(DEBUG) << "Push top cell " << top_cell << " on top cell stack";
  }
  bool NeedIncreaseGradOrder(const std::string &obj_id);
  void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
                         const CNodePtr &cnode) const;
  void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info) const;
  AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const;
  void SetBpropGraphJitLevel(const py::object &obj) const;
  void ClearGlobalRes() const;
  void ClearGradRes();
  void ClearPipelineTopCellRes();

  // Higher derivative
  inline bool IsNestedGrad() const { return grad_order_ > 1; }
  inline void IncreaseGradOrder() {
    ++grad_order_;
    MS_LOG(DEBUG) << "Increase grad order, current grad_order is " << grad_order_;
  }
  inline void DecreaseGradOrder() {
    if (grad_order_ > 0) {
      --grad_order_;
    }
    MS_LOG(DEBUG) << "Decrease grad order, current grad_order is " << grad_order_;
  }
  inline bool IsHighOrderTopCell() const {
    return !input_args_info_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_;
  }
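  // Per the grad_order_ comment further below: each nested grad call bumps grad_order_
  // by one, so e.g. a second-order grad(grad(net))(x) runs with grad_order_ == 2 and
  // IsNestedGrad() == true, while a plain first-order grad(net)(x) runs with it at 1.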

  // Kernel graph ids for control flow are handed out downward from UINT32_MAX (see the
  // member initializer below), presumably to keep them disjoint from ordinary graph ids.
  uint32_t kernel_graph_id_for_control_flow() { return --kernel_graph_id_for_control_flow_; }
  void ClearPreTopCell(const TopCellInfoPtr &new_top_cell, bool is_need_clear_device_mem);
  bool GetTopCellDynamicFlag(const InputArgsInfoPtr &input_args_info, const std::string &obj_id_with_grad_order);
  void SwitchTopCell();
  TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id, const std::string &input_args_id);
  void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
                          const std::vector<ValuePtr> &forward_args, AnfNodePtrList *inputs);
  void MakeNestedCnode(bool has_custom_bprop, const std::vector<ValuePtr> &forward_args,
                       const FuncGraphPtr &cur_run_bprop_graph, const BaseRef &out);
  TopCellInfoPtr PopTopCellStack();
  void PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info);
  void PopInputArgsInfoStack();
  void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info);
  bool IsNewCellId();
  void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info, bool is_bprop_need_get_forward_graph);
  bool IsCreateIrGrad();
  void MakeNewTopCell(const InputArgsInfoPtr &input_args_info);
  bool NewTopCellIsPipelineTopCell(const InputArgsInfoPtr &input_args_info);

  // Manage resources when running the grad process.
  void NewGraphInner(const py::object &obj, const py::args &args);
  InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, bool is_bprop_need_get_forward_graph);
  void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
  void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
  void SetForwardLastNodeInfo(const ValuePtr &v) const;
  void GetCustomBpropPrim(const py::object &obj, const py::args &args, const InputArgsInfoPtr &input_args_info);
  void DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) const;
  void CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info);
  void GetGradGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                    const std::vector<size_t> &p_args);
  FuncGraphPtr GetBpropGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                             const std::vector<size_t> &p_args);
  std::vector<tensor::BaseTensorPtr> GetWeightsArgs(const py::object &weights, bool *weight_param_is_tuple) const;
  std::vector<tensor::BaseTensorPtr> GetDefaultWeights() const;
  void CheckParamShapeAndType(const ParameterPtr &param_node, const abstract::AbstractBasePtr &input_abs,
                              const abstract::AbstractBasePtr &ir_abs) const;
  void UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph) const;
  std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const;
  // Manage resources for constructing the forward graph.
  AnfNodePtr GetOutputNodeAsInput(const std::string &obj_id) const;
  AnfNodePtr GetValueSequenceInput(const ValuePtr &v) const;
  AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
                                    const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
  void DispatchGradQueueTask(std::function<void(void)> &&task) const;
  void DispatchAssistQueueTask(std::function<void(void)> task) const;
  void ResetMetaGradInfoForNewTopCell(const InputArgsInfoPtr &input_args_info) const;
  void ClearBpropTask() const;

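  // The AsyncHqueue members at the bottom of this class back DispatchGradQueueTask and
  // DispatchAssistQueueTask above; bprop work is presumably executed asynchronously on
  // those queues and synchronized through WaitBpropTask() / ClearBpropTask().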
  bool init_{false};
  bool grad_flag_{false};
  bool enable_grad_{true};
  bool is_run_recompute_{false};
  bool save_graphs_{false};
  bool forward_use_dynamic_shape_process_{false};

  uint32_t kernel_graph_id_for_control_flow_{UINT32_MAX};
  size_t custom_bprop_cell_count_{0};

  // If grad_order=1, this is the first derivative; if grad_order=2, the second derivative; and so on.
  size_t grad_order_{0};
  // If grad is called without calling set_grad first, grad_first_ is true.
  bool grad_first_{false};

  // Used to reserve capacity for the auto grad map.
  size_t op_num_in_bprop_graph_{kDefaultContainerSize};
  std::string grad_operation_;

  TopCellInfoPtr top_cell_{nullptr};
  InputArgsInfoPtr top_input_args_info_{nullptr};

  // Records every cell's info for sharing, regardless of whether a grad graph needs to be constructed.
  std::stack<InputArgsInfoPtr> input_args_info_stack_;

  // For top cells nested in other top cells; important for high-order grad.
  std::stack<TopCellInfoPtr> top_cell_stack_;

  // Used in the set_grad scenario. If the top cell was already found in CheckAlreadyRun,
  // there is no need to search for it again in RunGrad.
  TopCellInfoPtr finded_top_cell_;
  // Records all top cells that have been run.
  TopCellIdWithTopCell already_run_top_cell_;
  // Records pipeline top cells.
  PipelineTopCellMap pipeline_top_cell_map_;

  std::set<std::string> dynamic_inputs_cells_;
  std::vector<TopCellInfoPtr> need_gc_top_cell_list_;

  ForwardExecutorWeakPtr forward_executor_;
  JitPtr jit_;
  DynamicShapePtr dynamic_shape_{nullptr};
  runtime::AsyncHqueuePtr bprop_queue_;
  runtime::AsyncHqueuePtr assist_queue_;
};
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_