1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/grad/grad.h"
18 #include <algorithm>
19 #include "ops/conv_pool_op_name.h"
20 #include "ops/nn_op_name.h"
21 #include "ops/math_op_name.h"
22 #include "ops/sequence_ops.h"
23 #include "ops/framework_ops.h"
24 #include "pipeline/pynative/grad/top_cell.h"
25 #include "pipeline/pynative/grad/function/func_grad.h"
26 #include "pipeline/pynative/grad/ir/ir_grad.h"
27 #include "pipeline/pynative/pynative_utils.h"
28 #include "pipeline/jit/ps/pipeline.h"
29 #include "ir/cell.h"
30 #include "ir/func_graph_cloner.h"
31 #include "pipeline/jit/ps/parse/data_converter.h"
32 #include "pipeline/jit/ps/debug/trace.h"
33 #include "include/backend/optimizer/helper.h"
34 #include "include/common/utils/convert_utils_py.h"
35 #include "frontend/optimizer/ad/grad.h"
36 #include "frontend/optimizer/environ_conversion.h"
37 #include "pipeline/jit/ps/pass.h"
38 #include "pybind_api/gil_scoped_long_running.h"
39 #include "frontend/optimizer/fallback_rewriter.h"
40 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
41 #include "runtime/pynative/op_executor.h"
42 
43 namespace mindspore {
44 namespace pynative {
45 namespace {
46 const mindspore::HashSet<std::string> kHookOp = {"HookBackward", "CellBackwardHook"};
47 constexpr char kGrad[] = "grad";
48 constexpr auto kNeedRecompute = "is_cell_recompute";
49 constexpr auto kInternalParams = "internal_params";
50 constexpr auto kUsedBpropInputs = "used_bprop_inputs";
51 constexpr size_t kContainerRatio = 2;
52 
53 void ParsePyArgsToInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::args &args,
54                                 bool is_bprop_need_get_forward_graph) {
55   MS_EXCEPTION_IF_NULL(input_args_info);
56   input_args_info->has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
57   MS_LOG(DEBUG) << "Cell has custom bprop " << input_args_info->has_custom_bprop;
58   bool is_top_cell = input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell;
59   if (is_top_cell) {
60     pipeline::CheckArgsValid(obj, args);
61   }
62   // Only the top cell or custom bprop cell requires value conversion
63   if (is_top_cell || input_args_info->has_custom_bprop || is_bprop_need_get_forward_graph) {
64     input_args_info->obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
65     input_args_info->input_size = args.size();
66     for (size_t i = 0; i < input_args_info->input_size; ++i) {
67       const auto &id = PyNativeAlgo::PyParser::GetIdByPyObj(args[i]);
68       (void)input_args_info->input_arg_id_vec.emplace_back(id);
69     }
70     const auto &forward = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor();
71     for (size_t i = 0; i < input_args_info->input_size; ++i) {
72       input_args_info->input_args_id += input_args_info->input_arg_id_vec[i] + "_";
73       // Get arg value
74       if (py::isinstance<py::list>(args[i])) {
75         (void)input_args_info->input_arg_value_vec.emplace_back(
76           PyNativeAlgo::DataConvert::PyObjToValue(py::cast<py::tuple>(args[i])));
77       } else {
78         (void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
79       }
80 
81       // Get arg abstract
82       auto abs = forward->GetNodeAbsById(input_args_info->input_arg_id_vec[i]);
83       if (abs == nullptr) {
84         abs = input_args_info->input_arg_value_vec.back()->ToAbstract();
85       }
86       (void)input_args_info->input_arg_base_shape_vec.emplace_back(abs->BuildShape());
87     }
88     input_args_info->cell_id = PyNativeAlgo::Common::GetCellId(
89       input_args_info->obj_id, input_args_info->input_arg_id_vec, input_args_info->input_arg_value_vec);
90     MS_LOG(DEBUG) << "Cell_id is " << input_args_info->cell_id << ", is grad topest cell "
91                   << input_args_info->is_grad_topest_cell << ", is high order top cell "
92                   << input_args_info->is_high_order_top_cell << ", is bprop need get forward graph "
93                   << is_bprop_need_get_forward_graph;
94   }
95 }
96 
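// If the value contains no tensor (a single non-tensor value, or a sequence without any tensor element),
// return it as a value node; otherwise return nullptr so the caller builds a normal tensor input.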
97 AnfNodePtr GetNonTensorInput(const ValuePtr &v, const std::string &obj_id) {
98   MS_EXCEPTION_IF_NULL(v);
99   bool is_value_seq = v->isa<ValueSequence>();
100   bool is_single_non_tensor = !is_value_seq && !PyNativeAlgo::Common::IsTensor(v);
101   bool mixed_tensor = true;
102   if (is_value_seq) {
103     const auto &v_seq = v->cast<ValueSequencePtr>();
104     mixed_tensor = std::any_of(v_seq->value().begin(), v_seq->value().end(),
105                                [](const ValuePtr &e) { return PyNativeAlgo::Common::IsTensor(e, true); });
106   }
107   if (is_single_non_tensor || !mixed_tensor) {
108     auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v);
109     MS_LOG(DEBUG) << "Get input value node " << v_node->ToString() << ", id " << obj_id;
110     return v_node;
111   }
112   return nullptr;
113 }
114 
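// Convert an output value into tensor form where possible: tensors pass through, mixed sequences are filtered,
// scalars become tensors, and dicts may be converted to tuples.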
115 ValuePtr ConvertOutputValueToTensor(const ValuePtr &v, bool dict_convert_to_tuple) {
116   MS_EXCEPTION_IF_NULL(v);
117   if (PyNativeAlgo::Common::IsTensor(v, true)) {
118     return v;
119   }
120   if (v->isa<ValueSequence>()) {
121     auto v_seq = v->cast<ValueSequencePtr>();
122     if (v_seq->size() == 0) {
123       MS_LOG(EXCEPTION) << "Get empty value seq";
124     }
125     // All values are tensors
126     if (std::all_of(v_seq->value().begin(), v_seq->value().end(),
127                     [](const ValuePtr &e) { return PyNativeAlgo::Common::IsTensor(e, true); })) {
128       MS_LOG(DEBUG) << "All output value is tensor";
129       return v;
130     }
131     MS_LOG(DEBUG) << "Output is value sequence, but have tensor and other type mixed. Its value is " << v->ToString();
132     return PyNativeAlgo::Common::FilterSensValues(v, dict_convert_to_tuple);
133   }
134   if (v->isa<FloatImm>()) {
135     double input_value = v->cast<FP32ImmPtr>()->value();
136     return std::make_shared<tensor::Tensor>(input_value, kFloat32);
137   }
138   if (v->isa<BoolImm>()) {
139     return std::make_shared<tensor::Tensor>(v->cast<BoolImmPtr>()->value(), kBool);
140   }
141   if (v->isa<IntegerImm>()) {
142     int64_t input = v->cast<Int64ImmPtr>()->value();
143     return std::make_shared<tensor::Tensor>(input, kInt64);
144   }
145   if (v->isa<ValueDictionary>() && dict_convert_to_tuple) {
146     MS_LOG(DEBUG) << "Get dict value";
147     return PyNativeAlgo::DataConvert::ConvertValueDictToValueTuple(v);
148   }
149   MS_LOG(DEBUG) << "Output is " << v->ToString() << ", abstract "
150                 << PyNativeAlgo::Common::SetAbstractValueToAnyValue(v->ToAbstract());
151   return v;
152 }
153 
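// Run the final optimization pass on the bprop graph and dump the optimized IR for debugging.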
154 FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph, bool has_control_flow) {
155   MS_LOG(DEBUG) << "Do bprop graph final opt";
156   MS_EXCEPTION_IF_NULL(bprop_graph);
157   auto resource = std::make_shared<pipeline::Resource>();
158   resource->set_func_graph(bprop_graph);
159   auto manager = resource->manager();
160   MS_EXCEPTION_IF_NULL(manager);
161   manager->AddFuncGraph(bprop_graph);
162   FuncGraphPtr after_opt_bg = nullptr;
163   after_opt_bg = pipeline::FinalBpropGraphPass(resource, has_control_flow);
164   PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", after_opt_bg);
165   return after_opt_bg;
166 }
167 
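// Build the runtime argument list for the bprop graph: expand the sens value according to its type, then append
// default parameters when the graph has more parameters than supplied args.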
168 void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::ResourcePtr &res,
169                        size_t graph_param_size, SensType sens_type, VectorRef *const arg_list) {
170   MS_EXCEPTION_IF_NULL(arg_list);
171   MS_EXCEPTION_IF_NULL(res);
172   auto graph = res->func_graph();
173   MS_EXCEPTION_IF_NULL(graph);
174   const auto &graph_params = graph->parameters();
175   if (graph_params.size() < graph_param_size) {
176     MS_LOG(EXCEPTION) << "Get initial bprop graph param size " << graph_param_size << " less than current param size "
177                       << graph_params.size() << ". Graph parameters maybe update by kernel graph compile stage";
178   }
179   std::vector<ValuePtr> input_arg_list;
180   if (sens_type == SensType::kNormal) {
181     input_arg_list = input_vec;
182   } else if (sens_type == SensType::kTuple) {
183     PyNativeAlgo::DataConvert::FlattenArgs(input_vec, &input_arg_list, true);
184   } else {
185     input_arg_list.assign(input_vec.begin(), input_vec.end() - kIndex1);
186     const auto &v_sens = input_vec.back();
187     MS_EXCEPTION_IF_NULL(v_sens);
188     if (!v_sens->isa<ValueDictionary>()) {
189       MS_LOG(EXCEPTION) << "Get sens not dict " << v_sens->ToString();
190     }
191     const auto &v_dict = v_sens->cast<ValueDictionaryPtr>();
192     ValuePtrList key_inputs;
193     ValuePtrList value_inputs;
194     for (const auto &elem : v_dict->value()) {
195       (void)key_inputs.emplace_back(elem.first);
196       (void)value_inputs.emplace_back(elem.second);
197     }
198     auto key = std::make_shared<ValueTuple>(key_inputs);
199     auto value = std::make_shared<ValueTuple>(value_inputs);
200     (void)input_arg_list.emplace_back(key);
201     (void)input_arg_list.emplace_back(value);
202   }
203   (void)std::transform(input_arg_list.begin(), input_arg_list.end(), std::back_inserter(*arg_list),
204                        [](const ValuePtr &v) { return v; });
205   size_t arg_size = arg_list->size();
206   if (arg_size != graph_param_size) {
207     // There may be some default parameters among the inputs
208     MS_LOG(DEBUG) << "Get args size " << arg_size << ", graph param size " << graph_param_size;
209     for (std::size_t i = arg_size; i < graph_param_size; ++i) {
210       MS_EXCEPTION_IF_NULL(graph_params[i]);
211       auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
212       MS_EXCEPTION_IF_NULL(param_ptr);
213       if (!param_ptr->has_default()) {
214         MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param, " << param_ptr->DebugString();
215       }
216       if (!param_ptr->default_param()->isa<tensor::BaseTensor>()) {
217         MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->DebugString()
218                           << "] is not initialized, need to call `.init_data()`";
219       }
220       arg_list->push_back(param_ptr->default_param());
221     }
222   }
223 }
224 
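// Drop the extra parameters appended during kernel graph compilation so the graph keeps only its original inputs.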
225 void RestoreBpropGraphParameter(const FuncGraphPtr &graph, size_t graph_param_size) {
226   auto parameters = graph->parameters();
227   // On Ascend, the kernel graph may be adjusted and some control parameters inserted
228   if (parameters.size() > graph_param_size) {
229     (void)parameters.erase(parameters.begin() + graph_param_size, parameters.end());
230     graph->set_parameters(std::move(parameters));
231   }
232 }
233 
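// Extract the sens argument (the last forward arg), convert it to tensor form, and record it together with its
// sens type on input_args_info.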
234 void SetSensValue(const prim::GradOperationPtr &grad, const InputArgsInfoPtr &input_args_info, const py::args &args,
235                   bool dict_convert_to_tuple) {
236   MS_EXCEPTION_IF_NULL(grad);
237   if (!grad->sens_param()) {
238     return;
239   }
240   size_t forward_args_size = args.size() - 1;
241   auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
242   MS_LOG(DEBUG) << "Get sens param " << sens_v->ToString();
243   const auto &sens_tensor = ConvertOutputValueToTensor(sens_v, dict_convert_to_tuple);
244   if (sens_tensor == nullptr) {
245     MS_LOG(EXCEPTION) << "sens convert tensor is nullptr";
246   }
247   // Sens may already exist and may need to be updated
248   MS_EXCEPTION_IF_NULL(input_args_info);
249   if (input_args_info->input_arg_value_vec.size() == args.size()) {
250     input_args_info->input_arg_value_vec.pop_back();
251   }
252   (void)input_args_info->input_arg_value_vec.emplace_back(sens_tensor);
253   if (sens_tensor->isa<ValueSequence>()) {
254     input_args_info->sens_type = SensType::kTuple;
255   } else if (!dict_convert_to_tuple) {
256     input_args_info->sens_type = SensType::kDict;
257   }
258 }
259 
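// Concatenate the ids of all weights that are parameters requiring grad into a single string.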
260 std::string GetWeightsObjIdsByWeights(const py::object &weights) {
261   auto is_require_grad = [](const ValuePtr &value) {
262     MS_EXCEPTION_IF_NULL(value);
263     if (!value->isa<tensor::BaseTensor>()) {
264       return false;
265     }
266     auto t = value->cast<tensor::BaseTensorPtr>();
267     MS_EXCEPTION_IF_NULL(t);
268     if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) {
269       return true;
270     }
271     return false;
272   };
273 
274   std::string weights_obj_id;
275   auto append_weights_info = [&weights_obj_id, is_require_grad](const py::object &obj) {
276     const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(obj);
277     if (is_require_grad(v)) {
278       (void)weights_obj_id.append("_").append(PyNativeAlgo::Common::GetIdByValue(v));
279     }
280   };
281 
282   if (py::isinstance<py::tuple>(weights)) {
283     const auto &weights_tuple = weights.cast<py::tuple>();
284     for (size_t i = 0; i < weights_tuple.size(); ++i) {
285       append_weights_info(weights_tuple[i]);
286     }
287   } else if (py::isinstance<py::list>(weights)) {
288     const auto &weights_list = weights.cast<py::list>();
289     for (size_t i = 0; i < weights_list.size(); ++i) {
290       append_weights_info(weights_list[i]);
291     }
292   } else if (!py::isinstance<py::none>(weights)) {
293     append_weights_info(weights);
294   }
295 
296   return weights_obj_id;
297 }
298 
299 void FreeSpecialOpValue(const std::string &op_name, const FrontendOpRunInfoPtr &op_run_info, ValuePtr *const output) {
300   // Special cases, manually free more inputs.
301   static mindspore::HashSet<std::string> kMulOp{
302     kMulOpName,
303     kMatMulOpName,
304     kConv2DOpName,
305   };
306   static mindspore::HashSet<std::string> kDivOp{
307     kDivOpName,
308     kRealDivOpName,
309   };
310   if (op_name == kBatchNormOpName) {
311     // 1. BatchNorm is a multi-output node; its out[0] and out[1] are not used.
312     auto seq_v = (*output)->cast<ValueSequencePtr>();
313     MS_EXCEPTION_IF_NULL(seq_v);
314     ValuePtrList new_v_list{seq_v->value()};
315     new_v_list[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex0]);
316     new_v_list[kIndex1] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex1]);
317     *output = std::make_shared<ValueTuple>(new_v_list);
318     MS_LOG(DEBUG) << "Clear device address for output[0, 1] of " << op_name;
319   } else if (op_name == kLayerNormOpName) {
320     // 2. LayerNorm is a multi-output node; its out[0] is not used.
321     auto seq_v = (*output)->cast<ValueSequencePtr>();
322     MS_EXCEPTION_IF_NULL(seq_v);
323     ValuePtrList new_v_list{seq_v->value()};
324     new_v_list[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex0]);
325     *output = std::make_shared<ValueTuple>(new_v_list);
326     MS_LOG(DEBUG) << "Clear device address for output[0] of " << op_name;
327   } else if (kMulOp.find(op_name) != kMulOp.end()) {
328     // 3. For operators like Mul, dx ONLY relies on y, and dy ONLY relies on x,
329     //    so if y is a value node, dy is useless and we can free x ahead of time.
330     bool x_is_const_value = PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex0]);
331     bool y_is_const_value = PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex1]);
332     if (x_is_const_value && op_run_info->base_op_run_info.expanded_input_values[kIndex1]->isa<tensor::BaseTensor>()) {
333       op_run_info->op_grad_info->input_value[kIndex1] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(
334         op_run_info->base_op_run_info.expanded_input_values[kIndex1]);
335       MS_LOG(DEBUG) << "Clear device address for inputs[1] of " << op_name;
336     }
337     if (y_is_const_value && op_run_info->base_op_run_info.expanded_input_values[kIndex0]->isa<tensor::BaseTensor>()) {
338       op_run_info->op_grad_info->input_value[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(
339         op_run_info->base_op_run_info.expanded_input_values[kIndex0]);
340       MS_LOG(DEBUG) << "Clear device address for inputs[0] of " << op_name;
341     }
342   } else if (kDivOp.find(op_name) != kDivOp.end()) {
343     // 4. For operators like Div, dx does not rely on the output node, so if y is a value node (dy is not needed), we can free the output.
344     if (PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex1])) {
345       *output = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(*output);
346       MS_LOG(DEBUG) << "Clear device address for the output of " << op_name;
347     }
348   }
349 }
350 
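// Free device memory of op inputs/outputs that the bprop expander reports as unused; skipped for high-order top cells.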
351 void FreeUselessValue(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
352   MS_EXCEPTION_IF_NULL(op_run_info);
353   MS_EXCEPTION_IF_NULL(top_cell);
354   if (top_cell->is_high_order_top_cell()) {
355     return;
356   }
357 
358   const auto &unused_inputs = BpropExpander::GetUnusedInputs(op_run_info->op_grad_info->op_prim->name());
359   for (const auto i : unused_inputs) {
360     if (i < op_run_info->input_size) {
361       // Free input not used by bprop
362       op_run_info->op_grad_info->input_value[i] =
363         PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(op_run_info->op_grad_info->input_value[i]);
364     } else if (i == op_run_info->input_size) {
365       // Process output: free output not used by bprop
366       op_run_info->op_grad_info->out_value =
367         PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(op_run_info->op_grad_info->out_value);
368     }
369   }
370 
371   // Free special op memory
372   FreeSpecialOpValue(op_run_info->op_grad_info->op_prim->name(), op_run_info, &op_run_info->op_grad_info->out_value);
373 }
374 
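// Record the op output for grad, free values not needed by bprop, and wrap the op grad info into a GradParam.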
375 GradParamPtr CreateOpGradParam(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
376   MS_EXCEPTION_IF_NULL(op_run_info);
377   bool out_used_in_bporp_graph = true;
378   op_run_info->op_grad_info->out_value = op_run_info->real_out;
379   FreeUselessValue(op_run_info, top_cell);
380 
381   op_run_info->op_grad_info->out_abs = op_run_info->base_op_run_info.abstract;
382   auto grad_param = std::make_shared<GradParam>(op_run_info->op_grad_info, top_cell->use_dynamic_shape_process());
383   grad_param->out_used_in_bporp_graph = out_used_in_bporp_graph;
384   return grad_param;
385 }
386 
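// Mark the top cell as containing a bprop cut op when a hook primitive is encountered.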
387 void CheckBpropCutNode(const TopCellInfoPtr &top_cell, const PrimitivePtr &op_prim) {
388   MS_EXCEPTION_IF_NULL(op_prim);
389   if (top_cell->has_bprop_cut_op()) {
390     return;
391   }
392   if (op_prim->name() == kHookBackwardName || op_prim->name() == kCellBackwardHookName) {
393     top_cell->set_has_bprop_cut_op(true);
394   }
395 }
396 
397 void CloneParameter(const AnfNodePtr &node, const KernelGraphPtr &new_graph) {
398   MS_EXCEPTION_IF_NULL(node);
399   MS_EXCEPTION_IF_NULL(new_graph);
400   auto old_param = node->cast<ParameterPtr>();
401   MS_EXCEPTION_IF_NULL(old_param);
402   auto new_param = new_graph->add_parameter();
403   new_param->set_name(old_param->name());
404   if (auto t = PyNativeAlgo::Common::GetTensorFromParam(old_param); t != nullptr) {
405     const auto &param_info = t->param_info();
406     if (param_info != nullptr) {
407       const auto &param_name = param_info->name();
408       new_param->set_name(param_name);
409       new_param->debug_info()->set_name(param_name);
410     }
411     new_param->set_default_param(t);
412   }
413   new_param->set_abstract(old_param->abstract());
414   new_param->set_scope(old_param->scope());
415 }
416 
417 KernelGraphPtr CloneKernelGraph(const FuncGraphPtr &func_graph) {
418   MS_EXCEPTION_IF_NULL(func_graph);
419   MS_LOG(DEBUG) << "Begin clone kernel graph";
420   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
421   MS_EXCEPTION_IF_NULL(kernel_graph);
422   auto new_graph = std::make_shared<session::KernelGraph>();
423   const auto &params = kernel_graph->parameters();
424   for (auto &param : params) {
425     CloneParameter(param, new_graph);
426   }
427   auto out = InlineClone(kernel_graph, new_graph, new_graph->parameters());
428   new_graph->set_output(out);
429   PyNativeAlgo::Common::FreeFuncGraphForwardNodes(func_graph);
430   return new_graph;
431 }
432 
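// Join the ids of all positional args into a single underscore-separated string.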
433 std::string GetInputArgsId(const py::args &args) {
434   std::string input_args_id;
435   for (size_t i = 0; i < args.size(); ++i) {
436     input_args_id += PyNativeAlgo::PyParser::GetIdByPyObj(args[i]) + "_";
437   }
438   return input_args_id;
439 }
440 
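// Replace inputs that are not listed in used_bprop_inputs with fake values (releasing their device memory), and
// append weights from internal_params as extra inputs.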
441 void SetCustomBpropInputs(const py::object &obj, const InputArgsInfoPtr &input_args_info) {
442   if (py::hasattr(obj, kUsedBpropInputs)) {
443     py::object object = py::getattr(obj, kUsedBpropInputs);
444     if (!py::isinstance<py::tuple>(object) && !py::isinstance<py::list>(object)) {
445       MS_LOG(EXCEPTION) << "For cell bprop, used bprop inputs sholud be tuple or list";
446     }
447     auto used_bprop_inputs = py::cast<py::tuple>(object);
448     std::unordered_set<int64_t> used_inputs;
449     for (size_t i = 0; i < used_bprop_inputs.size(); ++i) {
450       const auto value = PyNativeAlgo::DataConvert::PyObjToValue(used_bprop_inputs[i]);
451       MS_EXCEPTION_IF_NULL(value);
452       int64_t used_index = GetValue<int64_t>(value);
453       (void)used_inputs.insert(used_index);
454     }
455     const size_t input_size = input_args_info->input_arg_value_vec.size();
456     for (size_t i = 0; i < input_size; ++i) {
457       const auto &input_value = input_args_info->input_arg_value_vec[i];
458       if (used_inputs.find(i) == used_inputs.end()) {
459         auto fake_value = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(input_value);
460         input_args_info->input_arg_value_vec[i] = fake_value;
461       }
462     }
463     if (used_inputs.find(input_size) == used_inputs.end()) {
464       auto fake_value = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(input_args_info->out_value);
465       input_args_info->out_value = fake_value;
466     }
467   }
468 
469   if (py::hasattr(obj, kInternalParams)) {
470     py::object weights = py::getattr(obj, kInternalParams);
471     if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
472       auto weights_tuple = py::cast<py::tuple>(weights);
473       for (size_t i = 0; i < weights_tuple.size(); ++i) {
474         const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
475         auto tensor = value->cast<tensor::TensorPtr>();
476         MS_EXCEPTION_IF_NULL(tensor);
477         (void)input_args_info->input_arg_value_vec.emplace_back(tensor);
478         (void)input_args_info->input_arg_id_vec.emplace_back(tensor->id());
479       }
480     }
481   }
482 }
483 }  // namespace
484 
485 ForwardExecutorPtr GradExecutor::forward() const {
486   auto forward_executor = forward_executor_.lock();
487   MS_EXCEPTION_IF_NULL(forward_executor);
488   return forward_executor;
489 }
490 
491 void GradExecutor::Init() {
492   if (init_) {
493     return;
494   }
495 #ifdef _MSC_VER
496   static WinBpropRegister reg;
497   reg.DoNothing();
498   MS_LOG(DEBUG) << "Do windows bprop expander register";
499 #endif
500   init_ = true;
501 }
502 
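// Pop the current top cell from the stack and return the new top, or nullptr if the stack becomes empty.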
503 TopCellInfoPtr GradExecutor::PopTopCellStack() {
504   if (top_cell_stack_.empty()) {
505     MS_LOG(EXCEPTION) << "Stack top cell stack is empty";
506   }
507   MS_LOG(DEBUG) << "Pop top cell " << top_cell_stack_.top() << " on top cell stack";
508   top_cell_stack_.pop();
509   TopCellInfoPtr top_cell = nullptr;
510   if (!top_cell_stack_.empty()) {
511     top_cell = top_cell_stack_.top();
512   }
513   top_cell == nullptr ? MS_LOG(DEBUG) << "Top cell stack has no top cell"
514                       : MS_LOG(DEBUG) << "Top cell stack size " << top_cell_stack_.size();
515   return top_cell;
516 }
517 
518 void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
519   input_args_info_stack_.push(input_args_info);
520 }
521 
522 void GradExecutor::PopInputArgsInfoStack() {
523   if (input_args_info_stack_.empty()) {
524     MS_LOG(EXCEPTION) << "Stack input_args_info_stack_ is empty";
525   }
526   input_args_info_stack_.pop();
527 }
528 
529 void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info) {
530   MS_EXCEPTION_IF_NULL(input_args_info);
531   // Convert input args to parameters for top cell graph in construct.
532   std::vector<ValuePtr> input_param_values;
533   const auto &input_value = input_args_info->input_arg_value_vec;
534   if (input_args_info->input_size != 0 && input_value.empty()) {
535     MS_LOG(EXCEPTION) << "Input value is empty";
536   }
537 
538   AbstractBasePtrList abs_list;
539   for (size_t i = 0; i < input_args_info->input_size; ++i) {
540     const auto &v = input_value[i];
541     auto param_i_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(v->ToAbstract());
542     if (!top_cell()->is_bprop_need_get_forward_graph()) {
543       (void)PyNativeAlgo::Common::SetValueGradInfo(v, top_cell(), InputType::kInput);
544       (void)input_param_values.emplace_back(v);
545       (void)abs_list.emplace_back(param_i_abs);
546     }
547     RecordForwardGraphForInput(v, input_args_info->input_arg_id_vec[i], param_i_abs);
548   }
549   if (top_cell_->is_bprop_need_get_forward_graph()) {
550     MS_LOG(DEBUG) << "Run bprop function, no need do prepare for grad";
551     return;
552   }
553   // If a new cell id comes up, the bprop graph uses cnodes for reuse
554   if (IsCreateIrGrad()) {
555     top_cell_->set_is_ir_grad(true);
556   }
557   if (top_cell_->is_ir_grad()) {
558     top_cell_->set_auto_grad_cell_ptr(
559       std::make_shared<autograd::IrGrad>(input_param_values, abs_list, op_num_in_bprop_graph_ * kContainerRatio,
560                                          assist_queue_, !top_cell_->is_high_order_top_cell(), is_run_recompute_));
561   } else {
562     top_cell_->set_auto_grad_cell_ptr(
563       std::make_shared<autograd::FuncGrad>(input_param_values, op_num_in_bprop_graph_ * kContainerRatio,
564                                            !top_cell_->is_high_order_top_cell(), is_run_recompute_));
565   }
566 }
567 
568 bool GradExecutor::IsCreateIrGrad() {
569   if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
570     // If the already run cell id is in the pipeline top cell map, there is no need to store it in
571     // already_run_top_cell_ again when running CheckNeedCompileGraph
572     if (pipeline_top_cell_map_.find(top_cell_->already_run_cell_id()) == pipeline_top_cell_map_.end()) {
573       top_cell_->set_need_compile_graph(true);
574       // If the top cell cannot be found in either already_run_top_cell_ or pipeline_top_cell_map_, a new ir grad can be created
575       if (!top_cell_->use_dynamic_shape_process()) {
576         return true;
577       }
578     }
579     return false;
580   }
581   return false;
582 }
583 
584 void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_info,
585                                             bool is_bprop_need_get_forward_graph) {
586   MS_LOG(DEBUG) << "InitResourceAndDfBuilder";
587   MS_EXCEPTION_IF_NULL(input_args_info);
588   forward()->WaitForwardTask();
589   // We need to wait for the construct bprop task of the outer top cell to finish. If the main thread runs quickly
590   // when it executes gradnet and clears the bprop_queue, the bprop task of the outer top cell may not be finished,
591   // which will cause a cnode-not-found error.
592   WaitBpropTask();
593   if (input_args_info->is_grad_topest_cell) {
594     MS_LOG(DEBUG) << "Make new topest graph";
595     ResetMetaGradInfoForNewTopCell(input_args_info);
596     MakeNewTopCell(input_args_info);
597   } else if (input_args_info->is_high_order_top_cell) {
598     MS_LOG(DEBUG) << "Nested grad graph existed in construct";
599 
600     // High-order inputs are outputs of upper-level top cell ops, so their meta grad info needs to be backed up too.
601     for (auto &item : input_args_info->input_arg_value_vec) {
602       top_cell_->BackUpValueMetaGradInfo(item);
603     }
604     ResetMetaGradInfoForNewTopCell(input_args_info);
605     MakeNewTopCell(input_args_info);
606     // High-order must use ir grad
607     top_cell_->set_is_ir_grad(true);
608   } else if (is_bprop_need_get_forward_graph) {
609     MS_LOG(DEBUG) << "Run custom bprop function and make forward graph";
610     // Make a top cell just for getting the forward graph; no grad work is needed
611     MakeNewTopCell(input_args_info);
612     curr_g()->debug_info()->set_name("bprop_forward_graph");
613     top_cell_->set_is_bprop_need_get_forward_graph(is_bprop_need_get_forward_graph);
614   }
615 
616   // Init kPynativeCellPtr with input parameters of top cell
617   if (!top_cell_->is_init_kpynative()) {
618     auto graph_info_cg = std::make_shared<PyNGraphInfo>();
619     top_cell_->SetGraphInfoMap(curr_g(), graph_info_cg);
620     HandleInputArgsForTopCell(input_args_info);
621     top_cell_->set_init_kpynative(true);
622   }
623 }
624 
625 void GradExecutor::ResetMetaGradInfoForNewTopCell(const InputArgsInfoPtr &input_args_info) const {
626   // To fix the scene where the user runs the forward network twice with the grad flag and then calls the grad()
627   // interface, we need to clear the last top cell's parameter grad info to avoid influencing the bprop graph
628   // construction of the current top cell.
629   if (top_cell_ != nullptr) {
630     MS_LOG(DEBUG) << "Reset meta grad info for top cell " << top_cell_;
631     top_cell_->ResetMetaGradInfo();
632   }
633 
634   // To fix the scene like: 1. net(x1) 2. x2 = deepcopy(x1) 3. net(x2) 4. grad_net(x2) 5. grad_net(x1).
635   // x1's auto_grad_meta_data will be copied to x2; x2's grad will use the same auto_grad_meta_data, clear x1's
636   // variable and set x2's variable.
637   // When executing grad_net(x1), x1's variable will not be found, so we need to clear the input's
638   // auto_grad_meta_data before executing.
639   for (auto &item : input_args_info->input_arg_value_vec) {
640     top_cell_->ClearValueMetaGradInfo(item);
641   }
642 }
643 
644 void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
645   // Run the custom bprop function, and the bprop function is under high-order.
646   // If a bprop forward graph needs to be made, a new top cell is created to serve it, and that is the current top_cell_.
647   bool running_bprop_function = top_cell_ != nullptr && top_cell_->grad_is_running();
648   bool is_bprop_need_get_forward_graph = running_bprop_function && top_cell_->is_high_order_top_cell();
649 
650   const auto input_args_info = GetInputArgsInfo(obj, args, is_bprop_need_get_forward_graph);
651   PushInputArgsInfoStack(input_args_info);
652   MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << ", cell_id " << PyNativeAlgo::PyParser::GetIdByPyObj(obj)
653                 << ", input args info ptr " << input_args_info.get();
654 
655   // Make top graph and init resource
656   if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell ||
657       is_bprop_need_get_forward_graph) {
658     InitResourceAndDfBuilder(input_args_info, is_bprop_need_get_forward_graph);
659   }
660 }
661 
662 InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args,
663                                                 bool is_bprop_need_get_forward_graph) {
664   const auto &input_args_info = std::make_shared<InputArgsInfo>(input_args_info_stack_.empty(), IsHighOrderTopCell());
665   ParsePyArgsToInputArgsInfo(input_args_info, obj, args, is_bprop_need_get_forward_graph);
666 
667   if (input_args_info->has_custom_bprop) {
668     custom_bprop_cell_count_ += 1;
669   }
670   // If CheckAlreadyRun runs first, grad_order_ will be increased by 1 (high-order scenario).
671   // If NetA.set_grad() makes this run first and CheckAlreadyRun runs later, grad_order_ needs to be increased by 1 here
672   if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell) {
673     if (grad_order_ == 0) {
674       IncreaseGradOrder();
675     }
676     input_args_info->already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id);
677     MS_LOG(DEBUG) << "Get already run top cell id " << input_args_info->already_run_cell_id;
678     // top_input_args_info_ indicates the currently running cell info
679     top_input_args_info_ = input_args_info;
680   }
681   return input_args_info;
682 }
683 
684 bool GradExecutor::GetTopCellDynamicFlag(const InputArgsInfoPtr &input_args_info,
685                                          const std::string &obj_id_with_grad_order) {
686   MS_EXCEPTION_IF_NULL(input_args_info);
687   // Only a forward process has run, and forward is dynamic (by set_inputs)
688   if (forward_use_dynamic_shape_process_) {
689     MS_LOG(DEBUG) << "Get forward dynamic";
690     return true;
691   }
692 
693   // Set by set_inputs
694   if (dynamic_inputs_cells_.find(input_args_info->obj_id) != dynamic_inputs_cells_.end()) {
695     MS_LOG(DEBUG) << "Get dynamic from set inputs";
696     return true;
697   }
698 
699   // Dynamic structure
700   auto pre_top_cell = GetAlreadyRunTopCell(input_args_info->already_run_cell_id);
701   if (pre_top_cell != nullptr && pre_top_cell->use_dynamic_shape_process()) {
702     MS_LOG(DEBUG) << "Get dynamic shape from already run top cell";
703     return true;
704   }
705 
706   // Dynamic structure for pipeline top cell
707   pre_top_cell = GetPipelineRunTopCell(input_args_info->already_run_cell_id);
708   if (pre_top_cell != nullptr && pre_top_cell->use_dynamic_shape_process()) {
709     MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell";
710     return true;
711   }
712 
713   // Dynamic shape
714   if (std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
715                   [&obj_id_with_grad_order](const auto &item) {
716                     if (item.second != nullptr && item.second->obj_id_with_grad_order() == obj_id_with_grad_order) {
717                       return item.second->use_dynamic_shape_process();
718                     }
719                     return false;
720                   })) {
721     MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell with obj_id_with_grad_order " << obj_id_with_grad_order;
722     return true;
723   }
724 
725   // Dynamic shape for pipeline top cell
726   return std::any_of(
727     pipeline_top_cell_map_.begin(), pipeline_top_cell_map_.end(), [&obj_id_with_grad_order](const auto &item) {
728       const auto &pipe_top_cell_list = item.second;
729       if (std::any_of(pipe_top_cell_list.begin(), pipe_top_cell_list.end(),
730                       [&obj_id_with_grad_order](const auto &pipe_item) {
731                         if (pipe_item != nullptr && pipe_item->obj_id_with_grad_order() == obj_id_with_grad_order) {
732                           return pipe_item->use_dynamic_shape_process();
733                         }
734                         return false;
735                       })) {
736         MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell with obj_id_with_grad_order "
737                       << obj_id_with_grad_order;
738         return true;
739       }
740       return false;
741     });
742 }
743 
744 void GradExecutor::MakeNewTopCell(const InputArgsInfoPtr &input_args_info) {
745   MS_EXCEPTION_IF_NULL(input_args_info);
746 
747   auto fg = std::make_shared<FuncGraph>();
748   fg->debug_info()->set_name("pynative_forward_graph");
749   auto resource = std::make_shared<pipeline::Resource>();
750 
751   finded_top_cell_ = nullptr;
752   bool new_top_cell_is_pipeline_top_cell = NewTopCellIsPipelineTopCell(input_args_info);
753 
754   bool new_top_cell_is_pipeline_high_order =
755     input_args_info->is_high_order_top_cell && new_top_cell_is_pipeline_top_cell;
756   // If the outer layer top cell is also a pipeline top cell, the top cell stack may be empty. Here, it needs to be
757   // pushed onto the top cell stack too when running MakeNestedCnode or running a bprop function (and the bprop
758   // function has another grad), because we need to know which is the outer layer top cell when the inner run finishes.
759   if (top_cell_ != nullptr && top_cell_->is_pipeline_top_cell() &&
760       (new_top_cell_is_pipeline_high_order || top_cell_->grad_is_running())) {
761     PushTopCellStack(top_cell_);
762   }
763 
764   const auto &obj_id_with_grad_order = GetAlreadyRunCellId(input_args_info->obj_id);
765   MS_LOG(DEBUG) << "Get obj id with grad order " << obj_id_with_grad_order;
766   top_cell_ = std::make_shared<TopCellInfo>(
767     input_args_info->is_high_order_top_cell, grad_order_, obj_id_with_grad_order, input_args_info->cell_id,
768     input_args_info->already_run_cell_id, resource, fg, op_num_in_bprop_graph_ * kContainerRatio);
769   top_cell_->set_forward_already_run(true);
770   top_cell_->set_input_args_id(input_args_info->input_args_id);
771   auto use_dynamic_shape_process = GetTopCellDynamicFlag(input_args_info, obj_id_with_grad_order);
772   top_cell_->set_use_dynamic_shape_process(use_dynamic_shape_process);
773   top_cell_->set_need_save_dynamic_detect_nodes(
774     dynamic_shape()->IsNeedSaveDynamicDetectNodes(top_cell_, use_dynamic_shape_process));
775   top_cell_->set_input_args_info(top_input_args_info_);
776   if (dynamic_shape()->enable_unknown_shape()) {
777     dynamic_shape()->TryChangeTopCellToUnknownShape(top_input_args_info_->obj_id,
778                                                     top_input_args_info_->input_arg_base_shape_vec, true);
779   }
780   top_cell_->set_has_bprop_cut_op(input_args_info->has_custom_bprop);
781   top_cell_->set_grad_first(grad_first_);
782   grad_first_ = false;
783   MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << ", top cell ptr " << top_cell_.get()
784                 << " with input args id " << top_cell_->input_args_id();
785 
786   if (new_top_cell_is_pipeline_top_cell) {
787     pipeline_top_cell_map_[input_args_info->already_run_cell_id].emplace_back(top_cell_);
788     top_cell_->set_is_pipeline_top_cell(true);
789     // If the pipeline top cell is high-order, it needs to be managed by the stack when running MakeNestedCnode, so push it onto the stack.
790     if (top_cell_->is_high_order_top_cell()) {
791       PushTopCellStack(top_cell_);
792     }
793     MS_LOG(DEBUG) << "Create pipeline top cell, input args id " << top_cell_->input_args_id()
794                   << ". The pipeline map size now "
795                   << pipeline_top_cell_map_[input_args_info->already_run_cell_id].size();
796   } else {
797     // Common top cell
798     PushTopCellStack(top_cell_);
799   }
800 }
801 
802 bool GradExecutor::NewTopCellIsPipelineTopCell(const InputArgsInfoPtr &input_args_info) {
803   // net.set_grad.
804   // pipeline, net(input1), grad(net)(input1), net(input2), grad(net)(input2),...
805   const auto it = pipeline_top_cell_map_.find(input_args_info->already_run_cell_id);
806   if (it != pipeline_top_cell_map_.end()) {
807     // First pipeline top cell
808     MS_EXCEPTION_IF_CHECK_FAIL(!it->second.empty(), "Pipeline top cell map is empty");
809 
810     // net.set_grad
811     // grad(net)(input1) -> this will generate an element in already_run_top_cell_ and do a complete grad operation;
812     // Then, run another net(input1) -> this will decide it should update op info because a complete grad operation
813     // has been done before; But then run another net(input1) -> this will get the pipeline top cell, which should
814     // have 2 elements, and they need to compile the ir graph because this is the first step for running the pipeline
815     // top cell. Then, run grad(net)(input1) -> this will find the matched top cell in already_run_top_cell_ because
816     // already_run_cell_id is matched, but this is not correct because the current process is in the pipeline top cell
817     // now. So, this will meet an auto grad meta error.
818     // Erasing the top cell info from already_run_top_cell_ is needed.
819     auto iter = already_run_top_cell_.find(input_args_info->already_run_cell_id);
820     if (iter != already_run_top_cell_.end()) {
821       MS_LOG(DEBUG) << "Erase top cell from already run top cell";
822       // Need to use the ir top cell; the current top cell in pipeline_top_cell_map_ is func grad. So, exchange them.
823       it->second.front() = iter->second;
824       it->second.front()->set_need_compile_graph(true);
825       already_run_top_cell_.erase(iter);
826     }
827     it->second.front()->set_is_pipeline_top_cell(true);
828     return true;
829   }
830   // net.set_grad.
831   // 1. grad(net)(input), top cell id will include grad_operation_;
832   // 2. net(input1), grad(net)(input1), net(input2), grad(net)(input2), ..., top cell id does not include grad_operation_.
833   // In the second case, grad(net)(input) should be a pipeline cell too.
834   auto iter = std::find_if(pipeline_top_cell_map_.begin(), pipeline_top_cell_map_.end(),
835                            [&input_args_info](const auto &iter_pipe) {
836                              return input_args_info->already_run_cell_id.find(iter_pipe.first) != std::string::npos;
837                            });
838   if (iter != pipeline_top_cell_map_.end()) {
839     input_args_info->already_run_cell_id = iter->first;
840     return true;
841   }
842   return false;
843 }
844 
845 void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v) const {
846   MS_EXCEPTION_IF_NULL(v);
847   auto value = v;
848   if (v->isa<tensor::CSRTensor>()) {
849     auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
850     value = csr_tensorptr->GetValues();
851   } else if (v->isa<tensor::COOTensor>()) {
852     auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
853     value = coo_tensorptr->GetValues();
854   }
855   (void)PyNativeAlgo::Common::SetValueGradInfo(value, top_cell_, InputType::kConstant);
856   // Set the last output abstract, which will be used for sens
857   auto fake_v = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(value);
858   top_cell()->SetLastOutputValueForwardOutputFlag(fake_v);
859   if (forward()->enable_async()) {
860     auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
861     auto task = [auto_grad_cell_ptr, fake_v]() { auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(fake_v); };
862     DispatchGradQueueTask(std::move(task));
863   } else {
864     top_cell()->auto_grad_cell_ptr()->UpdateOutputNodeOfTopCell(fake_v);
865   }
866 }
867 
868 void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
869   if (input_args_info_stack_.empty()) {
870     return;
871   }
872   const auto input_args_info = input_args_info_stack_.top();
873   MS_EXCEPTION_IF_NULL(input_args_info);
874   MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << ", cell_id " << PyNativeAlgo::PyParser::GetIdByPyObj(obj)
875                 << ", input args info ptr " << input_args_info.get();
876   if (input_args_info->is_grad_topest_cell) {
877     grad_flag_ = false;
878   }
879 
880   // If there is a custom bprop in the forward running process of the cell, DoGradForCustomBprop needs to be done;
881   // If the top cell is only for obtaining the forward graph, there is no need to do DoGradForCustomBprop.
882   bool need_do_custom_bprop_grad = false;
883   if (input_args_info->has_custom_bprop && custom_bprop_cell_count_ != 0) {
884     --custom_bprop_cell_count_;
885     need_do_custom_bprop_grad = custom_bprop_cell_count_ == 0;
886   }
887   if (!top_cell()->is_bprop_need_get_forward_graph() && need_do_custom_bprop_grad) {
888     GetCustomBpropPrim(obj, args, input_args_info);
889     runtime::OpExecutor::GetInstance().WaitAll();
890     input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out, false);
891     // Recompute needs the conversion regardless of non-tensor inputs; it may be a middle cell that does not call EndGraphImpl
892     if (input_args_info->is_need_recompute) {
893       input_args_info->out_value =
894         ConvertOutputValueToTensor(input_args_info->out_value, !top_cell()->jit_out_has_dict());
895     }
896     const auto &out_id = PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value);
897     SetCustomBpropInputs(obj, input_args_info);
898     DoGradForCustomBprop(input_args_info, out_id);
899   }
900 
901   // Get top cell endgraph
902   if (input_args_info->cell_id == top_cell()->cell_id()) {
903     runtime::OpExecutor::GetInstance().WaitAll();
904     if (input_args_info->out_value == nullptr) {
905       input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out, false);
906     }
907     MS_LOG(DEBUG) << "Get cell output value " << input_args_info->out_value->ToString();
908     EndGraphImpl(input_args_info);
909   }
910   PopInputArgsInfoStack();
911 }
912 
913 void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
914   auto out_tensor = ConvertOutputValueToTensor(input_args_info->out_value, !top_cell()->jit_out_has_dict());
915   std::vector<std::string> output_tensors_id;
916   PyNativeAlgo::DataConvert::ConvertValueTensorId(out_tensor, &output_tensors_id);
917   top_cell()->set_outputs_ids(std::move(output_tensors_id));
918   if (out_tensor != nullptr) {
919     input_args_info->out_value = out_tensor;
920   }
921 
922   // If network runs twice, and one of the runs is an empty network, the following judgment will take effect
923   ValuePtrList inputs{input_args_info->out_value};
924   AbstractBasePtrList abs{PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args_info->out_value->ToAbstract())};
925   auto node_info = std::make_shared<DynamicDetectNodeInfo>(nullptr, abs, nullptr);
926   (void)dynamic_shape()->CheckNodeDynamic(top_cell(), inputs, node_info);
927 
928   // Only dump the last forward graph or the bprop forward graph
929   if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
930     auto output_node =
931       GetInput(input_args_info->out_value, PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value));
932     curr_g()->set_output(output_node);
933     PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
934     MS_LOG(DEBUG) << "Save forward graph";
935   }
936   if (top_cell_->is_bprop_need_get_forward_graph()) {
937     MS_LOG(DEBUG) << "Run bprop no need do grad";
938     return;
939   }
940 
941   // Set sens value for grad
942   SetForwardLastNodeInfo(input_args_info->out_value);
943 
944   if (input_args_info->is_grad_topest_cell) {
945     MS_LOG(DEBUG) << "Cur top last cell " << input_args_info->cell_id;
946     top_cell()->ClearCellHookOp();
947   }
948 
949   top_cell()->CheckSubCellHookChanged();
950   // Check whether the graph needs to be compiled after each top cell has finished running
951   CheckNeedCompileGraph(input_args_info);
952   if (!top_cell_->grad_first()) {
953     DecreaseGradOrder();
954   }
955   top_input_args_info_ = input_args_info;
956 }
957 
958 void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) const {
959   MS_EXCEPTION_IF_NULL(input_args_info);
960   MS_EXCEPTION_IF_NULL(input_args_info->custom_bprop_prim);
961   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
962   op_run_info->requires_grad = true;
963   op_run_info->base_op_run_info.op_name = input_args_info->custom_bprop_prim->name();
964   op_run_info->op_grad_info->op_prim = input_args_info->custom_bprop_prim;
965   op_run_info->op_grad_info->input_value = input_args_info->input_arg_value_vec;
966   op_run_info->op_grad_info->is_need_recompute = input_args_info->is_need_recompute;
967   op_run_info->input_size = input_args_info->input_arg_value_vec.size();
968   op_run_info->input_value_id = input_args_info->input_arg_id_vec;
969   op_run_info->real_out = input_args_info->out_value;
970   op_run_info->out_value_id = out_id;
971   op_run_info->base_op_run_info.abstract =
972     PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args_info->out_value->ToAbstract());
973   op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
974   for (size_t i = 0; i < op_run_info->input_size; ++i) {
975     const auto &value = input_args_info->input_arg_value_vec[i];
976     (void)op_run_info->op_grad_info->input_abs.emplace_back(
977       PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract()));
978     op_run_info->op_grad_info->input_value_grad_type[i] =
979       PyNativeAlgo::Common::SetValueGradInfo(value, top_cell(), InputType::kConstant);
980   }
981   op_run_info->op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_run_info->real_out);
982   (void)PyNativeAlgo::Common::SetValueGradInfo(op_run_info->real_out, nullptr, InputType::kOpOutput);
983   DoOpGrad(op_run_info);
984   top_cell()->GetOpInfo(op_run_info, false);
985   UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
986                                              op_run_info->base_op_run_info.stream_id);
987   auto node_info = std::make_shared<DynamicDetectNodeInfo>(
988     op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_abs, op_run_info->base_op_run_info.abstract);
989   (void)dynamic_shape()->CheckNodeDynamic(top_cell(), op_run_info->op_grad_info->input_value, node_info);
990   RecordForwardGraph(op_run_info);
991 }
992 
993 void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args,
994                                       const InputArgsInfoPtr &input_args_info) {
995   MS_EXCEPTION_IF_NULL(input_args_info);
996   MS_LOG(DEBUG) << "Do grad for custom bprop";
997   py::function bprop_func = py::getattr(obj, parse::CUSTOM_BPROP_NAME);
998   py::object code_obj = py::getattr(bprop_func, "__code__");
999   py::object co_name = py::getattr(code_obj, "co_name");
1000   if (std::string(py::str(co_name)) == "staging_specialize") {
1001     MS_LOG(EXCEPTION) << "Decorating bprop with '@jit' is not supported.";
1002   }
1003 
1004   auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
1005   MS_EXCEPTION_IF_NULL(input_args_info);
1006   if (py::isinstance<Cell>(obj)) {
1007     const auto &cell_ptr = obj.cast<CellPtr>();
1008     input_args_info->is_need_recompute = cell_ptr->HasAttr(kNeedRecompute);
1009     fake_prim->set_bprop_cls_name(cell_ptr->name());
1010   }
1011   if (input_args_info->input_arg_value_vec.empty()) {
1012     for (size_t i = 0; i < args.size(); ++i) {
1013       (void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
1014     }
1015   }
1016   fake_prim->AddBackwardHookFn(0, bprop_func);
1017 
1018   (void)fake_prim->AddAttr("cell_id", MakeValue(input_args_info->cell_id));
1019   (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
1020 
1021   input_args_info->custom_bprop_prim = fake_prim;
1022 }
1023 
1024 void GradExecutor::ClearPreTopCell(const TopCellInfoPtr &new_top_cell, bool is_need_clear_device_mem) {
1025   MS_EXCEPTION_IF_NULL(new_top_cell);
1026   // Clear already run top cell and device mem
1027   for (auto iter = already_run_top_cell_.begin(); iter != already_run_top_cell_.end();) {
1028     MS_EXCEPTION_IF_NULL(iter->second);
1029     if (iter->second->obj_id_with_grad_order() == new_top_cell->obj_id_with_grad_order()) {
1030       if (is_need_clear_device_mem) {
1031         iter->second->ClearDeviceMemory();
1032         (void)need_gc_top_cell_list_.emplace_back(iter->second);
1033       }
1034       iter = already_run_top_cell_.erase(iter);
1035     } else {
1036       (void)iter++;
1037     }
1038   }
1039 }
1040 
1041 void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
1042   const auto &already_top_cell_id = top_cell()->already_run_cell_id();
1043   bool is_new_cell_id = false;
1044   // Get a new cell id for common grad; even for dynamic shapes, the first step will come in here too.
1045   if (top_cell_->need_compile_graph()) {
1046     MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
1047     // net.set_grad.
1048     // 1. grad(net)(input), top cell id will include grad_operation_;
1049     // 2. net(input), grad(net)(input), top cell id does not include grad_operation_. But, keep only one for accurate
1050     // finding when calling GetTopCell
1051     auto it = std::find_if(
1052       already_run_top_cell_.begin(), already_run_top_cell_.end(),
1053       [&already_top_cell_id](const auto &item) { return item.first.find(already_top_cell_id) != std::string::npos; });
1054     if (it != already_run_top_cell_.end()) {
1055       already_run_top_cell_.erase(it);
1056     }
1057     already_run_top_cell_[already_top_cell_id] = top_cell_;
1058     is_new_cell_id = true;
1059   }
1060   // In the first step, prepare the first top cell to be a pipeline top cell if applicable
1061   const auto it = pipeline_top_cell_map_.find(top_cell_->already_run_cell_id());
1062   if (it == pipeline_top_cell_map_.end()) {
1063     MS_LOG(DEBUG) << "Prepare the first top cell to be pipeline top cell";
1064     top_cell_->set_need_compile_graph(true);
1065     pipeline_top_cell_map_[already_top_cell_id].emplace_back(top_cell_);
1066     // If the top cell is the first top cell, the pipeline top cell map backs it up and returns;
1067     // But, if the top cell is not pipeline (run via the already run top cell map), in the second step it can not
1068     // return here and should go down to judge the compile status.
1069     if (is_new_cell_id) {
1070       return;
1071     }
1072   }
1073   // Get the pipeline top cell in the first step, and judge by whether the first top cell has completed a backward run
1074   if (top_cell_->is_pipeline_top_cell()) {
1075     MS_EXCEPTION_IF_CHECK_FAIL(!it->second.empty(), "Pipeline top cell map is empty");
1076     if (!it->second.front()->is_finish_backward()) {
1077       top_cell_->set_need_compile_graph(true);
1078       // Get dynamic structure
1079       if (top_cell_->use_dynamic_shape_process()) {
1080         it->second.front()->set_use_dynamic_shape_process(true);
1081       }
1082       MS_LOG(DEBUG) << "Get pipeline top cell has never been ran, input args " << top_cell_->input_args_id();
1083       return;
1084     }
1085   }
1086 
1087   // Older top cell id or dynamic shape
1088   MS_EXCEPTION_IF_NULL(input_args_info);
1089   // In high-order situations, the internal top cell has changed but the outer top cell remains unchanged, so the
1090   // outer bprop graph needs to be compiled again
1091   if (top_cell_->use_dynamic_shape_process() || top_cell_->force_top_cell_compile()) {
1092     // Functions need to be compiled every time.
1093     top_cell_->use_dynamic_shape_process() ? MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again"
1094                                            : MS_LOG(DEBUG) << "Force outer graph compile graph";
1095     if (!top_cell_->is_pipeline_top_cell()) {
1096       auto has_higher_order = std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
1097                                           [](const auto &elem) { return elem.second->is_high_order_top_cell(); });
1098       ClearPreTopCell(top_cell_, input_args_info->is_grad_topest_cell && !has_higher_order);
1099       already_run_top_cell_[already_top_cell_id] = top_cell_;
1100     } else {
1101       MS_LOG(DEBUG) << "Get pipeline top cell, input args " << top_cell_->input_args_id();
1102     }
1103     top_cell_->set_need_compile_graph(true);
1104     top_cell_->set_force_top_cell_compile(false);
1105   } else {
1106     MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " no need to compile graph again";
1107     if (!top_cell_->is_pipeline_top_cell()) {
1108       top_cell_->set_need_compile_graph(false);
1109       auto pre_top_cell = GetAlreadyRunTopCell(already_top_cell_id);
1110       MS_EXCEPTION_IF_NULL(pre_top_cell);
1111       pre_top_cell->set_input_args_id(top_cell_->input_args_id());
1112       // In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
1113       // the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
1114       if (!input_args_info->is_grad_topest_cell) {
1115         pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), top_cell_->graph_info_map().at(top_cell_->fg()));
1116       }
1117       pre_top_cell->set_forward_already_run(true);
1118       pre_top_cell->set_input_args_info(input_args_info);
1119       top_cell_stack_.top() = pre_top_cell;
1120     } else {
1121       MS_LOG(DEBUG) << "Get pipeline top cell, input args " << top_cell_->input_args_id();
1122     }
1123   }
1124 }
1125 
1126 TopCellInfoPtr GradExecutor::GetAlreadyRunTopCell(const std::string &already_run_cell_id) const {
1127   const auto it = already_run_top_cell_.find(already_run_cell_id);
1128   if (it != already_run_top_cell_.end()) {
1129     return it->second;
1130   }
1131   return nullptr;
1132 }
1133 
1134 TopCellInfoPtr GradExecutor::GetPipelineRunTopCell(const std::string &already_run_cell_id) const {
1135   const auto it = pipeline_top_cell_map_.find(already_run_cell_id);
1136   if (it != pipeline_top_cell_map_.end()) {
1137     return it->second.front();
1138   }
1139   return nullptr;
1140 }
1141 
1142 TopCellInfoPtr GradExecutor::GetPipelineTopCell(const std::string &already_run_cell_id,
1143                                                 const std::string &input_args_id, bool is_reverse_match) const {
1144   for (const auto &t : pipeline_top_cell_map_) {
1145     bool is_find = is_reverse_match ? t.first.find(already_run_cell_id) != std::string::npos
1146                                     : already_run_cell_id.find(t.first) != std::string::npos;
1147     if (is_find) {
1148       // If the first ir top cell has finished backward, skip it
1149       auto begin =
1150         !t.second.empty() && t.second.front()->is_finish_backward() ? t.second.begin() + 1 : t.second.begin();
1151       auto input_args_id_with_top_cell =
1152         std::find_if(begin, t.second.end(), [input_args_id](const TopCellInfoPtr &pipe_top_cell) {
1153           return input_args_id == pipe_top_cell->input_args_id();
1154         });
1155       if (input_args_id_with_top_cell == t.second.end()) {
1156         MS_LOG(DEBUG) << "Can not find top cell with input args id " << input_args_id;
1157         continue;
1158       }
1159       MS_LOG(DEBUG) << "Find pipeline top cell with input args id " << input_args_id;
1160       return *input_args_id_with_top_cell;
1161     }
1162   }
1163   MS_LOG(DEBUG) << "Can not find cell id " << already_run_cell_id << " in pipeline top cell map";
1164   return nullptr;
1165 }
1166 
1167 void GradExecutor::ErasePipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
1168                                         bool is_pipeline_ir_top_cell) {
1169   for (auto &t : pipeline_top_cell_map_) {
1170     if (already_run_cell_id.find(t.first) == std::string::npos) {
1171       continue;
1172     }
1173 
1174     // If the top cell is not a pipeline ir top cell and the first ir top cell has finished backward, skip the first ir top cell
1175     auto begin = !is_pipeline_ir_top_cell && !t.second.empty() && t.second.front()->is_finish_backward()
1176                    ? t.second.begin() + 1
1177                    : t.second.begin();
1178     auto input_args_id_with_top_cell = std::find_if(
1179       begin, t.second.end(),
1180       [input_args_id](const TopCellInfoPtr &pipe_top_cell) { return input_args_id == pipe_top_cell->input_args_id(); });
1181     if (input_args_id_with_top_cell == t.second.end()) {
1182       MS_LOG(DEBUG) << "Can not find top cell with input args id " << input_args_id;
1183       continue;
1184     }
1185     MS_LOG(DEBUG) << "Erase pipeline top cell " << input_args_id_with_top_cell->get() << " with input args id "
1186                   << input_args_id << ". The pipeline map size now " << t.second.size() - 1;
1187     t.second.erase(input_args_id_with_top_cell);
1188     if (t.second.empty()) {
1189       MS_LOG(DEBUG) << "Pipeline top cell map with already run cell id " << already_run_cell_id
1190                     << " is empty, erase it from the pipeline map";
1191       pipeline_top_cell_map_.erase(t.first);
1192     }
1193     return;
1194   }
1195 }
1196 
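// Overview: backward entry for a top cell, roughly grad(net)(*inputs) on the Python side. Wait for
// pending forward tasks, locate the matching top cell, set the sens value, then either reuse an
// already compiled graph (RunGradGraph) or build a new bprop graph: ir grad goes through
// GetGradGraph + RunGradGraph, while function grad goes through RunGradFunc.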
1197 py::object GradExecutor::RunGrad(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
1198                                  const py::object &grad_position, const py::args &args) {
1199   // Wait for forward tasks to finish.
1200   runtime::OpExecutor::GetInstance().WaitAll();
1201 
1202   GetTopCellWithInputArgsRespectTo(grad, obj, args);
1203   MS_EXCEPTION_IF_NULL(top_cell_);
1204   MS_LOG(DEBUG) << "Run top cell " << top_cell_;
1205 
1206   // The input args info must be updated to the current one even if the graph does not need to be compiled again
1207   top_input_args_info_ = top_cell_->input_args_info();
1208   MS_EXCEPTION_IF_NULL(top_input_args_info_);
1209   // Set sens
1210   SetSensValue(grad, top_input_args_info_, args, !top_cell_->jit_out_has_dict());
1211 
1212   MS_LOG(DEBUG) << "RunGrad start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
1213                 << ", input args info ptr " << top_input_args_info_.get();
1214 
1215   if (!top_cell_->need_compile_graph()) {
1216     MS_LOG(DEBUG) << "No need compile graph, graph is ir_grad " << top_cell_->is_ir_grad();
1217     // If no compilation is needed, the queued bprop construction tasks can be cleared.
1218     (void)need_gc_top_cell_list_.emplace_back(top_cell_);
1219     ClearBpropTask();
1220     top_cell_->ClearMetaGradInfo();
1221 
1222     // If the top cell is a pipeline top cell, finded_top_cell_ is the top cell itself;
1223     // otherwise, it is the ir top cell stored in already_run_top_cell_.
1224     if (!ReplacePipelineTopCellForwardOutput()) {
1225       finded_top_cell_->set_shadow_top_cell(top_cell_.get());
1226       top_cell_ = finded_top_cell_;
1227       finded_top_cell_ = nullptr;
1228     }
1229     // Top cell cleanup must happen after the pipeline forward output replacement, because replace_info must not be cleared earlier
1230     AsyncClearTopCell();
1231     top_cell_->UpdateTopCellInfo(false, false, false);
1232     return RunGradGraph();
1233   }
1234 
1235   MS_LOG(DEBUG) << "Need compile graph, graph is ir_grad " << top_cell_->is_ir_grad();
1236   WaitBpropTask();
1237   AsyncClearTopCell();
1238   top_cell_ = finded_top_cell_;
1239   finded_top_cell_ = nullptr;
1240   op_num_in_bprop_graph_ = top_cell_->op_index();
1241   top_cell_->set_grad_operation(grad_operation_);
1242   top_cell_->UpdateTopCellInfo(false, false, true);
1243   top_cell_->ResumeMetaGradInfo();
1244   SetBpropGraphJitLevel(obj);
1245   bool weight_param_is_tuple = true;
1246   auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple);
1247   auto p_args = GetGradPositionArgs(grad_position, grad->get_by_position_);
1248   autograd::GradAttr grad_attr(grad->get_all_, grad->get_by_list_, grad->sens_param_, grad->get_by_position_,
1249                                weight_param_is_tuple);
1250   if (top_cell_->is_ir_grad()) {
1251     GetGradGraph(grad_attr, w_args, p_args);
1252     return RunGradGraph();
1253   }
1254   return RunGradFunc(grad_attr, w_args, p_args);
1255 }
1256 
1257 std::string GradExecutor::GetAlreadyRunCellId(const std::string &obj_id) const {
1258   std::string already_run_cell_id(obj_id);
1259   already_run_cell_id += "_" + std::to_string(grad_order_ == 0 ? 1 : grad_order_);
1260   already_run_cell_id += "_" + grad_operation_;
1261   return already_run_cell_id;
1262 }
1263 
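// Overview: locate the top cell matching (grad, obj, args). If CheckAlreadyRun already found one
// (finded_top_cell_), only the current/previous top cell bookkeeping is adjusted; otherwise the cell
// id is rebuilt from obj and args (with sens stripped) and looked up via GetTopCell.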
1264 void GradExecutor::GetTopCellWithInputArgsRespectTo(const prim::GradOperationPtr &grad, const py::object &obj,
1265                                                     const py::args &args) {
1266   auto reset_flag = [this]() {
1267     if (finded_top_cell_->is_pipeline_top_cell()) {
1268       top_cell_ = finded_top_cell_;
1269     } else if (top_cell_ != nullptr &&
1270                finded_top_cell_->already_run_cell_id().find(top_cell_->already_run_cell_id()) == std::string::npos) {
1271       // NetA.set_grad, NetB.set_grad,
1272       // then run NetA(input) and NetB(input) to get the loss, and then run grad(NetA)(input), grad(NetB)(input).
1273       // When grad(NetA)(input) runs, finded_top_cell_ is the grad of NetA, but the top cell is grad(NetB)(input),
1274       // which does not match, so they need to be exchanged.
1275       // The meta grad info of NetB must be reset because NetB ran after NetA and did not do this operation in
1276       // MakeNewTopCell. If they share inputs or weight parameters, the auto grad meta may hit a nullptr.
1277       top_cell_->ResetMetaGradInfo();
1278       top_cell_ = finded_top_cell_;
1279     }
1280   };
1281 
1282   if (finded_top_cell_ != nullptr) {
1283     reset_flag();
1284     return;
1285   }
1286   MS_EXCEPTION_IF_NULL(grad);
1287   py::args args_without_sens;
1288   if (grad->sens_param_) {
1289     // If there is a sens, it will not hit the already run cache
1290     if (args.size() == 0) {
1291       MS_LOG(EXCEPTION) << "args.size:" << args.size() << " is invalid when sens_param is set.";
1292     }
1293     auto tuple_args_size = args.size() - 1;
1294     py::tuple tuple_args(tuple_args_size);
1295     for (size_t i = 0; i < tuple_args_size; ++i) {
1296       tuple_args[i] = args[i];
1297     }
1298     args_without_sens = tuple_args;
1299   } else {
1300     args_without_sens = args;
1301   }
1302   const auto &id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args_without_sens);
1303   const auto &cell_id =
1304     PyNativeAlgo::Common::GetCellId(PyNativeAlgo::PyParser::GetIdByPyObj(obj), id_v.first, id_v.second);
1305   const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
1306   const auto &input_args_id = GetInputArgsId(args_without_sens);
1307   MS_LOG(DEBUG) << "Get input cell id " << cell_id << " and already run cell id " << already_run_cell_id
1308                 << ", input args id " << input_args_id;
1309   finded_top_cell_ = GetTopCell(already_run_cell_id, input_args_id);
1310   MS_EXCEPTION_IF_NULL(finded_top_cell_);
1311   reset_flag();
1312 }
1313 
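// Overview: for a pipeline top cell that needs no recompilation, switch to its ir top cell and
// update the forward tensors recorded in the ir top cell's replace info. Returns false when the
// current top cell is not a pipeline top cell.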
1314 bool GradExecutor::ReplacePipelineTopCellForwardOutput() {
1315   // If top cell is pipeline top cell, need to get its ir top cell
1316   if (!top_cell_->is_pipeline_top_cell()) {
1317     return false;
1318   }
1319   auto pipeline_ir_top_cell = GetPipelineRunTopCell(top_cell_->already_run_cell_id());
1320   if (pipeline_ir_top_cell == nullptr) {
1321     MS_LOG(EXCEPTION) << "Can not find pipeline ir top cell " << top_cell_->already_run_cell_id()
1322                       << " in pipeline top cell map";
1323   }
1324   UpdatePipelineTopCellFowardTensor(pipeline_ir_top_cell->replace_info(), top_cell_->replace_info());
1325   pipeline_ir_top_cell->set_shadow_top_cell(top_cell_.get());
1326   top_cell_ = pipeline_ir_top_cell;
1327   MS_LOG(DEBUG) << "Run no need compile ir top cell " << top_cell_;
1328   return true;
1329 }
1330 
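// Overview: build the bprop func graph of the top cell, attach it to the resource, apply the needed
// conversions (environ conversion, dict/tuple sens handling, dict-output rewriting), then compile it
// via TaskEmitAction/ExecuteAction so that RunGradGraph can execute it.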
1331 void GradExecutor::GetGradGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
1332                                 const std::vector<size_t> &p_args) {
1333   // Get bprop graph of top cell
1334   auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
1335   auto resource = top_cell()->resource();
1336   MS_EXCEPTION_IF_NULL(resource);
1337   resource->set_func_graph(bprop_graph);
1338   auto manager = resource->manager();
1339   MS_EXCEPTION_IF_NULL(manager);
1340   manager->AddFuncGraph(bprop_graph, true);
1341   bprop_graph->ResetOwnNodes();
1342   // If the auto grad cell is cleared before ResetOwnNodes, memory may be corrupted.
1343   AsyncClearAutoGradCell(top_cell());
1344   if (top_cell()->has_control_flow()) {
1345     (void)opt::EnvironConversion(resource);
1346   }
1347   if (top_input_args_info_->sens_type == SensType::kDict) {
1348     PyNativeAlgo::Common::ProcessDictParam(bprop_graph, top_input_args_info_->input_size);
1349   } else if (top_input_args_info_->sens_type == SensType::kTuple) {
1350     PyNativeAlgo::Common::ProcessTupleParam(bprop_graph, top_input_args_info_->input_size);
1351   }
1352   if (top_cell()->jit_out_has_dict()) {
1353     MS_LOG(DEBUG) << "Jit out is dict, need convert make dict to pyexecute";
1354     (void)mindspore::opt::RewriterAfterOptA(resource->func_graph(), resource);
1355   }
1356   top_cell()->SaveForwardOutputTensorInfoInBpropGraph(resource->func_graph());
1357   PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
1358   resource->SetBackendAsync([]() { return compile::CreateBackend(); });
1359   MS_LOG(DEBUG) << "Start task emit action";
1360   (void)TaskEmitAction(resource);
1361   MS_LOG(DEBUG) << "Start execute action";
1362   (void)ExecuteAction(resource);
1363   top_cell()->UpdateTopCellInfo(false, false, true);
1364   resource->Clean();
1365 }
1366 
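// Overview: convert the Python `weights` argument into a list of tensors. Accepts a
// __parameter_tuple__ object, a tuple/list of parameters, or a single parameter; when weights is
// None, it falls back to the parameters recorded in the top cell (GetDefaultWeights).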
1367 std::vector<tensor::BaseTensorPtr> GradExecutor::GetWeightsArgs(const py::object &weights,
1368                                                                 bool *weight_param_is_tuple) const {
1369   std::vector<tensor::BaseTensorPtr> w_args;
1370   if (py::hasattr(weights, "__parameter_tuple__")) {
1371     const auto &weights_tuple = weights.cast<py::tuple>();
1372     MS_LOG(DEBUG) << "Get weights tuple size " << weights_tuple.size();
1373     for (size_t i = 0; i < weights_tuple.size(); ++i) {
1374       const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
1375       auto tensor = value->cast<tensor::BaseTensorPtr>();
1376       MS_EXCEPTION_IF_NULL(tensor);
1377       (void)w_args.emplace_back(tensor);
1378     }
1379   } else {
1380     MS_LOG(DEBUG) << "No parameter tuple get, try get weights params by input weight";
1381     if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
1382       auto weights_tuple = py::cast<py::tuple>(weights);
1383       for (size_t i = 0; i < weights_tuple.size(); ++i) {
1384         const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
1385         auto tensor = value->cast<tensor::BaseTensorPtr>();
1386         MS_EXCEPTION_IF_NULL(tensor);
1387         (void)w_args.emplace_back(tensor);
1388       }
1389     } else if (!py::isinstance<py::none>(weights)) {
1390       // Single input
1391       const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights);
1392       auto tensor = value->cast<tensor::BaseTensorPtr>();
1393       MS_EXCEPTION_IF_NULL(tensor);
1394       (void)w_args.emplace_back(tensor);
1395       *weight_param_is_tuple = false;
1396     } else {
1397       return GetDefaultWeights();
1398     }
1399   }
1400   return w_args;
1401 }
1402 
1403 std::vector<tensor::BaseTensorPtr> GradExecutor::GetDefaultWeights() const {
1404   std::vector<tensor::BaseTensorPtr> w_args;
1405   for (const auto &params : top_cell()->param_grad_info()) {
1406     const auto &tensor = params.first;
1407     if (tensor->is_parameter()) {
1408       (void)w_args.emplace_back(tensor);
1409     }
1410   }
1411   return w_args;
1412 }
1413 
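// Overview: convert grad_position into a list of indices. Only meaningful when grad_by_position is
// set; a non-tuple value or an empty tuple raises an exception.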
1414 std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const {
1415   std::vector<size_t> pos_args;
1416   if (!get_by_position) {
1417     return pos_args;
1418   }
1419   if (py::isinstance<py::tuple>(grad_position)) {
1420     const auto &tuple = grad_position.cast<py::tuple>();
1421     (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
1422                          [](const py::handle &elem) { return elem.cast<int64_t>(); });
1423     if (pos_args.empty()) {
1424       MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
1425     }
1426     return pos_args;
1427   }
1428   MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
1429 }
1430 
1431 void GradExecutor::CheckParamShapeAndType(const ParameterPtr &param_node, const abstract::AbstractBasePtr &ir_abs,
1432                                           const abstract::AbstractBasePtr &input_abs) const {
1433   MS_EXCEPTION_IF_NULL(param_node);
1434   MS_EXCEPTION_IF_NULL(ir_abs);
1435   MS_EXCEPTION_IF_NULL(input_abs);
1436   const auto &ir_shape = ir_abs->BuildShape()->ToString();
1437   const auto &input_shape = input_abs->BuildShape()->ToString();
1438   if (input_shape != "()" && ir_shape != "()") {
1439     if (input_shape != ir_shape) {
1440       // The sens shape in the ir graph is determined by the graph output, so it can be dynamic; but the input shape
1441       // is determined by the user input, which cannot be dynamic.
1442       if (param_node->debug_info()->name() != "sens" || !ir_abs->BuildShape()->IsDynamic()) {
1443         MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
1444                                  << param_node->DebugString() << ", ir_abs " << ir_abs->ToString() << ", input_abs "
1445                                  << input_abs->ToString();
1446       }
1447     }
1448     const auto &ir_dtype = ir_abs->BuildType()->ToString();
1449     const auto &input_dtype = input_abs->BuildType()->ToString();
1450     if (input_dtype != ir_dtype) {
1451       MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
1452                               << param_node->DebugString();
1453     }
1454   }
1455 }
1456 
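// Overview: set or verify the abstracts of the bprop graph parameters using the actual input values.
// Parameters with default values (weights) keep their own abstract; the remaining parameters are
// matched with input_args in order.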
1457 void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
1458                                         const FuncGraphPtr &bprop_graph) const {
1459   MS_EXCEPTION_IF_NULL(bprop_graph);
1460   const auto &bprop_params = bprop_graph->parameters();
1461   // bprop_params include inputs, parameters and sens, so their size should be no less than the inputs size
1462   if (bprop_params.size() < input_args.size()) {
1463     MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << input_args.size();
1464   }
1465   size_t index = 0;
1466   for (const auto &param : bprop_params) {
1467     auto param_node = param->cast<ParameterPtr>();
1468     if (param_node->has_default()) {
1469       MS_EXCEPTION_IF_NULL(param_node->abstract());
1470     } else {
1471       const auto &input_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args[index]->ToAbstract());
1472       if (param_node->abstract() != nullptr) {
1473         CheckParamShapeAndType(param_node, param_node->abstract(), input_abs);
1474       } else {
1475         param_node->set_abstract(input_abs);
1476       }
1477       ++index;
1478     }
1479   }
1480 }
1481 
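// Overview: finish ir grad and post-process the resulting bprop graph: final optimization, cloning
// for high-order grad, lifting clone for control flow, and setting the flags that decide how the
// graph is executed (single-op mode, bprop cut, jit call graph, kernel graph id for control flow).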
1482 FuncGraphPtr GradExecutor::GetBpropGraph(const autograd::GradAttr &grad_attr,
1483                                          const std::vector<tensor::BaseTensorPtr> &w_args,
1484                                          const std::vector<size_t> &p_args) {
1485   MS_EXCEPTION_IF_NULL(top_input_args_info_);
1486   const auto &auto_grad_cell = std::dynamic_pointer_cast<autograd::IrGrad>(top_cell()->auto_grad_cell_ptr());
1487   MS_EXCEPTION_IF_NULL(auto_grad_cell);
1488   // Update bprop_graph_run_by_single_op for the bprop graph; if it is true, passes like
1489   // ConvertMakeTupleInputToDynamicInput will not take effect
1490   auto_grad_cell->set_bprop_graph_run_by_single_op(top_cell()->use_dynamic_shape_process());
1491   FuncGraphPtr bprop_graph = auto_grad_cell->Finish(w_args, p_args, grad_attr);
1492   MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
1493   UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph);
1494   if (top_cell()->need_do_final_opt()) {
1495     bprop_graph = BpropGraphFinalOpt(bprop_graph, top_cell()->has_control_flow());
1496   }
1497   if (top_input_args_info_->is_high_order_top_cell) {
1498     MS_LOG(DEBUG) << "Get high grad";
1499     top_cell()->resource()->set_optimize_graph(bprop_graph);
1500     bool has_bprop_cut = bprop_graph->has_flag(kFlagPyNativeBpropGraphWithBpropCut);
1501     if (bprop_graph->isa<session::KernelGraph>()) {
1502       bprop_graph = CloneKernelGraph(bprop_graph);
1503     } else {
1504       bprop_graph = BasicClone(bprop_graph);
1505     }
1506     if (has_bprop_cut) {
1507       bprop_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
1508     }
1509     PyNativeAlgo::Common::ReplaceCNodeWithValueNode(bprop_graph);
1510   } else {
1511     top_cell()->resource()->set_optimize_graph(bprop_graph);
1512   }
1513   if (bprop_graph->has_flag(kFlagIsControlFlow)) {
1514     top_cell()->set_has_control_flow(true);
1515   }
1516   if (top_cell()->has_control_flow()) {
1517     bprop_graph = LiftingClone(bprop_graph);
1518   }
1519   bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1520   bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
1521   bprop_graph->set_flag(kFlagPyNativeBpropGraphIsDynamic, top_cell()->use_dynamic_shape_process());
1522 
1523   // Update bprop cut flag. There are two scenarios:
1524   // 1. kHookBackwardName or kCellBackwardHookName
1525   // 2. Custom op bprop(set in auto_grad.cc by kFlagPyNativeBpropGraphWithBpropCut)
1526   bprop_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut,
1527                         bprop_graph->has_flag(kFlagPyNativeBpropGraphWithBpropCut) || top_cell()->has_bprop_cut_op());
1528 
1529   // Update the run-graph-by-single-op flag. There are two scenarios:
1530   // 1. Dynamic shape or dynamic structure
1531   // 2. Has a bprop cut op
1532   // If set_inputs is used but the graph has control flow, we need to run by actor.
1533   bprop_graph->set_flag(kFlagEnableRunGraphBySingleOp,
1534                         auto_grad_cell->bprop_graph_run_by_single_op() && !bprop_graph->has_flag(kFlagIsControlFlow));
1535   top_cell()->set_use_dynamic_shape_process(bprop_graph->has_flag(kFlagEnableRunGraphBySingleOp));
1536   if (top_cell()->has_call_graph()) {
1537     bprop_graph->set_flag(kFlagPyNativeWithJitCallGraph, true);
1538   }
1539   bool has_control_flow = top_cell()->has_control_flow();
1540   bprop_graph->set_flag(kFlagIsPyNativeBpropKernelGraph, !has_control_flow);
1541   // A control flow graph generates kernel graphs again during graph compilation, and its graph id conflicts with the default id 0
1542   if (has_control_flow) {
1543     auto kernel_graph = bprop_graph->cast<KernelGraphPtr>();
1544     MS_EXCEPTION_IF_NULL(kernel_graph);
1545     kernel_graph->set_graph_id(kernel_graph_id_for_control_flow());
1546   }
1547   return bprop_graph;
1548 }
1549 
1550 bool GradExecutor::NeedIncreaseGradOrder(const std::string &obj_id) {
1551   // top_cell_ == nullptr means grad is called first
1552   // top_cell_->obj_id_with_grad_order() includes obj_id and grad_order
1553   // If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos, the current cell is not the top cell;
1554   // another cell or function needs to get grad, so a high-order grad comes up
1555   if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id + "_") == std::string::npos) {
1556     IncreaseGradOrder();
1557     return true;
1558   }
1559   return false;
1560 }
1561 
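// Overview: called before the forward process to decide whether forward has to run again. It also
// maintains grad_order_ for high-order grad and records grad_operation_ (get_all/get_by_list/
// sens_param plus grad hash and weights id), which later distinguishes different grad operations on
// the same cell.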
1562 py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
1563                                          const py::object &weights, const py::object &grad_hash_id,
1564                                          const py::args &args) {
1565   const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
1566 
1567   // The rule of grad order is:
1568   // Scenario 1: net.set_grad, net(input) is called first; grad_order_ is increased before MakeNewTopCell and
1569   // decreased when running to EndGraphImpl, indicating that a complete bprop graph construction is finished. Then
1570   // calling grad(net)(input) won't be affected and forward_run is true.
1571   // Scenario 2: grad(net)(input) is called first; grad_order_ is increased before MakeNewTopCell and decreased in
1572   // RunGrad. The reason for this design is that if grad(net)(input) is called first and the order were decreased in
1573   // EndGraphImpl, matching problems would occur during RunGrad because the GradOperation information in
1574   // already_run_cell_id would differ. The GradOperation information includes the grad order to distinguish
1575   // high-order grad.
1576   // The flag grad_first_ distinguishes these two scenarios: it has no effect in scenario 1, otherwise it works.
1577   bool need_increase_grad_order = NeedIncreaseGradOrder(obj_id);
1578   // Include weight param size and required grad flag
1579   std::string grad_hash_id_str;
1580   if (!py::isinstance<py::none>(grad_hash_id)) {
1581     grad_hash_id_str = std::string(py::str(grad_hash_id));
1582   }
1583 
1584   std::string weights_obj_id = GetWeightsObjIdsByWeights(weights);
1585   grad_operation_ = std::to_string(static_cast<int>(grad->get_all_)) +
1586                     std::to_string(static_cast<int>(grad->get_by_list_)) +
1587                     std::to_string(static_cast<int>(grad->sens_param_)) + grad_hash_id_str + weights_obj_id;
1588 
1589   auto input_args_id = GetInputArgsId(args);
1590   // Under the condition that the stack is empty (forward process completed or no forward process),
1591   // check whether the forward process needs to run again
1592   bool forward_run = false;
1593   if (input_args_info_stack_.empty()) {
1594     const auto &id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args);
1595     auto cell_id = PyNativeAlgo::Common::GetCellId(obj_id, id_v.first, id_v.second);
1596     const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
1597     MS_LOG(DEBUG) << "Get check already run top cell id " << check_already_run_cell_id;
1598     auto find_top_cell = GetTopCell(check_already_run_cell_id, input_args_id);
1599     if (find_top_cell != nullptr) {
1600       MS_LOG(DEBUG) << "Find already run top cell " << find_top_cell;
1601       forward_run = find_top_cell->forward_already_run();
1602       bool input_args_changed =
1603         !find_top_cell->input_args_id().empty() && find_top_cell->input_args_id() != input_args_id;
1604       if (forward_run && input_args_changed) {
1605         MS_LOG(DEBUG) << "The input info " << input_args_id << " is not the same with pre input info "
1606                       << find_top_cell->input_args_id() << ", forward process will run again";
1607         forward_run = false;
1608       }
1609       // The pipeline top cell has finished forward, but grad runs on the previous pipeline top cell, so the auto
1610       // meta grad info needs to be reset
1611       if (top_cell_ != nullptr && top_cell_->is_pipeline_top_cell() && top_cell_->input_args_id() != input_args_id) {
1612         WaitBpropTask();
1613         top_cell_->ResetMetaGradInfo();
1614       }
1615       if (forward_run) {
1616         // need_increase_grad_order being true means the grad order was increased in preparation for grad;
1617         // but forward_run is true now, so forward does not need to run again and the grad order must be decreased.
1618         if (need_increase_grad_order) {
1619           DecreaseGradOrder();
1620         }
1621         finded_top_cell_ = find_top_cell;
1622       }
1623     }
1624   }
1625   if (!forward_run) {
1626     grad_first_ = true;
1627   }
1628   forward_run ? MS_LOG(DEBUG) << "Top cell has already run with input args id " << input_args_id
1629               : MS_LOG(DEBUG) << "Top cell has not run before with input args id " << input_args_id;
1630   return BaseRefToPyData(forward_run);
1631 }
1632 
1633 py::object GradExecutor::RunGradFunc(const autograd::GradAttr &grad_attr,
1634                                      const std::vector<tensor::BaseTensorPtr> &w_args,
1635                                      const std::vector<size_t> &p_args) {
1636   MS_EXCEPTION_IF_NULL(top_input_args_info_);
1637   ValuePtr sens = nullptr;
1638   if (grad_attr.has_sens) {
1639     sens = top_input_args_info_->input_arg_value_vec.back();
1640   }
1641 
1642   MS_LOG(DEBUG) << "Eval run begin";
1643   MS_EXCEPTION_IF_NULL(top_cell_);
1644   const auto &auto_grad_cell = std::dynamic_pointer_cast<autograd::FuncGrad>(top_cell_->auto_grad_cell_ptr());
1645   MS_EXCEPTION_IF_NULL(auto_grad_cell);
1646   top_cell_->set_grad_is_running(true);
1647   auto grads = auto_grad_cell->Finish(w_args, p_args, grad_attr, sens);
1648   MS_EXCEPTION_IF_NULL(grads);
1649   MS_EXCEPTION_IF_NULL(top_cell_);
1650   top_cell_->set_grad_is_running(false);
1651   top_input_args_info_ = top_cell_->input_args_info();
1652   MS_LOG(DEBUG) << "Eval run end";
1653 
1654   // Clear top cell resource
1655   top_cell_->ClearMetaGradInfo();
1656   // Func grad needs the auto grad meta in Finish, so clear it after Finish.
1657   AsyncClearAutoGradCell(top_cell_);
1658   ClearGradRes();
1659 
1660   // For custom nested grad, we need to resume grad info when finish custom grad.
1661   if (top_cell_ != nullptr) {
1662     top_cell_->ResumeMetaGradInfo();
1663   }
1664   return BaseRefToPyData(grads);
1665 }
1666 
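// Overview: execute the compiled bprop graph with the converted input arguments, then call
// MakeNestedCnode so that an outer top cell (if any) can record this grad for high-order grad.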
1667 py::object GradExecutor::RunGradGraph() {
1668   MS_EXCEPTION_IF_NULL(top_input_args_info_);
1669   MS_EXCEPTION_IF_NULL(top_cell_);
1670   const auto &resource = top_cell_->resource();
1671   MS_EXCEPTION_IF_NULL(resource);
1672   MS_LOG(DEBUG) << "Run top cell " << top_cell_ << " and its shadow top cell " << top_cell_->shadow_top_cell();
1673   VectorRef arg_list;
1674   SetGraphInputArgs(top_input_args_info_->input_arg_value_vec, resource, top_cell_->initial_graph_param_size(),
1675                     top_input_args_info_->sens_type, &arg_list);
1676   MS_LOG(DEBUG) << "Convert args size " << top_input_args_info_->input_arg_value_vec.size() << ", graph param size "
1677                 << arg_list.size();
1678 
1679   auto context = MsContext::GetInstance();
1680   MS_EXCEPTION_IF_NULL(context);
1681   context->SetJitLevel(kAttrJitLevelO0);
1682 
1683   compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
1684   MS_EXCEPTION_IF_NULL(run);
1685 
1686   MS_LOG(DEBUG) << "Eval run " << MsContext::GetInstance()->backend_policy();
1687   top_cell_->set_grad_is_running(true);
1688   BaseRef out_value = (*run)(arg_list);
1689   MS_EXCEPTION_IF_NULL(top_cell_);
1690   top_cell_->set_grad_is_running(false);
1691   top_input_args_info_ = top_cell_->input_args_info();
1692   MS_LOG(DEBUG) << "Eval run end";
1693 
1694   // Do high-order grad
1695   MakeNestedCnode(top_input_args_info_->has_custom_bprop, top_input_args_info_->input_arg_value_vec,
1696                   resource->optimize_graph(), out_value);
1697 
1698   // For custom nested grad, we need to resume grad info when finish custom grad.
1699   if (top_cell_ != nullptr) {
1700     top_cell_->ResumeMetaGradInfo();
1701   }
1702   return BaseRefToPyData(out_value);
1703 }
1704 
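// Overview: for nested (high-order) grad, take the just-run bprop graph, switch back to the outer
// top cell, replace inner inputs/weights with outer graph parameters, and feed the graph into the
// outer cell's auto grad (KPynativeWithFProp) so the outer bprop can differentiate through it.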
1705 void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<ValuePtr> &forward_args,
1706                                    const FuncGraphPtr &cur_run_bprop_graph, const BaseRef &out) {
1707   MS_EXCEPTION_IF_NULL(top_input_args_info_);
1708   if (top_input_args_info_->is_grad_topest_cell) {
1709     MS_LOG(DEBUG) << "No nested grad find";
1710     MS_EXCEPTION_IF_NULL(top_cell_);
1711     top_cell_->ClearMetaGradInfo();
1712     ClearGradRes();
1713     return;
1714   }
1715   MS_LOG(DEBUG) << "Do high grad";
1716   // first_grad_fg may be modified in auto grad, and first_grad_fg can be used multiple times
1717   auto first_grad_fg = cur_run_bprop_graph;
1718   MS_LOG(DEBUG) << "Current top cell ptr " << top_cell().get() << " and its shadow top cell "
1719                 << top_cell_->shadow_top_cell();
1720   top_cell_->set_is_finish_backward(true);
1721   if (has_custom_bprop) {
1722     first_grad_fg = curr_g();
1723     // The bprop top cell is only used for getting the forward graph
1724     top_cell_ = PopTopCellStack();
1725     MS_LOG(DEBUG) << "Bprop nested, after get bprop forward graph, current top cell ptr " << top_cell().get();
1726   } else {
1727     RestoreBpropGraphParameter(cur_run_bprop_graph, top_cell()->initial_graph_param_size());
1728   }
1729 
1730   MS_EXCEPTION_IF_NULL(first_grad_fg);
1731   PyNativeAlgo::Common::DumpGraphIR("first_grad_fg.ir", first_grad_fg);
1732   ValuePtrList weights_args;
1733   const std::string cur_top_cell_id = top_cell()->obj_id_with_grad_order();
1734   bool use_dynamic_shape_process = top_cell()->use_dynamic_shape_process() || top_cell()->vm_compile();
1735   bool has_call_graph = top_cell()->has_call_graph();
1736   auto inner_graph_info = top_cell()->graph_info_map().at(curr_g());
1737   SwitchTopCell();
1738   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
1739   op_run_info->requires_grad = true;
1740   op_run_info->op_grad_info->input_value = forward_args;
1741   op_run_info->input_size = forward_args.size();
1742   auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out, true, true);
1743   // Get output values
1744   if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
1745     std::vector<ValuePtr> out_v{out_value};
1746     out_value = std::make_shared<ValueTuple>(out_v);
1747   }
1748   MS_EXCEPTION_IF_NULL(out_value);
1749   RecordNestedGraph(first_grad_fg, inner_graph_info, forward_args, out_value);
1750 
1751   // Get input values
1752   PyNativeAlgo::Common::SetGraphInputAndWeightsInfo(op_run_info, first_grad_fg, top_cell());
1753   auto grad_fg = first_grad_fg;
1754   if (has_call_graph) {
1755     auto r = std::make_shared<pipeline::Resource>();
1756     jit()->set_eliminate_forward(false);
1757     (void)first_grad_fg->transforms().erase(kGrad);
1758     auto opt = opt::Optimizer::MakeEmptyOptimizer(r);
1759     opt->set_is_first_order_j(false);
1760     grad_fg = ad::Grad(first_grad_fg, opt);
1761     jit()->set_eliminate_forward(true);
1762   }
1763   auto op_grad_info = std::make_shared<OpGradInfo>();
1764   op_grad_info->input_value = op_run_info->op_grad_info->input_value;
1765   op_grad_info->input_abs = op_run_info->op_grad_info->input_abs;
1766   op_grad_info->out_value = out_value;
1767   op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_grad_info->out_value);
1768   op_grad_info->out_abs = first_grad_fg->output()->abstract();
1769   op_grad_info->input_value_grad_type = op_run_info->op_grad_info->input_value_grad_type;
1770   auto grad_param = std::make_shared<GradParam>(op_grad_info, use_dynamic_shape_process);
1771   grad_param->fg = grad_fg;
1772   grad_param->source_fg = first_grad_fg;
1773   grad_param->is_control_flow = has_call_graph;
1774   // If func grad and ir grad use the same ad grad graph (cache hit), dout may be wrong due to a different type
1775   // (tuple or flattened tuple)
1776   grad_param->graph_cache_key = cur_top_cell_id + std::to_string(top_cell()->is_ir_grad());
1777   if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
1778     MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph ";
1779   }
1780   top_cell()->set_need_do_final_opt(true);
1781 }
1782 
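// Overview: map the inner graph's inputs and used weights to nodes of the outer graph, reusing outer
// parameters when they already exist and creating them in the outer graph info otherwise.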
1783 void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
1784                                       const std::vector<ValuePtr> &forward_args, AnfNodePtrList *inputs) {
1785   MS_EXCEPTION_IF_NULL(inner_graph_info);
1786   auto outer_graph_info = top_cell()->graph_info_map().at(curr_g());
1787   MS_EXCEPTION_IF_NULL(outer_graph_info);
1788   for (const auto &forward_arg : forward_args) {
1789     const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_arg);
1790     const auto it = outer_graph_info->input_params.find(id);
1791     if (it != outer_graph_info->input_params.end()) {
1792       // Can find in outer graph
1793       MS_LOG(DEBUG) << "Replace input param id " << id;
1794       // Replace inner graph param by outer graph param
1795       (void)inputs->emplace_back(it->second);
1796     } else {
1797       MS_LOG(DEBUG) << "Can't find input param id " << id;
1798       // The inner graph input param is not found in the outer graph, so add it to the outer graph
1799       (void)inputs->emplace_back(GetInput(forward_arg, id));
1800     }
1801   }
1802   mindspore::HashSet<std::string> inner_graph_used_weights_set;
1803   // Weight in inner graph
1804   const auto &fir_graph_parameters = first_grad_fg->parameters();
1805   for (const auto &param : fir_graph_parameters) {
1806     auto weight_tensor = PyNativeAlgo::Common::GetTensorFromParam(param);
1807     if (weight_tensor != nullptr) {
1808       (void)inner_graph_used_weights_set.emplace(weight_tensor->id());
1809     }
1810   }
1811   for (const auto &weight : inner_graph_info->weight_params) {
1812     // If a weight is used in the graph but does not need grad from GradNet, it is a value node and needs no replacement
1813     if (inner_graph_used_weights_set.find(weight.first) == inner_graph_used_weights_set.end()) {
1814       continue;
1815     }
1816     const auto it = outer_graph_info->weight_params.find(weight.first);
1817     if (it != outer_graph_info->weight_params.end()) {
1818       // Can find in outer graph
1819       MS_LOG(DEBUG) << "Replace weight param name " << weight.second->name() << ", id " << weight.first;
1820       (void)inputs->emplace_back(it->second);
1821     } else {
1822       MS_LOG(DEBUG) << "Can't find weight param name " << weight.second->name() << ", id " << weight.first;
1823       top_cell()->SetParamNodeMapInGraphInfoMap(weight.first, weight.second, true);
1824       (void)inputs->emplace_back(weight.second);
1825     }
1826   }
1827 }
1828 
1829 void GradExecutor::SwitchTopCell() {
1830   ClearPipelineTopCellRes();
1831   // Get outer top cell
1832   auto outer_top_cell = PopTopCellStack();
1833   MS_EXCEPTION_IF_NULL(outer_top_cell);
1834   MS_LOG(DEBUG) << "Get outer top cell ptr " << outer_top_cell.get();
1835   // If the inner graph compiled its graph, the outer graph must be compiled as well
1836   if (top_cell()->vm_compile()) {
1837     outer_top_cell->set_force_top_cell_compile(true);
1838     outer_top_cell->set_use_dynamic_shape_process(outer_top_cell->use_dynamic_shape_process() ||
1839                                                   top_cell()->use_dynamic_shape_process());
1840   }
1841   outer_top_cell->ResumeMetaGradInfo();
1842   set_top_cell(outer_top_cell);
1843 }
1844 
1845 void GradExecutor::ClearGlobalRes() const {
1846   abstract::AnalysisContext::ClearContext();
1847   parse::data_converter::ClearObjectCache();
1848   parse::Parser::CleanParserResource();
1849   trace::ClearTraceStack();
1850   ad::CleanRes();
1851   pipeline::ReclaimOptimizer();
1852 }
1853 
1854 void GradExecutor::ClearGradRes() {
1855   MS_LOG(DEBUG) << "Top cell run finish " << top_cell_ << " and its shadow top cell " << top_cell_->shadow_top_cell();
1856   // Pop current top cell on stack
1857   if (!top_cell_->is_pipeline_top_cell()) {
1858     (void)PopTopCellStack();
1859   }
1860 
1861   if (!top_cell_stack_.empty() && top_cell_->is_pipeline_top_cell()) {
1862     MS_LOG(DEBUG) << "Top cell stack real running top cell " << top_cell_stack_.top();
1863     if (top_cell_stack_.top() == top_cell_) {
1864       MS_LOG(DEBUG) << "Pop pipeline top cell " << top_cell_stack_.top() << " from stack with input args id "
1865                     << top_cell_stack_.top()->input_args_id();
1866       (void)PopTopCellStack();
1867     }
1868   }
1869   auto has_higher_order = std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
1870                                       [](const auto &elem) { return elem.second->is_high_order_top_cell(); });
1871   // High order must not clean
1872   if (!has_higher_order) {
1873     top_cell_->ClearDeviceMemory();
1874   }
1875 
1876   top_cell_->input_args_info()->Reset();
1877   top_cell_->set_is_finish_backward(true);
1878   ClearPipelineTopCellRes();
1879   top_input_args_info_ = nullptr;
1880   ClearGlobalRes();
1881   MS_LOG(DEBUG) << "Current top cell stack size " << top_cell_stack_.size() << ", pipeline top cell map size "
1882                 << pipeline_top_cell_map_.size() << ", pipeline top cell map with already run cell id "
1883                 << top_cell_->already_run_cell_id() << " size "
1884                 << (pipeline_top_cell_map_.find(top_cell_->already_run_cell_id()) == pipeline_top_cell_map_.end()
1885                       ? 0
1886                       : pipeline_top_cell_map_[top_cell_->already_run_cell_id()].size());
1887   top_cell_ = nullptr;
1888   // Nested grad, get outer top cell if exist
1889   // Run top cell with bprop, and bprop has grad, after running inner grad, top cell should be restore
1890   if (!top_cell_stack_.empty()) {
1891     top_cell_ = top_cell_stack_.top();
1892     MS_LOG(DEBUG) << "Get outer top cell " << top_cell_ << " as the currently running top cell";
1893   }
1894 }
1895 
1896 void GradExecutor::ClearPipelineTopCellRes() {
1897   // Remove the pipeline top cell from the pipeline top cell map, excluding the first one
1898   if (top_cell_->is_pipeline_top_cell()) {
1899     // Run second step and following step
1900     if (top_cell_->shadow_top_cell() != nullptr) {
1901       ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->shadow_top_cell()->input_args_id(), false);
1902       top_cell_->set_shadow_top_cell(nullptr);
1903     } else if (!top_cell_->is_ir_grad()) {
1904       // Pipeline top cell exclude the first top cell
1905       ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->input_args_id(), false);
1906     }
1907   } else {
1908     // If the top cell is not a pipeline top cell, it was stored in the pipeline top cell map in the first step, so it
1909     // needs to be deleted from the map here.
1910     ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->input_args_id(), true);
1911   }
1912   if (top_cell_->grad_first()) {
1913     DecreaseGradOrder();
1914   }
1915   grad_operation_.clear();
1916 }
1917 
1918 void GradExecutor::ClearRes() {
1919   MS_LOG(DEBUG) << "Clear grad res";
1920   WaitBpropTask();
1921   init_ = false;
1922   grad_flag_ = false;
1923   enable_grad_ = true;
1924   is_run_recompute_ = false;
1925   save_graphs_ = false;
1926   forward_use_dynamic_shape_process_ = false;
1927 
1928   kernel_graph_id_for_control_flow_ = UINT32_MAX;
1929   custom_bprop_cell_count_ = 0;
1930   grad_order_ = 0;
1931   op_num_in_bprop_graph_ = kDefaultContainerSize;
1932   grad_operation_.clear();
1933 
1934   top_cell_ = nullptr;
1935   top_input_args_info_ = nullptr;
1936   std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
1937   std::stack<TopCellInfoPtr>().swap(top_cell_stack_);
1938   already_run_top_cell_.clear();
1939   pipeline_top_cell_map_.clear();
1940   dynamic_inputs_cells_.clear();
1941   need_gc_top_cell_list_.clear();
1942   dynamic_shape()->Clear();
1943   jit()->Clear();
1944 }
1945 
1946 void GradExecutor::AsyncClearTopCell() {
1947   for (const auto &need_gc_top_cell : need_gc_top_cell_list_) {
1948     if (forward()->enable_async()) {
1949       auto task = [need_gc_top_cell]() { need_gc_top_cell->Clear(); };
1950       DispatchGradQueueTask(std::move(task));
1951     } else {
1952       need_gc_top_cell->Clear();
1953     }
1954   }
1955   need_gc_top_cell_list_.clear();
1956 }
1957 
1958 void GradExecutor::AsyncClearAutoGradCell(const TopCellInfoPtr &top_cell) {
1959   if (forward()->enable_async()) {
1960     auto task = [top_cell] { top_cell->set_auto_grad_cell_ptr(nullptr); };
1961     DispatchGradQueueTask(std::move(task));
1962   } else {
1963     top_cell->set_auto_grad_cell_ptr(nullptr);
1964   }
1965 }
1966 
1967 void GradExecutor::WorkerJoin() {
1968   bprop_queue_->WorkerJoin();
1969   assist_queue_->WorkerJoin();
1970 }
1971 
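// Overview: resolve a forward value to an AnfNode of the current graph, trying in order: non-tensor
// input, input/weight parameter, the output of a recorded op (with TupleGetItem for multi-output
// nodes), a value sequence, and finally a plain ValueNode.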
1972 AnfNodePtr GradExecutor::GetInput(const ValuePtr &v, const string &obj_id) const {
1973   // Is not a tensor
1974   AnfNodePtr node = GetNonTensorInput(v, obj_id);
1975   if (node != nullptr) {
1976     return node;
1977   }
1978   // Get param input
1979   node = GetParamInput(v, obj_id);
1980   if (node != nullptr) {
1981     return node;
1982   }
1983   // Get op output
1984   node = GetOutputNodeAsInput(obj_id);
1985   if (node != nullptr) {
1986     return node;
1987   }
1988   // A tuple returns in this case: x = op1, y = op2, return (x, y)
1989   // or a scalar or (scalar, tensor)
1990   node = GetValueSequenceInput(v);
1991   if (node != nullptr) {
1992     return node;
1993   }
1994   auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v);
1995   MS_LOG(DEBUG) << "Get input value node " << v_node->ToString() << ", id " << obj_id;
1996   return v_node;
1997 }
1998 
1999 AnfNodePtr GradExecutor::GetParamInput(const ValuePtr &v, const std::string &id) const {
2000   const auto &graph_info = top_cell()->graph_info_map().at(curr_g());
2001   MS_EXCEPTION_IF_NULL(graph_info);
2002   // Get input param input
2003   const auto it = graph_info->input_params.find(id);
2004   if (it != graph_info->input_params.end()) {
2005     MS_LOG(DEBUG) << "Get input param " << id;
2006     return it->second;
2007   }
2008 
2009   // Get weight param input
2010   MS_EXCEPTION_IF_NULL(v);
2011   if (v->isa<tensor::BaseTensor>() && v->cast<tensor::BaseTensorPtr>()->is_parameter()) {
2012     const auto item_by_id = graph_info->weight_params.find(id);
2013     if (item_by_id != graph_info->weight_params.end()) {
2014       MS_LOG(DEBUG) << "Get weight param " << id;
2015       return item_by_id->second;
2016     }
2017     MS_LOG(DEBUG) << "Add new weight param " << id;
2018     auto tensor = v->cast<tensor::BaseTensorPtr>();
2019     const auto &param_info = tensor->param_info();
2020     MS_EXCEPTION_IF_NULL(param_info);
2021     const auto &param_name = param_info->name();
2022     // Add new weight param to graph info
2023     auto weight_param = curr_g()->add_parameter();
2024     weight_param->set_name(param_name);
2025     weight_param->debug_info()->set_name(param_name);
2026     weight_param->set_default_param(tensor);
2027     weight_param->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
2028     top_cell()->SetParamNodeMapInGraphInfoMap(id, weight_param, true);
2029     return weight_param;
2030   }
2031   return nullptr;
2032 }
2033 
2034 AnfNodePtr GradExecutor::GetOutputNodeAsInput(const std::string &obj_id) const {
2035   const auto &graph_info = top_cell()->graph_info_map().at(curr_g());
2036   MS_EXCEPTION_IF_NULL(graph_info);
2037   const auto it = graph_info->node_map.find(obj_id);
2038   if (it == graph_info->node_map.end()) {
2039     return nullptr;
2040   }
2041   // Single output CNode
2042   if (it->second.second.size() == 1 && it->second.second[0] == -1) {
2043     MS_LOG(DEBUG) << "Get input node " << it->second.first->ToString() << ", id " << obj_id;
2044     return it->second.first;
2045   }
2046   // Create tuple get item node for multiple output CNode
2047   return CreateTupleGetItemNode(obj_id, it->second);
2048 }
2049 
2050 AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v) const {
2051   MS_EXCEPTION_IF_NULL(v);
2052   if (!v->isa<ValueSequence>()) {
2053     return nullptr;
2054   }
2055   ValuePtrList input_args;
2056   abstract::AbstractBasePtrList abs_list;
2057   AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
2058   const auto &obj_tuple = v->cast<ValueSequencePtr>();
2059   const auto &v_list = obj_tuple->value();
2060   for (size_t i = 0; i < obj_tuple->size(); ++i) {
2061     const auto &v_arg = v_list[i];
2062     // FuncGraph values have no definition for grad, skip them
2063     if (v_arg->isa<FuncGraph>()) {
2064       continue;
2065     }
2066     (void)input_args.emplace_back(v_arg);
2067     const std::string &id = PyNativeAlgo::Common::GetIdByValue(v_arg);
2068     (void)inputs.emplace_back(GetInput(v_arg, id));
2069     (void)abs_list.emplace_back(PyNativeAlgo::Common::SetAbstractValueToAnyValue(v_arg->ToAbstract()));
2070     (void)GetValueSequenceInput(v_arg);
2071   }
2072   // Create make tuple node and record to graph info map.
2073   auto cnode = curr_g()->NewCNode(inputs);
2074   cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
2075   MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
2076   return cnode;
2077 }
2078 
2079 AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
2080                                                 const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const {
2081   AnfNodePtr c_node = out.first->cast<CNodePtr>();
2082   bool param_is_sequence = false;
2083   if (c_node == nullptr) {
2084     // Input param is tuple or list
2085     if (GetParamInput(MakeValue(true), obj_id) != nullptr) {
2086       MS_LOG(EXCEPTION) << "Get wrong input node " << out.first->DebugString();
2087     }
2088     param_is_sequence = true;
2089     c_node = out.first;
2090   }
2091   MS_LOG(DEBUG) << "Sequence input node " << c_node->DebugString() << ", id " << obj_id << ", out second "
2092                 << out.second;
2093   // Create tuple get item node
2094   auto abs = c_node->abstract();
2095   MS_EXCEPTION_IF_NULL(abs);
2096   for (auto idx : out.second) {
2097     AnfNodePtrList tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), c_node, NewValueNode(idx)};
2098     c_node = curr_g()->NewCNode(tuple_get_item_inputs);
2099     if (!abs->isa<abstract::AbstractSequence>()) {
2100       MS_LOG(EXCEPTION) << "Input node abs is not sequence " << abs->ToString();
2101     }
2102     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2103     if (static_cast<size_t>(idx) >= abs_seq->size()) {
2104       MS_LOG(EXCEPTION) << "Index exceeds the size of elements. Index " << idx << ", element size " << abs_seq->size();
2105     }
2106     abs = abs_seq->elements()[static_cast<size_t>(idx)];
2107     MS_EXCEPTION_IF_NULL(abs);
2108     c_node->set_abstract(abs);
2109     if (param_is_sequence) {
2110       c_node->set_user_data(kParamterIsSequence, MakeValue(param_is_sequence));
2111     }
2112   }
2113   MS_LOG(DEBUG) << "Create tuple getitem node " << c_node->DebugString() << ", abs " << c_node->abstract()->ToString();
2114   return c_node;
2115 }
2116 
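// Overview: look up a top cell by already_run_cell_id, first in already_run_top_cell_ (exact match,
// or partial match depending on whether forward or grad ran first), then in the pipeline top cell
// map. A match whose grad operation differs from the current one is dropped so that the backward
// graph is rebuilt.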
2117 TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id, const std::string &input_args_id) {
2118   TopCellInfoPtr find_top_cell = nullptr;
2119   for (const auto &[cell_id, top_cell] : already_run_top_cell_) {
2120     MS_EXCEPTION_IF_NULL(top_cell);
2121     MS_LOG(DEBUG) << "Top cell " << top_cell << " with already run cell id " << cell_id << ", input args id "
2122                   << top_cell->input_args_id();
2123     // Complete match, means run grad operation first
2124     if (top_cell->already_run_cell_id() == already_run_cell_id) {
2125       find_top_cell = top_cell;
2126       break;
2127     }
2128     // Partial match, means run forward first without grad_operation in already run cell id
2129     if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
2130         top_cell->already_run_cell_id().back() == '_') {
2131       find_top_cell = top_cell;
2132       break;
2133     }
2134     // Partial match, means grad ran first but is followed by another net's grad
2135     if (top_cell->already_run_cell_id().find(already_run_cell_id) != std::string::npos &&
2136         already_run_cell_id.back() == '_') {
2137       find_top_cell = top_cell;
2138       break;
2139     }
2140   }
2141 
2142   // Get pipeline top cell
2143   if (find_top_cell == nullptr) {
2144     MS_LOG(DEBUG) << "Not find in already run top cell map, try find in pipeline top cell map";
2145     find_top_cell = GetPipelineTopCell(already_run_cell_id, input_args_id, already_run_cell_id.back() == '_');
2146   } else if (find_top_cell->is_pipeline_top_cell()) {
2147     // Delete the first pipeline top cell from the already run top cell map
2148     (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
2149     if (find_top_cell->input_args_id() != input_args_id) {
2150       MS_LOG(DEBUG) << "Find top cell input args id " << find_top_cell->input_args_id()
2151                     << " not match current input args id " << input_args_id << ", try find in pipeline top cell map";
2152       find_top_cell = GetPipelineTopCell(already_run_cell_id, input_args_id, already_run_cell_id.back() == '_');
2153     }
2154   }
2155 
2156   // Same top cell info, but the grad operation is different, so construct the backward graph again
2157   if (find_top_cell != nullptr) {
2158     if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
2159       MS_LOG(DEBUG) << "Already existing grad operation " << find_top_cell->grad_operation()
2160                     << " is different from new " << grad_operation_;
2161       (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
2162       return nullptr;
2163     }
2164     return find_top_cell;
2165   }
2166   return nullptr;
2167 }
2168 
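// Mark hook changes on the given cell: if the cell belongs to the current top cell's id, the whole top cell is
// flagged as hook changed; when grad is required, the cell id is also recorded as a sub cell with changed hooks.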
2169 void GradExecutor::SetHookChanged(const py::object &cell) const {
2170   if (top_cell_ == nullptr) {
2171     return;
2172   }
2173   const auto &cell_id = PyNativeAlgo::PyParser::GetIdByPyObj(cell);
2174   if (top_cell_->cell_id().find(cell_id) != std::string::npos) {
2175     top_cell_->set_hook_changed(true);
2176   }
2177   if (RequiresGrad()) {
2178     top_cell_->set_sub_cell_hook_changed(cell_id);
2179   }
2180 }
2181 
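// Per-op grad bookkeeping: record the forward graph (returning early when only the forward graph is needed for
// a custom bprop), run ad grad for the op, refresh the forward tensor info used for replacement in the bprop
// graph, check for BpropCut nodes, and feed the node into dynamic detection.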
2182 void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
2183   MS_EXCEPTION_IF_NULL(op_run_info);
2184   RecordForwardGraph(op_run_info);
2185   if (top_cell_->is_bprop_need_get_forward_graph()) {
2186     MS_LOG(DEBUG) << "Just need forward graph";
2187     return;
2188   }
2189   DoOpGrad(op_run_info);
2190   top_cell()->GetOpInfo(op_run_info, false);
2191   UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
2192                                              op_run_info->base_op_run_info.stream_id);
2193   DynamicDetectNodeInfoPtr node_info;
2194   if (op_run_info->op_grad_info->output_value_simple_info != nullptr) {
2195     node_info = std::make_shared<DynamicDetectNodeInfo>(op_run_info->op_grad_info->op_prim);
2196   } else {
2197     node_info = std::make_shared<DynamicDetectNodeInfo>(
2198       op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_abs, op_run_info->op_grad_info->out_abs);
2199   }
2200   CheckBpropCutNode(top_cell(), op_run_info->op_grad_info->op_prim);
2201   (void)dynamic_shape()->CheckNodeDynamic(top_cell(), op_run_info->op_grad_info->input_value, node_info);
2202 }
2203 
2204 void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
2205                                      const CNodePtr &cnode) const {
2206   MS_EXCEPTION_IF_NULL(cnode);
2207   MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << ", out value id " << obj_id;
2208   // In hook computation, the output is a copy of the input. If the hook input is an input parameter and a following
2209   // op uses the hook output as input, GetInput would always find that parameter, so delete it from the param map.
2210   MS_EXCEPTION_IF_NULL(op_run_info);
2211   if (op_run_info->run_in_vm && kHookOp.find(op_run_info->base_op_run_info.op_name) != kHookOp.end()) {
2212     for (size_t i = 0; i < op_run_info->input_size; ++i) {
2213       top_cell()->DeleteParamNodeInfo(curr_g(), op_run_info->input_value_id[i]);
2214     }
2215   }
2216   top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode);
2217 }
2218 
2219 // Run ad grad for the current op and connect its grad graph with the previous op
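// With async grad enabled, KPynativeOp is wrapped in a task and dispatched to the bprop queue so it runs off
// the foreground thread; otherwise it is executed synchronously on the top cell's auto grad cell.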
2220 void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info) const {
2221   MS_EXCEPTION_IF_NULL(op_run_info);
2222   auto &&grad_param = CreateOpGradParam(op_run_info, top_cell());
2223   if (forward()->enable_async()) {
2224     auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
2225     auto task = [auto_grad_cell_ptr, grad_param]() { (void)auto_grad_cell_ptr->KPynativeOp(grad_param); };
2226     DispatchGradQueueTask(std::move(task));
2227   } else {
2228     (void)top_cell()->auto_grad_cell_ptr()->KPynativeOp(grad_param);
2229   }
2230 }
2231 
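// Maintain the op_info -> forward output mapping used to replace stale forward outputs in the bprop graph.
// On the first run (or in dynamic-shape processing with no pipeline top cell) the ids are saved; on later runs
// the cached tensor info is updated, and pipeline top cells whose ir-grad backward has finished store the
// outputs for replacement instead.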
2232 void GradExecutor::UpdateTopCellForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v,
2233                                                               const size_t &stream_id) const {
2234   auto pre_top_cell = GetAlreadyRunTopCell(top_cell()->already_run_cell_id());
2235   // If the shapes of the last two steps are the same, pre_top_cell is not empty.
2236   // But if dynamic shape is enabled at this point, SaveTensorIdWithOpInfo still needs to be executed.
2237   if (pre_top_cell == nullptr || use_dynamic_shape_process()) {
2238     // First run top cell, save op output info for replacement
2239     pre_top_cell = GetPipelineRunTopCell(top_cell_->already_run_cell_id());
2240     if (pre_top_cell == nullptr) {
2241       top_cell_->SaveTensorIdWithOpInfo(op_info, v);
2242       use_dynamic_shape_process() ? MS_LOG(DEBUG) << "Current top cell is in dynamic process"
2243                                   : MS_LOG(DEBUG)
2244                                       << "Top cell " << top_cell_->already_run_cell_id() << " run firstly, op info "
2245                                       << op_info << ", output id " << PyNativeAlgo::Common::GetIdByValue(v);
2246       return;
2247     }
2248   }
2249 
2250   // In dynamic process, no replacement is needed
2251   if (top_cell_->use_dynamic_shape_process()) {
2252     return;
2253   }
2254 
2255   // Top cell is a pipeline top cell
2256   if (top_cell_->is_pipeline_top_cell()) {
2257     if (pre_top_cell->is_finish_backward()) {
2258       // Not the first run of the top cell, do the update; save op_info -> tensor
2259       if (!pre_top_cell->is_ir_grad()) {
2260         MS_LOG(EXCEPTION) << "Forward replace top cell must be ir grad";
2261       }
2262       MS_LOG(DEBUG) << "Store pipeline top cell " << top_cell_->already_run_cell_id() << " with input args id "
2263                     << top_cell_->input_args_id() << ", op info " << op_info << ", output id "
2264                     << PyNativeAlgo::Common::GetIdByValue(v);
2265       StoreForwardOutputWithOpInfo(pre_top_cell->replace_info().op_info_with_tensor_object, op_info, v,
2266                                    &top_cell_->replace_info());
2267     } else {
2268       // The ir grad has not been run before, indicating this is the first step; the top cell will run func grad independently
2269       MS_LOG(DEBUG) << "Current top cell is pipeline top cell and run firstly, op info " << op_info << ", output id "
2270                     << PyNativeAlgo::Common::GetIdByValue(v);
2271     }
2272   } else {
2273     // Not the first run of the top cell, do the update
2274     MS_LOG(DEBUG) << "Update top cell forward output tensor info " << op_info << ", output id "
2275                   << PyNativeAlgo::Common::GetIdByValue(v);
2276     UpdateForwardOutputTensorInfo(op_info, v, pre_top_cell->replace_info());
2277   }
2278 }
2279 
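// If the input node was produced by another cell's CellBackwardHook, either directly or through a TupleGetItem
// on a multi-output hook, return the hook's real input instead so the recorded forward graph bypasses the hook.
// Illustrative rewrite in pseudo-IR (documentation only):
//   CellBackwardHook(x)                      ->  x
//   TupleGetItem(CellBackwardHook(x, y), 1)  ->  y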
2280 AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const {
2281   if (input_node == nullptr) {
2282     MS_LOG(DEBUG) << "The input node is nullptr.";
2283     return input_node;
2284   }
2285   const auto &cell_backward_hook_op = top_cell()->cell_backward_hook_op();
2286   for (const auto &elem : cell_backward_hook_op) {
2287     constexpr size_t cell_backward_hook_num = 2;
2288     if (elem.second.size() < cell_backward_hook_num) {  // In the cell's own scope, no need to skip the backward hook op.
2289       continue;
2290     }
2291     // The input node is the first backward hook op of another cell, skip the backward hook op.
2292     if (IsPrimitiveCNode(input_node, prim::kPrimCellBackwardHook) && input_node == elem.second[0]) {
2293       // Single input.
2294       auto backward_hook_op = input_node->cast<CNodePtr>();
2295       MS_EXCEPTION_IF_NULL(backward_hook_op);
2296       return backward_hook_op->input(1);
2297     } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
2298       // Multi inputs.
2299       auto tuple_get_item = input_node->cast<CNodePtr>();
2300       MS_EXCEPTION_IF_NULL(tuple_get_item);
2301       auto inp_in_tuple = tuple_get_item->input(1);
2302       MS_EXCEPTION_IF_NULL(inp_in_tuple);
2303       if (IsPrimitiveCNode(inp_in_tuple, prim::kPrimCellBackwardHook) && inp_in_tuple == elem.second[0]) {
2304         constexpr size_t idx = 2;
2305         auto idx_node = tuple_get_item->input(idx);
2306         MS_EXCEPTION_IF_NULL(idx_node);
2307         auto value_node = idx_node->cast<ValueNodePtr>();
2308         MS_EXCEPTION_IF_NULL(value_node);
2309         auto out_idx = GetValue<int64_t>(value_node->value());
2310         auto backward_hook_op = inp_in_tuple->cast<CNodePtr>();
2311         MS_EXCEPTION_IF_NULL(backward_hook_op);
2312         return backward_hook_op->input(1 + LongToSize(out_idx));
2313       }
2314     }
2315   }
2316   return input_node;
2317 }
2318 
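// Build the forward CNode for the op from its primitive and the recorded input nodes, skipping other cells'
// backward hook ops; CellBackwardHook CNodes are recorded on the top cell so they can be skipped later.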
2319 CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const {
2320   MS_EXCEPTION_IF_NULL(op_run_info);
2321   AnfNodePtrList inputs;
2322   (void)inputs.emplace_back(NewValueNode(op_run_info->op_grad_info->op_prim));
2323   for (size_t i = 0; i < op_run_info->input_size; i++) {
2324     AnfNodePtr input_node = nullptr;
2325     const auto node = GetInput(op_run_info->op_grad_info->input_value[i], op_run_info->input_value_id[i]);
2326     input_node = GetRealInputNodeBySkipHook(node);
2327     // update abstract
2328     if (input_node != nullptr) {
2329       (void)inputs.emplace_back(input_node);
2330     }
2331   }
2332   const auto &cnode = curr_g()->NewCNodeInOrder(inputs);
2333   if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
2334     top_cell()->RecordCellBackwardHookOp(op_run_info->cell_obj_id, cnode);
2335   }
2336   MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
2337                 << cnode->DebugString();
2338   return cnode;
2339 }
2340 
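// The forward graph is only recorded when graph saving is enabled or a custom bprop needs it. Input/output ids
// are derived lazily from the values, and when simple infer left no abstract, it is rebuilt from the real output.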
2341 void GradExecutor::RecordForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const {
2342   if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
2343     MS_EXCEPTION_IF_NULL(op_run_info);
2344     if (op_run_info->input_value_id.empty()) {
2345       (void)std::transform(op_run_info->op_grad_info->input_value.begin(), op_run_info->op_grad_info->input_value.end(),
2346                            std::back_inserter(op_run_info->input_value_id),
2347                            [](const ValuePtr &value) { return PyNativeAlgo::Common::GetIdByValue(value); });
2348     }
2349     if (op_run_info->out_value_id.empty()) {
2350       op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
2351     }
2352     const auto &cnode = ConstructForwardGraph(op_run_info);
2353     MS_EXCEPTION_IF_NULL(cnode);
2354     // With simple infer, the abstract is nullptr
2355     if (op_run_info->base_op_run_info.abstract == nullptr) {
2356       cnode->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->real_out->ToAbstract()));
2357     } else {
2358       cnode->set_abstract(op_run_info->base_op_run_info.abstract);
2359     }
2360     SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
2361   }
2362 }
2363 
2364 void GradExecutor::RecordForwardGraphForInput(const ValuePtr &value, const string &input_id,
2365                                               const abstract::AbstractBasePtr &param_abs) {
2366   save_graphs_ = MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
2367   if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
2368     auto new_param = curr_g()->add_parameter();
2369     new_param->set_abstract(param_abs);
2370     if (value->isa<ValueSequence>()) {
2371       top_cell()->SetNodeMapInGraphInfoMap(input_id, new_param, true);
2372     }
2373     top_cell()->SetParamNodeMapInGraphInfoMap(input_id, new_param);
2374   }
2375 }
2376 
2377 void GradExecutor::RecordNestedGraph(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
2378                                      const std::vector<ValuePtr> &forward_args, const ValuePtr &out) {
2379   if (save_graphs_) {
2380     AnfNodePtrList inputs{NewValueNode(first_grad_fg)};
2381     DoParameterReplace(first_grad_fg, inner_graph_info, forward_args, &inputs);
2382     auto cnode = curr_g()->NewCNode(inputs);
2383     auto out_id = PyNativeAlgo::Common::GetIdByValue(out);
2384     top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
2385     cnode->set_abstract(first_grad_fg->output()->abstract());
2386     MS_LOG(DEBUG) << "Nested make cnode is: " << cnode->DebugString() << ", out id " << out_id;
2387   }
2388 }
2389 
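// Forward the cell's JitConfig dict (if present) to the graph executor so the bprop graph is compiled with the
// cell's jit settings. Assumed Python-side usage, for illustration only:
//   net.set_jit_config(JitConfig(jit_level="O1"))  # later grads of net honor this config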
2390 void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
2391   if (!py::hasattr(obj, kAttrCellJitConfigDict)) {
2392     return;
2393   }
2394 
2395   auto jit_config = py::getattr(obj, kAttrCellJitConfigDict);
2396   if (!py::isinstance<py::dict>(jit_config)) {
2397     MS_LOG(EXCEPTION) << "JitConfig only supports dict!";
2398   }
2399   auto jit_config_dict = jit_config.cast<py::dict>();
2400   auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
2401   MS_EXCEPTION_IF_NULL(graph_executor);
2402   graph_executor->SetJitConfig(jit_config_dict);
2403 }
2404 
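// Record cells whose inputs were declared dynamic, keyed by object id; SetTopCellDynamicAttr later switches the
// top cell into dynamic shape processing when it encounters one of these ids.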
2405 void GradExecutor::SaveDynamicInputsCells(const py::object &obj, const py::args &args) {
2406   const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
2407   MS_LOG(INFO) << "SaveDynamicInputsCells: "
2408                << (py::isinstance<Cell>(obj) ? obj_id + " " + obj.cast<CellPtr>()->ToString()
2409                                              : py::getattr(obj, "__name__").cast<std::string>());
2410   (void)dynamic_inputs_cells_.insert(obj_id);
2411 }
2412 
2413 void GradExecutor::SetTopCellDynamicAttr(const py::object &cell) {
2414   if (top_cell_ == nullptr) {
2415     return;
2416   }
2417 
2418   if (top_cell()->use_dynamic_shape_process()) {
2419     // Top cell is already dynamic, no need to set again.
2420     return;
2421   }
2422   top_cell()->set_use_dynamic_shape_process(dynamic_inputs_cells_.count(PyNativeAlgo::PyParser::GetIdByPyObj(cell)));
2423 }
2424 
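// Async bprop/assist task plumbing: a failed push means the queue already holds an exception, which is rethrown
// by CheckException. Clear and Wait release the GIL first so queued tasks that need it can finish.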
2425 void GradExecutor::DispatchGradQueueTask(std::function<void(void)> &&task) const {
2426   if (!bprop_queue_->Push(new (std::nothrow) BpropTask(std::move(task)))) {
2427     bprop_queue_->CheckException();
2428   }
2429 }
2430 
2431 void GradExecutor::ClearBpropTask() const {
2432   if (bprop_queue_ != nullptr) {
2433     GilReleaseWithCheck gil_release;
2434     bprop_queue_->Clear();
2435     assist_queue_->Clear();
2436     bprop_queue_->CheckException();
2437   }
2438 }
2439 
2440 void GradExecutor::WaitBpropTask() const {
2441   if (bprop_queue_ != nullptr) {
2442     GilReleaseWithCheck gil_release;
2443     bprop_queue_->Wait();
2444     assist_queue_->Wait();
2445     bprop_queue_->CheckException();
2446   }
2447 }
2448 
2449 void GradExecutor::DispatchAssistQueueTask(std::function<void(void)> task) const {
2450   bool success = assist_queue_->Push(new (std::nothrow) BpropTask(std::move(task)));
2451   if (!success) {
2452     assist_queue_->CheckException();
2453   }
2454 }
2455 
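// After fork(), the child process re-initializes both task queues and clears the backend held by the pyboost op
// executor, since worker threads and device resources inherited from the parent are not usable in the child.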
2456 void GradExecutor::ChildAfterFork() {
2457   MS_LOG(DEBUG) << "GradExecutor reinitialize after fork.";
2458   if (bprop_queue_ != nullptr) {
2459     MS_LOG(DEBUG) << "Reinitialize bprop_queue_.";
2460     bprop_queue_->ChildAfterFork();
2461   }
2462   if (assist_queue_ != nullptr) {
2463     MS_LOG(DEBUG) << "Reinitialize assist_queue_.";
2464     assist_queue_->ChildAfterFork();
2465   }
2466   runtime::PyBoostOpExecute::GetInstance().ClearBackend();
2467   MS_LOG(DEBUG) << "GradExecutor reinitialize after fork done.";
2468 }
2469 }  // namespace pynative
2470 }  // namespace mindspore
2471