1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/grad/grad.h"
18 #include <algorithm>
19 #include "ops/conv_pool_op_name.h"
20 #include "ops/nn_op_name.h"
21 #include "ops/math_op_name.h"
22 #include "ops/sequence_ops.h"
23 #include "ops/framework_ops.h"
24 #include "pipeline/pynative/grad/top_cell.h"
25 #include "pipeline/pynative/grad/function/func_grad.h"
26 #include "pipeline/pynative/grad/ir/ir_grad.h"
27 #include "pipeline/pynative/pynative_utils.h"
28 #include "pipeline/jit/ps/pipeline.h"
29 #include "ir/cell.h"
30 #include "ir/func_graph_cloner.h"
31 #include "pipeline/jit/ps/parse/data_converter.h"
32 #include "pipeline/jit/ps/debug/trace.h"
33 #include "include/backend/optimizer/helper.h"
34 #include "include/common/utils/convert_utils_py.h"
35 #include "frontend/optimizer/ad/grad.h"
36 #include "frontend/optimizer/environ_conversion.h"
37 #include "pipeline/jit/ps/pass.h"
38 #include "pybind_api/gil_scoped_long_running.h"
39 #include "frontend/optimizer/fallback_rewriter.h"
40 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
41 #include "runtime/pynative/op_executor.h"
42
43 namespace mindspore {
44 namespace pynative {
45 namespace {
46 const mindspore::HashSet<std::string> kHookOp = {"HookBackward", "CellBackwardHook"};
47 constexpr char kGrad[] = "grad";
48 constexpr auto kNeedRecompute = "is_cell_recompute";
49 constexpr auto kInternalParams = "internal_params";
50 constexpr auto kUsedBpropInputs = "used_bprop_inputs";
51 constexpr size_t kContainerRatio = 2;
52
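// Parse the Python cell object and its args into InputArgsInfo: record the custom bprop flag, the arg ids, values
// and base shapes, and the cell id for top cells and custom bprop cells.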
void ParsePyArgsToInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::args &args,
54 bool is_bprop_need_get_forward_graph) {
55 MS_EXCEPTION_IF_NULL(input_args_info);
56 input_args_info->has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
57 MS_LOG(DEBUG) << "Cell has custom bprop " << input_args_info->has_custom_bprop;
58 bool is_top_cell = input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell;
59 if (is_top_cell) {
60 pipeline::CheckArgsValid(obj, args);
61 }
62 // Only the top cell or custom bprop cell requires value conversion
63 if (is_top_cell || input_args_info->has_custom_bprop || is_bprop_need_get_forward_graph) {
64 input_args_info->obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
65 input_args_info->input_size = args.size();
66 for (size_t i = 0; i < input_args_info->input_size; ++i) {
67 const auto &id = PyNativeAlgo::PyParser::GetIdByPyObj(args[i]);
68 (void)input_args_info->input_arg_id_vec.emplace_back(id);
69 }
70 const auto &forward = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor();
71 for (size_t i = 0; i < input_args_info->input_size; ++i) {
72 input_args_info->input_args_id += input_args_info->input_arg_id_vec[i] + "_";
73 // Get arg value
74 if (py::isinstance<py::list>(args[i])) {
75 (void)input_args_info->input_arg_value_vec.emplace_back(
76 PyNativeAlgo::DataConvert::PyObjToValue(py::cast<py::tuple>(args[i])));
77 } else {
78 (void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
79 }
80
81 // Get arg abstract
82 auto abs = forward->GetNodeAbsById(input_args_info->input_arg_id_vec[i]);
83 if (abs == nullptr) {
84 abs = input_args_info->input_arg_value_vec.back()->ToAbstract();
85 }
86 (void)input_args_info->input_arg_base_shape_vec.emplace_back(abs->BuildShape());
87 }
88 input_args_info->cell_id = PyNativeAlgo::Common::GetCellId(
89 input_args_info->obj_id, input_args_info->input_arg_id_vec, input_args_info->input_arg_value_vec);
90 MS_LOG(DEBUG) << "Cell_id is " << input_args_info->cell_id << ", is grad topest cell "
91 << input_args_info->is_grad_topest_cell << ", is high order top cell "
92 << input_args_info->is_high_order_top_cell << ", is bprop need get forward graph "
93 << is_bprop_need_get_forward_graph;
94 }
95 }
96
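// Return a ValueNode when v is a single non-tensor value or a value sequence that contains no tensors; otherwise
// return nullptr.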
AnfNodePtr GetNonTensorInput(const ValuePtr &v, const std::string &obj_id) {
98 MS_EXCEPTION_IF_NULL(v);
99 bool is_value_seq = v->isa<ValueSequence>();
100 bool is_single_non_tensor = !is_value_seq && !PyNativeAlgo::Common::IsTensor(v);
101 bool mixed_tensor = true;
102 if (is_value_seq) {
103 const auto &v_seq = v->cast<ValueSequencePtr>();
104 mixed_tensor = std::any_of(v_seq->value().begin(), v_seq->value().end(),
105 [](const ValuePtr &e) { return PyNativeAlgo::Common::IsTensor(e, true); });
106 }
107 if (is_single_non_tensor || !mixed_tensor) {
108 auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v);
109 MS_LOG(DEBUG) << "Get input value node " << v_node->ToString() << ", id " << obj_id;
110 return v_node;
111 }
112 return nullptr;
113 }
114
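// Convert an output value to tensor form: scalars become tensors, sequences with mixed types are filtered for sens,
// and dicts are optionally converted to tuples; tensor values are returned unchanged.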
ValuePtr ConvertOutputValueToTensor(const ValuePtr &v, bool dict_convert_to_tuple) {
116 MS_EXCEPTION_IF_NULL(v);
117 if (PyNativeAlgo::Common::IsTensor(v, true)) {
118 return v;
119 }
120 if (v->isa<ValueSequence>()) {
121 auto v_seq = v->cast<ValueSequencePtr>();
122 if (v_seq->size() == 0) {
123 MS_LOG(EXCEPTION) << "Get empty value seq";
124 }
    // All values are tensors
126 if (std::all_of(v_seq->value().begin(), v_seq->value().end(),
127 [](const ValuePtr &e) { return PyNativeAlgo::Common::IsTensor(e, true); })) {
128 MS_LOG(DEBUG) << "All output value is tensor";
129 return v;
130 }
131 MS_LOG(DEBUG) << "Output is value sequence, but have tensor and other type mixed. Its value is " << v->ToString();
132 return PyNativeAlgo::Common::FilterSensValues(v, dict_convert_to_tuple);
133 }
134 if (v->isa<FloatImm>()) {
135 double input_value = v->cast<FP32ImmPtr>()->value();
136 return std::make_shared<tensor::Tensor>(input_value, kFloat32);
137 }
138 if (v->isa<BoolImm>()) {
139 return std::make_shared<tensor::Tensor>(v->cast<BoolImmPtr>()->value(), kBool);
140 }
141 if (v->isa<IntegerImm>()) {
142 int64_t input = v->cast<Int64ImmPtr>()->value();
143 return std::make_shared<tensor::Tensor>(input, kInt64);
144 }
145 if (v->isa<ValueDictionary>() && dict_convert_to_tuple) {
146 MS_LOG(DEBUG) << "Get dict value";
147 return PyNativeAlgo::DataConvert::ConvertValueDictToValueTuple(v);
148 }
149 MS_LOG(DEBUG) << "Output is " << v->ToString() << ", abstract "
150 << PyNativeAlgo::Common::SetAbstractValueToAnyValue(v->ToAbstract());
151 return v;
152 }
153
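// Run the final optimization pass on the bprop graph through a pipeline resource and dump the optimized IR.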
FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph, bool has_control_flow) {
155 MS_LOG(DEBUG) << "Do bprop graph final opt";
156 MS_EXCEPTION_IF_NULL(bprop_graph);
157 auto resource = std::make_shared<pipeline::Resource>();
158 resource->set_func_graph(bprop_graph);
159 auto manager = resource->manager();
160 MS_EXCEPTION_IF_NULL(manager);
161 manager->AddFuncGraph(bprop_graph);
162 FuncGraphPtr after_opt_bg = nullptr;
163 after_opt_bg = pipeline::FinalBpropGraphPass(resource, has_control_flow);
164 PyNativeAlgo::Common::DumpGraphIR("after_final_opt.ir", after_opt_bg);
165 return after_opt_bg;
166 }
167
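// Fill arg_list with the input values (handling tuple and dict sens), then append default parameters when the graph
// has more parameters than the given args.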
void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::ResourcePtr &res,
169 size_t graph_param_size, SensType sens_type, VectorRef *const arg_list) {
170 MS_EXCEPTION_IF_NULL(arg_list);
171 MS_EXCEPTION_IF_NULL(res);
172 auto graph = res->func_graph();
173 MS_EXCEPTION_IF_NULL(graph);
174 const auto &graph_params = graph->parameters();
175 if (graph_params.size() < graph_param_size) {
    MS_LOG(EXCEPTION) << "Initial bprop graph param size " << graph_param_size << " is larger than current param size "
                      << graph_params.size() << ". Graph parameters may have been updated at the kernel graph compile stage";
178 }
179 std::vector<ValuePtr> input_arg_list;
180 if (sens_type == SensType::kNormal) {
181 input_arg_list = input_vec;
182 } else if (sens_type == SensType::kTuple) {
183 PyNativeAlgo::DataConvert::FlattenArgs(input_vec, &input_arg_list, true);
184 } else {
185 input_arg_list.assign(input_vec.begin(), input_vec.end() - kIndex1);
186 const auto &v_sens = input_vec.back();
187 MS_EXCEPTION_IF_NULL(v_sens);
188 if (!v_sens->isa<ValueDictionary>()) {
189 MS_LOG(EXCEPTION) << "Get sens not dict " << v_sens->ToString();
190 }
191 const auto &v_dict = v_sens->cast<ValueDictionaryPtr>();
192 ValuePtrList key_inputs;
193 ValuePtrList value_inputs;
194 for (const auto &elem : v_dict->value()) {
195 (void)key_inputs.emplace_back(elem.first);
196 (void)value_inputs.emplace_back(elem.second);
197 }
198 auto key = std::make_shared<ValueTuple>(key_inputs);
199 auto value = std::make_shared<ValueTuple>(value_inputs);
200 (void)input_arg_list.emplace_back(key);
201 (void)input_arg_list.emplace_back(value);
202 }
203 (void)std::transform(input_arg_list.begin(), input_arg_list.end(), std::back_inserter(*arg_list),
204 [](const ValuePtr &v) { return v; });
205 size_t arg_size = arg_list->size();
206 if (arg_size != graph_param_size) {
    // The graph may have some default parameters in addition to the given inputs
208 MS_LOG(DEBUG) << "Get args size " << arg_size << ", graph param size " << graph_param_size;
209 for (std::size_t i = arg_size; i < graph_param_size; ++i) {
210 MS_EXCEPTION_IF_NULL(graph_params[i]);
211 auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
212 MS_EXCEPTION_IF_NULL(param_ptr);
213 if (!param_ptr->has_default()) {
214 MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param, " << param_ptr->DebugString();
215 }
216 if (!param_ptr->default_param()->isa<tensor::BaseTensor>()) {
217 MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->DebugString()
218 << "] is not initialized, need to call `.init_data()`";
219 }
220 arg_list->push_back(param_ptr->default_param());
221 }
222 }
223 }
224
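// Trim the graph parameters back to graph_param_size, dropping control parameters that kernel graph compilation may
// have appended.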
void RestoreBpropGraphParameter(const FuncGraphPtr &graph, size_t graph_param_size) {
226 auto parameters = graph->parameters();
  // On Ascend, the kernel graph may have been adjusted and some control parameters inserted
228 if (parameters.size() > graph_param_size) {
229 (void)parameters.erase(parameters.begin() + graph_param_size, parameters.end());
230 graph->set_parameters(std::move(parameters));
231 }
232 }
233
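// When sens_param is set, convert the last arg (the sens) to tensor form, store it in input_arg_value_vec and record
// the sens type (tuple or dict).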
void SetSensValue(const prim::GradOperationPtr &grad, const InputArgsInfoPtr &input_args_info, const py::args &args,
235 bool dict_convert_to_tuple) {
236 MS_EXCEPTION_IF_NULL(grad);
237 if (!grad->sens_param()) {
238 return;
239 }
240 size_t forward_args_size = args.size() - 1;
241 auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
242 MS_LOG(DEBUG) << "Get sens param " << sens_v->ToString();
243 const auto &sens_tensor = ConvertOutputValueToTensor(sens_v, dict_convert_to_tuple);
244 if (sens_tensor == nullptr) {
245 MS_LOG(EXCEPTION) << "sens convert tensor is nullptr";
246 }
  // Sens already exists and may need to be updated
248 MS_EXCEPTION_IF_NULL(input_args_info);
249 if (input_args_info->input_arg_value_vec.size() == args.size()) {
250 input_args_info->input_arg_value_vec.pop_back();
251 }
252 (void)input_args_info->input_arg_value_vec.emplace_back(sens_tensor);
253 if (sens_tensor->isa<ValueSequence>()) {
254 input_args_info->sens_type = SensType::kTuple;
255 } else if (!dict_convert_to_tuple) {
256 input_args_info->sens_type = SensType::kDict;
257 }
258 }
259
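// Concatenate the ids of all weights that require grad into a single string.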
std::string GetWeightsObjIdsByWeights(const py::object &weights) {
261 auto is_require_grad = [](const ValuePtr &value) {
262 MS_EXCEPTION_IF_NULL(value);
263 if (!value->isa<tensor::BaseTensor>()) {
264 return false;
265 }
266 auto t = value->cast<tensor::BaseTensorPtr>();
267 MS_EXCEPTION_IF_NULL(t);
268 if (t->is_parameter() && t->param_info() != nullptr && t->param_info()->requires_grad()) {
269 return true;
270 }
271 return false;
272 };
273
274 std::string weights_obj_id;
275 auto append_weights_info = [&weights_obj_id, is_require_grad](const py::object &obj) {
276 const auto &v = PyNativeAlgo::DataConvert::PyObjToValue(obj);
277 if (is_require_grad(v)) {
278 (void)weights_obj_id.append("_").append(PyNativeAlgo::Common::GetIdByValue(v));
279 }
280 };
281
282 if (py::isinstance<py::tuple>(weights)) {
283 const auto &weights_tuple = weights.cast<py::tuple>();
284 for (size_t i = 0; i < weights_tuple.size(); ++i) {
285 append_weights_info(weights_tuple[i]);
286 }
287 } else if (py::isinstance<py::list>(weights)) {
288 const auto &weights_list = weights.cast<py::list>();
289 for (size_t i = 0; i < weights_list.size(); ++i) {
290 append_weights_info(weights_list[i]);
291 }
292 } else if (!py::isinstance<py::none>(weights)) {
293 append_weights_info(weights);
294 }
295
296 return weights_obj_id;
297 }
298
void FreeSpecialOpValue(const std::string &op_name, const FrontendOpRunInfoPtr &op_run_info, ValuePtr *const output) {
300 // Special cases, manually free more inputs.
301 static mindspore::HashSet<std::string> kMulOp{
302 kMulOpName,
303 kMatMulOpName,
304 kConv2DOpName,
305 };
306 static mindspore::HashSet<std::string> kDivOp{
307 kDivOpName,
308 kRealDivOpName,
309 };
310 if (op_name == kBatchNormOpName) {
    // 1. BatchNorm is a multi-output node; its out[0] and out[1] are not used.
312 auto seq_v = (*output)->cast<ValueSequencePtr>();
313 MS_EXCEPTION_IF_NULL(seq_v);
314 ValuePtrList new_v_list{seq_v->value()};
315 new_v_list[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex0]);
316 new_v_list[kIndex1] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex1]);
317 *output = std::make_shared<ValueTuple>(new_v_list);
318 MS_LOG(DEBUG) << "Clear device address for output[0, 1] of " << op_name;
319 } else if (op_name == kLayerNormOpName) {
    // 2. LayerNorm is a multi-output node; its out[0] is not used.
321 auto seq_v = (*output)->cast<ValueSequencePtr>();
322 MS_EXCEPTION_IF_NULL(seq_v);
323 ValuePtrList new_v_list{seq_v->value()};
324 new_v_list[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(new_v_list[kIndex0]);
325 *output = std::make_shared<ValueTuple>(new_v_list);
326 MS_LOG(DEBUG) << "Clear device address for output[0] of " << op_name;
327 } else if (kMulOp.find(op_name) != kMulOp.end()) {
    // 3. For operators like Mul, dx ONLY relies on y and dy ONLY relies on x,
    // so if y is a value node, dy is useless and we can free x ahead of time (and vice versa).
330 bool x_is_const_value = PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex0]);
331 bool y_is_const_value = PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex1]);
332 if (x_is_const_value && op_run_info->base_op_run_info.expanded_input_values[kIndex1]->isa<tensor::BaseTensor>()) {
333 op_run_info->op_grad_info->input_value[kIndex1] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(
334 op_run_info->base_op_run_info.expanded_input_values[kIndex1]);
335 MS_LOG(DEBUG) << "Clear device address for inputs[1] of " << op_name;
336 }
337 if (y_is_const_value && op_run_info->base_op_run_info.expanded_input_values[kIndex0]->isa<tensor::BaseTensor>()) {
338 op_run_info->op_grad_info->input_value[kIndex0] = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(
339 op_run_info->base_op_run_info.expanded_input_values[kIndex0]);
340 MS_LOG(DEBUG) << "Clear device address for inputs[0] of " << op_name;
341 }
342 } else if (kDivOp.find(op_name) != kDivOp.end()) {
    // 4. For operators like Div, only dy relies on the output node, so if y is a value node we can free the output.
344 if (PyNativeAlgo::Common::IsConstant(op_run_info->op_grad_info->input_value_grad_type[kIndex1])) {
345 *output = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(*output);
346 MS_LOG(DEBUG) << "Clear device address for the output of " << op_name;
347 }
348 }
349 }
350
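// Replace inputs and outputs that bprop does not need (per BpropExpander's unused-input list) with fake values
// without device addresses to release device memory; skipped for high-order top cells.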
void FreeUselessValue(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
352 MS_EXCEPTION_IF_NULL(op_run_info);
353 MS_EXCEPTION_IF_NULL(top_cell);
354 if (top_cell->is_high_order_top_cell()) {
355 return;
356 }
357
358 const auto &unused_inputs = BpropExpander::GetUnusedInputs(op_run_info->op_grad_info->op_prim->name());
359 for (const auto i : unused_inputs) {
360 if (i < op_run_info->input_size) {
      // Free inputs that bprop does not use
362 op_run_info->op_grad_info->input_value[i] =
363 PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(op_run_info->op_grad_info->input_value[i]);
364 } else if (i == op_run_info->input_size) {
      // Process the output: free it if bprop does not use it
366 op_run_info->op_grad_info->out_value =
367 PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(op_run_info->op_grad_info->out_value);
368 }
369 }
370
371 // Free special op memory
372 FreeSpecialOpValue(op_run_info->op_grad_info->op_prim->name(), op_run_info, &op_run_info->op_grad_info->out_value);
373 }
374
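// Package the op grad info of a forward op into a GradParam after freeing the values that bprop does not need.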
GradParamPtr CreateOpGradParam(const FrontendOpRunInfoPtr &op_run_info, const TopCellInfoPtr &top_cell) {
376 MS_EXCEPTION_IF_NULL(op_run_info);
377 bool out_used_in_bporp_graph = true;
378 op_run_info->op_grad_info->out_value = op_run_info->real_out;
379 FreeUselessValue(op_run_info, top_cell);
380
381 op_run_info->op_grad_info->out_abs = op_run_info->base_op_run_info.abstract;
382 auto grad_param = std::make_shared<GradParam>(op_run_info->op_grad_info, top_cell->use_dynamic_shape_process());
383 grad_param->out_used_in_bporp_graph = out_used_in_bporp_graph;
384 return grad_param;
385 }
386
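// Mark the top cell as containing a bprop cut op when the op is HookBackward or CellBackwardHook.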
void CheckBpropCutNode(const TopCellInfoPtr &top_cell, const PrimitivePtr &op_prim) {
388 MS_EXCEPTION_IF_NULL(op_prim);
389 if (top_cell->has_bprop_cut_op()) {
390 return;
391 }
392 if (op_prim->name() == kHookBackwardName || op_prim->name() == kCellBackwardHookName) {
393 top_cell->set_has_bprop_cut_op(true);
394 }
395 }
396
void CloneParameter(const AnfNodePtr &node, const KernelGraphPtr &new_graph) {
398 MS_EXCEPTION_IF_NULL(node);
399 MS_EXCEPTION_IF_NULL(new_graph);
400 auto old_param = node->cast<ParameterPtr>();
401 MS_EXCEPTION_IF_NULL(old_param);
402 auto new_param = new_graph->add_parameter();
403 new_param->set_name(old_param->name());
404 if (auto t = PyNativeAlgo::Common::GetTensorFromParam(old_param); t != nullptr) {
405 const auto ¶m_info = t->param_info();
406 if (param_info != nullptr) {
407 const auto ¶m_name = param_info->name();
408 new_param->set_name(param_name);
409 new_param->debug_info()->set_name(param_name);
410 }
411 new_param->set_default_param(t);
412 }
413 new_param->set_abstract(old_param->abstract());
414 new_param->set_scope(old_param->scope());
415 }
416
KernelGraphPtr CloneKernelGraph(const FuncGraphPtr &func_graph) {
418 MS_EXCEPTION_IF_NULL(func_graph);
419 MS_LOG(DEBUG) << "Begin clone kernel graph";
420 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
421 MS_EXCEPTION_IF_NULL(kernel_graph);
422 auto new_graph = std::make_shared<session::KernelGraph>();
423 const auto ¶ms = kernel_graph->parameters();
424 for (auto ¶m : params) {
425 CloneParameter(param, new_graph);
426 }
427 auto out = InlineClone(kernel_graph, new_graph, new_graph->parameters());
428 new_graph->set_output(out);
429 PyNativeAlgo::Common::FreeFuncGraphForwardNodes(func_graph);
430 return new_graph;
431 }
432
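// Join the ids of all python args into a single id string.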
std::string GetInputArgsId(const py::args &args) {
434 std::string input_args_id;
435 for (size_t i = 0; i < args.size(); ++i) {
436 input_args_id += PyNativeAlgo::PyParser::GetIdByPyObj(args[i]) + "_";
437 }
438 return input_args_id;
439 }
440
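// For a custom bprop cell, replace the inputs and output that the bprop does not use with fake values, and append
// the weights declared in internal_params to the input value/id lists.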
void SetCustomBpropInputs(const py::object &obj, const InputArgsInfoPtr &input_args_info) {
442 if (py::hasattr(obj, kUsedBpropInputs)) {
443 py::object object = py::getattr(obj, kUsedBpropInputs);
444 if (!py::isinstance<py::tuple>(object) && !py::isinstance<py::list>(object)) {
      MS_LOG(EXCEPTION) << "For cell bprop, used bprop inputs should be a tuple or list";
446 }
447 auto used_bprop_inputs = py::cast<py::tuple>(object);
448 std::unordered_set<int64_t> used_inputs;
449 for (size_t i = 0; i < used_bprop_inputs.size(); ++i) {
450 const auto value = PyNativeAlgo::DataConvert::PyObjToValue(used_bprop_inputs[i]);
451 MS_EXCEPTION_IF_NULL(value);
452 int used_index = GetValue<int64_t>(value);
453 (void)used_inputs.insert(used_index);
454 }
455 const size_t input_size = input_args_info->input_arg_value_vec.size();
456 for (size_t i = 0; i < input_size; ++i) {
457 const auto &input_value = input_args_info->input_arg_value_vec[i];
458 if (used_inputs.find(i) == used_inputs.end()) {
459 auto fake_value = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(input_value);
460 input_args_info->input_arg_value_vec[i] = fake_value;
461 }
462 }
463 if (used_inputs.find(input_size) == used_inputs.end()) {
464 auto fake_value = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(input_args_info->out_value);
465 input_args_info->out_value = fake_value;
466 }
467 }
468
469 if (py::hasattr(obj, kInternalParams)) {
470 py::object weights = py::getattr(obj, kInternalParams);
471 if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
472 auto weights_tuple = py::cast<py::tuple>(weights);
473 for (size_t i = 0; i < weights_tuple.size(); ++i) {
474 const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
475 auto tensor = value->cast<tensor::TensorPtr>();
476 MS_EXCEPTION_IF_NULL(tensor);
477 (void)input_args_info->input_arg_value_vec.emplace_back(tensor);
478 (void)input_args_info->input_arg_id_vec.emplace_back(tensor->id());
479 }
480 }
481 }
482 }
483 } // namespace
484
ForwardExecutorPtr GradExecutor::forward() const {
486 auto forward_executor = forward_executor_.lock();
487 MS_EXCEPTION_IF_NULL(forward_executor);
488 return forward_executor;
489 }
490
void GradExecutor::Init() {
492 if (init_) {
493 return;
494 }
495 #ifdef _MSC_VER
496 static WinBpropRegister reg;
497 reg.DoNothing();
498 MS_LOG(DEBUG) << "Do windows bprop expander register";
499 #endif
500 init_ = true;
501 }
502
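// Pop the current top cell from the top cell stack and return the new stack top, or nullptr if the stack is empty.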
TopCellInfoPtr GradExecutor::PopTopCellStack() {
504 if (top_cell_stack_.empty()) {
    MS_LOG(EXCEPTION) << "Top cell stack is empty";
506 }
507 MS_LOG(DEBUG) << "Pop top cell " << top_cell_stack_.top() << " on top cell stack";
508 top_cell_stack_.pop();
509 TopCellInfoPtr top_cell = nullptr;
510 if (!top_cell_stack_.empty()) {
511 top_cell = top_cell_stack_.top();
512 }
513 top_cell == nullptr ? MS_LOG(DEBUG) << "Top cell stack has no top cell"
514 : MS_LOG(DEBUG) << "Top cell stack size " << top_cell_stack_.size();
515 return top_cell;
516 }
517
void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
519 input_args_info_stack_.push(input_args_info);
520 }
521
void GradExecutor::PopInputArgsInfoStack() {
523 if (input_args_info_stack_.empty()) {
524 MS_LOG(EXCEPTION) << "Stack input_args_info_stack_ is empty";
525 }
526 input_args_info_stack_.pop();
527 }
528
void GradExecutor::HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info) {
530 MS_EXCEPTION_IF_NULL(input_args_info);
531 // Convert input args to parameters for top cell graph in construct.
532 std::vector<ValuePtr> input_param_values;
533 const auto &input_value = input_args_info->input_arg_value_vec;
534 if (input_args_info->input_size != 0 && input_value.empty()) {
535 MS_LOG(EXCEPTION) << "Input value is empty";
536 }
537
538 AbstractBasePtrList abs_list;
539 for (size_t i = 0; i < input_args_info->input_size; ++i) {
540 const auto &v = input_value[i];
541 auto param_i_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(v->ToAbstract());
542 if (!top_cell()->is_bprop_need_get_forward_graph()) {
543 (void)PyNativeAlgo::Common::SetValueGradInfo(v, top_cell(), InputType::kInput);
544 (void)input_param_values.emplace_back(v);
545 (void)abs_list.emplace_back(param_i_abs);
546 }
547 RecordForwardGraphForInput(v, input_args_info->input_arg_id_vec[i], param_i_abs);
548 }
549 if (top_cell_->is_bprop_need_get_forward_graph()) {
550 MS_LOG(DEBUG) << "Run bprop function, no need do prepare for grad";
551 return;
552 }
  // If a new cell id comes up, the bprop graph uses cnodes for reuse
554 if (IsCreateIrGrad()) {
555 top_cell_->set_is_ir_grad(true);
556 }
557 if (top_cell_->is_ir_grad()) {
558 top_cell_->set_auto_grad_cell_ptr(
559 std::make_shared<autograd::IrGrad>(input_param_values, abs_list, op_num_in_bprop_graph_ * kContainerRatio,
560 assist_queue_, !top_cell_->is_high_order_top_cell(), is_run_recompute_));
561 } else {
562 top_cell_->set_auto_grad_cell_ptr(
563 std::make_shared<autograd::FuncGrad>(input_param_values, op_num_in_bprop_graph_ * kContainerRatio,
564 !top_cell_->is_high_order_top_cell(), is_run_recompute_));
565 }
566 }
567
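// Return true when the top cell is found in neither already_run_top_cell_ nor pipeline_top_cell_map_ and does not
// use the dynamic shape process, which means a new ir grad can be created; also marks it as needing graph compile.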
bool GradExecutor::IsCreateIrGrad() {
569 if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
    // If the already-run cell id is in the pipeline top cell map, there is no need to store it in
    // already_run_top_cell_ again when CheckNeedCompileGraph runs
572 if (pipeline_top_cell_map_.find(top_cell_->already_run_cell_id()) == pipeline_top_cell_map_.end()) {
573 top_cell_->set_need_compile_graph(true);
      // If the top cell is found in neither already_run_top_cell_ nor pipeline_top_cell_map_, a new ir grad can be created
575 if (!top_cell_->use_dynamic_shape_process()) {
576 return true;
577 }
578 }
579 return false;
580 }
581 return false;
582 }
583
void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_info,
585 bool is_bprop_need_get_forward_graph) {
586 MS_LOG(DEBUG) << "InitResourceAndDfBuilder";
587 MS_EXCEPTION_IF_NULL(input_args_info);
588 forward()->WaitForwardTask();
  // We need to wait for the bprop construction task of the outer top cell to finish. If the main thread runs quickly
  // when it executes the grad net and clears the bprop_queue, the bprop task of the outer top cell may not have
  // finished, which would cause a "cnode not found" error.
592 WaitBpropTask();
593 if (input_args_info->is_grad_topest_cell) {
594 MS_LOG(DEBUG) << "Make new topest graph";
595 ResetMetaGradInfoForNewTopCell(input_args_info);
596 MakeNewTopCell(input_args_info);
597 } else if (input_args_info->is_high_order_top_cell) {
598 MS_LOG(DEBUG) << "Nested grad graph existed in construct";
599
    // High-order inputs are outputs of upper-level top cell ops, so their meta grad info needs to be backed up too.
601 for (auto &item : input_args_info->input_arg_value_vec) {
602 top_cell_->BackUpValueMetaGradInfo(item);
603 }
604 ResetMetaGradInfoForNewTopCell(input_args_info);
605 MakeNewTopCell(input_args_info);
606 // High-order must use ir grad
607 top_cell_->set_is_ir_grad(true);
608 } else if (is_bprop_need_get_forward_graph) {
609 MS_LOG(DEBUG) << "Run custom bprop function and make forward graph";
    // Make a top cell just for getting the forward graph; nothing needs to be done for grad
611 MakeNewTopCell(input_args_info);
612 curr_g()->debug_info()->set_name("bprop_forward_graph");
613 top_cell_->set_is_bprop_need_get_forward_graph(is_bprop_need_get_forward_graph);
614 }
615
616 // Init kPynativeCellPtr with input parameters of top cell
617 if (!top_cell_->is_init_kpynative()) {
618 auto graph_info_cg = std::make_shared<PyNGraphInfo>();
619 top_cell_->SetGraphInfoMap(curr_g(), graph_info_cg);
620 HandleInputArgsForTopCell(input_args_info);
621 top_cell_->set_init_kpynative(true);
622 }
623 }
624
void GradExecutor::ResetMetaGradInfoForNewTopCell(const InputArgsInfoPtr &input_args_info) const {
  // To fix the scene where the user calls the forward network twice with the grad flag and then calls the grad()
  // interface: we need to clear the last top cell's parameter grad info to avoid influencing the bprop graph
  // construction of the current top cell.
629 if (top_cell_ != nullptr) {
630 MS_LOG(DEBUG) << "Reset meta grad info for top cell " << top_cell_;
631 top_cell_->ResetMetaGradInfo();
632 }
633
  // To fix the scene like: 1. net(x1) 2. x2 = deepcopy(x1) 3. net(x2) 4. grad_net(x2) 5. grad_net(x1).
  // x1's auto_grad_meta_data will be copied to x2; x2's grad will use the same auto_grad_meta_data, clear x1's
  // variable and set x2's variable.
  // When executing grad_net(x1), x1's variable will not be found, so we need to clear the inputs'
  // auto_grad_meta_data before executing.
639 for (auto &item : input_args_info->input_arg_value_vec) {
640 top_cell_->ClearValueMetaGradInfo(item);
641 }
642 }
643
void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
  // A custom bprop function is running, and the bprop function is under a high-order grad.
  // If the bprop forward graph needs to be built, a new top cell is created to serve it, and that becomes the current top_cell_.
647 bool running_bprop_function = top_cell_ != nullptr && top_cell_->grad_is_running();
648 bool is_bprop_need_get_forward_graph = running_bprop_function && top_cell_->is_high_order_top_cell();
649
650 const auto input_args_info = GetInputArgsInfo(obj, args, is_bprop_need_get_forward_graph);
651 PushInputArgsInfoStack(input_args_info);
652 MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << ", cell_id " << PyNativeAlgo::PyParser::GetIdByPyObj(obj)
653 << ", input args info ptr " << input_args_info.get();
654
655 // Make top graph and init resource
656 if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell ||
657 is_bprop_need_get_forward_graph) {
658 InitResourceAndDfBuilder(input_args_info, is_bprop_need_get_forward_graph);
659 }
660 }
661
InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args,
663 bool is_bprop_need_get_forward_graph) {
664 const auto &input_args_info = std::make_shared<InputArgsInfo>(input_args_info_stack_.empty(), IsHighOrderTopCell());
665 ParsePyArgsToInputArgsInfo(input_args_info, obj, args, is_bprop_need_get_forward_graph);
666
667 if (input_args_info->has_custom_bprop) {
668 custom_bprop_cell_count_ += 1;
669 }
  // If CheckAlreadyRun runs first, grad_order_ will already have been increased by 1 (high-order scenario).
  // With NetA.set_grad(), we come here first and CheckAlreadyRun runs later, so grad_order_ needs to be increased by 1 here.
672 if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell) {
673 if (grad_order_ == 0) {
674 IncreaseGradOrder();
675 }
676 input_args_info->already_run_cell_id = GetAlreadyRunCellId(input_args_info->cell_id);
677 MS_LOG(DEBUG) << "Get already run top cell id " << input_args_info->already_run_cell_id;
678 // top_input_args_info_ indicate current running cell info
679 top_input_args_info_ = input_args_info;
680 }
681 return input_args_info;
682 }
683
bool GradExecutor::GetTopCellDynamicFlag(const InputArgsInfoPtr &input_args_info,
685 const std::string &obj_id_with_grad_order) {
686 MS_EXCEPTION_IF_NULL(input_args_info);
  // There is only a forward process, and the forward is dynamic (set via set_inputs)
688 if (forward_use_dynamic_shape_process_) {
689 MS_LOG(DEBUG) << "Get forward dynamic";
690 return true;
691 }
692
693 // Set by set_inputs
694 if (dynamic_inputs_cells_.find(input_args_info->obj_id) != dynamic_inputs_cells_.end()) {
695 MS_LOG(DEBUG) << "Get dynamic from set inputs";
696 return true;
697 }
698
699 // Dynamic structure
700 auto pre_top_cell = GetAlreadyRunTopCell(input_args_info->already_run_cell_id);
701 if (pre_top_cell != nullptr && pre_top_cell->use_dynamic_shape_process()) {
702 MS_LOG(DEBUG) << "Get dynamic shape from already run top cell";
703 return true;
704 }
705
706 // Dynamic structure for pipeline top cell
707 pre_top_cell = GetPipelineRunTopCell(input_args_info->already_run_cell_id);
708 if (pre_top_cell != nullptr && pre_top_cell->use_dynamic_shape_process()) {
709 MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell";
710 return true;
711 }
712
713 // Dynamic shape
714 if (std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
715 [&obj_id_with_grad_order](const auto &item) {
716 if (item.second != nullptr && item.second->obj_id_with_grad_order() == obj_id_with_grad_order) {
717 return item.second->use_dynamic_shape_process();
718 }
719 return false;
720 })) {
721 MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell with obj_id_with_grad_order " << obj_id_with_grad_order;
722 return true;
723 }
724
725 // Dynamic shape for pipeline top cell
726 return std::any_of(
727 pipeline_top_cell_map_.begin(), pipeline_top_cell_map_.end(), [&obj_id_with_grad_order](const auto &item) {
728 const auto &pipe_top_cell_list = item.second;
729 if (std::any_of(pipe_top_cell_list.begin(), pipe_top_cell_list.end(),
730 [&obj_id_with_grad_order](const auto &pipe_item) {
731 if (pipe_item != nullptr && pipe_item->obj_id_with_grad_order() == obj_id_with_grad_order) {
732 return pipe_item->use_dynamic_shape_process();
733 }
734 return false;
735 })) {
736 MS_LOG(DEBUG) << "Get dynamic shape from pipeline top cell with obj_id_with_grad_order "
737 << obj_id_with_grad_order;
738 return true;
739 }
740 return false;
741 });
742 }
743
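// Create a new TopCellInfo with a fresh forward graph and resource, set its dynamic-shape and pipeline flags, and
// push it onto the top cell stack when needed.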
void GradExecutor::MakeNewTopCell(const InputArgsInfoPtr &input_args_info) {
745 MS_EXCEPTION_IF_NULL(input_args_info);
746
747 auto fg = std::make_shared<FuncGraph>();
748 fg->debug_info()->set_name("pynative_forward_graph");
749 auto resource = std::make_shared<pipeline::Resource>();
750
751 finded_top_cell_ = nullptr;
752 bool new_top_cell_is_pipeline_top_cell = NewTopCellIsPipelineTopCell(input_args_info);
753
754 bool new_top_cell_is_pipeline_high_order =
755 input_args_info->is_high_order_top_cell && new_top_cell_is_pipeline_top_cell;
  // If the outer-layer top cell is also a pipeline top cell, the top cell stack may be empty. It needs to be pushed
  // onto the top cell stack here too when running MakeNestedCnode or running a bprop function (where the bprop
  // function has another grad), because we need to know which is the outer-layer top cell when the inner run finishes.
759 if (top_cell_ != nullptr && top_cell_->is_pipeline_top_cell() &&
760 (new_top_cell_is_pipeline_high_order || top_cell_->grad_is_running())) {
761 PushTopCellStack(top_cell_);
762 }
763
764 const auto &obj_id_with_grad_order = GetAlreadyRunCellId(input_args_info->obj_id);
765 MS_LOG(DEBUG) << "Get obj id with grad order " << obj_id_with_grad_order;
766 top_cell_ = std::make_shared<TopCellInfo>(
767 input_args_info->is_high_order_top_cell, grad_order_, obj_id_with_grad_order, input_args_info->cell_id,
768 input_args_info->already_run_cell_id, resource, fg, op_num_in_bprop_graph_ * kContainerRatio);
769 top_cell_->set_forward_already_run(true);
770 top_cell_->set_input_args_id(input_args_info->input_args_id);
771 auto use_dynamic_shape_process = GetTopCellDynamicFlag(input_args_info, obj_id_with_grad_order);
772 top_cell_->set_use_dynamic_shape_process(use_dynamic_shape_process);
773 top_cell_->set_need_save_dynamic_detect_nodes(
774 dynamic_shape()->IsNeedSaveDynamicDetectNodes(top_cell_, use_dynamic_shape_process));
775 top_cell_->set_input_args_info(top_input_args_info_);
776 if (dynamic_shape()->enable_unknown_shape()) {
777 dynamic_shape()->TryChangeTopCellToUnknownShape(top_input_args_info_->obj_id,
778 top_input_args_info_->input_arg_base_shape_vec, true);
779 }
780 top_cell_->set_has_bprop_cut_op(input_args_info->has_custom_bprop);
781 top_cell_->set_grad_first(grad_first_);
782 grad_first_ = false;
783 MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << ", top cell ptr " << top_cell_.get()
784 << " with input args id " << top_cell_->input_args_id();
785
786 if (new_top_cell_is_pipeline_top_cell) {
787 pipeline_top_cell_map_[input_args_info->already_run_cell_id].emplace_back(top_cell_);
788 top_cell_->set_is_pipeline_top_cell(true);
    // If the pipeline top cell is high-order, it needs to be managed by the stack when running MakeNestedCnode, so push it onto the stack.
790 if (top_cell_->is_high_order_top_cell()) {
791 PushTopCellStack(top_cell_);
792 }
793 MS_LOG(DEBUG) << "Create pipeline top cell, input args id " << top_cell_->input_args_id()
794 << ". The pipeline map size now "
795 << pipeline_top_cell_map_[input_args_info->already_run_cell_id].size();
796 } else {
797 // Common top cell
798 PushTopCellStack(top_cell_);
799 }
800 }
801
bool GradExecutor::NewTopCellIsPipelineTopCell(const InputArgsInfoPtr &input_args_info) {
803 // net.set_grad.
804 // pipeline, net(input1), grad(net)(input1), net(input2), grad(net)(input2),...
805 const auto it = pipeline_top_cell_map_.find(input_args_info->already_run_cell_id);
806 if (it != pipeline_top_cell_map_.end()) {
807 // First pipeline top cell
    MS_EXCEPTION_IF_CHECK_FAIL(!it->second.empty(), "Pipeline top cell map is empty");
809
    // net.set_grad
    // grad(net)(input1) -> this generates an element in already_run_top_cell_ and does a complete grad operation;
    // Then, another net(input1) run -> this is treated as an op info update because a complete grad operation was
    // done before; But yet another net(input1) run -> this gets a pipeline top cell which should have 2
    // elements, and they need to compile the ir graph because this is the first step of running a pipeline top cell.
    // Then, running grad(net)(input1) -> this would find the matched top cell in already_run_top_cell_ because the
    // already_run_cell_id matches, but that is not correct because the current process is now in a pipeline top cell
    // and it would hit an auto grad meta error.
    // Erasing the top cell info from already_run_top_cell_ is needed.
819 auto iter = already_run_top_cell_.find(input_args_info->already_run_cell_id);
820 if (iter != already_run_top_cell_.end()) {
821 MS_LOG(DEBUG) << "Erase top cell from already run top cell";
      // The ir top cell needs to be used, but the current top cell in pipeline_top_cell_map_ is func grad, so exchange them.
823 it->second.front() = iter->second;
824 it->second.front()->set_need_compile_graph(true);
825 already_run_top_cell_.erase(iter);
826 }
827 it->second.front()->set_is_pipeline_top_cell(true);
828 return true;
829 }
  // net.set_grad.
  // 1. grad(net)(input): the top cell id will include grad_operation_;
  // 2. net(input1), grad(net)(input1), net(input2), grad(net)(input2), ...: the top cell id does not include
  // grad_operation_. In the second case, grad(net)(input) should be a pipeline cell too.
834 auto iter = std::find_if(pipeline_top_cell_map_.begin(), pipeline_top_cell_map_.end(),
835 [&input_args_info](const auto &iter_pipe) {
836 return input_args_info->already_run_cell_id.find(iter_pipe.first) != std::string::npos;
837 });
838 if (iter != pipeline_top_cell_map_.end()) {
839 input_args_info->already_run_cell_id = iter->first;
840 return true;
841 }
842 return false;
843 }
844
void GradExecutor::SetForwardLastNodeInfo(const ValuePtr &v) const {
846 MS_EXCEPTION_IF_NULL(v);
847 auto value = v;
848 if (v->isa<tensor::CSRTensor>()) {
849 auto csr_tensorptr = v->cast<tensor::CSRTensorPtr>();
850 value = csr_tensorptr->GetValues();
851 } else if (v->isa<tensor::COOTensor>()) {
852 auto coo_tensorptr = v->cast<tensor::COOTensorPtr>();
853 value = coo_tensorptr->GetValues();
854 }
855 (void)PyNativeAlgo::Common::SetValueGradInfo(value, top_cell_, InputType::kConstant);
856 // Set last output abstract and will be used for sens
857 auto fake_v = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(value);
858 top_cell()->SetLastOutputValueForwardOutputFlag(fake_v);
859 if (forward()->enable_async()) {
860 auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
861 auto task = [auto_grad_cell_ptr, fake_v]() { auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(fake_v); };
862 DispatchGradQueueTask(std::move(task));
863 } else {
864 top_cell()->auto_grad_cell_ptr()->UpdateOutputNodeOfTopCell(fake_v);
865 }
866 }
867
void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, const py::args &args) {
869 if (input_args_info_stack_.empty()) {
870 return;
871 }
872 const auto input_args_info = input_args_info_stack_.top();
873 MS_EXCEPTION_IF_NULL(input_args_info);
874 MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << ", cell_id " << PyNativeAlgo::PyParser::GetIdByPyObj(obj)
875 << ", input args info ptr " << input_args_info.get();
876 if (input_args_info->is_grad_topest_cell) {
877 grad_flag_ = false;
878 }
879
880 // If there is a custom bprop in the forward running process of the cell, need to do DoGradForCustomBprop;
881 // If the top cell is only for obtaining the forward graph, there is no need to do DoGradForCustomBprop.
882 bool need_do_custom_bprop_grad = false;
883 if (input_args_info->has_custom_bprop && custom_bprop_cell_count_ != 0) {
884 --custom_bprop_cell_count_;
885 need_do_custom_bprop_grad = custom_bprop_cell_count_ == 0;
886 }
887 if (!top_cell()->is_bprop_need_get_forward_graph() && need_do_custom_bprop_grad) {
888 GetCustomBpropPrim(obj, args, input_args_info);
889 runtime::OpExecutor::GetInstance().WaitAll();
890 input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out, false);
    // Recompute needs this regardless of non-tensor inputs; it may be a middle cell that does not call EndGraphImpl
892 if (input_args_info->is_need_recompute) {
893 input_args_info->out_value =
894 ConvertOutputValueToTensor(input_args_info->out_value, !top_cell()->jit_out_has_dict());
895 }
896 const auto &out_id = PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value);
897 SetCustomBpropInputs(obj, input_args_info);
898 DoGradForCustomBprop(input_args_info, out_id);
899 }
900
901 // Get top cell endgraph
902 if (input_args_info->cell_id == top_cell()->cell_id()) {
903 runtime::OpExecutor::GetInstance().WaitAll();
904 if (input_args_info->out_value == nullptr) {
905 input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out, false);
906 }
907 MS_LOG(DEBUG) << "Get cell output value " << input_args_info->out_value->ToString();
908 EndGraphImpl(input_args_info);
909 }
910 PopInputArgsInfoStack();
911 }
912
void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
914 auto out_tensor = ConvertOutputValueToTensor(input_args_info->out_value, !top_cell()->jit_out_has_dict());
915 std::vector<std::string> output_tensors_id;
916 PyNativeAlgo::DataConvert::ConvertValueTensorId(out_tensor, &output_tensors_id);
917 top_cell()->set_outputs_ids(std::move(output_tensors_id));
918 if (out_tensor != nullptr) {
919 input_args_info->out_value = out_tensor;
920 }
921
922 // If network runs twice, and one of the runs is an empty network, the following judgment will take effect
923 ValuePtrList inputs{input_args_info->out_value};
924 AbstractBasePtrList abs{PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args_info->out_value->ToAbstract())};
925 auto node_info = std::make_shared<DynamicDetectNodeInfo>(nullptr, abs, nullptr);
926 (void)dynamic_shape()->CheckNodeDynamic(top_cell(), inputs, node_info);
927
  // Only dump the last forward graph or the bprop forward graph
929 if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
930 auto output_node =
931 GetInput(input_args_info->out_value, PyNativeAlgo::Common::GetIdByValue(input_args_info->out_value));
932 curr_g()->set_output(output_node);
933 PyNativeAlgo::Common::DumpGraphIR("fg.ir", curr_g());
934 MS_LOG(DEBUG) << "Save forward graph";
935 }
936 if (top_cell_->is_bprop_need_get_forward_graph()) {
937 MS_LOG(DEBUG) << "Run bprop no need do grad";
938 return;
939 }
940
941 // Set sens value for grad
942 SetForwardLastNodeInfo(input_args_info->out_value);
943
944 if (input_args_info->is_grad_topest_cell) {
945 MS_LOG(DEBUG) << "Cur top last cell " << input_args_info->cell_id;
946 top_cell()->ClearCellHookOp();
947 }
948
949 top_cell()->CheckSubCellHookChanged();
  // Check whether the graph needs to be compiled after each top cell has finished running
951 CheckNeedCompileGraph(input_args_info);
952 if (!top_cell_->grad_first()) {
953 DecreaseGradOrder();
954 }
955 top_input_args_info_ = input_args_info;
956 }
957
void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info, const std::string &out_id) const {
959 MS_EXCEPTION_IF_NULL(input_args_info);
960 MS_EXCEPTION_IF_NULL(input_args_info->custom_bprop_prim);
961 auto op_run_info = std::make_shared<FrontendOpRunInfo>();
962 op_run_info->requires_grad = true;
963 op_run_info->base_op_run_info.op_name = input_args_info->custom_bprop_prim->name();
964 op_run_info->op_grad_info->op_prim = input_args_info->custom_bprop_prim;
965 op_run_info->op_grad_info->input_value = input_args_info->input_arg_value_vec;
966 op_run_info->op_grad_info->is_need_recompute = input_args_info->is_need_recompute;
967 op_run_info->input_size = input_args_info->input_arg_value_vec.size();
968 op_run_info->input_value_id = input_args_info->input_arg_id_vec;
969 op_run_info->real_out = input_args_info->out_value;
970 op_run_info->out_value_id = out_id;
971 op_run_info->base_op_run_info.abstract =
972 PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args_info->out_value->ToAbstract());
973 op_run_info->op_grad_info->input_value_grad_type.resize(op_run_info->input_size);
974 for (size_t i = 0; i < op_run_info->input_size; ++i) {
975 const auto &value = input_args_info->input_arg_value_vec[i];
976 (void)op_run_info->op_grad_info->input_abs.emplace_back(
977 PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract()));
978 op_run_info->op_grad_info->input_value_grad_type[i] =
979 PyNativeAlgo::Common::SetValueGradInfo(value, top_cell(), InputType::kConstant);
980 }
981 op_run_info->op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_run_info->real_out);
982 (void)PyNativeAlgo::Common::SetValueGradInfo(op_run_info->real_out, nullptr, InputType::kOpOutput);
983 DoOpGrad(op_run_info);
984 top_cell()->GetOpInfo(op_run_info, false);
985 UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
986 op_run_info->base_op_run_info.stream_id);
987 auto node_info = std::make_shared<DynamicDetectNodeInfo>(
988 op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_abs, op_run_info->base_op_run_info.abstract);
989 (void)dynamic_shape()->CheckNodeDynamic(top_cell(), op_run_info->op_grad_info->input_value, node_info);
990 RecordForwardGraph(op_run_info);
991 }
992
void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args,
994 const InputArgsInfoPtr &input_args_info) {
995 MS_EXCEPTION_IF_NULL(input_args_info);
996 MS_LOG(DEBUG) << "Do grad for custom bprop";
997 py::function bprop_func = py::getattr(obj, parse::CUSTOM_BPROP_NAME);
998 py::object code_obj = py::getattr(bprop_func, "__code__");
999 py::object co_name = py::getattr(code_obj, "co_name");
1000 if (std::string(py::str(co_name)) == "staging_specialize") {
1001 MS_LOG(EXCEPTION) << "Decorating bprop with '@jit' is not supported.";
1002 }
1003
1004 auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name());
1005 MS_EXCEPTION_IF_NULL(input_args_info);
1006 if (py::isinstance<Cell>(obj)) {
1007 const auto &cell_ptr = obj.cast<CellPtr>();
1008 input_args_info->is_need_recompute = cell_ptr->HasAttr(kNeedRecompute);
1009 fake_prim->set_bprop_cls_name(cell_ptr->name());
1010 }
1011 if (input_args_info->input_arg_value_vec.empty()) {
1012 for (size_t i = 0; i < args.size(); ++i) {
1013 (void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
1014 }
1015 }
1016 fake_prim->AddBackwardHookFn(0, bprop_func);
1017
1018 (void)fake_prim->AddAttr("cell_id", MakeValue(input_args_info->cell_id));
1019 (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
1020
1021 input_args_info->custom_bprop_prim = fake_prim;
1022 }
1023
void GradExecutor::ClearPreTopCell(const TopCellInfoPtr &new_top_cell, bool is_need_clear_device_mem) {
1025 MS_EXCEPTION_IF_NULL(new_top_cell);
1026 // Clear already run top cell and device mem
1027 for (auto iter = already_run_top_cell_.begin(); iter != already_run_top_cell_.end();) {
1028 MS_EXCEPTION_IF_NULL(iter->second);
1029 if (iter->second->obj_id_with_grad_order() == new_top_cell->obj_id_with_grad_order()) {
1030 if (is_need_clear_device_mem) {
1031 iter->second->ClearDeviceMemory();
1032 (void)need_gc_top_cell_list_.emplace_back(iter->second);
1033 }
1034 iter = already_run_top_cell_.erase(iter);
1035 } else {
1036 (void)iter++;
1037 }
1038 }
1039 }
1040
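// Decide whether the bprop graph of the current top cell needs to be (re)compiled, and maintain
// already_run_top_cell_ and pipeline_top_cell_map_ accordingly.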
void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
1042 const auto &already_top_cell_id = top_cell()->already_run_cell_id();
1043 bool is_new_cell_id = false;
  // Get the new cell id for common grad; even for dynamic shapes, the first step will come in here too.
1045 if (top_cell_->need_compile_graph()) {
1046 MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
    // net.set_grad.
    // 1. grad(net)(input): the top cell id will include grad_operation_;
    // 2. net(input), grad(net)(input): the top cell id does not include grad_operation_. But only keep one, for an
    // accurate find when calling GetTopCell
1051 auto it = std::find_if(
1052 already_run_top_cell_.begin(), already_run_top_cell_.end(),
1053 [&already_top_cell_id](const auto &item) { return item.first.find(already_top_cell_id) != std::string::npos; });
1054 if (it != already_run_top_cell_.end()) {
1055 already_run_top_cell_.erase(it);
1056 }
1057 already_run_top_cell_[already_top_cell_id] = top_cell_;
1058 is_new_cell_id = true;
1059 }
  // In the first step, prepare the first top cell to be a pipeline top cell if applicable
1061 const auto it = pipeline_top_cell_map_.find(top_cell_->already_run_cell_id());
1062 if (it == pipeline_top_cell_map_.end()) {
1063 MS_LOG(DEBUG) << "Prepare the first top cell to be pipeline top cell";
1064 top_cell_->set_need_compile_graph(true);
1065 pipeline_top_cell_map_[already_top_cell_id].emplace_back(top_cell_);
    // If the top cell is the first top cell, the pipeline top cell map backs it up and we return;
    // But if the top cell is not pipeline (run via the already-run top cell map), in the second step we cannot return
    // here and should go down to judge the compile status.
1069 if (is_new_cell_id) {
1070 return;
1071 }
1072 }
  // Get the pipeline top cell in the first step, and judge by whether the first top cell has completed a backward run
1074 if (top_cell_->is_pipeline_top_cell()) {
    MS_EXCEPTION_IF_CHECK_FAIL(!it->second.empty(), "Pipeline top cell map is empty");
1076 if (!it->second.front()->is_finish_backward()) {
1077 top_cell_->set_need_compile_graph(true);
1078 // Get dynamic structure
1079 if (top_cell_->use_dynamic_shape_process()) {
1080 it->second.front()->set_use_dynamic_shape_process(true);
1081 }
1082 MS_LOG(DEBUG) << "Get pipeline top cell has never been ran, input args " << top_cell_->input_args_id();
1083 return;
1084 }
1085 }
1086
1087 // Older top cell id or dynamic shape
1088 MS_EXCEPTION_IF_NULL(input_args_info);
1089 // In high-order situations, the internal top cell has changed, but the outer top cell remains unchanged. Then outer
1090 // bprop graph needs to compile again
1091 if (top_cell_->use_dynamic_shape_process() || top_cell_->force_top_cell_compile()) {
    // Functions need to be compiled every time.
1093 top_cell_->use_dynamic_shape_process() ? MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again"
1094 : MS_LOG(DEBUG) << "Force outer graph compile graph";
1095 if (!top_cell_->is_pipeline_top_cell()) {
1096 auto has_higher_order = std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
1097 [](const auto &elem) { return elem.second->is_high_order_top_cell(); });
1098 ClearPreTopCell(top_cell_, input_args_info->is_grad_topest_cell && !has_higher_order);
1099 already_run_top_cell_[already_top_cell_id] = top_cell_;
1100 } else {
1101 MS_LOG(DEBUG) << "Get pipeline top cell, input args " << top_cell_->input_args_id();
1102 }
1103 top_cell_->set_need_compile_graph(true);
1104 top_cell_->set_force_top_cell_compile(false);
1105 } else {
1106 MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " no need to compile graph again";
1107 if (!top_cell_->is_pipeline_top_cell()) {
1108 top_cell_->set_need_compile_graph(false);
1109 auto pre_top_cell = GetAlreadyRunTopCell(already_top_cell_id);
1110 MS_EXCEPTION_IF_NULL(pre_top_cell);
1111 pre_top_cell->set_input_args_id(top_cell_->input_args_id());
1112 // In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
1113 // the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
1114 if (!input_args_info->is_grad_topest_cell) {
1115 pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), top_cell_->graph_info_map().at(top_cell_->fg()));
1116 }
1117 pre_top_cell->set_forward_already_run(true);
1118 pre_top_cell->set_input_args_info(input_args_info);
1119 top_cell_stack_.top() = pre_top_cell;
1120 } else {
1121 MS_LOG(DEBUG) << "Get pipeline top cell, input args " << top_cell_->input_args_id();
1122 }
1123 }
1124 }
1125
TopCellInfoPtr GradExecutor::GetAlreadyRunTopCell(const std::string &already_run_cell_id) const {
1127 const auto it = already_run_top_cell_.find(already_run_cell_id);
1128 if (it != already_run_top_cell_.end()) {
1129 return it->second;
1130 }
1131 return nullptr;
1132 }
1133
TopCellInfoPtr GradExecutor::GetPipelineRunTopCell(const std::string &already_run_cell_id) const {
1135 const auto it = pipeline_top_cell_map_.find(already_run_cell_id);
1136 if (it != pipeline_top_cell_map_.end()) {
1137 return it->second.front();
1138 }
1139 return nullptr;
1140 }
1141
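// Find a pipeline top cell whose already-run cell id matches (substring match, optionally reversed) and whose input
// args id equals input_args_id; return nullptr if none is found.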
TopCellInfoPtr GradExecutor::GetPipelineTopCell(const std::string &already_run_cell_id,
1143 const std::string &input_args_id, bool is_reverse_match) const {
1144 for (const auto &t : pipeline_top_cell_map_) {
1145 bool is_find = is_reverse_match ? t.first.find(already_run_cell_id) != std::string::npos
1146 : already_run_cell_id.find(t.first) != std::string::npos;
1147 if (is_find) {
1148 // If backward has finished, skip the first ir top cell
1149 auto begin =
1150 !t.second.empty() && t.second.front()->is_finish_backward() ? t.second.begin() + 1 : t.second.begin();
1151 auto input_args_id_with_top_cell =
1152 std::find_if(begin, t.second.end(), [input_args_id](const TopCellInfoPtr &pipe_top_cell) {
1153 return input_args_id == pipe_top_cell->input_args_id();
1154 });
1155 if (input_args_id_with_top_cell == t.second.end()) {
1156 MS_LOG(DEBUG) << "Can not find top cell with input args id " << input_args_id;
1157 continue;
1158 }
1159 MS_LOG(DEBUG) << "Find pipeline top cell with input args id " << input_args_id;
1160 return *input_args_id_with_top_cell;
1161 }
1162 }
1163 MS_LOG(DEBUG) << "Can not find cell id " << already_run_cell_id << " in pipeline top cell map";
1164 return nullptr;
1165 }
1166
1167 void GradExecutor::ErasePipelineTopCell(const std::string &already_run_cell_id, const std::string &input_args_id,
1168 bool is_pipeline_ir_top_cell) {
1169 for (auto &t : pipeline_top_cell_map_) {
1170 if (already_run_cell_id.find(t.first) == std::string::npos) {
1171 continue;
1172 }
1173
1174 // Unless we are erasing the pipeline ir top cell itself, skip the first ir top cell when it has finished backward
1175 auto begin = !is_pipeline_ir_top_cell && !t.second.empty() && t.second.front()->is_finish_backward()
1176 ? t.second.begin() + 1
1177 : t.second.begin();
1178 auto input_args_id_with_top_cell = std::find_if(
1179 begin, t.second.end(),
1180 [input_args_id](const TopCellInfoPtr &pipe_top_cell) { return input_args_id == pipe_top_cell->input_args_id(); });
1181 if (input_args_id_with_top_cell == t.second.end()) {
1182 MS_LOG(DEBUG) << "Can not find top cell with input args id " << input_args_id;
1183 continue;
1184 }
1185 MS_LOG(DEBUG) << "Erase pipeline top cell " << input_args_id_with_top_cell->get() << " with input args id "
1186 << input_args_id << ". The pipeline map size now " << t.second.size() - 1;
1187 t.second.erase(input_args_id_with_top_cell);
1188 if (t.second.empty()) {
1189 MS_LOG(DEBUG) << "Pipeline top cell map with already run cell id " << already_run_cell_id
1190 << " is empty, erase it from the pipeline map";
1191 pipeline_top_cell_map_.erase(t.first);
1192 }
1193 return;
1194 }
1195 }
1196
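// Grad entry. Roughly: when the matched top cell needs no re-compilation, the cached bprop graph is reused and run
// directly; otherwise the bprop graph is rebuilt, either through the ir grad path (GetGradGraph + RunGradGraph) or
// the function grad path (RunGradFunc).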
1197 py::object GradExecutor::RunGrad(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &weights,
1198 const py::object &grad_position, const py::args &args) {
1199 // Wait for forward tasks to finish.
1200 runtime::OpExecutor::GetInstance().WaitAll();
1201
1202 GetTopCellWithInputArgsRespectTo(grad, obj, args);
1203 MS_EXCEPTION_IF_NULL(top_cell_);
1204 MS_LOG(DEBUG) << "Run top cell " << top_cell_;
1205
1206 // Input args info must be updated to the current one even if the graph does not need to be compiled again
1207 top_input_args_info_ = top_cell_->input_args_info();
1208 MS_EXCEPTION_IF_NULL(top_input_args_info_);
1209 // Set sens
1210 SetSensValue(grad, top_input_args_info_, args, !top_cell_->jit_out_has_dict());
1211
1212 MS_LOG(DEBUG) << "RunGrad start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
1213 << ", input args info ptr " << top_input_args_info_.get();
1214
1215 if (!top_cell_->need_compile_graph()) {
1216 MS_LOG(DEBUG) << "No need compile graph, graph is ir_grad " << top_cell_->is_ir_grad();
1217 // If no compilation is needed, we can clear the bprop construction queue.
1218 (void)need_gc_top_cell_list_.emplace_back(top_cell_);
1219 ClearBpropTask();
1220 top_cell_->ClearMetaGradInfo();
1221
1222 // If the top cell is a pipeline top cell, finded_top_cell_ will be itself;
1223 // otherwise, it is the ir top cell in already_run_top_cell_.
1224 if (!ReplacePipelineTopCellForwardOutput()) {
1225 finded_top_cell_->set_shadow_top_cell(top_cell_.get());
1226 top_cell_ = finded_top_cell_;
1227 finded_top_cell_ = nullptr;
1228 }
1229 // Top cell cleanup must come after the pipeline forward output replacement, because replace_info cannot be cleared before it
1230 AsyncClearTopCell();
1231 top_cell_->UpdateTopCellInfo(false, false, false);
1232 return RunGradGraph();
1233 }
1234
1235 MS_LOG(DEBUG) << "Need compile graph, graph is ir_grad " << top_cell_->is_ir_grad();
1236 WaitBpropTask();
1237 AsyncClearTopCell();
1238 top_cell_ = finded_top_cell_;
1239 finded_top_cell_ = nullptr;
1240 op_num_in_bprop_graph_ = top_cell_->op_index();
1241 top_cell_->set_grad_operation(grad_operation_);
1242 top_cell_->UpdateTopCellInfo(false, false, true);
1243 top_cell_->ResumeMetaGradInfo();
1244 SetBpropGraphJitLevel(obj);
1245 bool weight_param_is_tuple = true;
1246 auto w_args = GetWeightsArgs(weights, &weight_param_is_tuple);
1247 auto p_args = GetGradPositionArgs(grad_position, grad->get_by_position_);
1248 autograd::GradAttr grad_attr(grad->get_all_, grad->get_by_list_, grad->sens_param_, grad->get_by_position_,
1249 weight_param_is_tuple);
1250 if (top_cell_->is_ir_grad()) {
1251 GetGradGraph(grad_attr, w_args, p_args);
1252 return RunGradGraph();
1253 }
1254 return RunGradFunc(grad_attr, w_args, p_args);
1255 }
1256
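// The already-run cell id is composed of the object id, the effective grad order (at least 1) and the grad
// operation string. Illustrative example with hypothetical values: obj_id "net_obj_0", grad_order 1 and
// grad_operation "110" give "net_obj_0_1_110".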
1257 std::string GradExecutor::GetAlreadyRunCellId(const std::string &obj_id) const {
1258 std::string already_run_cell_id(obj_id);
1259 already_run_cell_id += "_" + std::to_string(grad_order_ == 0 ? 1 : grad_order_);
1260 already_run_cell_id += "_" + grad_operation_;
1261 return already_run_cell_id;
1262 }
1263
1264 void GradExecutor::GetTopCellWithInputArgsRespectTo(const prim::GradOperationPtr &grad, const py::object &obj,
1265 const py::args &args) {
1266 auto reset_flag = [this]() {
1267 if (finded_top_cell_->is_pipeline_top_cell()) {
1268 top_cell_ = finded_top_cell_;
1269 } else if (top_cell_ != nullptr &&
1270 finded_top_cell_->already_run_cell_id().find(top_cell_->already_run_cell_id()) == std::string::npos) {
1271 // NetA.set_grad, NetB.set_grad
1272 // then run NetA(input), NetB(input) to get the loss, and then run grad(NetA)(input), grad(NetB)(input).
1273 // When running grad(NetA)(input), finded_top_cell_ is the grad top cell of NetA, but the top cell is that of
1274 // grad(NetB)(input), which does not match, so they need to be exchanged.
1275 // The meta grad info of NetB must be reset because NetB ran after NetA and did not do this in
1276 // MakeNewTopCell. If they share inputs or weight parameters, the auto grad meta may hit a nullptr.
1277 top_cell_->ResetMetaGradInfo();
1278 top_cell_ = finded_top_cell_;
1279 }
1280 };
1281
1282 if (finded_top_cell_ != nullptr) {
1283 reset_flag();
1284 return;
1285 }
1286 MS_EXCEPTION_IF_NULL(grad);
1287 py::args args_without_sens;
1288 if (grad->sens_param_) {
1289 // If there is a sens, it will not hit the already run cache
1290 if (args.empty()) {
1291 MS_LOG(EXCEPTION) << "args.size:" << args.size() << " is invalid, at least the sens arg is expected.";
1292 }
1293 auto tuple_args_size = args.size() - 1;
1294 py::tuple tuple_args(tuple_args_size);
1295 for (size_t i = 0; i < tuple_args_size; ++i) {
1296 tuple_args[i] = args[i];
1297 }
1298 args_without_sens = tuple_args;
1299 } else {
1300 args_without_sens = args;
1301 }
1302 const auto &id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args_without_sens);
1303 const auto &cell_id =
1304 PyNativeAlgo::Common::GetCellId(PyNativeAlgo::PyParser::GetIdByPyObj(obj), id_v.first, id_v.second);
1305 const auto &already_run_cell_id = GetAlreadyRunCellId(cell_id);
1306 const auto &input_args_id = GetInputArgsId(args_without_sens);
1307 MS_LOG(DEBUG) << "Get input cell id " << cell_id << " and already run cell id " << already_run_cell_id
1308 << ", input args id " << input_args_id;
1309 finded_top_cell_ = GetTopCell(already_run_cell_id, input_args_id);
1310 MS_EXCEPTION_IF_NULL(finded_top_cell_);
1311 reset_flag();
1312 }
1313
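// For a pipeline top cell, locate its ir top cell in the pipeline map, update the forward tensors recorded for
// replacement, and switch top_cell_ to that ir top cell. Returns false when the current top cell is not a
// pipeline top cell.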
1314 bool GradExecutor::ReplacePipelineTopCellForwardOutput() {
1315 // If top cell is pipeline top cell, need to get its ir top cell
1316 if (!top_cell_->is_pipeline_top_cell()) {
1317 return false;
1318 }
1319 auto pipeline_ir_top_cell = GetPipelineRunTopCell(top_cell_->already_run_cell_id());
1320 if (pipeline_ir_top_cell == nullptr) {
1321 MS_LOG(EXCEPTION) << "Can not find pipeline ir top cell " << top_cell_->already_run_cell_id()
1322 << " in pipeline top cell map";
1323 }
1324 UpdatePipelineTopCellFowardTensor(pipeline_ir_top_cell->replace_info(), top_cell_->replace_info());
1325 pipeline_ir_top_cell->set_shadow_top_cell(top_cell_.get());
1326 top_cell_ = pipeline_ir_top_cell;
1327 MS_LOG(DEBUG) << "Run no need compile ir top cell " << top_cell_;
1328 return true;
1329 }
1330
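// Build the ir bprop graph for the current top cell, attach it to the pipeline resource, apply the passes it needs
// (environ conversion for control flow, dict/tuple parameter processing, rewriting for dict jit outputs), then
// compile it through TaskEmitAction and ExecuteAction.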
1331 void GradExecutor::GetGradGraph(const autograd::GradAttr &grad_attr, const std::vector<tensor::BaseTensorPtr> &w_args,
1332 const std::vector<size_t> &p_args) {
1333 // Get bprop graph of top cell
1334 auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
1335 auto resource = top_cell()->resource();
1336 MS_EXCEPTION_IF_NULL(resource);
1337 resource->set_func_graph(bprop_graph);
1338 auto manager = resource->manager();
1339 MS_EXCEPTION_IF_NULL(manager);
1340 manager->AddFuncGraph(bprop_graph, true);
1341 bprop_graph->ResetOwnNodes();
1342 // If the auto grad cell is cleared before ResetOwnNodes, memory may be corrupted.
1343 AsyncClearAutoGradCell(top_cell());
1344 if (top_cell()->has_control_flow()) {
1345 (void)opt::EnvironConversion(resource);
1346 }
1347 if (top_input_args_info_->sens_type == SensType::kDict) {
1348 PyNativeAlgo::Common::ProcessDictParam(bprop_graph, top_input_args_info_->input_size);
1349 } else if (top_input_args_info_->sens_type == SensType::kTuple) {
1350 PyNativeAlgo::Common::ProcessTupleParam(bprop_graph, top_input_args_info_->input_size);
1351 }
1352 if (top_cell()->jit_out_has_dict()) {
1353 MS_LOG(DEBUG) << "Jit out is dict, need convert make dict to pyexecute";
1354 (void)mindspore::opt::RewriterAfterOptA(resource->func_graph(), resource);
1355 }
1356 top_cell()->SaveForwardOutputTensorInfoInBpropGraph(resource->func_graph());
1357 PyNativeAlgo::Common::DumpGraphIR("launch_bprop_graph.ir", bprop_graph);
1358 resource->SetBackendAsync([]() { return compile::CreateBackend(); });
1359 MS_LOG(DEBUG) << "Start task emit action";
1360 (void)TaskEmitAction(resource);
1361 MS_LOG(DEBUG) << "Start execute action";
1362 (void)ExecuteAction(resource);
1363 top_cell()->UpdateTopCellInfo(false, false, true);
1364 resource->Clean();
1365 }
1366
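// Convert the python weights object into a list of tensors. A __parameter_tuple__ attribute, a tuple/list of
// parameters, or a single parameter are all accepted; None falls back to the parameters recorded in the top cell.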
1367 std::vector<tensor::BaseTensorPtr> GradExecutor::GetWeightsArgs(const py::object &weights,
1368 bool *weight_param_is_tuple) const {
1369 std::vector<tensor::BaseTensorPtr> w_args;
1370 if (py::hasattr(weights, "__parameter_tuple__")) {
1371 const auto &weights_tuple = weights.cast<py::tuple>();
1372 MS_LOG(DEBUG) << "Get weights tuple size " << weights_tuple.size();
1373 for (size_t i = 0; i < weights_tuple.size(); ++i) {
1374 const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
1375 auto tensor = value->cast<tensor::BaseTensorPtr>();
1376 MS_EXCEPTION_IF_NULL(tensor);
1377 (void)w_args.emplace_back(tensor);
1378 }
1379 } else {
1380 MS_LOG(DEBUG) << "No parameter tuple get, try get weights params by input weight";
1381 if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
1382 auto weights_tuple = py::cast<py::tuple>(weights);
1383 for (size_t i = 0; i < weights_tuple.size(); ++i) {
1384 const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights_tuple[i]);
1385 auto tensor = value->cast<tensor::BaseTensorPtr>();
1386 MS_EXCEPTION_IF_NULL(tensor);
1387 (void)w_args.emplace_back(tensor);
1388 }
1389 } else if (!py::isinstance<py::none>(weights)) {
1390 // Single input
1391 const auto value = PyNativeAlgo::DataConvert::PyObjToValue(weights);
1392 auto tensor = value->cast<tensor::BaseTensorPtr>();
1393 MS_EXCEPTION_IF_NULL(tensor);
1394 (void)w_args.emplace_back(tensor);
1395 *weight_param_is_tuple = false;
1396 } else {
1397 return GetDefaultWeights();
1398 }
1399 }
1400 return w_args;
1401 }
1402
1403 std::vector<tensor::BaseTensorPtr> GradExecutor::GetDefaultWeights() const {
1404 std::vector<tensor::BaseTensorPtr> w_args;
1405 for (const auto &params : top_cell()->param_grad_info()) {
1406 const auto &tensor = params.first;
1407 if (tensor->is_parameter()) {
1408 (void)w_args.emplace_back(tensor);
1409 }
1410 }
1411 return w_args;
1412 }
1413
1414 std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_position, bool get_by_position) const {
1415 std::vector<size_t> pos_args;
1416 if (!get_by_position) {
1417 return pos_args;
1418 }
1419 if (py::isinstance<py::tuple>(grad_position)) {
1420 const auto &tuple = grad_position.cast<py::tuple>();
1421 (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
1422 [](const py::handle &elem) { return elem.cast<int64_t>(); });
1423 if (pos_args.empty()) {
1424 MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
1425 }
1426 return pos_args;
1427 }
1428 MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
1429 }
1430
1431 void GradExecutor::CheckParamShapeAndType(const ParameterPtr &param_node, const abstract::AbstractBasePtr &ir_abs,
1432 const abstract::AbstractBasePtr &input_abs) const {
1433 MS_EXCEPTION_IF_NULL(param_node);
1434 MS_EXCEPTION_IF_NULL(ir_abs);
1435 MS_EXCEPTION_IF_NULL(input_abs);
1436 const auto &ir_shape = ir_abs->BuildShape()->ToString();
1437 const auto &input_shape = input_abs->BuildShape()->ToString();
1438 if (input_shape != "()" && ir_shape != "()") {
1439 if (input_shape != ir_shape) {
1440 // The sens shape in the ir graph is determined by the graph output, so it can be dynamic; but the input shape is
1441 // determined by user input, which cannot be dynamic.
1442 if (param_node->debug_info()->name() != "sens" || !ir_abs->BuildShape()->IsDynamic()) {
1443 MS_EXCEPTION(ValueError) << "The shape should be " << ir_shape << ", but got " << input_shape << ", "
1444 << param_node->DebugString() << ", ir_abs " << ir_abs->ToString() << ", input_abs "
1445 << input_abs->ToString();
1446 }
1447 }
1448 const auto &ir_dtype = ir_abs->BuildType()->ToString();
1449 const auto &input_dtype = input_abs->BuildType()->ToString();
1450 if (input_dtype != ir_dtype) {
1451 MS_EXCEPTION(TypeError) << "The dtype should be " << ir_dtype << ", but got " << input_dtype << ", "
1452 << param_node->DebugString();
1453 }
1454 }
1455 }
1456
1457 void GradExecutor::UpdateParamAbsByArgs(const std::vector<ValuePtr> &input_args,
1458 const FuncGraphPtr &bprop_graph) const {
1459 MS_EXCEPTION_IF_NULL(bprop_graph);
1460 const auto &bprop_params = bprop_graph->parameters();
1461 // bprop_params includes inputs, parameters and sens, so it should be no smaller than the inputs size
1462 if (bprop_params.size() < input_args.size()) {
1463 MS_LOG(EXCEPTION) << "Df parameters size " << bprop_params.size() << " less than " << input_args.size();
1464 }
1465 size_t index = 0;
1466 for (const auto &param : bprop_params) {
1467 auto param_node = param->cast<ParameterPtr>();
1468 if (param_node->has_default()) {
1469 MS_EXCEPTION_IF_NULL(param_node->abstract());
1470 } else {
1471 const auto &input_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_args[index]->ToAbstract());
1472 if (param_node->abstract() != nullptr) {
1473 CheckParamShapeAndType(param_node, param_node->abstract(), input_abs);
1474 } else {
1475 param_node->set_abstract(input_abs);
1476 }
1477 ++index;
1478 }
1479 }
1480 }
1481
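// Finish ir auto-grad for the top cell and post-process the resulting bprop graph: update parameter abstracts from
// the real inputs, run the final optimization when required, clone the graph for high-order grad, and set the
// flags that control how it is executed (single-op execution, control flow, jit call graph, kernel graph reuse).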
1482 FuncGraphPtr GradExecutor::GetBpropGraph(const autograd::GradAttr &grad_attr,
1483 const std::vector<tensor::BaseTensorPtr> &w_args,
1484 const std::vector<size_t> &p_args) {
1485 MS_EXCEPTION_IF_NULL(top_input_args_info_);
1486 const auto &auto_grad_cell = std::dynamic_pointer_cast<autograd::IrGrad>(top_cell()->auto_grad_cell_ptr());
1487 MS_EXCEPTION_IF_NULL(auto_grad_cell);
1488 // Update bprop_graph_run_by_single_op for the bprop graph; if it is true, passes like ConvertMakeTupleInputToDynamicInput
1489 // will not take effect
1490 auto_grad_cell->set_bprop_graph_run_by_single_op(top_cell()->use_dynamic_shape_process());
1491 FuncGraphPtr bprop_graph = auto_grad_cell->Finish(w_args, p_args, grad_attr);
1492 MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
1493 UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph);
1494 if (top_cell()->need_do_final_opt()) {
1495 bprop_graph = BpropGraphFinalOpt(bprop_graph, top_cell()->has_control_flow());
1496 }
1497 if (top_input_args_info_->is_high_order_top_cell) {
1498 MS_LOG(DEBUG) << "Get high grad";
1499 top_cell()->resource()->set_optimize_graph(bprop_graph);
1500 bool has_bprop_cut = bprop_graph->has_flag(kFlagPyNativeBpropGraphWithBpropCut);
1501 if (bprop_graph->isa<session::KernelGraph>()) {
1502 bprop_graph = CloneKernelGraph(bprop_graph);
1503 } else {
1504 bprop_graph = BasicClone(bprop_graph);
1505 }
1506 if (has_bprop_cut) {
1507 bprop_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
1508 }
1509 PyNativeAlgo::Common::ReplaceCNodeWithValueNode(bprop_graph);
1510 } else {
1511 top_cell()->resource()->set_optimize_graph(bprop_graph);
1512 }
1513 if (bprop_graph->has_flag(kFlagIsControlFlow)) {
1514 top_cell()->set_has_control_flow(true);
1515 }
1516 if (top_cell()->has_control_flow()) {
1517 bprop_graph = LiftingClone(bprop_graph);
1518 }
1519 bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1520 bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
1521 bprop_graph->set_flag(kFlagPyNativeBpropGraphIsDynamic, top_cell()->use_dynamic_shape_process());
1522
1523 // Update the bprop cut flag. There are two scenarios:
1524 // 1. kHookBackwardName or kCellBackwardHookName
1525 // 2. Custom op bprop(set in auto_grad.cc by kFlagPyNativeBpropGraphWithBpropCut)
1526 bprop_graph->set_flag(kFlagPyNativeBpropGraphWithBpropCut,
1527 bprop_graph->has_flag(kFlagPyNativeBpropGraphWithBpropCut) || top_cell()->has_bprop_cut_op());
1528
1529 // Update the run-graph-by-single-op flag. There are two scenarios:
1530 // 1. Dynamic shape or dynamic structure
1531 // 2. Has bprop cut op
1532 // If set_inputs is used but the graph has control flow, we need to run by actor.
1533 bprop_graph->set_flag(kFlagEnableRunGraphBySingleOp,
1534 auto_grad_cell->bprop_graph_run_by_single_op() && !bprop_graph->has_flag(kFlagIsControlFlow));
1535 top_cell()->set_use_dynamic_shape_process(bprop_graph->has_flag(kFlagEnableRunGraphBySingleOp));
1536 if (top_cell()->has_call_graph()) {
1537 bprop_graph->set_flag(kFlagPyNativeWithJitCallGraph, true);
1538 }
1539 bool has_control_flow = top_cell()->has_control_flow();
1540 bprop_graph->set_flag(kFlagIsPyNativeBpropKernelGraph, !has_control_flow);
1541 // A control-flow graph will generate a kernel graph again when compiling graphs, and its graph id conflicts with the default id 0
1542 if (has_control_flow) {
1543 auto kernel_graph = bprop_graph->cast<KernelGraphPtr>();
1544 MS_EXCEPTION_IF_NULL(kernel_graph);
1545 kernel_graph->set_graph_id(kernel_graph_id_for_control_flow());
1546 }
1547 return bprop_graph;
1548 }
1549
1550 bool GradExecutor::NeedIncreaseGradOrder(const std::string &obj_id) {
1551 // top_cell_ == nullptr means grad is called first.
1552 // top_cell_->obj_id_with_grad_order() includes obj_id and grad_order.
1553 // If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos, the current cell is not the top cell;
1554 // another cell or function needs to get grad, so a high-order grad comes up.
1555 if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id + "_") == std::string::npos) {
1556 IncreaseGradOrder();
1557 return true;
1558 }
1559 return false;
1560 }
1561
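// Decide whether the forward process has to run again before grad, and return the result as a python bool.
// forward_run is true only when a matching top cell has already finished its forward run with the same input args.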
1562 py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
1563 const py::object &weights, const py::object &grad_hash_id,
1564 const py::args &args) {
1565 const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
1566
1567 // The rule of grad order is:
1568 // Scenario 1: net.set_grad, net(input) is called first; increase 1 before MakeNewTopCell and decrease 1 when running
1569 // EndGraphImpl, indicating that a complete bprop graph construction is finished. A later call to grad(net)(input)
1570 // is not affected and forward_run is true.
1571 // Scenario 2: if grad(net)(input) is called first, increase 1 before MakeNewTopCell and decrease 1 in RunGrad. The
1572 // reason for this design is that if grad(net)(input) were called first and decreased 1 in EndGraphImpl, it would cause
1573 // matching problems during RunGrad because the GradOperation information in already_run_cell_id would not be the
1574 // same. The GradOperation information includes the grad order to distinguish high-order grad.
1575 // The flag grad_first_ distinguishes these two scenarios. If scenario 1 is taken, grad_first_ does not take
1576 // effect; otherwise, it works.
1577 bool need_increase_grad_order = NeedIncreaseGradOrder(obj_id);
1578 // Include weight param size and required grad flag
1579 std::string grad_hash_id_str;
1580 if (!py::isinstance<py::none>(grad_hash_id)) {
1581 grad_hash_id_str = std::string(py::str(grad_hash_id));
1582 }
1583
1584 std::string weights_obj_id = GetWeightsObjIdsByWeights(weights);
1585 grad_operation_ = std::to_string(static_cast<int>(grad->get_all_)) +
1586 std::to_string(static_cast<int>(grad->get_by_list_)) +
1587 std::to_string(static_cast<int>(grad->sens_param_)) + grad_hash_id_str + weights_obj_id;
1588
1589 auto input_args_id = GetInputArgsId(args);
1590 // Under the condition that the stack is empty (forward process completed or no forward process),
1591 // check whether need to run forward process
1592 bool forward_run = false;
1593 if (input_args_info_stack_.empty()) {
1594 const auto &id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args);
1595 auto cell_id = PyNativeAlgo::Common::GetCellId(obj_id, id_v.first, id_v.second);
1596 const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
1597 MS_LOG(DEBUG) << "Get check already run top cell id " << check_already_run_cell_id;
1598 auto find_top_cell = GetTopCell(check_already_run_cell_id, input_args_id);
1599 if (find_top_cell != nullptr) {
1600 MS_LOG(DEBUG) << "Find already run top cell " << find_top_cell;
1601 forward_run = find_top_cell->forward_already_run();
1602 bool input_args_changed =
1603 !find_top_cell->input_args_id().empty() && find_top_cell->input_args_id() != input_args_id;
1604 if (forward_run && input_args_changed) {
1605 MS_LOG(DEBUG) << "The input info " << input_args_id << " is not the same with pre input info "
1606 << find_top_cell->input_args_id() << ", forward process will run again";
1607 forward_run = false;
1608 }
1609 // The pipeline top cell finished forward, but grad belongs to the previous pipeline top cell. The auto meta grad
1610 // info needs to be reset.
1611 if (top_cell_ != nullptr && top_cell_->is_pipeline_top_cell() && top_cell_->input_args_id() != input_args_id) {
1612 WaitBpropTask();
1613 top_cell_->ResetMetaGradInfo();
1614 }
1615 if (forward_run) {
1616 // If need_increase_grad_order is true, the grad order was increased in preparation for grad;
1617 // but forward_run is true now, so forward does not need to run again and the grad order must be decreased.
1618 if (need_increase_grad_order) {
1619 DecreaseGradOrder();
1620 }
1621 finded_top_cell_ = find_top_cell;
1622 }
1623 }
1624 }
1625 if (!forward_run) {
1626 grad_first_ = true;
1627 }
1628 forward_run ? MS_LOG(DEBUG) << "Top cell has already run with input args id " << input_args_id
1629 : MS_LOG(DEBUG) << "Top cell no run before with input args id " << input_args_id;
1630 return BaseRefToPyData(forward_run);
1631 }
1632
1633 py::object GradExecutor::RunGradFunc(const autograd::GradAttr &grad_attr,
1634 const std::vector<tensor::BaseTensorPtr> &w_args,
1635 const std::vector<size_t> &p_args) {
1636 MS_EXCEPTION_IF_NULL(top_input_args_info_);
1637 ValuePtr sens = nullptr;
1638 if (grad_attr.has_sens) {
1639 sens = top_input_args_info_->input_arg_value_vec.back();
1640 }
1641
1642 MS_LOG(DEBUG) << "Eval run begin";
1643 MS_EXCEPTION_IF_NULL(top_cell_);
1644 const auto &auto_grad_cell = std::dynamic_pointer_cast<autograd::FuncGrad>(top_cell_->auto_grad_cell_ptr());
1645 MS_EXCEPTION_IF_NULL(auto_grad_cell);
1646 top_cell_->set_grad_is_running(true);
1647 auto grads = auto_grad_cell->Finish(w_args, p_args, grad_attr, sens);
1648 MS_EXCEPTION_IF_NULL(grads);
1649 MS_EXCEPTION_IF_NULL(top_cell_);
1650 top_cell_->set_grad_is_running(false);
1651 top_input_args_info_ = top_cell_->input_args_info();
1652 MS_LOG(DEBUG) << "Eval run end";
1653
1654 // Clear top cell resource
1655 top_cell_->ClearMetaGradInfo();
1656 // Func grad needs to use the auto grad meta in Finish, so clear it after Finish.
1657 AsyncClearAutoGradCell(top_cell_);
1658 ClearGradRes();
1659
1660 // For custom nested grad, we need to resume grad info when finish custom grad.
1661 if (top_cell_ != nullptr) {
1662 top_cell_->ResumeMetaGradInfo();
1663 }
1664 return BaseRefToPyData(grads);
1665 }
1666
1667 py::object GradExecutor::RunGradGraph() {
1668 MS_EXCEPTION_IF_NULL(top_input_args_info_);
1669 MS_EXCEPTION_IF_NULL(top_cell_);
1670 const auto &resource = top_cell_->resource();
1671 MS_EXCEPTION_IF_NULL(resource);
1672 MS_LOG(DEBUG) << "Run top cell " << top_cell_ << " and its shadow top cell " << top_cell_->shadow_top_cell();
1673 VectorRef arg_list;
1674 SetGraphInputArgs(top_input_args_info_->input_arg_value_vec, resource, top_cell_->initial_graph_param_size(),
1675 top_input_args_info_->sens_type, &arg_list);
1676 MS_LOG(DEBUG) << "Convert args size " << top_input_args_info_->input_arg_value_vec.size() << ", graph param size "
1677 << arg_list.size();
1678
1679 auto context = MsContext::GetInstance();
1680 MS_EXCEPTION_IF_NULL(context);
1681 context->SetJitLevel(kAttrJitLevelO0);
1682
1683 compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>();
1684 MS_EXCEPTION_IF_NULL(run);
1685
1686 MS_LOG(DEBUG) << "Eval run " << MsContext::GetInstance()->backend_policy();
1687 top_cell_->set_grad_is_running(true);
1688 BaseRef out_value = (*run)(arg_list);
1689 MS_EXCEPTION_IF_NULL(top_cell_);
1690 top_cell_->set_grad_is_running(false);
1691 top_input_args_info_ = top_cell_->input_args_info();
1692 MS_LOG(DEBUG) << "Eval run end";
1693
1694 // Do high-order grad
1695 MakeNestedCnode(top_input_args_info_->has_custom_bprop, top_input_args_info_->input_arg_value_vec,
1696 resource->optimize_graph(), out_value);
1697
1698 // For custom nested grad, we need to resume grad info when finish custom grad.
1699 if (top_cell_ != nullptr) {
1700 top_cell_->ResumeMetaGradInfo();
1701 }
1702 return BaseRefToPyData(out_value);
1703 }
1704
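// High-order grad. Roughly: after the inner grad has run, take its bprop/forward graph, switch back to the outer
// top cell, record the graph in the outer forward graph with the inner inputs replaced by the outer graph's
// parameters, and pass it to the outer auto-grad cell through KPynativeWithFProp.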
1705 void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<ValuePtr> &forward_args,
1706 const FuncGraphPtr &cur_run_bprop_graph, const BaseRef &out) {
1707 MS_EXCEPTION_IF_NULL(top_input_args_info_);
1708 if (top_input_args_info_->is_grad_topest_cell) {
1709 MS_LOG(DEBUG) << "No nested grad find";
1710 MS_EXCEPTION_IF_NULL(top_cell_);
1711 top_cell_->ClearMetaGradInfo();
1712 ClearGradRes();
1713 return;
1714 }
1715 MS_LOG(DEBUG) << "Do high grad";
1716 // first_grad_fg may be modified in auto grad, and first_grad_fg can be used multiple times
1717 auto first_grad_fg = cur_run_bprop_graph;
1718 MS_LOG(DEBUG) << "Current top cell ptr " << top_cell().get() << " and its shadow top cell "
1719 << top_cell_->shadow_top_cell();
1720 top_cell_->set_is_finish_backward(true);
1721 if (has_custom_bprop) {
1722 first_grad_fg = curr_g();
1723 // The bprop top cell is just used for getting the forward graph
1724 top_cell_ = PopTopCellStack();
1725 MS_LOG(DEBUG) << "Bprop nested, after get bprop forward graph, current top cell ptr " << top_cell().get();
1726 } else {
1727 RestoreBpropGraphParameter(cur_run_bprop_graph, top_cell()->initial_graph_param_size());
1728 }
1729
1730 MS_EXCEPTION_IF_NULL(first_grad_fg);
1731 PyNativeAlgo::Common::DumpGraphIR("first_grad_fg.ir", first_grad_fg);
1732 ValuePtrList weights_args;
1733 const std::string cur_top_cell_id = top_cell()->obj_id_with_grad_order();
1734 bool use_dynamic_shape_process = top_cell()->use_dynamic_shape_process() || top_cell()->vm_compile();
1735 bool has_call_graph = top_cell()->has_call_graph();
1736 auto inner_graph_info = top_cell()->graph_info_map().at(curr_g());
1737 SwitchTopCell();
1738 auto op_run_info = std::make_shared<FrontendOpRunInfo>();
1739 op_run_info->requires_grad = true;
1740 op_run_info->op_grad_info->input_value = forward_args;
1741 op_run_info->input_size = forward_args.size();
1742 auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out, true, true);
1743 // Get output values
1744 if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
1745 std::vector<ValuePtr> out_v{out_value};
1746 out_value = std::make_shared<ValueTuple>(out_v);
1747 }
1748 MS_EXCEPTION_IF_NULL(out_value);
1749 RecordNestedGraph(first_grad_fg, inner_graph_info, forward_args, out_value);
1750
1751 // Get input values
1752 PyNativeAlgo::Common::SetGraphInputAndWeightsInfo(op_run_info, first_grad_fg, top_cell());
1753 auto grad_fg = first_grad_fg;
1754 if (has_call_graph) {
1755 auto r = std::make_shared<pipeline::Resource>();
1756 jit()->set_eliminate_forward(false);
1757 (void)first_grad_fg->transforms().erase(kGrad);
1758 auto opt = opt::Optimizer::MakeEmptyOptimizer(r);
1759 opt->set_is_first_order_j(false);
1760 grad_fg = ad::Grad(first_grad_fg, opt);
1761 jit()->set_eliminate_forward(true);
1762 }
1763 auto op_grad_info = std::make_shared<OpGradInfo>();
1764 op_grad_info->input_value = op_run_info->op_grad_info->input_value;
1765 op_grad_info->input_abs = op_run_info->op_grad_info->input_abs;
1766 op_grad_info->out_value = out_value;
1767 op_grad_info->output_size = PyNativeAlgo::Common::GetValueSize(op_grad_info->out_value);
1768 op_grad_info->out_abs = first_grad_fg->output()->abstract();
1769 op_grad_info->input_value_grad_type = op_run_info->op_grad_info->input_value_grad_type;
1770 auto grad_param = std::make_shared<GradParam>(op_grad_info, use_dynamic_shape_process);
1771 grad_param->fg = grad_fg;
1772 grad_param->source_fg = first_grad_fg;
1773 grad_param->is_control_flow = has_call_graph;
1774 // If func grad and ir grad use the same ad grad graph (cache hit), dout will be wrong because of the different types
1775 // (tuple or plant tuple)
1776 grad_param->graph_cache_key = cur_top_cell_id + std::to_string(top_cell()->is_ir_grad());
1777 if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
1778 MS_LOG(EXCEPTION) << "Failed to run ad grad for second grad graph ";
1779 }
1780 top_cell()->set_need_do_final_opt(true);
1781 }
1782
1783 void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
1784 const std::vector<ValuePtr> &forward_args, AnfNodePtrList *inputs) {
1785 MS_EXCEPTION_IF_NULL(inner_graph_info);
1786 auto outer_graph_info = top_cell()->graph_info_map().at(curr_g());
1787 MS_EXCEPTION_IF_NULL(outer_graph_info);
1788 for (const auto &forward_arg : forward_args) {
1789 const auto &id = PyNativeAlgo::Common::GetIdByValue(forward_arg);
1790 const auto it = outer_graph_info->input_params.find(id);
1791 if (it != outer_graph_info->input_params.end()) {
1792 // Can find in outer graph
1793 MS_LOG(DEBUG) << "Replace input param id " << id;
1794 // Replace inner graph param by outer graph param
1795 (void)inputs->emplace_back(it->second);
1796 } else {
1797 MS_LOG(DEBUG) << "Can't find input param id " << id;
1798 // The inner graph input param is not found in the outer graph, so it needs to be added to the outer graph
1799 (void)inputs->emplace_back(GetInput(forward_arg, id));
1800 }
1801 }
1802 mindspore::HashSet<std::string> inner_graph_used_weights_set;
1803 // Weight in inner graph
1804 const auto &fir_graph_parameters = first_grad_fg->parameters();
1805 for (const auto &param : fir_graph_parameters) {
1806 auto weight_tensor = PyNativeAlgo::Common::GetTensorFromParam(param);
1807 if (weight_tensor != nullptr) {
1808 (void)inner_graph_used_weights_set.emplace(weight_tensor->id());
1809 }
1810 }
1811 for (const auto &weight : inner_graph_info->weight_params) {
1812 // If a weight is used in the graph but the grad net does not require its grad, it will be a value node; no need to replace
1813 if (inner_graph_used_weights_set.find(weight.first) == inner_graph_used_weights_set.end()) {
1814 continue;
1815 }
1816 const auto it = outer_graph_info->weight_params.find(weight.first);
1817 if (it != outer_graph_info->weight_params.end()) {
1818 // Can find in outer graph
1819 MS_LOG(DEBUG) << "Replace weight param name " << weight.second->name() << ", id " << weight.first;
1820 (void)inputs->emplace_back(it->second);
1821 } else {
1822 MS_LOG(DEBUG) << "Can't find weight param name " << weight.second->name() << ", id " << weight.first;
1823 top_cell()->SetParamNodeMapInGraphInfoMap(weight.first, weight.second, true);
1824 (void)inputs->emplace_back(weight.second);
1825 }
1826 }
1827 }
1828
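// After an inner (nested) grad finishes, pop the outer top cell from the stack and make it current again,
// propagating the force-compile and dynamic-shape flags when the inner graph was compiled by the VM.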
1829 void GradExecutor::SwitchTopCell() {
1830 ClearPipelineTopCellRes();
1831 // Get outer top cell
1832 auto outer_top_cell = PopTopCellStack();
1833 MS_EXCEPTION_IF_NULL(outer_top_cell);
1834 MS_LOG(DEBUG) << "Get outer top cell ptr " << outer_top_cell.get();
1835 // If the inner graph compiled its graph, the outer one must be compiled as well
1836 if (top_cell()->vm_compile()) {
1837 outer_top_cell->set_force_top_cell_compile(true);
1838 outer_top_cell->set_use_dynamic_shape_process(outer_top_cell->use_dynamic_shape_process() ||
1839 top_cell()->use_dynamic_shape_process());
1840 }
1841 outer_top_cell->ResumeMetaGradInfo();
1842 set_top_cell(outer_top_cell);
1843 }
1844
1845 void GradExecutor::ClearGlobalRes() const {
1846 abstract::AnalysisContext::ClearContext();
1847 parse::data_converter::ClearObjectCache();
1848 parse::Parser::CleanParserResource();
1849 trace::ClearTraceStack();
1850 ad::CleanRes();
1851 pipeline::ReclaimOptimizer();
1852 }
1853
1854 void GradExecutor::ClearGradRes() {
1855 MS_LOG(DEBUG) << "Top cell run finish " << top_cell_ << " and its shadow top cell " << top_cell_->shadow_top_cell();
1856 // Pop current top cell on stack
1857 if (!top_cell_->is_pipeline_top_cell()) {
1858 (void)PopTopCellStack();
1859 }
1860
1861 if (!top_cell_stack_.empty() && top_cell_->is_pipeline_top_cell()) {
1862 MS_LOG(DEBUG) << "Top cell stack real running top cell " << top_cell_stack_.top();
1863 if (top_cell_stack_.top() == top_cell_) {
1864 MS_LOG(DEBUG) << "Pop pipeline top cell " << top_cell_stack_.top() << " from stack with input args id "
1865 << top_cell_stack_.top()->input_args_id();
1866 (void)PopTopCellStack();
1867 }
1868 }
1869 auto has_higher_order = std::any_of(already_run_top_cell_.begin(), already_run_top_cell_.end(),
1870 [](const auto &elem) { return elem.second->is_high_order_top_cell(); });
1871 // Must not clean while a high-order top cell exists
1872 if (!has_higher_order) {
1873 top_cell_->ClearDeviceMemory();
1874 }
1875
1876 top_cell_->input_args_info()->Reset();
1877 top_cell_->set_is_finish_backward(true);
1878 ClearPipelineTopCellRes();
1879 top_input_args_info_ = nullptr;
1880 ClearGlobalRes();
1881 MS_LOG(DEBUG) << "Current top cell stack size " << top_cell_stack_.size() << ", pipeline top cell map size "
1882 << pipeline_top_cell_map_.size() << ", pipeline top cell map with already run cell id "
1883 << top_cell_->already_run_cell_id() << " size "
1884 << (pipeline_top_cell_map_.find(top_cell_->already_run_cell_id()) == pipeline_top_cell_map_.end()
1885 ? 0
1886 : pipeline_top_cell_map_[top_cell_->already_run_cell_id()].size());
1887 top_cell_ = nullptr;
1888 // Nested grad: get the outer top cell if it exists.
1889 // When a top cell runs with a bprop and that bprop itself has grad, the top cell should be restored after running the inner grad
1890 if (!top_cell_stack_.empty()) {
1891 top_cell_ = top_cell_stack_.top();
1892 MS_LOG(DEBUG) << "Get outer top cell " << top_cell_ << " as the currently running top cell";
1893 }
1894 }
1895
1896 void GradExecutor::ClearPipelineTopCellRes() {
1897 // Remove the pipeline top cell from the pipeline top cell map, excluding the first one
1898 if (top_cell_->is_pipeline_top_cell()) {
1899 // Run second step and following step
1900 if (top_cell_->shadow_top_cell() != nullptr) {
1901 ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->shadow_top_cell()->input_args_id(), false);
1902 top_cell_->set_shadow_top_cell(nullptr);
1903 } else if (!top_cell_->is_ir_grad()) {
1904 // Pipeline top cell other than the first top cell
1905 ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->input_args_id(), false);
1906 }
1907 } else {
1908 // If the top cell is not a pipeline top cell, it was still stored in the pipeline top cell map in the first step,
1909 // so it needs to be deleted from the map here.
1910 ErasePipelineTopCell(top_cell_->already_run_cell_id(), top_cell_->input_args_id(), true);
1911 }
1912 if (top_cell_->grad_first()) {
1913 DecreaseGradOrder();
1914 }
1915 grad_operation_.clear();
1916 }
1917
1918 void GradExecutor::ClearRes() {
1919 MS_LOG(DEBUG) << "Clear grad res";
1920 WaitBpropTask();
1921 init_ = false;
1922 grad_flag_ = false;
1923 enable_grad_ = true;
1924 is_run_recompute_ = false;
1925 save_graphs_ = false;
1926 forward_use_dynamic_shape_process_ = false;
1927
1928 kernel_graph_id_for_control_flow_ = UINT32_MAX;
1929 custom_bprop_cell_count_ = 0;
1930 grad_order_ = 0;
1931 op_num_in_bprop_graph_ = kDefaultContainerSize;
1932 grad_operation_.clear();
1933
1934 top_cell_ = nullptr;
1935 top_input_args_info_ = nullptr;
1936 std::stack<InputArgsInfoPtr>().swap(input_args_info_stack_);
1937 std::stack<TopCellInfoPtr>().swap(top_cell_stack_);
1938 already_run_top_cell_.clear();
1939 pipeline_top_cell_map_.clear();
1940 dynamic_inputs_cells_.clear();
1941 need_gc_top_cell_list_.clear();
1942 dynamic_shape()->Clear();
1943 jit()->Clear();
1944 }
1945
1946 void GradExecutor::AsyncClearTopCell() {
1947 for (const auto &need_gc_top_cell : need_gc_top_cell_list_) {
1948 if (forward()->enable_async()) {
1949 auto task = [need_gc_top_cell]() { need_gc_top_cell->Clear(); };
1950 DispatchGradQueueTask(std::move(task));
1951 } else {
1952 need_gc_top_cell->Clear();
1953 }
1954 }
1955 need_gc_top_cell_list_.clear();
1956 }
1957
1958 void GradExecutor::AsyncClearAutoGradCell(const TopCellInfoPtr &top_cell) {
1959 if (forward()->enable_async()) {
1960 auto task = [top_cell] { top_cell->set_auto_grad_cell_ptr(nullptr); };
1961 DispatchGradQueueTask(std::move(task));
1962 } else {
1963 top_cell->set_auto_grad_cell_ptr(nullptr);
1964 }
1965 }
1966
1967 void GradExecutor::WorkerJoin() {
1968 bprop_queue_->WorkerJoin();
1969 assist_queue_->WorkerJoin();
1970 }
1971
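// Resolve the graph node for a forward value, trying in order: non-tensor input, parameter (input or weight),
// output of a previously recorded op, value sequence (built as a MakeTuple), and finally a constant value node.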
1972 AnfNodePtr GradExecutor::GetInput(const ValuePtr &v, const string &obj_id) const {
1973 // Is not a tensor
1974 AnfNodePtr node = GetNonTensorInput(v, obj_id);
1975 if (node != nullptr) {
1976 return node;
1977 }
1978 // Get param input
1979 node = GetParamInput(v, obj_id);
1980 if (node != nullptr) {
1981 return node;
1982 }
1983 // Get op output
1984 node = GetOutputNodeAsInput(obj_id);
1985 if (node != nullptr) {
1986 return node;
1987 }
1988 // A tuple is returned in this case: x = op1, y = op2, return (x, y);
1989 // or a scalar, or (scalar, tensor)
1990 node = GetValueSequenceInput(v);
1991 if (node != nullptr) {
1992 return node;
1993 }
1994 auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v);
1995 MS_LOG(DEBUG) << "Get input value node " << v_node->ToString() << ", id " << obj_id;
1996 return v_node;
1997 }
1998
1999 AnfNodePtr GradExecutor::GetParamInput(const ValuePtr &v, const std::string &id) const {
2000 const auto &graph_info = top_cell()->graph_info_map().at(curr_g());
2001 MS_EXCEPTION_IF_NULL(graph_info);
2002 // Get input param input
2003 const auto it = graph_info->input_params.find(id);
2004 if (it != graph_info->input_params.end()) {
2005 MS_LOG(DEBUG) << "Get input param " << id;
2006 return it->second;
2007 }
2008
2009 // Get weight param input
2010 MS_EXCEPTION_IF_NULL(v);
2011 if (v->isa<tensor::BaseTensor>() && v->cast<tensor::BaseTensorPtr>()->is_parameter()) {
2012 const auto item_by_id = graph_info->weight_params.find(id);
2013 if (item_by_id != graph_info->weight_params.end()) {
2014 MS_LOG(DEBUG) << "Get weight param " << id;
2015 return item_by_id->second;
2016 }
2017 MS_LOG(DEBUG) << "Add new weight param " << id;
2018 auto tensor = v->cast<tensor::BaseTensorPtr>();
2019 const auto &param_info = tensor->param_info();
2020 MS_EXCEPTION_IF_NULL(param_info);
2021 const auto &param_name = param_info->name();
2022 // Add new weight param to graph info
2023 auto weight_param = curr_g()->add_parameter();
2024 weight_param->set_name(param_name);
2025 weight_param->debug_info()->set_name(param_name);
2026 weight_param->set_default_param(tensor);
2027 weight_param->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
2028 top_cell()->SetParamNodeMapInGraphInfoMap(id, weight_param, true);
2029 return weight_param;
2030 }
2031 return nullptr;
2032 }
2033
2034 AnfNodePtr GradExecutor::GetOutputNodeAsInput(const std::string &obj_id) const {
2035 const auto &graph_info = top_cell()->graph_info_map().at(curr_g());
2036 MS_EXCEPTION_IF_NULL(graph_info);
2037 const auto it = graph_info->node_map.find(obj_id);
2038 if (it == graph_info->node_map.end()) {
2039 return nullptr;
2040 }
2041 // Single output CNode
2042 if (it->second.second.size() == 1 && it->second.second[0] == -1) {
2043 MS_LOG(DEBUG) << "Get input node " << it->second.first->ToString() << ", id " << obj_id;
2044 return it->second.first;
2045 }
2046 // Create tuple get item node for multiple output CNode
2047 return CreateTupleGetItemNode(obj_id, it->second);
2048 }
2049
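// For a ValueSequence input, build a MakeTuple CNode whose elements are resolved recursively through GetInput.
// FuncGraph values are skipped because they have no grad definition.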
2050 AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v) const {
2051 MS_EXCEPTION_IF_NULL(v);
2052 if (!v->isa<ValueSequence>()) {
2053 return nullptr;
2054 }
2055 ValuePtrList input_args;
2056 abstract::AbstractBasePtrList abs_list;
2057 AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
2058 const auto &obj_tuple = v->cast<ValueSequencePtr>();
2059 const auto &v_list = obj_tuple->value();
2060 for (size_t i = 0; i < obj_tuple->size(); ++i) {
2061 const auto &v_arg = v_list[i];
2062 // Graphs have no grad definition
2063 if (v_arg->isa<FuncGraph>()) {
2064 continue;
2065 }
2066 (void)input_args.emplace_back(v_arg);
2067 const std::string &id = PyNativeAlgo::Common::GetIdByValue(v_arg);
2068 (void)inputs.emplace_back(GetInput(v_arg, id));
2069 (void)abs_list.emplace_back(PyNativeAlgo::Common::SetAbstractValueToAnyValue(v_arg->ToAbstract()));
2070 (void)GetValueSequenceInput(v_arg);
2071 }
2072 // Create make tuple node and record to graph info map.
2073 auto cnode = curr_g()->NewCNode(inputs);
2074 cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
2075 MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
2076 return cnode;
2077 }
2078
2079 AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
2080 const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const {
2081 AnfNodePtr c_node = out.first->cast<CNodePtr>();
2082 bool param_is_sequence = false;
2083 if (c_node == nullptr) {
2084 // Input param is tuple or list
2085 if (GetParamInput(MakeValue(true), obj_id) != nullptr) {
2086 MS_LOG(EXCEPTION) << "Get wrong input node " << out.first->DebugString();
2087 }
2088 param_is_sequence = true;
2089 c_node = out.first;
2090 }
2091 MS_LOG(DEBUG) << "Sequence input node " << c_node->DebugString() << ", id " << obj_id << ", out second "
2092 << out.second;
2093 // Create tuple get item node
2094 auto abs = c_node->abstract();
2095 MS_EXCEPTION_IF_NULL(abs);
2096 for (auto idx : out.second) {
2097 AnfNodePtrList tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), c_node, NewValueNode(idx)};
2098 c_node = curr_g()->NewCNode(tuple_get_item_inputs);
2099 if (!abs->isa<abstract::AbstractSequence>()) {
2100 MS_LOG(EXCEPTION) << "Input node abs is not sequence " << abs->ToString();
2101 }
2102 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2103 if (static_cast<size_t>(idx) >= abs_seq->size()) {
2104 MS_LOG(EXCEPTION) << "Index exceeds the size of elements. Index " << idx << ", element size " << abs_seq->size();
2105 }
2106 abs = abs_seq->elements()[static_cast<size_t>(idx)];
2107 MS_EXCEPTION_IF_NULL(abs);
2108 c_node->set_abstract(abs);
2109 if (param_is_sequence) {
2110 c_node->set_user_data(kParamterIsSequence, MakeValue(param_is_sequence));
2111 }
2112 }
2113 MS_LOG(DEBUG) << "Create tuple getitem node " << c_node->DebugString() << ", abs " << c_node->abstract()->ToString();
2114 return c_node;
2115 }
2116
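// Find the top cell matching already_run_cell_id in the already-run map, accepting a complete match or a partial
// match in either direction, and fall back to the pipeline top cell map when necessary. A matching top cell whose
// recorded grad operation differs from the current one is discarded so the backward graph is rebuilt.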
2117 TopCellInfoPtr GradExecutor::GetTopCell(const std::string &already_run_cell_id, const std::string &input_args_id) {
2118 TopCellInfoPtr find_top_cell = nullptr;
2119 for (const auto &[cell_id, top_cell] : already_run_top_cell_) {
2120 MS_EXCEPTION_IF_NULL(top_cell);
2121 MS_LOG(DEBUG) << "Top cell " << top_cell << " with already run cell id " << cell_id << ", input args id "
2122 << top_cell->input_args_id();
2123 // Complete match, means run grad operation first
2124 if (top_cell->already_run_cell_id() == already_run_cell_id) {
2125 find_top_cell = top_cell;
2126 break;
2127 }
2128 // Partial match, means run forward first without grad_operation in already run cell id
2129 if (already_run_cell_id.find(top_cell->already_run_cell_id()) != std::string::npos &&
2130 top_cell->already_run_cell_id().back() == '_') {
2131 find_top_cell = top_cell;
2132 break;
2133 }
2134 // Partial match means grad ran first but was followed by another net's grad
2135 if (top_cell->already_run_cell_id().find(already_run_cell_id) != std::string::npos &&
2136 already_run_cell_id.back() == '_') {
2137 find_top_cell = top_cell;
2138 break;
2139 }
2140 }
2141
2142 // Get pipeline top cell
2143 if (find_top_cell == nullptr) {
2144 MS_LOG(DEBUG) << "Not find in already run top cell map, try find in pipeline top cell map";
2145 find_top_cell = GetPipelineTopCell(already_run_cell_id, input_args_id, already_run_cell_id.back() == '_');
2146 } else if (find_top_cell->is_pipeline_top_cell()) {
2147 // Delete the first pipeline top cell from the already run top cell map
2148 (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
2149 if (find_top_cell->input_args_id() != input_args_id) {
2150 MS_LOG(DEBUG) << "Find top cell input args id " << find_top_cell->input_args_id()
2151 << " not match current input args id " << input_args_id << ", try find in pipeline top cell map";
2152 find_top_cell = GetPipelineTopCell(already_run_cell_id, input_args_id, already_run_cell_id.back() == '_');
2153 }
2154 }
2155
2156 // Same top cell info but a different grad operation, so construct the backward graph again
2157 if (find_top_cell != nullptr) {
2158 if (!find_top_cell->grad_operation().empty() && find_top_cell->grad_operation() != grad_operation_) {
2159 MS_LOG(DEBUG) << "Already exist grad operation " << find_top_cell->grad_operation() << " is different with new "
2160 << grad_operation_;
2161 (void)already_run_top_cell_.erase(find_top_cell->already_run_cell_id());
2162 return nullptr;
2163 }
2164 return find_top_cell;
2165 }
2166 return nullptr;
2167 }
2168
2169 void GradExecutor::SetHookChanged(const py::object &cell) const {
2170 if (top_cell_ == nullptr) {
2171 return;
2172 }
2173 const auto &cell_id = PyNativeAlgo::PyParser::GetIdByPyObj(cell);
2174 if (top_cell_->cell_id().find(cell_id) != std::string::npos) {
2175 top_cell_->set_hook_changed(true);
2176 }
2177 if (RequiresGrad()) {
2178 top_cell_->set_sub_cell_hook_changed(cell_id);
2179 }
2180 }
2181
2182 void GradExecutor::ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const {
2183 MS_EXCEPTION_IF_NULL(op_run_info);
2184 RecordForwardGraph(op_run_info);
2185 if (top_cell_->is_bprop_need_get_forward_graph()) {
2186 MS_LOG(DEBUG) << "Just need forward graph";
2187 return;
2188 }
2189 DoOpGrad(op_run_info);
2190 top_cell()->GetOpInfo(op_run_info, false);
2191 UpdateTopCellForwardTensorInfoInBpropGraph(op_run_info->op_info, op_run_info->real_out,
2192 op_run_info->base_op_run_info.stream_id);
2193 DynamicDetectNodeInfoPtr node_info;
2194 if (op_run_info->op_grad_info->output_value_simple_info != nullptr) {
2195 node_info = std::make_shared<DynamicDetectNodeInfo>(op_run_info->op_grad_info->op_prim);
2196 } else {
2197 node_info = std::make_shared<DynamicDetectNodeInfo>(
2198 op_run_info->op_grad_info->op_prim, op_run_info->op_grad_info->input_abs, op_run_info->op_grad_info->out_abs);
2199 }
2200 CheckBpropCutNode(top_cell(), op_run_info->op_grad_info->op_prim);
2201 (void)dynamic_shape()->CheckNodeDynamic(top_cell(), op_run_info->op_grad_info->input_value, node_info);
2202 }
2203
2204 void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
2205 const CNodePtr &cnode) const {
2206 MS_EXCEPTION_IF_NULL(cnode);
2207 MS_LOG(DEBUG) << "Cnode is " << cnode->DebugString() << ", out value id " << obj_id;
2208 // In hook computation, the output is a copy of the input. If the hook input is an input param, following ops use the
2209 // hook output as input, but GetInput will always find the input param, so it must be deleted from the input param map
2210 MS_EXCEPTION_IF_NULL(op_run_info);
2211 if (op_run_info->run_in_vm && kHookOp.find(op_run_info->base_op_run_info.op_name) != kHookOp.end()) {
2212 for (size_t i = 0; i < op_run_info->input_size; ++i) {
2213 top_cell()->DeleteParamNodeInfo(curr_g(), op_run_info->input_value_id[i]);
2214 }
2215 }
2216 top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode);
2217 }
2218
2219 // Run ad grad for curr op and connect grad graph with previous op
2220 void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info) const {
2221 MS_EXCEPTION_IF_NULL(op_run_info);
2222 auto &&grad_param = CreateOpGradParam(op_run_info, top_cell());
2223 if (forward()->enable_async()) {
2224 auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
2225 auto task = [auto_grad_cell_ptr, grad_param]() { (void)auto_grad_cell_ptr->KPynativeOp(grad_param); };
2226 DispatchGradQueueTask(std::move(task));
2227 } else {
2228 (void)top_cell()->auto_grad_cell_ptr()->KPynativeOp(grad_param);
2229 }
2230 }
2231
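// Keep the forward output tensors referenced by the cached bprop graph in sync across steps. Roughly: on the first
// run the op output info is saved for later replacement; on later runs the recorded tensors are updated (or stored
// per pipeline top cell) so the cached graph sees the tensors of the current step.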
2232 void GradExecutor::UpdateTopCellForwardTensorInfoInBpropGraph(const std::string &op_info, const ValuePtr &v,
2233 const size_t &stream_id) const {
2234 auto pre_top_cell = GetAlreadyRunTopCell(top_cell()->already_run_cell_id());
2235 // The shape of the last two steps is the same, and the pre_top_cell is not empty.
2236 // But if a dynamic shape is enabled at this point, you still need to execute SaveTensorIdWithOpInfo.
2237 if (pre_top_cell == nullptr || use_dynamic_shape_process()) {
2238 // First run top cell, save op output info for replacement
2239 pre_top_cell = GetPipelineRunTopCell(top_cell_->already_run_cell_id());
2240 if (pre_top_cell == nullptr) {
2241 top_cell_->SaveTensorIdWithOpInfo(op_info, v);
2242 use_dynamic_shape_process() ? MS_LOG(DEBUG) << "Current top cell is in dynamic process"
2243 : MS_LOG(DEBUG)
2244 << "Top cell " << top_cell_->already_run_cell_id() << " run firstly, op info "
2245 << op_info << ", output id " << PyNativeAlgo::Common::GetIdByValue(v);
2246 return;
2247 }
2248 }
2249
2250 // In the dynamic process, no replacement is needed
2251 if (top_cell_->use_dynamic_shape_process()) {
2252 return;
2253 }
2254
2255 // top cell is pipeline top cell
2256 if (top_cell_->is_pipeline_top_cell()) {
2257 if (pre_top_cell->is_finish_backward()) {
2258 // Not first run top cell, do update; Save op_info -> tensor
2259 if (!pre_top_cell->is_ir_grad()) {
2260 MS_LOG(EXCEPTION) << "Forward repalce top cell must be ir grad";
2261 }
2262 MS_LOG(DEBUG) << "Store pipeline top cell " << top_cell_->already_run_cell_id() << " with input args id "
2263 << top_cell_->input_args_id() << ", op info " << op_info << ", output id "
2264 << PyNativeAlgo::Common::GetIdByValue(v);
2265 StoreForwardOutputWithOpInfo(pre_top_cell->replace_info().op_info_with_tensor_object, op_info, v,
2266 &top_cell_->replace_info());
2267 } else {
2268 // The first ir grad has not run before, indicating this is the first step; the top cell will run func grad independently
2269 MS_LOG(DEBUG) << "Current top cell is pipeline top cell and run firstly, op info " << op_info << ", output id "
2270 << PyNativeAlgo::Common::GetIdByValue(v);
2271 }
2272 } else {
2273 // Not first run top cell, do update
2274 MS_LOG(DEBUG) << "Update top cell forward output tensor info " << op_info << ", output id "
2275 << PyNativeAlgo::Common::GetIdByValue(v);
2276 UpdateForwardOutputTensorInfo(op_info, v, pre_top_cell->replace_info());
2277 }
2278 }
2279
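// When the input node is the backward hook op inserted by another cell (or a TupleGetItem on it), return the
// original input of that hook op instead, so the recorded forward graph skips the foreign hook.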
2280 AnfNodePtr GradExecutor::GetRealInputNodeBySkipHook(const AnfNodePtr &input_node) const {
2281 if (input_node == nullptr) {
2282 MS_LOG(DEBUG) << "The input node is nullptr.";
2283 return input_node;
2284 }
2285 const auto &cell_backward_hook_op = top_cell()->cell_backward_hook_op();
2286 for (const auto &elem : cell_backward_hook_op) {
2287 constexpr size_t cell_backward_hook_num = 2;
2288 if (elem.second.size() < cell_backward_hook_num) { // In cell own scope, no need to skip backward hook op.
2289 continue;
2290 }
2291 // The input node is the first backward hook op of another cell, skip the backward hook op.
2292 if (IsPrimitiveCNode(input_node, prim::kPrimCellBackwardHook) && input_node == elem.second[0]) {
2293 // Single input.
2294 auto backward_hook_op = input_node->cast<CNodePtr>();
2295 MS_EXCEPTION_IF_NULL(backward_hook_op);
2296 return backward_hook_op->input(1);
2297 } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
2298 // Multi inputs.
2299 auto tuple_get_item = input_node->cast<CNodePtr>();
2300 MS_EXCEPTION_IF_NULL(tuple_get_item);
2301 auto inp_in_tuple = tuple_get_item->input(1);
2302 MS_EXCEPTION_IF_NULL(inp_in_tuple);
2303 if (IsPrimitiveCNode(inp_in_tuple, prim::kPrimCellBackwardHook) && inp_in_tuple == elem.second[0]) {
2304 constexpr size_t idx = 2;
2305 auto idx_node = tuple_get_item->input(idx);
2306 MS_EXCEPTION_IF_NULL(idx_node);
2307 auto value_node = idx_node->cast<ValueNodePtr>();
2308 MS_EXCEPTION_IF_NULL(value_node);
2309 auto out_idx = GetValue<int64_t>(value_node->value());
2310 auto backward_hook_op = inp_in_tuple->cast<CNodePtr>();
2311 MS_EXCEPTION_IF_NULL(backward_hook_op);
2312 return backward_hook_op->input(1 + LongToSize(out_idx));
2313 }
2314 }
2315 }
2316 return input_node;
2317 }
2318
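// Build the forward CNode for an op from its primitive and hook-skipped input nodes, and record CellBackwardHook
// nodes on the top cell.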
CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  AnfNodePtrList inputs;
  (void)inputs.emplace_back(NewValueNode(op_run_info->op_grad_info->op_prim));
  for (size_t i = 0; i < op_run_info->input_size; i++) {
    AnfNodePtr input_node = nullptr;
    const auto node = GetInput(op_run_info->op_grad_info->input_value[i], op_run_info->input_value_id[i]);
    input_node = GetRealInputNodeBySkipHook(node);
    // update abstract
    if (input_node != nullptr) {
      (void)inputs.emplace_back(input_node);
    }
  }
  const auto &cnode = curr_g()->NewCNodeInOrder(inputs);
  if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
    top_cell()->RecordCellBackwardHookOp(op_run_info->cell_obj_id, cnode);
  }
  MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
                << cnode->DebugString();
  return cnode;
}

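// Record the op into the forward graph when graph saving is enabled or bprop needs the forward graph; fill in the
// input/output value ids, construct the CNode, and register it in the output node map.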
void GradExecutor::RecordForwardGraph(const FrontendOpRunInfoPtr &op_run_info) const {
  if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
    MS_EXCEPTION_IF_NULL(op_run_info);
    if (op_run_info->input_value_id.empty()) {
      (void)std::transform(op_run_info->op_grad_info->input_value.begin(), op_run_info->op_grad_info->input_value.end(),
                           std::back_inserter(op_run_info->input_value_id),
                           [](const ValuePtr &value) { return PyNativeAlgo::Common::GetIdByValue(value); });
    }
    if (op_run_info->out_value_id.empty()) {
      op_run_info->out_value_id = PyNativeAlgo::Common::GetIdByValue(op_run_info->real_out);
    }
    const auto &cnode = ConstructForwardGraph(op_run_info);
    MS_EXCEPTION_IF_NULL(cnode);
    // When simple infer was used, the abstract is nullptr; derive it from the real output instead.
    if (op_run_info->base_op_run_info.abstract == nullptr) {
      cnode->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(op_run_info->real_out->ToAbstract()));
    } else {
      cnode->set_abstract(op_run_info->base_op_run_info.abstract);
    }
    SaveOutputNodeMap(op_run_info->out_value_id, op_run_info, cnode);
  }
}

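// Add a parameter node for a network input to the current graph and register it in the top cell's graph info map.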
void GradExecutor::RecordForwardGraphForInput(const ValuePtr &value, const string &input_id,
                                              const abstract::AbstractBasePtr &param_abs) {
  save_graphs_ = MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs_ || top_cell_->is_bprop_need_get_forward_graph()) {
    auto new_param = curr_g()->add_parameter();
    new_param->set_abstract(param_abs);
    if (value->isa<ValueSequence>()) {
      top_cell()->SetNodeMapInGraphInfoMap(input_id, new_param, true);
    }
    top_cell()->SetParamNodeMapInGraphInfoMap(input_id, new_param);
  }
}

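// When graph saving is enabled, embed the first-order grad graph as a CNode of the current graph so the nested
// high-order structure is kept in the recorded forward graph.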
void GradExecutor::RecordNestedGraph(const FuncGraphPtr &first_grad_fg, const GraphInfoPtr &inner_graph_info,
                                     const std::vector<ValuePtr> &forward_args, const ValuePtr &out) {
  if (save_graphs_) {
    AnfNodePtrList inputs{NewValueNode(first_grad_fg)};
    DoParameterReplace(first_grad_fg, inner_graph_info, forward_args, &inputs);
    auto cnode = curr_g()->NewCNode(inputs);
    auto out_id = PyNativeAlgo::Common::GetIdByValue(out);
    top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
    cnode->set_abstract(first_grad_fg->output()->abstract());
    MS_LOG(DEBUG) << "Nested make cnode is: " << cnode->DebugString() << ", out id " << out_id;
  }
}

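// Pass the cell's jit config dict (if present) to the graph executor, which controls the jit level used when
// compiling the bprop graph.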
void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
  if (!py::hasattr(obj, kAttrCellJitConfigDict)) {
    return;
  }

  auto jit_config = py::getattr(obj, kAttrCellJitConfigDict);
  if (!py::isinstance<py::dict>(jit_config)) {
    MS_LOG(EXCEPTION) << "JitConfig only supports dict!";
  }
  auto jit_config_dict = jit_config.cast<py::dict>();
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  graph_executor->SetJitConfig(jit_config_dict);
}

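// Remember the id of a cell (or function) whose inputs are declared dynamic; SetTopCellDynamicAttr later uses this
// set to switch the top cell into the dynamic shape process.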
void GradExecutor::SaveDynamicInputsCells(const py::object &obj, const py::args &args) {
  const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
  MS_LOG(INFO) << "SaveDynamicInputsCells: "
               << (py::isinstance<Cell>(obj) ? obj_id + " " + obj.cast<CellPtr>()->ToString()
                                             : py::getattr(obj, "__name__").cast<std::string>());
  (void)dynamic_inputs_cells_.insert(obj_id);
}

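// Mark the current top cell as using the dynamic shape process if this cell was registered by
// SaveDynamicInputsCells.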
void GradExecutor::SetTopCellDynamicAttr(const py::object &cell) {
  if (top_cell_ == nullptr) {
    return;
  }

  if (top_cell()->use_dynamic_shape_process()) {
    // Top cell is already dynamic, no need to set again.
    return;
  }
  top_cell()->set_use_dynamic_shape_process(dynamic_inputs_cells_.count(PyNativeAlgo::PyParser::GetIdByPyObj(cell)));
}

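// Push a bprop task onto the async bprop queue; if the push fails, re-check and rethrow any pending exception.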
void GradExecutor::DispatchGradQueueTask(std::function<void(void)> &&task) const {
  if (!bprop_queue_->Push(new (std::nothrow) BpropTask(std::move(task)))) {
    bprop_queue_->CheckException();
  }
}

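// Discard all pending bprop and assist tasks (with the GIL released) and rethrow any recorded exception.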
void GradExecutor::ClearBpropTask() const {
  if (bprop_queue_ != nullptr) {
    GilReleaseWithCheck gil_release;
    bprop_queue_->Clear();
    assist_queue_->Clear();
    bprop_queue_->CheckException();
  }
}

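// Wait for all queued bprop and assist tasks to finish (with the GIL released), then rethrow any recorded exception.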
void GradExecutor::WaitBpropTask() const {
  if (bprop_queue_ != nullptr) {
    GilReleaseWithCheck gil_release;
    bprop_queue_->Wait();
    assist_queue_->Wait();
    bprop_queue_->CheckException();
  }
}

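// Push a task onto the assist queue; if the push fails, re-check and rethrow any pending exception.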
void GradExecutor::DispatchAssistQueueTask(std::function<void(void)> task) const {
  bool success = assist_queue_->Push(new (std::nothrow) BpropTask(std::move(task)));
  if (!success) {
    assist_queue_->CheckException();
  }
}

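// Re-initialize the task queues and clear the backend state in the child process after a fork.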
void GradExecutor::ChildAfterFork() {
  MS_LOG(DEBUG) << "GradExecutor reinitialize after fork.";
  if (bprop_queue_ != nullptr) {
    MS_LOG(DEBUG) << "Reinitialize bprop_queue_.";
    bprop_queue_->ChildAfterFork();
  }
  if (assist_queue_ != nullptr) {
    MS_LOG(DEBUG) << "Reinitialize assist_queue_.";
    assist_queue_->ChildAfterFork();
  }
  runtime::PyBoostOpExecute::GetInstance().ClearBackend();
  MS_LOG(DEBUG) << "GradExecutor reinitialize after fork done.";
}
}  // namespace pynative
}  // namespace mindspore