/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <mutex>
#include <stack>
#include <set>
#include <map>
#include "include/common/utils/convert_utils.h"
#include "include/common/profiler.h"
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "pybind11/numpy.h"
#include "pybind11/pytypes.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "pipeline/pynative/grad/auto_grad.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/ps/resource.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/ir/bprop_tensor_replace.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace pynative {
namespace py = pybind11;
class GradExecutor;
using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, AnfNodePtrList>;

struct PyNGraphInfo {
  OrderedMap<std::string, ParameterPtr> input_params;   // Hold input parameters
  OrderedMap<std::string, ParameterPtr> weight_params;  // Hold weight parameters
  // Hold an op's output or a combination of outputs
  mindspore::HashMap<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
};
using GraphInfoPtr = std::shared_ptr<PyNGraphInfo>;

using MetaGradInfoMap = mindspore::OrderedMap<tensor::BaseTensorPtr, AutoGradMetaDataPtr>;

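// TopCellInfo holds the per-top-cell state of the PyNative grad executor:
// compile/run flags, the ids that identify the cell and its input args, the
// forward func graph and compile resource, backward-hook bookkeeping, and the
// auto grad meta data of the tensors the cell touches.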
class TopCellInfo {
 public:
  TopCellInfo() = default;
  ~TopCellInfo() = default;
  TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
              std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg, size_t reserve_size)
      : is_high_order_top_cell_(is_high_order_top_cell),
        grad_order_(grad_order),
        obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
        cell_id_(std::move(cellid)),
        already_run_cell_id_(std::move(already_run_cell_id)),
        resource_(std::move(r)),
        fg_(std::move(fg)) {
    meta_grad_info_.reserve(reserve_size);
  }
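  // A minimal construction sketch (illustrative only; the argument values are
  // assumptions, not taken from real call sites):
  //   auto top_cell = std::make_shared<TopCellInfo>(
  //       /*is_high_order_top_cell=*/false, /*grad_order=*/1, obj_id + "_1",
  //       cell_id, already_run_cell_id, resource, fg, /*reserve_size=*/16);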

  inline bool is_init_kpynative() const { return is_init_kpynative_; }
  inline void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
  inline size_t grad_order() const { return grad_order_; }
  inline void set_hook_changed(bool hook_changed) { hook_changed_ = hook_changed; }
  inline void set_sub_cell_hook_changed(const std::string &sub_cell) { (void)sub_cell_hook_changed_.emplace(sub_cell); }
  inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; }
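  // Record a backward hook op under the given cell order (each cell object
  // registers two backward hook ops; see cell_backward_hook_op_ below).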
  void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
  void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info, bool is_jit_graph) const;
  inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
  inline bool forward_already_run() const { return forward_already_run_; }
  inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  inline bool need_compile_graph() const { return need_compile_graph_; }
  inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  inline bool vm_compile() const { return vm_compile_; }
  inline void set_force_top_cell_compile(bool force_top_cell_compile) {
    force_top_cell_compile_ = force_top_cell_compile;
  }
  inline bool force_top_cell_compile() const { return force_top_cell_compile_; }
  inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
  inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
  inline bool need_do_final_opt() const { return need_do_final_opt_; }
  inline void set_need_save_dynamic_detect_nodes(bool is_need_save_dynamic_detect_nodes) {
    is_need_save_dynamic_detect_nodes_ = is_need_save_dynamic_detect_nodes;
  }
  inline bool is_need_save_dynamic_detect_nodes() const { return is_need_save_dynamic_detect_nodes_; }
  inline pipeline::ResourcePtr resource() const { return resource_; }
  inline FuncGraphPtr fg() const {
    MS_EXCEPTION_IF_NULL(fg_);
    return fg_;
  }
  inline const bool &has_call_graph() const { return has_call_graph_; }
  inline void set_has_call_graph(bool has_call_graph) { has_call_graph_ = has_call_graph; }
  inline bool has_control_flow() const { return has_control_flow_; }
  inline void set_has_control_flow(bool has_control_flow) { has_control_flow_ = has_control_flow; }
  inline bool jit_out_has_dict() const { return jit_out_has_dict_; }
  inline void set_jit_out_has_dict(bool jit_out_has_dict) { jit_out_has_dict_ = jit_out_has_dict; }
  inline bool is_unknown_shape() const { return is_unknown_shape_; }
  inline void set_is_unknown_shape(bool is_unknown_shape) { is_unknown_shape_ = is_unknown_shape; }
  inline const std::string &cell_id() const { return cell_id_; }
  inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
  inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
  inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
  inline const std::string &input_args_id() const { return input_args_id_; }
  const std::string &grad_operation() const { return grad_operation_; }
  void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
  inline void CheckSubCellHookChanged() { sub_cell_hook_changed_.clear(); }
  inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) {
    graph_info_map_[fg] = graph_info;
  }
  inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
  inline autograd::AutoGradPtr auto_grad_cell_ptr() const {
    MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
    return auto_grad_cell_ptr_;
  }
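  // Note: move-assigning a new auto grad cell releases the previous one; the
  // profiler records this span as kPyNativeGradClearAutoGradCell.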
  void set_auto_grad_cell_ptr(autograd::AutoGradPtr &&auto_grad_cell_ptr) {
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                       runtime::ProfilerEvent::kPyNativeGradClearAutoGradCell,
                                       runtime::ProfilerRecorder::kNoName, true);
    auto_grad_cell_ptr_ = std::move(auto_grad_cell_ptr);
  }
  inline size_t op_index() const { return op_index_; }
  inline void IncreaseOpIndex() { ++op_index_; }
  inline size_t initial_graph_param_size() const { return initial_graph_param_size_; }
  TensorReplaceInfo &replace_info() { return replace_info_; }
  inline InputArgsInfoPtr input_args_info() { return input_args_info_; }
  inline void set_input_args_info(const InputArgsInfoPtr &input_args_info) { input_args_info_ = input_args_info; }
  void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) const;
  void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
  void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
                                bool need_save_sub_id = true) const;
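  // UpdateTopCellInfo presumably refreshes the three run/compile flags in one
  // call, e.g. (a sketch; the argument values are assumptions):
  //   top_cell->UpdateTopCellInfo(/*forward_already_run=*/true,
  //                               /*need_compile_graph=*/false, /*vm_compile=*/false);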
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
  void ClearDeviceMemory() const;
  void Clear();
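  // Helpers managing the auto grad meta data attached to tensors; judging by
  // their names, they register, back up, clear, reset and resume that info.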
  void AddMetaGradInfo(const tensor::BaseTensorPtr &tensor, const AutoGradMetaDataPtr &auto_grad_meta_data);
  void BackUpValueMetaGradInfo(const ValuePtr &value);
  void ClearValueMetaGradInfo(const ValuePtr &value);
  void ClearMetaGradInfo();
  void ResetMetaGradInfo();
  void ResumeMetaGradInfo();
  const MetaGradInfoMap &param_grad_info() const { return meta_grad_info_; }
  inline bool use_dynamic_shape_process() const { return use_dynamic_shape_process_; }
  inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
    use_dynamic_shape_process_ = use_dynamic_shape_process;
  }
  inline bool has_bprop_cut_op() const { return has_bprop_cut_op_; }
  inline void set_has_bprop_cut_op(bool has_bprop_cut_op) { has_bprop_cut_op_ = has_bprop_cut_op; }
  inline void set_resume_flag(bool resume_flag) { need_resume_meta_grad_ = resume_flag; }
  bool resume_flag() const { return need_resume_meta_grad_; }
  inline void set_is_ir_grad(bool is_ir_grad) { is_ir_grad_ = is_ir_grad; }
  bool is_ir_grad() const { return is_ir_grad_; }
  inline void set_grad_is_running(bool grad_is_running) { grad_is_running_ = grad_is_running; }
  bool grad_is_running() const { return grad_is_running_; }
  inline void set_grad_first(bool grad_first) { grad_first_ = grad_first; }
  bool grad_first() const { return grad_first_; }
  inline void set_is_bprop_need_get_forward_graph(bool is_bprop_need_get_forward_graph) {
    is_bprop_need_get_forward_graph_ = is_bprop_need_get_forward_graph;
  }
  bool is_bprop_need_get_forward_graph() const { return is_bprop_need_get_forward_graph_; }
  inline void set_is_finish_backward(bool is_finish_backward) { is_finish_backward_ = is_finish_backward; }
  bool is_finish_backward() const { return is_finish_backward_; }
  inline bool is_pipeline_top_cell() const { return is_pipeline_top_cell_; }
  inline void set_is_pipeline_top_cell(bool is_pipeline_top_cell) { is_pipeline_top_cell_ = is_pipeline_top_cell; }
  inline TopCellInfo *shadow_top_cell() const { return shadow_top_cell_; }
  inline void set_shadow_top_cell(TopCellInfo *shadow_top_cell) { shadow_top_cell_ = shadow_top_cell; }
  void SaveTensorIdWithOpInfo(const std::string &op_info, const ValuePtr &v) {
    SetIdWithOpInfo(v, op_info, kIndex0, &(replace_info_.id_with_op_info));
  }
  void SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph);
  void SetLastOutputValueForwardOutputFlag(const ValuePtr &v);
  void ChangeTopCellInfo(const std::vector<BaseShapePtr> &args_new_shape);
  const std::vector<std::string> &output_ids() const { return output_ids_; }
  void set_outputs_ids(std::vector<std::string> output_ids) { output_ids_ = std::move(output_ids); }
  // Check whether the tensor is a top cell output.
  bool IsOutputTensor(const tensor::BaseTensorPtr &tensor) const;

 private:
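  // Helpers that record how an op's (possibly nested) multi-output result is
  // mapped into the node_map of the current graph info.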
  void SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const;
  void SetNestedMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node,
                                             const std::vector<int64_t> &index_sequence) const;
  void SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
                                     const std::vector<int64_t> &index) const;
  bool hook_changed_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  bool vm_compile_{false};
  bool force_top_cell_compile_{false};
  bool is_high_order_top_cell_{false};
  bool need_do_final_opt_{false};
  bool is_need_save_dynamic_detect_nodes_{false};
  bool has_call_graph_{false};
  bool has_control_flow_{false};
  bool jit_out_has_dict_{false};
  bool is_unknown_shape_{false};
  bool use_dynamic_shape_process_{false};
  bool has_bprop_cut_op_{false};

  // Whether the top cell is running backward
  bool grad_is_running_{false};
  // True if grad() is called without set_grad having been called first
  bool grad_first_{false};

  // This top cell is used to get the forward graph
  bool is_bprop_need_get_forward_graph_{false};

  // Whether the param grad info needs to be resumed. If the net has only run
  // forward under set_grad (no gradient calculation yet), the weights' auto
  // grad meta data should be saved.
  bool need_resume_meta_grad_{false};
  std::map<tensor::BaseTensorPtr, AutoGradMetaDataPtr> param_grad_info_;

  // Running by actor (IR grad) or by func grad
  bool is_ir_grad_{false};

  // Whether gradient calculation has been completed
  bool is_finish_backward_{false};
  bool is_pipeline_top_cell_{false};
  // When this top cell needs no compilation and the IR top cell (actor) is
  // used for running, this records which top cell is actually running.
  TopCellInfo *shadow_top_cell_{};

  size_t grad_order_{0};
  size_t op_index_{0};

  // If the bprop graph has control flow, the bprop graph's parameter count may
  // change (grow)
  size_t initial_graph_param_size_{0};

  // Id without cell shape and type, with the grad order appended
  std::string obj_id_with_grad_order_;

  // Id with cell shape and type
  std::string cell_id_;

  // cell_id_ with grad_operation_ and grad_order_ appended
  std::string already_run_cell_id_;

  // Id of the cell's input args
  std::string input_args_id_;

  // Id built from GradOperation (get_all_ or get_by_list_), grad->sens_param and all weights
  std::string grad_operation_;

  // Ids of the forward output tensors, used for freeing tensors
  std::vector<std::string> output_ids_;

  pipeline::ResourcePtr resource_{nullptr};
  FuncGraphPtr fg_{nullptr};

  // Automatic differentiation
  autograd::AutoGradPtr auto_grad_cell_ptr_{nullptr};

  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;

  // Records that a `register hook` or `remove hook` function has been called by a sub cell.
  // The record spans from the begin to the end of the top cell.
  mindspore::HashSet<std::string> sub_cell_hook_changed_;
  // Record backward hook ops for each cell object.
  // Each cell object has two backward hook ops.
  CellIdWithBackwardHookOp cell_backward_hook_op_;

  // For forward output replacement
  TensorReplaceInfo replace_info_;
  MetaGradInfoMap meta_grad_info_;
  InputArgsInfoPtr input_args_info_{nullptr};
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
}  // namespace pynative
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_