/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <mutex>
#include <stack>
#include <set>
#include <map>
#include "include/common/utils/convert_utils.h"
#include "include/common/profiler.h"
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "pybind11/numpy.h"
#include "pybind11/pytypes.h"
#include "pybind_api/ir/base_ref_py.h"
#include "ir/anf.h"
#include "pipeline/pynative/grad/auto_grad.h"
#include "frontend/operator/composite/composite.h"
#include "pipeline/jit/ps/resource.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/ir/bprop_tensor_replace.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace pynative {
namespace py = pybind11;
class GradExecutor;
using CellIdWithBackwardHookOp = mindspore::HashMap<std::string, AnfNodePtrList>;

struct PyNGraphInfo {
  OrderedMap<std::string, ParameterPtr> input_params;   // Hold input parameters
  OrderedMap<std::string, ParameterPtr> weight_params;  // Hold weight parameters
  // Hold op output or combination of outputs
  mindspore::HashMap<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
};
using GraphInfoPtr = std::shared_ptr<PyNGraphInfo>;

using MetaGradInfoMap = mindspore::OrderedMap<tensor::BaseTensorPtr, AutoGradMetaDataPtr>;

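// TopCellInfo bundles the per-top-cell state tracked by the PyNative grad executor:
// compile/run flags, the pipeline resource and func graph, the graph info map, the
// autograd cell, and the tensor meta-grad info. A minimal construction sketch, assuming
// a GradExecutor-side call site (variable names here are illustrative, not from this file):
//
//   auto top_cell = std::make_shared<TopCellInfo>(false, grad_order, obj_id_with_order, cell_id,
//                                                 already_run_cell_id, resource, func_graph, reserve_size);
//   top_cell->set_forward_already_run(true);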
class TopCellInfo {
 public:
  TopCellInfo() = default;
  ~TopCellInfo() = default;
  TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
              std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg, size_t reserve_size)
      : is_high_order_top_cell_(is_high_order_top_cell),
        grad_order_(grad_order),
        obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
        cell_id_(std::move(cellid)),
        already_run_cell_id_(std::move(already_run_cell_id)),
        resource_(std::move(r)),
        fg_(std::move(fg)) {
    meta_grad_info_.reserve(reserve_size);
  }

  inline bool is_init_kpynative() const { return is_init_kpynative_; }
  inline void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
  inline size_t grad_order() const { return grad_order_; }
  inline void set_hook_changed(bool hook_changed) { hook_changed_ = hook_changed; }
  inline void set_sub_cell_hook_changed(const std::string &sub_cell) {
    (void)sub_cell_hook_changed_.emplace(sub_cell);
  }
  inline const CellIdWithBackwardHookOp &cell_backward_hook_op() const { return cell_backward_hook_op_; }
  void RecordCellBackwardHookOp(const std::string &cell_order, const AnfNodePtr &hook_op);
  void GetOpInfo(const FrontendOpRunInfoPtr &op_run_info, bool is_jit_graph) const;
  inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
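  // Hook bookkeeping (a summary inferred from the members declared below, not normative):
  // sub_cell_hook_changed_ records sub cells whose `register hook` / `remove hook` was
  // called between the begin and end of this top cell, and cell_backward_hook_op_ keeps
  // the backward hook ops recorded per cell object via RecordCellBackwardHookOp.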
  inline bool forward_already_run() const { return forward_already_run_; }
  inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
  inline bool need_compile_graph() const { return need_compile_graph_; }
  inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
  inline bool vm_compile() const { return vm_compile_; }
  inline void set_force_top_cell_compile(bool force_top_cell_compile) {
    force_top_cell_compile_ = force_top_cell_compile;
  }
  inline bool force_top_cell_compile() const { return force_top_cell_compile_; }
  inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
  inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
  inline bool need_do_final_opt() const { return need_do_final_opt_; }
  inline void set_need_save_dynamic_detect_nodes(bool is_need_save_dynamic_detect_nodes) {
    is_need_save_dynamic_detect_nodes_ = is_need_save_dynamic_detect_nodes;
  }
  inline bool is_need_save_dynamic_detect_nodes() const { return is_need_save_dynamic_detect_nodes_; }
  inline pipeline::ResourcePtr resource() const { return resource_; }
  inline FuncGraphPtr fg() const {
    MS_EXCEPTION_IF_NULL(fg_);
    return fg_;
  }
  inline const bool &has_call_graph() const { return has_call_graph_; }
  inline void set_has_call_graph(bool has_call_graph) { has_call_graph_ = has_call_graph; }
  inline bool has_control_flow() const { return has_control_flow_; }
  inline void set_has_control_flow(bool has_control_flow) { has_control_flow_ = has_control_flow; }
  inline bool jit_out_has_dict() const { return jit_out_has_dict_; }
  inline void set_jit_out_has_dict(bool jit_out_has_dict) { jit_out_has_dict_ = jit_out_has_dict; }
  inline bool is_unknown_shape() const { return is_unknown_shape_; }
  inline void set_is_unknown_shape(bool is_unknown_shape) { is_unknown_shape_ = is_unknown_shape; }
  inline const std::string &cell_id() const { return cell_id_; }
  inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
  inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
  inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
  inline const std::string &input_args_id() const { return input_args_id_; }
  const std::string &grad_operation() const { return grad_operation_; }
  void set_grad_operation(const std::string &grad_operation) { grad_operation_ = grad_operation; }
  inline void CheckSubCellHookChanged() { sub_cell_hook_changed_.clear(); }
  inline void SetGraphInfoMap(const FuncGraphPtr &fg, const GraphInfoPtr &graph_info) {
    graph_info_map_[fg] = graph_info;
  }
  inline const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() const { return graph_info_map_; }
  inline autograd::AutoGradPtr auto_grad_cell_ptr() const {
    MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr_);
    return auto_grad_cell_ptr_;
  }
  void set_auto_grad_cell_ptr(autograd::AutoGradPtr &&auto_grad_cell_ptr) {
    runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                       runtime::ProfilerEvent::kPyNativeGradClearAutoGradCell,
                                       runtime::ProfilerRecorder::kNoName, true);
    auto_grad_cell_ptr_ = std::move(auto_grad_cell_ptr);
  }
  inline size_t op_index() const { return op_index_; }
  inline void IncreaseOpIndex() { ++op_index_; }
  inline size_t initial_graph_param_size() const { return initial_graph_param_size_; }
  TensorReplaceInfo &replace_info() { return replace_info_; }
  inline InputArgsInfoPtr input_args_info() { return input_args_info_; }
  inline void set_input_args_info(const InputArgsInfoPtr &input_args_info) { input_args_info_ = input_args_info; }
  void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) const;
  void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
  void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
                                bool need_save_sub_id = true) const;
  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
  void ClearDeviceMemory() const;
  void Clear();
  void AddMetaGradInfo(const tensor::BaseTensorPtr &tensor, const AutoGradMetaDataPtr &auto_grad_meta_data);
  void BackUpValueMetaGradInfo(const ValuePtr &value);
  void ClearValueMetaGradInfo(const ValuePtr &value);
  void ClearMetaGradInfo();
  void ResetMetaGradInfo();
  void ResumeMetaGradInfo();
  const MetaGradInfoMap &param_grad_info() const { return meta_grad_info_; }
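  // Meta-grad info lifecycle (a sketch inferred from the method set above, not normative):
  // AddMetaGradInfo registers a tensor's AutoGradMetaData for this top cell;
  // BackUpValueMetaGradInfo and ResumeMetaGradInfo save and restore it across runs when
  // the resume flag below is set; ClearMetaGradInfo / ResetMetaGradInfo drop or rebuild it.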
  inline bool use_dynamic_shape_process() const { return use_dynamic_shape_process_; }
  inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
    use_dynamic_shape_process_ = use_dynamic_shape_process;
  }
  inline bool has_bprop_cut_op() const { return has_bprop_cut_op_; }
  inline void set_has_bprop_cut_op(bool has_bprop_cut_op) { has_bprop_cut_op_ = has_bprop_cut_op; }
  inline void set_resume_flag(bool resume_flag) { need_resume_meta_grad_ = resume_flag; }
  bool resume_flag() const { return need_resume_meta_grad_; }
  inline void set_is_ir_grad(bool is_ir_grad) { is_ir_grad_ = is_ir_grad; }
  bool is_ir_grad() const { return is_ir_grad_; }
  inline void set_grad_is_running(bool grad_is_running) { grad_is_running_ = grad_is_running; }
  bool grad_is_running() const { return grad_is_running_; }
  inline void set_grad_first(bool grad_first) { grad_first_ = grad_first; }
  bool grad_first() const { return grad_first_; }
  inline void set_is_bprop_need_get_forward_graph(bool is_bprop_need_get_forward_graph) {
    is_bprop_need_get_forward_graph_ = is_bprop_need_get_forward_graph;
  }
  bool is_bprop_need_get_forward_graph() const { return is_bprop_need_get_forward_graph_; }
  inline void set_is_finish_backward(bool is_finish_backward) { is_finish_backward_ = is_finish_backward; }
  bool is_finish_backward() const { return is_finish_backward_; }
  inline bool is_pipeline_top_cell() const { return is_pipeline_top_cell_; }
  inline void set_is_pipeline_top_cell(bool is_pipeline_top_cell) { is_pipeline_top_cell_ = is_pipeline_top_cell; }
  inline TopCellInfo *shadow_top_cell() const { return shadow_top_cell_; }
  inline void set_shadow_top_cell(TopCellInfo *shadow_top_cell) { shadow_top_cell_ = shadow_top_cell; }
  void SaveTensorIdWithOpInfo(const std::string &op_info, const ValuePtr &v) {
    SetIdWithOpInfo(v, op_info, kIndex0, &(replace_info_.id_with_op_info));
  }
  void SaveForwardOutputTensorInfoInBpropGraph(const FuncGraphPtr &func_graph);
  void SetLastOutputValueForwardOutputFlag(const ValuePtr &v);
  void ChangeTopCellInfo(const std::vector<BaseShapePtr> &args_new_shape);
  const std::vector<std::string> &output_ids() const { return output_ids_; }
  void set_outputs_ids(std::vector<std::string> output_ids) { output_ids_ = std::move(output_ids); }
  // Check whether the tensor is a top cell output.
  bool IsOutputTensor(const tensor::BaseTensorPtr &tensor) const;
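  // output_ids_ (declared below) caches the ids of this top cell's forward output tensors;
  // per the member comment they are used for tensor freeing, so IsOutputTensor presumably
  // lets ClearDeviceMemory skip tensors that are still live as top cell outputs.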

 private:
  void SetMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node) const;
  void SetNestedMultipleOutputToGraphInfoMap(const string &id, const AnfNodePtr &node,
                                             const std::vector<int64_t> &index_sequence) const;
  void SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
                                     const std::vector<int64_t> &index) const;
  bool hook_changed_{false};
  bool is_init_kpynative_{false};
  bool forward_already_run_{false};
  bool need_compile_graph_{false};
  bool vm_compile_{false};
  bool force_top_cell_compile_{false};
  bool is_high_order_top_cell_{false};
  bool need_do_final_opt_{false};
  bool is_need_save_dynamic_detect_nodes_{false};
  bool has_call_graph_{false};
  bool has_control_flow_{false};
  bool jit_out_has_dict_{false};
  bool is_unknown_shape_{false};
  bool use_dynamic_shape_process_{false};
  bool has_bprop_cut_op_{false};

  // Top cell is running backward
  bool grad_is_running_{false};
  // If grad is called without calling set_grad first, grad_first is true
  bool grad_first_{false};

  // Top cell is used for getting the forward graph
  bool is_bprop_need_get_forward_graph_{false};

  // Whether the param grad info needs to be resumed. If the net has only run forward under
  // set_grad (no gradient calculation yet), the weight auto-grad meta data should be saved.
  bool need_resume_meta_grad_{false};
  std::map<tensor::BaseTensorPtr, AutoGradMetaDataPtr> param_grad_info_;

  // Running by actor or by func grad
  bool is_ir_grad_{false};

  // Whether gradient calculation has been completed
  bool is_finish_backward_{false};
  bool is_pipeline_top_cell_{false};
  // When this top cell needs no compiling and runs via the IR top cell (actor), this records
  // which top cell is actually running.
  TopCellInfo *shadow_top_cell_{};

  size_t grad_order_{0};
  size_t op_index_{0};

  // If the bprop graph has control flow, the bprop graph's parameter count may change (grow)
  size_t initial_graph_param_size_{0};

  // Id without cell shape and type, with grad order appended
  std::string obj_id_with_grad_order_;

  // Id with cell shape and type
  std::string cell_id_;

  // cell_id_ with grad_operation_ and grad_order_ appended
  std::string already_run_cell_id_;

  // Cell input args id
  std::string input_args_id_;

  // GradOperation flags (get_all_ or get_by_list_), grad->sens_param and weights id
  std::string grad_operation_;
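  // Illustrative composition of the ids above (based only on the member comments; the exact
  // concatenation is defined in the implementation, and the pieces shown are assumptions):
  //   obj_id_with_grad_order_ ~ obj_id + grad_order_
  //   cell_id_                ~ obj_id + input shapes/types
  //   already_run_cell_id_    ~ cell_id_ + grad_operation_ + grad_order_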
  // Forward output tensor ids, used for tensor freeing
  std::vector<std::string> output_ids_;

  pipeline::ResourcePtr resource_{nullptr};
  FuncGraphPtr fg_{nullptr};

  // Automatic differentiation
  autograd::AutoGradPtr auto_grad_cell_ptr_{nullptr};

  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;

  // Record whether the `register hook` or `remove hook` function has been called by a sub cell.
  // The record range is between the begin and end of the top cell.
  mindspore::HashSet<std::string> sub_cell_hook_changed_;
  // Record backward hook ops for each cell object.
  // Each cell object has two backward hook ops.
  CellIdWithBackwardHookOp cell_backward_hook_op_;

  // For forward output replacement
  TensorReplaceInfo replace_info_;
  MetaGradInfoMap meta_grad_info_;
  InputArgsInfoPtr input_args_info_{nullptr};
};
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
}  // namespace pynative
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_TOP_CELL_H_