/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_

#include <memory>
#include <string>
#include <utility>
#include <stack>
#include <set>
#include <vector>
#include <map>
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/top_cell.h"
#include "pipeline/pynative/grad/jit/jit_grad.h"
#include "runtime/pipeline/async_hqueue.h"
#include "pipeline/pynative/grad/bprop_task.h"
#include "pipeline/pynative/grad/ir/dynamic_shape.h"
#include "pipeline/pynative/grad/variable.h"
#include "pipeline/jit/ps/resource.h"

namespace mindspore {
namespace pynative {
namespace py = pybind11;
class ForwardExecutor;
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;

class GradExecutor {
  // Key: id of a cell that has already run; value: that cell's already-run top cell.
  using TopCellIdWithTopCell = std::map<std::string, TopCellInfoPtr>;
  // Key: id of a cell that has already run; value: its pipeline top cells.
  using PipelineTopCellMap = std::map<std::string, std::vector<TopCellInfoPtr>>;

 public:
  GradExecutor() = default;
  ~GradExecutor() = default;
  explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
      : forward_executor_(ForwardExecutorWeakPtr(forward_executor)),
        jit_(std::make_shared<Jit>()),
        dynamic_shape_(std::make_shared<DynamicShape>()),
        bprop_queue_(std::make_shared<runtime::AsyncHqueue>("bprop_queue")),
        assist_queue_(std::make_shared<runtime::AsyncHqueue>("assist_queue")) {}

  void Init();
  std::function<void(const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2) {
    NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
  std::function<void(const py::object &, const py::object &, const py::args &)> LinkGraph =
    [this](auto &&PH1, auto &&PH2, auto &&PH3) {
      EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                    std::forward<decltype(PH3)>(PH3));
    };
  std::function<py::object(const prim::GradOperationPtr &, const py::object &, const py::object &, const py::object &,
                           const py::args &)>
    Run = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
      return RunGrad(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
                     std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4),
                     std::forward<decltype(PH5)>(PH5));
    };

  inline TopCellInfoPtr top_cell() const {
    MS_EXCEPTION_IF_NULL(top_cell_);
    return top_cell_;
  }
  inline DynamicShapePtr dynamic_shape() const {
    MS_EXCEPTION_IF_NULL(dynamic_shape_);
    return dynamic_shape_;
  }
  inline JitPtr jit() const {
    MS_EXCEPTION_IF_NULL(jit_);
    return jit_;
  }
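  // Illustrative sketch (not part of this header's contract): the InitGraph / LinkGraph / Run
  // hooks above are what the pybind11 layer typically drives from the Python side. The names
  // grad_executor, forward_run, cell, grad_op, weights, grad_position, and args below are
  // hypothetical:
  //
  //   grad_executor->InitGraph(cell, args);        // forwards to NewGraphInner
  //   py::object out = forward_run(cell, args);    // hypothetical forward execution
  //   grad_executor->LinkGraph(cell, out, args);   // forwards to EndGraphInner
  //   py::object grads = grad_executor->Run(grad_op, cell, weights, grad_position, args);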
  inline bool TopCellHasNotBeenCreate() const { return top_cell_ == nullptr; }
  inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
  inline bool grad_flag() const { return grad_flag_; }
  inline void set_grad_flag(bool flag) { grad_flag_ = flag; }
  inline bool enable_grad() const { return enable_grad_; }
  inline void set_enable_grad(bool enable_grad) { enable_grad_ = enable_grad; }
  inline bool RequiresGrad() const { return enable_grad() && grad_flag(); }
  inline void set_is_run_recompute(bool is_run_recompute) { is_run_recompute_ = is_run_recompute; }
  // Construct grad graph for jit.
  inline size_t custom_bprop_cell_count() const { return custom_bprop_cell_count_; }
  inline runtime::AsyncHqueuePtr bprop_queue() const { return bprop_queue_; }
  TopCellIdWithTopCell &already_run_top_cell() { return already_run_top_cell_; }
  void SetHookChanged(const py::object &cell) const;
  py::object RunGrad(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                     const py::object &grad_position, const py::args &args);
  py::object RunGradFunc(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                         const std::vector<size_t> &p_args);
  py::object RunGradGraph();
  CNodePtr ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  void RecordForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const;
  void RecordForwardGraphForInput(const ValuePtr &value, const string &input_id,
                                  const abstract::AbstractBasePtr &param_abs);
  void RecordNestedGraph(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
                         const std::vector<ValuePtr> &forward_args, const ValuePtr &out);
  py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
                             const py::object &grad_hash_id, const py::args &args);
  TopCellInfoPtr GetAlreadyRunTopCell(const std::string &already_run_cell_id) const;
  TopCellInfoPtr GetPipelineRunTopCell(const std::string &already_run_cell_id) const;
  TopCellInfoPtr GetPipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
                                    bool is_reverse_match) const;
  void ErasePipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
                            bool is_pipeline_ir_top_cell);
  void GetTopCellWithInputArgsRespectTo(const prim::GradOperationPtr &grad, const py::object &obj,
                                        const py::args &args);
  bool ReplacePipelineTopCellForwardOutput();
  void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
  AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
  AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
  void UpdateTopCellForwardTensorInfoInBpropGraph(const string &op_info, const ValuePtr &v,
                                                  const size_t &stream_id) const;
  void ClearRes();
  void AsyncClearTopCell();
  void AsyncClearAutoGradCell(const TopCellInfoPtr &top_cell);
  void WorkerJoin();
  void WaitBpropTask() const;
  void SaveDynamicInputsCells(const py::object &obj, const py::args &args);
  void SetTopCellDynamicAttr(const py::object &cell);
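  // Sketch (an assumption, not a contract): callers are expected to gate bprop recording on
  // RequiresGrad(), which holds only when grad is both enabled and flagged, e.g.
  //
  //   if (grad_executor->RequiresGrad()) {
  //     grad_executor->ProcessOpGradInfo(op_run_info);
  //   }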
  bool use_dynamic_shape_process() const {
    if (top_cell_ == nullptr) {
      return false;
    }
    return top_cell()->use_dynamic_shape_process();
  }

  void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
    if (top_cell_ == nullptr) {
      return;
    }
    return top_cell()->set_use_dynamic_shape_process(use_dynamic_shape_process);
  }

  inline bool forward_use_dynamic_shape_process() const { return forward_use_dynamic_shape_process_; }
  inline void set_forward_use_dynamic_shape_process(bool forward_use_dynamic_shape_process) {
    forward_use_dynamic_shape_process_ = forward_use_dynamic_shape_process;
  }
  inline bool is_cell_has_dynamic_inputs(const std::string &obj_id) const {
    return dynamic_inputs_cells_.count(obj_id) > 0;
  }
  std::string GetAlreadyRunCellId(const std::string &obj_id) const;

  inline bool is_high_order_top_cell() const { return top_cell_ != nullptr && top_cell_->is_high_order_top_cell(); }
  void ChildAfterFork();

 private:
  ForwardExecutorPtr forward() const;
  inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
  inline void PushTopCellStack(const TopCellInfoPtr &top_cell) {
    top_cell_stack_.push(top_cell);
    MS_LOG(DEBUG) << "Push top cell " << top_cell << " on top cell stack";
  }
  bool NeedIncreaseGradOrder(const std::string &obj_id);
  void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
                         const CNodePtr &cnode) const;
  void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info) const;
  AnfNodePtr GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const;
  void SetBpropGraphJitLevel(const py::object &obj) const;
  void ClearGlobalRes() const;
  void ClearGradRes();
  void ClearPipelineTopCellRes();

  // Higher derivative.
  inline bool IsNestedGrad() const { return grad_order_ > 1; }
  inline void IncreaseGradOrder() {
    ++grad_order_;
    MS_LOG(DEBUG) << "Increase grad order, current grad_order is " << grad_order_;
  }
  inline void DecreaseGradOrder() {
    if (grad_order_ > 0) {
      --grad_order_;
    }
    MS_LOG(DEBUG) << "Decrease grad order, current grad_order is " << grad_order_;
  }
  inline bool IsHighOrderTopCell() const {
    return !input_args_info_stack_.empty() && IsNestedGrad() && top_cell()->grad_order() != grad_order_;
  }
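  // Illustrative note (a sketch of the intended flow, not a contract): for a nested call such
  // as grad(grad(net))(x) on the Python side, the outer grad raises grad_order_ to 2 via
  // IncreaseGradOrder(), so IsNestedGrad() holds while the inner top cell is being built, and
  // DecreaseGradOrder() unwinds the count as each top cell finishes.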
  uint32_t kernel_graph_id_for_control_flow() { return --kernel_graph_id_for_control_flow_; }
  void ClearPreTopCell(const TopCellInfoPtr &new_top_cell, bool is_need_clear_device_mem);
  bool GetTopCellDynamicFlag(const InputArgsInfoPtr &input_args_info, const std::string &obj_id_with_grad_order);
  void SwitchTopCell();
  TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id, const std::string &input_args_id);
  void DoParameterReplace(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
                          const std::vector<ValuePtr> &forward_args, AnfNodePtrList *inputs);
  void MakeNestedCnode(bool has_custom_bprop, const std::vector<ValuePtr> &forward_args,
                       const FuncGraphPtr &cur_run_bprop_graph, const BaseRef &out);
  TopCellInfoPtr PopTopCellStack();
  void PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info);
  void PopInputArgsInfoStack();
  void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info);
  bool IsNewCellId();
  void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info, bool is_bprop_need_get_forward_graph);
  bool IsCreateIrGrad();
  void MakeNewTopCell(const InputArgsInfoPtr &input_args_info);
  bool NewTopCellIsPipelineTopCell(const InputArgsInfoPtr &input_args_info);

  // Manage resources when running the grad process.
  void NewGraphInner(const py::object &obj, const py::args &args);
  InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args, bool is_bprop_need_get_forward_graph);
  void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
  void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
  void SetForwardLastNodeInfo(const ValuePtr &v) const;
  void GetCustomBpropPrim(const py::object &obj, const py::args &args, const InputArgsInfoPtr &input_args_info);
  void DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) const;
  void CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info);
  void GetGradGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                    const std::vector<size_t> &p_args);
  FuncGraphPtr GetBpropGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
                             const std::vector<size_t> &p_args);
  std::vector<tensor::BaseTensorPtr> GetWeightsArgs(const py::object &weights, bool *weight_param_is_tuple) const;
  std::vector<tensor::BaseTensorPtr> GetDefaultWeights() const;
  void CheckParamShapeAndType(const ParameterPtr &param_node, const abstract::AbstractBasePtr &input_abs,
                              const abstract::AbstractBasePtr &ir_abs) const;
  void UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args, const FuncGraphPtr &bprop_graph) const;
  std::vector<size_t> GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const;
  // Manage resources for constructing the forward graph.
  AnfNodePtr GetOutputNodeAsInput(const std::string &obj_id) const;
  AnfNodePtr GetValueSequenceInput(const ValuePtr &v) const;
  AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
                                    const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
  void DispatchGradQueueTask(std::function<void(void)> &&task) const;
  void DispatchAssistQueueTask(std::function<void(void)> task) const;
  void ResetMetaGradInfoForNewTopCell(const InputArgsInfoPtr &input_args_info) const;
  void ClearBpropTask() const;
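  // Sketch (an assumption about the threading model): bprop work is packaged as tasks on the
  // calling thread and executed asynchronously on bprop_queue_ / assist_queue_ (declared
  // below), with WaitBpropTask() as the synchronization point. The task body is hypothetical:
  //
  //   DispatchGradQueueTask([op_grad_info]() { /* hypothetical: run one bprop step */ });
  //   ...
  //   WaitBpropTask();  // wait for queued bprop tasks before reading results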
  bool init_{false};
  bool grad_flag_{false};
  bool enable_grad_{true};
  bool is_run_recompute_{false};
  bool save_graphs_{false};
  bool forward_use_dynamic_shape_process_{false};

  uint32_t kernel_graph_id_for_control_flow_{UINT32_MAX};
  size_t custom_bprop_cell_count_{0};

  // grad_order_ == 1 indicates the first derivative, grad_order_ == 2 the second derivative, and so on.
  size_t grad_order_{0};
  // True if grad is called without calling set_grad first.
  bool grad_first_{false};

  // Reserve size for the auto grad map.
  size_t op_num_in_bprop_graph_{kDefaultContainerSize};
  std::string grad_operation_;

  TopCellInfoPtr top_cell_{nullptr};
  InputArgsInfoPtr top_input_args_info_{nullptr};

  // Records every cell's info for sharing, regardless of whether a grad graph needs to be constructed.
  std::stack<InputArgsInfoPtr> input_args_info_stack_;

  // For top cells nested inside other top cells; important for high-order grad.
  std::stack<TopCellInfoPtr> top_cell_stack_;

  // Used in the set-grad scenario. If the top cell was already set in CheckAlreadyRun, there is no need to find it
  // again in RunGrad.
  TopCellInfoPtr finded_top_cell_;
  // Records all top cells that have been run.
  TopCellIdWithTopCell already_run_top_cell_;
  // Records pipeline top cells.
  PipelineTopCellMap pipeline_top_cell_map_;

  std::set<std::string> dynamic_inputs_cells_;
  std::vector<TopCellInfoPtr> need_gc_top_cell_list_;

  ForwardExecutorWeakPtr forward_executor_;
  JitPtr jit_;
  DynamicShapePtr dynamic_shape_{nullptr};
  runtime::AsyncHqueuePtr bprop_queue_;
  runtime::AsyncHqueuePtr assist_queue_;
};
}  // namespace pynative
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_GRAD_H_