/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/function/func_grad.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "pybind_api/gil_scoped_long_running.h"
#include "include/common/utils/primitive_utils.h"
#include "include/common/utils/hook.h"
#include "pipeline/pynative/pynative_utils.h"
#include "ops/framework_ops.h"
#include "ops/other_ops.h"

namespace mindspore::pynative::autograd {
namespace {
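// Accumulate two gradient values. A kNone operand means "no gradient yet", so the other value is returned as-is;
// otherwise the two values are added through the FuncBuilder.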
ValuePtr Add(const ValuePtr &input, const ValuePtr &other, const FuncBuilderPtr &func_impl) {
  if (input->isa<None>()) {
    MS_EXCEPTION_IF_NULL(other);
    return other;
  }
  if (other->isa<None>()) {
    MS_EXCEPTION_IF_NULL(input);
    return input;
  }
  auto result = func_impl->Add(input, other);
  MS_EXCEPTION_IF_NULL(result);
  return result;
}

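// Accumulate `other` into the gradient buffer at position `input_index`.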
void Add(const ValuePtr &other, size_t input_index, const FuncBuilderPtr &func_impl, std::vector<ValuePtr> *inputs) {
  if (input_index >= inputs->size()) {
    MS_LOG(EXCEPTION) << "The input index should be less than the inputs size";
  }

  (*inputs)[input_index] = Add(inputs->at(input_index), other, func_impl);
}

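// Build a gradient list of `output_size` entries where only `input_index` holds `grad`; the remaining slots are
// filled with kNone and resolved lazily later.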
ValuePtrList PaddingGradientInput(const ValuePtr &grad, size_t output_size, size_t input_index) {
  ValuePtrList gradients;
  gradients.reserve(output_size);
  for (size_t i = 0; i < output_size; ++i) {
    if (input_index == i) {
      (void)gradients.emplace_back(grad);
    } else {
      // If the gradient is missing, we just set kNone; the zero gradient is filled in lazily by the
      // LazeUpdateZeroGradient method.
      (void)gradients.emplace_back(kNone);
    }
  }
  return gradients;
}

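// Pack the forward inputs (and, when recompute is disabled, the forward output) into the argument list that is later
// passed to the Python hook/bprop function.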
VectorRef GeneratePythonArgs(const ValuePtrList &inputs, const ValuePtr &output, bool is_need_recompute) {
  VectorRef args;
  for (const auto &value : inputs) {
    (void)args.emplace_back(value);
  }
  // If recompute is not needed, we save the output.
  if (!is_need_recompute) {
    (void)args.emplace_back(output);
  }
  return args;
}

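// Collapse a value list into a single value: a lone non-sequence value is returned directly, anything else is
// wrapped in a ValueTuple.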
ValuePtr ValueListToValue(const ValuePtrList &values, const abstract::AbstractBasePtr &abs) {
  if (values.size() == kSizeZero) {
    MS_LOG(EXCEPTION) << "The values should not be empty!";
  }
  if (values.size() == kSizeOne && !abs->isa<abstract::AbstractSequence>()) {
    return values[kIndex0];
  }
  return std::make_shared<ValueTuple>(values);
}

bool IsOutputBothEmpty(const ValuePtr &input_grads, const ValuePtr &weight_grads) {
  if (!input_grads->isa<ValueTuple>() || !weight_grads->isa<ValueTuple>()) {
    return false;
  }
  auto input_grads_tuple = input_grads->cast<ValueTuplePtr>();
  auto weight_grads_tuple = weight_grads->cast<ValueTuplePtr>();
  return input_grads_tuple->size() == 0 && weight_grads_tuple->size() == 0;
}

ValuePtr GenerateEmptyTupleValue() {
  std::vector<ValuePtr> value_list;
  auto inputs_value = std::make_shared<ValueTuple>(value_list);
  auto weights_value = std::make_shared<ValueTuple>(value_list);
  std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
  return std::make_shared<ValueTuple>(tuple_list);
}

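// Attach the producing variable and output index to the auto-grad metadata of every tensor in the flattened outputs,
// creating the metadata on demand.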
void SetFlattenTensorGradMetaData(const ValuePtrList &flatten_outs, const VariablePtr &variable) {
  for (size_t i = 0; i < flatten_outs.size(); ++i) {
    if (flatten_outs[i]->isa<tensor::BaseTensor>()) {
      auto tensor = flatten_outs[i]->cast<tensor::BaseTensorPtr>();
      auto auto_grad_meta_data = tensor->auto_grad_meta_data();
      if (auto_grad_meta_data == nullptr) {
        MS_LOG(DEBUG) << "tensor has no auto_grad_meta_data";
        auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
        tensor->set_auto_grad_meta_data(auto_grad_meta_data);
      }
      auto_grad_meta_data->set_variable(variable);
      auto_grad_meta_data->set_output_index(i);
    }
  }
}

bool IsValidTensorInput(const ValuePtr &v) {
  MS_EXCEPTION_IF_NULL(v);
  return v->isa<tensor::BaseTensor>() || v->isa<tensor::MetaSparseTensor>();
}

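// A tensor needs grad when its variable is marked as requiring grad; for a value sequence of tensors, any element
// needing grad is enough.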
bool IsNeedComputeGrad(const ValuePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (input->isa<tensor::BaseTensor>()) {
    auto input_tensor = input->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    auto variable = auto_grad_meta_data->variable();
    if (variable != nullptr && variable->is_need_grad()) {
      return true;
    }
  } else if (input->isa<ValueSequence>()) {
    auto seq = input->cast<ValueSequencePtr>();
    if (!seq->value().empty() && !seq->value().front()->isa<tensor::BaseTensor>()) {
      return false;
    }
    return std::any_of(seq->value().begin(), seq->value().end(),
                       [](const ValuePtr &val) { return IsNeedComputeGrad(val); });
  }
  return false;
}

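// Run the backward hooks registered on the output tensor against the incoming gradient. Hooks only apply to a single
// tensor gradient; the hooks are cleared after they have been executed.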
ValuePtrList CallBackwardHooks(const ValuePtr &value, ValuePtrList *grad_in) {
  if (value == nullptr) {
    MS_LOG(DEBUG) << "Get null value";
    return *grad_in;
  }
  MS_EXCEPTION_IF_NULL(grad_in);
  auto tensor = value->cast<tensor::BaseTensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Hooks only work on tensors; value " << value->ToString() << " is not supported";
    return *grad_in;
  }
  auto auto_grad_meta = tensor->auto_grad_meta_data();
  MS_EXCEPTION_IF_NULL(auto_grad_meta);
  if (auto_grad_meta->backward_hooks().empty()) {
    MS_LOG(DEBUG) << "Get empty backward hooks for tensor id " << tensor->id();
    return *grad_in;
  }
  if (grad_in->size() != kSizeOne) {
    MS_LOG(EXCEPTION) << "Tensor hooks only work on a single tensor value; value sequences are not supported";
  }
  runtime::OpExecutor::GetInstance().WaitAll();
  for (const auto &hook : auto_grad_meta->backward_hooks()) {
    MS_LOG(DEBUG) << "Run hook id " << hook.first;
    MS_EXCEPTION_IF_NULL(hook.second);
    (*grad_in)[kIndex0] = (*(hook.second))(grad_in->front());
  }
  runtime::OpExecutor::GetInstance().WaitAll();
  MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(*grad_in, "After hook print gradient in: ");
  auto_grad_meta->ClearBackwardHooks();
  return *grad_in;
}
}  // namespace

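// Run the expander-generated bprop function for this node: wrap the saved forward inputs/output and the incoming
// gradients as func nodes, call the bprop builder, and post-process the emitted gradient values.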
ValuePtrList FuncBackwardNode::CallBackward(const ValuePtrList &gradients_in) {
  MS_LOG(DEBUG) << "Begin CallBackward: " << name();
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  auto ir_builder = FuncBuilder(name_, device_target, nullptr);
  auto inputs = PreProcess(gradients_in, &ir_builder);
  ir_builder.SetInputs(name(), &inputs, &attrs_);
  const std::vector<NodePtr> cal_grads_node = grad_func()(&ir_builder);
  ValuePtrList cal_grads_values;
  cal_grads_values.reserve(cal_grads_node.size());
  // A binary op's grad result may be nullptr, so we convert it to kNone.
  (void)std::transform(cal_grads_node.begin(), cal_grads_node.end(), std::back_inserter(cal_grads_values),
                       [](const NodePtr &node) -> ValuePtr {
                         if (node == nullptr) {
                           return kNone;
                         }
                         return node->Value();
                       });
  auto gradients = PostProcess(cal_grads_values);
  MS_LOG(DEBUG) << "End CallBackward: " << name();
  return gradients;
}

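// Convert the saved op inputs, the op output and the incoming dout into the NodePtr list expected by the bprop
// builder. A single non-sequence dout is passed directly; otherwise the douts are packed into a ValueTuple.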
NodePtrList FuncBackwardNode::PreProcess(const ValuePtrList &dout, FuncBuilder *emitter) {
  NodePtrList node_inputs;
  node_inputs.reserve(op_inputs_.size() + kSizeFive);
  for (size_t i = 0; i < op_inputs_.size(); ++i) {
    auto func_node = emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]);
    func_node->set_need_compute_grad_out(IsNeedComputeGrad(op_inputs_[i]));
    (void)node_inputs.emplace_back(func_node);
  }
  (void)node_inputs.emplace_back(emitter->NewFuncNode(op_output_, out_abstract_, InputType::kOpOutput));
  if (dout.size() == kSizeOne && !op_output_->isa<ValueSequence>()) {
    (void)node_inputs.emplace_back(emitter->NewFuncNode(dout[kIndex0], out_abstract_, InputType::kOpOutput));
  } else {
    (void)node_inputs.emplace_back(
      emitter->NewFuncNode(std::make_shared<ValueTuple>(dout), out_abstract_, InputType::kOpOutput));
  }
  return node_inputs;
}

void FuncBackwardNode::Release() {
  op_inputs_.clear();
  op_output_ = nullptr;
}

ValuePtrList HookBackwardNode::CallBackward(const ValuePtrList &grads) {
  runtime::OpExecutor::GetInstance().WaitAll();
  MS_LOG(DEBUG) << "Begin HookBackwardNode CallBackward ";
  auto gradient = ValueListToValue(grads, out_abstract_);
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // The Python grad func cannot process None, so we convert None to a zero tensor.
  auto func_builder = FuncBuilder(name_, device_target, nullptr);
  auto filled_zeros_grad = func_builder.FillZeros(gradient, out_abstract_);
  (void)args_.emplace_back(filled_zeros_grad);
  py::gil_scoped_acquire gil_acquire;
  auto out = prim_->RunHookFunction(args_);
  ValuePtrList gradient_values;
  if (utils::isa<PyObjectRef>(out)) {
    PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
    auto out_py_tuple = py_ref.object_;
    ConvertPyObjectToTensor(out_py_tuple, &gradient_values);
  }
  if (gradient_values.empty()) {
    MS_LOG(EXCEPTION) << "Hook fn output is not <PyObjectRef> type!";
  }
  auto gradient_tensors = PostProcess(gradient_values);
  MS_LOG(DEBUG) << "End HookBackwardNode CallBackward";
  runtime::OpExecutor::GetInstance().WaitAll();
  return gradient_tensors;
}

void HookBackwardNode::Release() { args_.clear(); }

ValuePtrList GraphBackwardNode::CallBackward(const ValuePtrList &grads) {
  MS_LOG(DEBUG) << "Begin GraphBackwardNode CallBackward ";
  MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(grads, "bprop cut input grads: ");
  auto graph_call_back = PyNativeAlgo::AutoGrad::CreateGraphCallBack(func_graph_, cache_key_, graph_call_condition_);
  // Add graph din
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  auto ir_builder = FuncBuilder(name_, device_target, nullptr);
  auto real_dout = LazeUpdateZeroGradient(grads, &ir_builder, op_output_);

  // If the output comes from jit and contains a dict, its keys and values are converted into tuples as inputs.
  if (!graph_call_condition_.jit_out_has_dict_) {
    for (const auto &arg : real_dout) {
      (void)args_.emplace_back(arg);
    }
  } else {
    if (!op_output_->isa<ValueDictionary>()) {
      MS_LOG(EXCEPTION) << "Get wrong data type " << op_output_->ToString();
    }
    const auto &v_dict = op_output_->cast<ValueDictionaryPtr>();
    ValuePtrList key_inputs;
    for (const auto &elem : v_dict->value()) {
      (void)key_inputs.emplace_back(elem.first);
    }
    (void)args_.emplace_back(std::make_shared<ValueTuple>(key_inputs));
    (void)args_.emplace_back(std::make_shared<ValueTuple>(real_dout));
  }
  auto gradient_vec_ref = graph_call_back(args_);
  auto gradient_values = common::AnfAlgo::TransformVectorRefToMultiValue(gradient_vec_ref);
  auto gradient_tensors = PostProcess(gradient_values);
  MS_LOG(DEBUG) << "End GraphBackwardNode CallBackward";
  return gradient_tensors;
}

ValuePtrList GraphRoot::BuildFlattenSensGradient(const ValuePtrList &sens_gradient) const {
  ValuePtrList real_gradients;
  for (const auto &index : gradient_index_) {
    if (index >= sens_gradient.size()) {
      MS_LOG(EXCEPTION) << "Input gradient index should be smaller than the flatten_values size!";
    }
    (void)real_gradients.emplace_back(sens_gradient[index]);
  }
  return real_gradients;
}

FuncGrad::FuncGrad(const ValuePtrList &input_param_values, size_t op_num_in_bprop_graph, bool grad_by_value,
                   bool is_run_recompute) {
  MS_LOG(DEBUG) << "Start FuncGrad, input size: " << input_param_values.size();
  for (size_t i = 0; i < input_param_values.size(); ++i) {
    const auto &input_param_value = input_param_values[i];
    auto func_node = std::make_shared<BackwardNode>("input" + std::to_string(i));
    auto variable = std::make_shared<FuncVariable>(func_node, true);

    if (!input_param_value->isa<ValueSequence>()) {
      // For hook input
      func_node->set_op_output(input_param_value);
      PyNativeAlgo::AutoGrad::SetGradInfoForInputs(input_param_value, variable);
    } else {
      variable->set_is_need_grad(false);
    }
    (void)variable_set_.insert(variable);
    (void)cell_inputs_.emplace_back(input_param_value, variable);
  }
  is_run_recompute_ = is_run_recompute;
  device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  func_impl_ = std::make_shared<FuncBuilder>("func_emitter", device_target_);
}

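// Record the grad graph node for a single forward op: build the matching backward node (expander bprop, custom
// Python bprop or hook bprop), register it in the variable set and tag the op outputs with the new variable.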
bool FuncGrad::KPynativeOp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);

  auto &prim = grad_param->op_grad_info->op_prim;
  if (!PyNativeAlgo::AutoGrad::IsPrimNeedGrad(prim) ||
      (grad_by_value_ && !PyNativeAlgo::AutoGrad::NeedGrad(grad_param->op_grad_info->input_value))) {
    MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need to do op grad.";
    return true;
  }
  auto flatten_inputs = PyNativeAlgo::DataConvert::FlattenTensorSeqInValueSeq(grad_param->op_grad_info->input_value);
  ConstructParameterNodes(flatten_inputs);
  BackwardNodePtr fn = nullptr;
  bool is_custom_prim =
    IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
  if (!is_custom_prim) {
    auto handle = expander::bprop::BpropIRBuilderFactory::Instance().GetBuilder(prim->name());
    if (handle != nullptr) {
      fn = BuildFuncBackwardNode(prim, handle->func, flatten_inputs, grad_param->op_grad_info);
    } else {
      fn = BuildCustomBackwardNode(prim, flatten_inputs, grad_param->op_grad_info);
    }
  } else {
    PyNativeAlgo::AutoGrad::CheckRecomputeInputs(grad_param);
    fn = BuildHookBackwardNode(prim, flatten_inputs, grad_param->op_grad_info);
  }
  auto variable = std::make_shared<FuncVariable>(fn, false);
  if (isa<FakeBackwardNode>(fn)) {
    variable->set_is_fake_bprop(true);
    variable->set_fake_prim_name(prim->name());
  }

  (void)variable_set_.insert(variable);
  SetFlattenTensorGradMetaData(PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(grad_param->op_grad_info->out_value),
                               variable);
  MS_LOG(DEBUG) << "End update next edge for " << variable->ToString();
  return true;
}

void FuncGrad::UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) {
  MS_LOG(DEBUG) << "Real output of top cell is " << PyNativeAlgo::Common::GetIdByValue(sens_out);
  sens_value_ = sens_out;
  auto flatten_sens = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_out);
  ConstructParameterNodes(flatten_sens);
}

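// Create the GraphRoot node that feeds the backward pass; the root gradient is either the user-provided sens
// gradient or a ones-like of the forward output, flattened to match the output layout.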
void FuncGrad::BuildForwardLastNode(const ValuePtr &sens_gradient) {
  ValuePtrList root_gradient_value;
  if (sens_gradient == nullptr) {
    root_gradient_value = OnsLike(sens_value_);
  } else {
    root_gradient_value = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_gradient);
  }
  auto root = std::make_shared<GraphRoot>("GraphRoot");
  auto flatten_args = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_value_);
  root->UpdateNextEdges(flatten_args);
  root_gradients_ = root->BuildFlattenSensGradient(root_gradient_value);
  auto sens_variable = std::make_shared<FuncVariable>(root, false);
  if (root_gradients_.empty()) {
    sens_variable->set_is_need_grad(false);
  }
  (void)variable_set_.insert(sens_variable);
  last_variable_ = sens_variable;
}

bool FuncGrad::KPynativeWithFProp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);
  MS_LOG(DEBUG) << "Do KPynativeWithFProp";
  if (!grad_by_value_) {
    MS_LOG(EXCEPTION) << "Higher-order grad does not support the pyboost call";
  }
  auto fn = BuildGraphBackwardNode(grad_param);
  auto variable = std::make_shared<FuncVariable>(fn, false);
  (void)variable_set_.insert(variable);
  SetFlattenTensorGradMetaData(PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(grad_param->op_grad_info->out_value),
                               variable);
  return true;
}

BackwardNodePtr FuncGrad::BuildGraphBackwardNode(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);
  if (ir_bprop_ == nullptr) {
    ir_bprop_ = std::make_unique<IrBprop>(std::make_shared<AdParam>(), device_target_, grad_by_value_);
  }
  grad_param->is_func_grad = true;
  auto [cache_hit, bprop_graph] = ir_bprop_->GetBpropGraph(grad_param);
  bool is_jit_dynamic_shape = grad_param->is_jit_graph && grad_param->use_dynamic_shape_process;
  // Save the replace info on the first run.
  if (!cache_hit && is_jit_dynamic_shape && grad_param->has_added_v) {
    const auto &jit = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->jit();
    jit->SaveForwardOutputTensorInfoInBpropGraph(bprop_graph);
  }
  VectorRef input_args;
  (void)std::transform(grad_param->op_grad_info->input_value.begin(), grad_param->op_grad_info->input_value.end(),
                       std::back_inserter(input_args), [](const ValuePtr &v) { return v; });
  PyNativeAlgo::Common::DumpGraphIR("call_graph.ir", bprop_graph);
  auto fn = std::make_shared<GraphBackwardNode>(
    bprop_graph->ToString(), bprop_graph, input_args, grad_param->op_grad_info->out_value,
    grad_param->op_grad_info->output_size, grad_param->graph_cache_key, grad_param->is_control_flow,
    grad_param->is_jit_graph, grad_param->use_dynamic_shape_process, grad_param->jit_out_has_dict);
  auto flatten_inputs = PyNativeAlgo::DataConvert::FlattenTensorSeqInValueSeq(grad_param->op_grad_info->input_value);
  ConstructParameterNodes(flatten_inputs);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

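// Walk the variable set in reverse topological order starting from the last (root) variable, call each node's
// backward function, and accumulate the returned gradients into the buffers of its next edges. Leaf variables keep
// their accumulated gradient; intermediate buffers are released as soon as they are consumed.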
void FuncGrad::BackPropagate() {
  MS_LOG(DEBUG) << "Begin BackPropagate";
  const auto &last_node_reverse_iter = GetLastNodeReverseIter();
  const auto &root_fn = (*last_node_reverse_iter)->func_node();
  mindspore::HashMap<BackwardNode *, ValuePtrList> input_buffer;
  (void)input_buffer.insert({root_fn.get(), root_gradients_});
  MS_LOG(DEBUG) << "Is running recompute grad " << is_run_recompute_;
  for (auto iter = last_node_reverse_iter; iter != variable_set_.rend(); ++iter) {
    const auto &variable = *iter;
    const auto &fn = variable->func_node();
    MS_LOG(DEBUG) << "Begin calculate op: " << fn->name() << " gradients!";
    if (!variable->is_need_propagate() || !variable->is_need_grad()) {
      MS_LOG(DEBUG) << "No need grad, variable is: " << variable->ToString();
      continue;
    }
    if (static_cast<bool>(MS_UNLIKELY(variable->is_fake_bprop()))) {
      MS_LOG(EXCEPTION) << "Illegal primitive " << variable->fake_prim_name() << "'s bprop not defined";
    }
    if (input_buffer.find(fn.get()) == input_buffer.end()) {
      MS_LOG(EXCEPTION) << "Fn has no gradient";
    }
    auto &gradient_in = input_buffer[fn.get()];
    MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(gradient_in, "Begin print gradient in: ");
    // If a hook is registered on a weight and the weight is used in a recompute cell, the hook would run here,
    // which is not expected; so hooks are skipped while running recompute grad.
    if (!is_run_recompute_) {
      gradient_in = CallBackwardHooks(fn->op_output(), &gradient_in);
    }
    auto gradient_out = fn->CallBackward(gradient_in);
    MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(gradient_out, "Begin print gradient out: ");
    if (gradient_out.size() != fn->next_edges().size()) {
      MS_LOG(EXCEPTION) << "Fn gradient size should be the same as the next edges size";
    }
    for (size_t i = 0; i < fn->next_edges().size(); ++i) {
      const auto &next_edge = fn->next_edges()[i];
      const auto &last_variable = next_edge.variable;
      // If the network does not calculate input grads, some ops are pruned, so we skip them here.
      if (!last_variable->is_need_grad()) {
        MS_LOG(DEBUG) << "Variable does not need grad, " << last_variable->ToString();
        continue;
      }
      const auto &last_fn = last_variable->func_node();
      const auto &last_gradient = gradient_out[i];
      // If last_gradient is kNone, this tensor's gradient is zeros.
      if (last_gradient->isa<None>()) {
        MS_LOG(DEBUG) << last_variable->ToString() << ", its gradient is kNone!";
        continue;
      }
      if (input_buffer.find(last_fn.get()) != input_buffer.end()) {
        Add(last_gradient, next_edge.input_index, func_impl_, &input_buffer[last_fn.get()]);
      } else {
        input_buffer[last_fn.get()] =
          PaddingGradientInput(last_gradient, last_fn->output_size(), next_edge.input_index);
      }
      last_variable->set_is_need_propagate(true);
    }
    if (variable->is_leaf()) {
      MS_LOG(DEBUG) << "Get leaf node " << variable->ToString();
      auto grads = input_buffer[fn.get()];
      if (grads.empty() || grads[0]->isa<None>()) {
        MS_LOG(EXCEPTION) << variable->ToString() << ", " << (grads.empty() ? "grad is empty" : "grad is kNone");
      }
      auto grad_tensor = grads[0]->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(grad_tensor);
      variable->set_grad(grad_tensor);
    }
    (void)input_buffer.erase(fn.get());
    variable->Release();
  }
  MS_LOG(DEBUG) << "End BackPropagate";
}

OrderedSet<FuncVariablePtr>::reverse_iterator FuncGrad::GetLastNodeReverseIter() {
  for (auto iter = variable_set_.rbegin(); iter != variable_set_.rend(); ++iter) {
    if (*iter == last_variable_) {
      last_variable_->set_is_need_propagate(true);
      return iter;
    }
  }
  return variable_set_.rend();
}

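// Create leaf variables for parameter tensors that require grad and have not been registered yet, so they can
// receive gradients during back propagation.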
void FuncGrad::ConstructParameterNodes(const ValuePtrList &inputs) {
  for (const auto &value : inputs) {
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      auto auto_grad_meta_data = tensor->auto_grad_meta_data();
      MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
      if (auto_grad_meta_data->variable() != nullptr) {
        continue;
      }
      if (auto_grad_meta_data->input_type() == InputType::kParameter &&
          PyNativeAlgo::Common::IsParamRequiresGrad(tensor)) {
        auto fn = std::make_shared<BackwardNode>("parameter");
        fn->set_op_output(value);
        auto variable = std::make_shared<FuncVariable>(fn, true);
        auto_grad_meta_data->set_variable(variable);
        (void)variable_set_.insert(variable);
        weights_used_in_graph_.emplace_back(tensor);
      }
    }
  }
}

BackwardNodePtr FuncGrad::BuildFuncBackwardNode(const PrimitivePtr &prim, const expander::bprop::BpropBuilderFunc &func,
                                                const ValuePtrList &flatten_inputs, const OpGradInfoPtr &op_grad_info) {
  PyNativeAlgo::AutoGrad::CheckAndSetAbstract(op_grad_info);
  auto fn = std::make_shared<FuncBackwardNode>(
    prim->name(), func, prim->attrs(), op_grad_info->input_value, op_grad_info->input_abs, op_grad_info->out_value,
    op_grad_info->output_size, op_grad_info->out_abs, op_grad_info->input_value_grad_type);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

BackwardNodePtr FuncGrad::BuildCustomBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                  const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
  {
    py::gil_scoped_acquire gil;
    auto prim_py = prim->cast<PrimitivePyPtr>();
    if (prim_py == nullptr) {
      MS_LOG(DEBUG) << "Prim is not a PrimitivePy, cannot find a Python bprop";
      return BuildFakeBackwardNode(prim, flatten_inputs, op_grad_info);
    }
    py::function fn = prim_py->GetBpropFunction();
    if (py::isinstance<py::none>(fn)) {
      fn = GetBpropFunction(prim->name());
    }
    if (!fn || py::isinstance<py::none>(fn)) {
      MS_LOG(INFO) << "Cannot find bprop function for " << prim->name() << ". fn: " << ConvertPyObjToString(fn);
      return BuildFakeBackwardNode(prim, flatten_inputs, op_grad_info);
    }
    (void)prim_py->AddBackwardHookFn(0, fn);
    (void)prim_py->AddAttr("custom_op_bprop", MakeValue(true));
  }
  return BuildHookBackwardNode(prim, flatten_inputs, op_grad_info);
}

BackwardNodePtr FuncGrad::BuildHookBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  auto bprop_cut = PyNativeAlgo::AutoGrad::BuildBpropCutPrim(prim, op_grad_info->is_need_recompute);
  VectorRef args =
    GeneratePythonArgs(op_grad_info->input_value, op_grad_info->out_value, op_grad_info->is_need_recompute);
  auto fn = std::make_shared<HookBackwardNode>(prim->name(), bprop_cut, std::move(args), op_grad_info->output_size,
                                               op_grad_info->out_abs);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

BackwardNodePtr FuncGrad::BuildFakeBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  auto fn = std::make_shared<FakeBackwardNode>(prim->name(), op_grad_info->output_size);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

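// Assemble the final grad output according to the grad attributes: gradients w.r.t. inputs, w.r.t. weights, both, or
// (when nothing is requested explicitly) the gradient of the first input.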
ValuePtr FuncGrad::GetGrads(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                            const GradAttr &grad_attr) {
  auto inputs_grad = GetInputGrads(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
  auto weights_grad = GetWeightGrads(grad_attr.grad_weights, weights, grad_attr.weight_param_is_tuple);
  // Gradients wrt inputs and weights.
  if (inputs_grad != nullptr && weights_grad != nullptr) {
    if (IsOutputBothEmpty(inputs_grad, weights_grad)) {
      return GenerateEmptyTupleValue();
    }
    ValuePtrList gradients{inputs_grad, weights_grad};
    return std::make_shared<ValueTuple>(gradients);
  }
  // Gradients wrt inputs.
  if (inputs_grad != nullptr) {
    return inputs_grad;
  }
  // Gradients wrt weights.
  if (weights_grad != nullptr) {
    return weights_grad;
  }
  // grad_all_inputs, grad_weights and get_by_position are all false.
  if (cell_inputs_.empty()) {
    // If there are no input nodes, return an empty tuple.
    return std::make_shared<ValueTuple>(ValuePtrList{});
  }

  // If there are input nodes, return the gradient of the first input node.
  // Tuple, List and scalar inputs are ignored.
  if (IsValidTensorInput(cell_inputs_[kIndex0].first)) {
    return PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(
      cell_inputs_[kIndex0].first, cell_inputs_[kIndex0].second->grad(), func_impl_.get(), SpecialType::kZerosLikeType);
  }
  MS_LOG(DEBUG) << "The first input node is not a tensor " << cell_inputs_[0].first->ToString();
  return std::make_shared<ValueTuple>(ValuePtrList{});
}

ValuePtr FuncGrad::GetInputGrads(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position) {
  std::vector<size_t> grad_pos_list;
  if (get_by_position) {
    grad_pos_list = grad_position;
  } else if (grad_all_inputs) {
    grad_pos_list.resize(cell_inputs_.size());
    iota(grad_pos_list.begin(), grad_pos_list.end(), 0);
  } else {
    return nullptr;
  }
  ValuePtrList input_grads;
  input_grads.reserve(cell_inputs_.size());
  if (!cell_inputs_.empty()) {
    for (size_t index : grad_pos_list) {
      if (index >= cell_inputs_.size()) {
        MS_LOG(EXCEPTION) << "Position index " << index << " exceeds the input size.";
      }
      // Tuple, List and scalar inputs are ignored.
      if (!IsValidTensorInput(cell_inputs_[index].first)) {
        MS_LOG(DEBUG) << cell_inputs_[index].first->ToString() << " is not a tensor";
        continue;
      }
      ValuePtr real_dout = PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(
        cell_inputs_[index].first, cell_inputs_[index].second->grad(), func_impl_.get(), SpecialType::kZerosLikeType);
      (void)input_grads.emplace_back(real_dout);
    }
    if (get_by_position && input_grads.size() == kSizeOne) {
      return input_grads[kIndex0];
    }
  }
  return std::make_shared<ValueTuple>(input_grads);
}

ValuePtr FuncGrad::GetWeightGrads(bool grad_weights, const tensor::BaseTensorPtrList &weights,
                                  bool weight_param_is_tuple) {
  // No need to return gradient of weights.
  if (!grad_weights) {
    return nullptr;
  }
  if (weight_param_is_tuple) {
    ValuePtrList weight_grads;
    weight_grads.reserve(weights.size());
    for (const auto &weight : weights) {
      (void)weight_grads.emplace_back(GetWeightGrad(weight));
    }
    return std::make_shared<ValueTuple>(weight_grads);
  }
  return GetWeightGrad(weights[0]);
}

ValuePtr FuncGrad::GetWeightGrad(const tensor::BaseTensorPtr &weight) {
  MS_EXCEPTION_IF_NULL(weight);
  auto auto_grad_meta_data = weight->auto_grad_meta_data();
  if (auto_grad_meta_data == nullptr) {
    return func_impl_->Zeros(weight);
  }
  auto variable = auto_grad_meta_data->variable();
  const auto &func_variable = std::dynamic_pointer_cast<FuncVariable>(variable);
  if (variable != nullptr && variable->is_need_grad()) {
    // If the weight is used in the forward network but requires_grad is false, return zeros like it.
    if (func_variable->grad() == nullptr ||
        (weight->param_info() != nullptr && !weight->param_info()->requires_grad())) {
      MS_LOG(INFO) << "The weight participates in the forward calculation, but requires_grad is false";
      return func_impl_->Zeros(weight);
    }
    auto weight_grad = func_variable->grad();
    return weight_grad;
  }
  MS_LOG(INFO) << "The parameter does not need grad, tensor: " << PyNativeAlgo::Common::GetIdByValue(weight);
  return func_impl_->Zeros(weight);
}

void FuncGrad::ClearGrads(const tensor::BaseTensorPtrList &weights) {
  // Clear input grads.
  for (const auto &input : cell_inputs_) {
    input.second->set_grad(nullptr);
  }
  cell_inputs_.clear();
  // Clear weights grad info
  for (const auto &weight : weights) {
    weight->set_auto_grad_meta_data(nullptr);
  }
}

ValuePtrList FuncGrad::OnsLike(const ValuePtr &sens) {
  MS_EXCEPTION_IF_NULL(sens);
  auto flatten_values = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens);
  const auto &v = PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(std::make_shared<ValueTuple>(flatten_values), nullptr,
                                                                func_impl_.get(), SpecialType::kOnesLikeType);
  auto v_seq = v->cast<ValueTuplePtr>();
  return v_seq->value();
}

void FuncGrad::CheckSensShapeAndType(const ValuePtr &sens_gradient) {
  if (sens_gradient == nullptr) {
    return;
  }
  const auto sens_gradient_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_gradient->ToAbstract());
  const auto out_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_value_->ToAbstract());
  const auto &sens_gradient_shape = sens_gradient_abs->BuildShape()->ToString();
  const auto &out_shape = out_abs->BuildShape()->ToString();
  if (sens_gradient_shape != "()" && out_shape != "()") {
    if (sens_gradient_shape != out_shape) {
      // The sens shape in the ir graph is determined by the graph output, so it can be a dynamic shape; but the
      // input shape is determined by user input, which cannot be a dynamic shape.
      if (!sens_gradient_abs->BuildShape()->IsDynamic()) {
        MS_EXCEPTION(ValueError) << "The shape should be " << out_shape << ", but got " << sens_gradient_shape
                                 << ", sens gradient abs " << sens_gradient_abs->ToString() << ", out abs "
                                 << out_abs->ToString();
      }
    }
    const auto &sens_gradient_dtype = sens_gradient_abs->BuildType()->ToString();
    const auto &out_dtype = out_abs->BuildType()->ToString();
    if (sens_gradient_dtype != out_dtype) {
      MS_EXCEPTION(TypeError) << "The dtype should be " << out_dtype << ", but got " << sens_gradient_dtype
                              << ", sens gradient abs " << sens_gradient_abs->ToString() << ", out abs "
                              << out_abs->ToString();
    }
  }
}

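// Prune variables that will not contribute to the requested gradients (inputs, weights and every intermediate node
// whose next edges no longer need grad), so BackPropagate can skip them.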
void FuncGrad::PruningGradGraph(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr,
                                const std::vector<size_t> &grad_position) {
  PruningInput(grad_attr, grad_position);
  PruningWeights(weights, grad_attr);

  // Prune all nodes in the grad graph.
  for (const auto &variable : variable_set_) {
    if (variable->is_leaf()) {
      continue;
    }
    bool is_need_grad =
      std::any_of(variable->func_node()->next_edges().begin(), variable->func_node()->next_edges().end(),
                  [](const auto &edge) { return edge.variable->is_need_grad(); });
    if (!is_need_grad) {
      variable->set_is_need_grad(false);
    }
  }
}

void FuncGrad::PruningInput(const GradAttr &grad_attr, const std::vector<size_t> &grad_position) {
  mindspore::HashSet<size_t> grad_pos_list{grad_position.begin(), grad_position.end()};
  // Prune inputs by position in the grad graph.
  if (grad_attr.get_by_position) {
    for (size_t i = 0; i < cell_inputs_.size(); ++i) {
      if (grad_pos_list.find(i) == grad_pos_list.end()) {
        cell_inputs_[i].second->set_is_need_grad(false);
      }
    }
    return;
  }

  // When nothing is requested explicitly, keep only the first input in the grad graph.
  if (!grad_attr.grad_all_inputs && !grad_attr.get_by_position && !grad_attr.grad_weights) {
    for (size_t i = 1; i < cell_inputs_.size(); ++i) {
      cell_inputs_[i].second->set_is_need_grad(false);
    }
  }

  // Prune all inputs when only weight grads are requested.
  if (!grad_attr.grad_all_inputs && grad_attr.grad_weights) {
    for (auto &cell_input : cell_inputs_) {
      cell_input.second->set_is_need_grad(false);
    }
  }
}

void FuncGrad::PruningWeights(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr) {
  // Pruning weights in grad graph
  if (grad_attr.grad_weights) {
    mindspore::HashSet<std::string> grad_weights_id;
    for (const auto &weight : weights) {
      (void)grad_weights_id.emplace(weight->id());
    }
    for (const auto &weight : weights_used_in_graph_) {
      if (grad_weights_id.find(weight->id()) == grad_weights_id.end()) {
        auto variable = weight->auto_grad_meta_data()->variable();
        MS_EXCEPTION_IF_NULL(variable);
        variable->set_is_need_grad(false);
      }
    }
  } else {
    for (const auto &weight : weights_used_in_graph_) {
      auto variable = weight->auto_grad_meta_data()->variable();
      MS_EXCEPTION_IF_NULL(variable);
      variable->set_is_need_grad(false);
    }
  }
}

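// Entry point of the backward pass: validate the sens value, build the graph root, prune unneeded variables, run
// back propagation, then collect and return the requested gradients.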
ValuePtr FuncGrad::Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                          const GradAttr &grad_attr, const ValuePtr &sens) {
  CheckSensShapeAndType(sens);
  BuildForwardLastNode(sens);
  PruningGradGraph(weights, grad_attr, grad_position);
  if (last_variable_->is_need_grad()) {
    GilReleaseWithCheck gil_release;
    BackPropagate();
  }
  python_adapter::PyAdapterCallback::ProcessUnPairedCellHook(true);
  ValuePtr gradients = GetGrads(weights, grad_position, grad_attr);
  ClearGrads(weights);
  return gradients;
}
}  // namespace mindspore::pynative::autograd