/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/function/func_grad.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "pybind_api/gil_scoped_long_running.h"
#include "include/common/utils/primitive_utils.h"
#include "include/common/utils/hook.h"
#include "pipeline/pynative/pynative_utils.h"
#include "ops/framework_ops.h"
#include "ops/other_ops.h"

namespace mindspore::pynative::autograd {
namespace {
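// Accumulate two gradient values. kNone acts as the identity: if either operand is kNone,
// the other operand is returned directly; otherwise the values are added via the FuncBuilder.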
ValuePtr Add(const ValuePtr &input, const ValuePtr &other, const FuncBuilderPtr &func_impl) {
  if (input->isa<None>()) {
    MS_EXCEPTION_IF_NULL(other);
    return other;
  }
  if (other->isa<None>()) {
    MS_EXCEPTION_IF_NULL(input);
    return input;
  }
  auto result = func_impl->Add(input, other);
  MS_EXCEPTION_IF_NULL(result);
  return result;
}

void Add(const ValuePtr &other, size_t input_index, const FuncBuilderPtr &func_impl, std::vector<ValuePtr> *inputs) {
  if (input_index >= inputs->size()) {
    MS_LOG(EXCEPTION) << "The input index should be less than the inputs size";
  }

  (*inputs)[input_index] = Add(inputs->at(input_index), other, func_impl);
}

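// Build a gradient list of size output_size with `grad` placed at input_index and kNone everywhere else.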
ValuePtrList PaddingGradientInput(const ValuePtr &grad, size_t output_size, size_t input_index) {
  ValuePtrList gradients;
  gradients.reserve(output_size);
  for (size_t i = 0; i < output_size; ++i) {
    if (input_index == i) {
      (void)gradients.emplace_back(grad);
    } else {
      // If there is no gradient for this position, set kNone; the zero gradient is filled in lazily by
      // the LazeUpdateZeroGradient method.
      (void)gradients.emplace_back(kNone);
    }
  }
  return gradients;
}

VectorRef GeneratePythonArgs(const ValuePtrList &inputs, const ValuePtr &output, bool is_need_recompute) {
  VectorRef args;
  for (const auto &value : inputs) {
    (void)args.emplace_back(value);
  }
  // If recompute is not needed, save the output as well.
  if (!is_need_recompute) {
    (void)args.emplace_back(output);
  }
  return args;
}

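// Wrap a flattened value list back into a single value: a single non-sequence output is returned
// as-is, otherwise the values are packed into a ValueTuple.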
ValuePtr ValueListToValue(const ValuePtrList &values, const abstract::AbstractBasePtr &abs) {
  if (values.size() == kSizeZero) {
    MS_LOG(EXCEPTION) << "The values should not be empty!";
  }
  if (values.size() == kSizeOne && !abs->isa<abstract::AbstractSequence>()) {
    return values[kIndex0];
  }
  return std::make_shared<ValueTuple>(values);
}

bool IsOutputBothEmpty(const ValuePtr &input_grads, const ValuePtr &weight_grads) {
  if (!input_grads->isa<ValueTuple>() || !weight_grads->isa<ValueTuple>()) {
    return false;
  }
  auto input_grads_tuple = input_grads->cast<ValueTuplePtr>();
  auto weight_grads_tuple = weight_grads->cast<ValueTuplePtr>();
  return input_grads_tuple->size() == 0 && weight_grads_tuple->size() == 0;
}

ValuePtr GenerateEmptyTupleValue() {
  std::vector<ValuePtr> value_list;
  auto inputs_value = std::make_shared<ValueTuple>(value_list);
  auto weights_value = std::make_shared<ValueTuple>(value_list);
  std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
  return std::make_shared<ValueTuple>(tuple_list);
}

void SetFlattenTensorGradMetaData(const ValuePtrList &flatten_outs, const VariablePtr &variable) {
  for (size_t i = 0; i < flatten_outs.size(); ++i) {
    if (flatten_outs[i]->isa<tensor::BaseTensor>()) {
      auto tensor = flatten_outs[i]->cast<tensor::BaseTensorPtr>();
      auto auto_grad_meta_data = tensor->auto_grad_meta_data();
      if (auto_grad_meta_data == nullptr) {
        MS_LOG(DEBUG) << "Tensor has no auto_grad_meta_data";
        auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
        tensor->set_auto_grad_meta_data(auto_grad_meta_data);
      }
      auto_grad_meta_data->set_variable(variable);
      auto_grad_meta_data->set_output_index(i);
    }
  }
}

bool IsValidTensorInput(const ValuePtr &v) {
  MS_EXCEPTION_IF_NULL(v);
  return v->isa<tensor::BaseTensor>() || v->isa<tensor::MetaSparseTensor>();
}

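// Check whether the gradient of an input needs to be computed: a tensor needs grad if it is bound to
// a variable that requires grad; a tensor sequence needs grad if any of its elements does.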
bool IsNeedComputeGrad(const ValuePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (input->isa<tensor::BaseTensor>()) {
    auto input_tensor = input->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    auto variable = auto_grad_meta_data->variable();
    if (variable != nullptr && variable->is_need_grad()) {
      return true;
    }
  } else if (input->isa<ValueSequence>()) {
    auto seq = input->cast<ValueSequencePtr>();
    if (!seq->value().empty() && !seq->value().front()->isa<tensor::BaseTensor>()) {
      return false;
    }
    return std::any_of(seq->value().begin(), seq->value().end(),
                       [](const ValuePtr &val) { return IsNeedComputeGrad(val); });
  }
  return false;
}

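// Run the backward hooks registered on the tensor that produced this gradient. Hooks only apply to a
// single tensor gradient; each hook receives the current gradient and its return value replaces it.
// The hooks are cleared after they have run.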
ValuePtrList CallBackwardHooks(const ValuePtr &value, ValuePtrList *grad_in) {
  if (value == nullptr) {
    MS_LOG(DEBUG) << "Get null value";
    return *grad_in;
  }
  MS_EXCEPTION_IF_NULL(grad_in);
  auto tensor = value->cast<tensor::BaseTensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Hooks only work on tensors, value " << value->ToString() << " is not supported";
    return *grad_in;
  }
  auto auto_grad_meta = tensor->auto_grad_meta_data();
  MS_EXCEPTION_IF_NULL(auto_grad_meta);
  if (auto_grad_meta->backward_hooks().empty()) {
    MS_LOG(DEBUG) << "Get empty backward hooks for tensor id " << tensor->id();
    return *grad_in;
  }
  if (grad_in->size() != kSizeOne) {
    MS_LOG(EXCEPTION) << "Tensor hooks only work on a single tensor value, value sequences are not supported";
  }
  runtime::OpExecutor::GetInstance().WaitAll();
  for (const auto &hook : auto_grad_meta->backward_hooks()) {
    MS_LOG(DEBUG) << "Run hook id " << hook.first;
    MS_EXCEPTION_IF_NULL(hook.second);
    (*grad_in)[kIndex0] = (*(hook.second))(grad_in->front());
  }
  runtime::OpExecutor::GetInstance().WaitAll();
  MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(*grad_in, "After hook print gradient in: ");
  auto_grad_meta->ClearBackwardHooks();
  return *grad_in;
}
}  // namespace

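// Run the expander bprop function of this node: build a FuncBuilder for the current device target,
// assemble the forward inputs, output and dout as func nodes, call the registered grad function, and
// convert the resulting nodes back into values (nullptr results become kNone).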
ValuePtrList FuncBackwardNode::CallBackward(const ValuePtrList &gradients_in) {
  MS_LOG(DEBUG) << "Begin CallBackward: " << name();
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  auto ir_builder = FuncBuilder(name_, device_target, nullptr);
  auto inputs = PreProcess(gradients_in, &ir_builder);
  ir_builder.SetInputs(name(), &inputs, &attrs_);
  const std::vector<NodePtr> cal_grads_node = grad_func()(&ir_builder);
  ValuePtrList cal_grads_values;
  cal_grads_values.reserve(cal_grads_node.size());
  // A binary op grad result may be nullptr, so we need to convert it to kNone.
  (void)std::transform(cal_grads_node.begin(), cal_grads_node.end(), std::back_inserter(cal_grads_values),
                       [](const NodePtr &node) -> ValuePtr {
                         if (node == nullptr) {
                           return kNone;
                         }
                         return node->Value();
                       });
  auto gradients = PostProcess(cal_grads_values);
  MS_LOG(DEBUG) << "End CallBackward: " << name();
  return gradients;
}

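// Convert the saved forward inputs, the forward output and the incoming dout into func nodes that the
// bprop IR builder can consume. A single-element dout is passed through directly; otherwise the douts
// are packed into a ValueTuple.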
NodePtrList FuncBackwardNode::PreProcess(const ValuePtrList &dout, FuncBuilder *emitter) {
  NodePtrList node_inputs;
  node_inputs.reserve(op_inputs_.size() + kSizeFive);
  for (size_t i = 0; i < op_inputs_.size(); ++i) {
    auto func_node = emitter->NewFuncNode(op_inputs_[i], input_abstract_[i], grad_type_[i]);
    func_node->set_need_compute_grad_out(IsNeedComputeGrad(op_inputs_[i]));
    (void)node_inputs.emplace_back(func_node);
  }
  (void)node_inputs.emplace_back(emitter->NewFuncNode(op_output_, out_abstract_, InputType::kOpOutput));
  if (dout.size() == kSizeOne && !op_output_->isa<ValueSequence>()) {
    (void)node_inputs.emplace_back(emitter->NewFuncNode(dout[kIndex0], out_abstract_, InputType::kOpOutput));
  } else {
    (void)node_inputs.emplace_back(
      emitter->NewFuncNode(std::make_shared<ValueTuple>(dout), out_abstract_, InputType::kOpOutput));
  }
  return node_inputs;
}

void FuncBackwardNode::Release() {
  op_inputs_.clear();
  op_output_ = nullptr;
}

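// Run the Python hook primitive for this node: fill kNone gradients with zeros (Python hooks cannot
// handle None), call the hook function under the GIL, and convert the Python result back to tensors.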
ValuePtrList HookBackwardNode::CallBackward(const ValuePtrList &grads) {
  runtime::OpExecutor::GetInstance().WaitAll();
  MS_LOG(DEBUG) << "Begin HookBackwardNode CallBackward ";
  auto gradient = ValueListToValue(grads, out_abstract_);
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // The Python grad func cannot process None, so we need to convert None to a zero tensor.
  auto func_builder = FuncBuilder(name_, device_target, nullptr);
  auto filled_zeros_grad = func_builder.FillZeros(gradient, out_abstract_);
  (void)args_.emplace_back(filled_zeros_grad);
  py::gil_scoped_acquire gil_acquire;
  auto out = prim_->RunHookFunction(args_);
  ValuePtrList gradient_values;
  if (utils::isa<PyObjectRef>(out)) {
    PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
    auto out_py_tuple = py_ref.object_;
    ConvertPyObjectToTensor(out_py_tuple, &gradient_values);
  }
  if (gradient_values.empty()) {
    MS_LOG(EXCEPTION) << "Hook fn output is not <PyObjectRef> type!";
  }
  auto gradient_tensors = PostProcess(gradient_values);
  MS_LOG(DEBUG) << "End HookBackwardNode CallBackward";
  runtime::OpExecutor::GetInstance().WaitAll();
  return gradient_tensors;
}

void HookBackwardNode::Release() { args_.clear(); }

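// Run the bprop graph attached to this node: fill lazy zero gradients, append the douts (or the dict
// keys and values when the jit output is a dict) to the saved forward args, call the graph, and
// post-process the returned gradients.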
ValuePtrList GraphBackwardNode::CallBackward(const ValuePtrList &grads) {
  MS_LOG(DEBUG) << "Begin GraphBackwardNode CallBackward ";
  MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(grads, "bprop cut input grads: ");
  auto graph_call_back = PyNativeAlgo::AutoGrad::CreateGraphCallBack(func_graph_, cache_key_, graph_call_condition_);
  // Add graph din
  const auto &device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  auto ir_builder = FuncBuilder(name_, device_target, nullptr);
  auto real_dout = LazeUpdateZeroGradient(grads, &ir_builder, op_output_);

  // If the output comes from jit and contains a dict, keys and values are converted into tuples as inputs.
  if (!graph_call_condition_.jit_out_has_dict_) {
    for (const auto &arg : real_dout) {
      (void)args_.emplace_back(arg);
    }
  } else {
    if (!op_output_->isa<ValueDictionary>()) {
      MS_LOG(EXCEPTION) << "Get wrong data type " << op_output_->ToString();
    }
    const auto &v_dict = op_output_->cast<ValueDictionaryPtr>();
    ValuePtrList key_inputs;
    for (const auto &elem : v_dict->value()) {
      (void)key_inputs.emplace_back(elem.first);
    }
    (void)args_.emplace_back(std::make_shared<ValueTuple>(key_inputs));
    (void)args_.emplace_back(std::make_shared<ValueTuple>(real_dout));
  }
  auto gradient_vec_ref = graph_call_back(args_);
  auto gradient_values = common::AnfAlgo::TransformVectorRefToMultiValue(gradient_vec_ref);
  auto gradient_tensors = PostProcess(gradient_values);
  MS_LOG(DEBUG) << "End GraphBackwardNode CallBackward";
  return gradient_tensors;
}

ValuePtrList GraphRoot::BuildFlattenSensGradient(const ValuePtrList &sens_gradient) const {
  ValuePtrList real_gradients;
  for (const auto &index : gradient_index_) {
    if (index >= sens_gradient.size()) {
      MS_LOG(EXCEPTION) << "The input gradient index should be smaller than the flatten_values size!";
    }
    (void)real_gradients.emplace_back(sens_gradient[index]);
  }
  return real_gradients;
}

FuncGrad::FuncGrad(const ValuePtrList &input_param_values, size_t op_num_in_bprop_graph, bool grad_by_value,
                   bool is_run_recompute) {
  MS_LOG(DEBUG) << "Start FuncGrad, input size: " << input_param_values.size();
  for (size_t i = 0; i < input_param_values.size(); ++i) {
    const auto &input_param_value = input_param_values[i];
    auto func_node = std::make_shared<BackwardNode>("input" + std::to_string(i));
    auto variable = std::make_shared<FuncVariable>(func_node, true);

    if (!input_param_value->isa<ValueSequence>()) {
      // For hook input
      func_node->set_op_output(input_param_value);
      PyNativeAlgo::AutoGrad::SetGradInfoForInputs(input_param_value, variable);
    } else {
      variable->set_is_need_grad(false);
    }
    (void)variable_set_.insert(variable);
    (void)cell_inputs_.emplace_back(input_param_value, variable);
  }
  is_run_recompute_ = is_run_recompute;
  device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  func_impl_ = std::make_shared<FuncBuilder>("func_emitter", device_target_);
}

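// Record the backward node of a single pynative op: skip primitives that do not need grad, build a
// func/hook/custom backward node depending on the primitive, register it in the variable set, and
// attach grad meta data to the flattened outputs.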
bool FuncGrad::KPynativeOp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);

  auto &prim = grad_param->op_grad_info->op_prim;
  if (!PyNativeAlgo::AutoGrad::IsPrimNeedGrad(prim) ||
      (grad_by_value_ && !PyNativeAlgo::AutoGrad::NeedGrad(grad_param->op_grad_info->input_value))) {
    MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need to do op grad.";
    return true;
  }
  auto flatten_inputs = PyNativeAlgo::DataConvert::FlattenTensorSeqInValueSeq(grad_param->op_grad_info->input_value);
  ConstructParameterNodes(flatten_inputs);
  BackwardNodePtr fn = nullptr;
  bool is_custom_prim =
    IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
  if (!is_custom_prim) {
    auto handle = expander::bprop::BpropIRBuilderFactory::Instance().GetBuilder(prim->name());
    if (handle != nullptr) {
      fn = BuildFuncBackwardNode(prim, handle->func, flatten_inputs, grad_param->op_grad_info);
    } else {
      fn = BuildCustomBackwardNode(prim, flatten_inputs, grad_param->op_grad_info);
    }
  } else {
    PyNativeAlgo::AutoGrad::CheckRecomputeInputs(grad_param);
    fn = BuildHookBackwardNode(prim, flatten_inputs, grad_param->op_grad_info);
  }
  auto variable = std::make_shared<FuncVariable>(fn, false);
  if (isa<FakeBackwardNode>(fn)) {
    variable->set_is_fake_bprop(true);
    variable->set_fake_prim_name(prim->name());
  }

  (void)variable_set_.insert(variable);
  SetFlattenTensorGradMetaData(PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(grad_param->op_grad_info->out_value),
                               variable);
  MS_LOG(DEBUG) << "End update next edge for " << variable->ToString();
  return true;
}

void FuncGrad::UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) {
  MS_LOG(DEBUG) << "Real output of top cell is " << PyNativeAlgo::Common::GetIdByValue(sens_out);
  sens_value_ = sens_out;
  auto flatten_sens = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_out);
  ConstructParameterNodes(flatten_sens);
}

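// Build the root node of the backward graph. The root gradient is either the user-provided sens
// gradient or ones-like values of the forward output; the GraphRoot variable is registered as the
// last variable from which back propagation starts.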
void FuncGrad::BuildForwardLastNode(const ValuePtr &sens_gradient) {
  ValuePtrList root_gradient_value;
  if (sens_gradient == nullptr) {
    root_gradient_value = OnsLike(sens_value_);
  } else {
    root_gradient_value = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_gradient);
  }
  auto root = std::make_shared<GraphRoot>("GraphRoot");
  auto flatten_args = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens_value_);
  root->UpdateNextEdges(flatten_args);
  root_gradients_ = root->BuildFlattenSensGradient(root_gradient_value);
  auto sens_variable = std::make_shared<FuncVariable>(root, false);
  if (root_gradients_.empty()) {
    sens_variable->set_is_need_grad(false);
  }
  (void)variable_set_.insert(sens_variable);
  last_variable_ = sens_variable;
}

bool FuncGrad::KPynativeWithFProp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);
  MS_LOG(DEBUG) << "Do KPynativeWithFProp";
  if (!grad_by_value_) {
    MS_LOG(EXCEPTION) << "Higher-order grad does not support the pyboost call";
  }
  auto fn = BuildGraphBackwardNode(grad_param);
  auto variable = std::make_shared<FuncVariable>(fn, false);
  (void)variable_set_.insert(variable);
  SetFlattenTensorGradMetaData(PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(grad_param->op_grad_info->out_value),
                               variable);
  return true;
}

BackwardNodePtr FuncGrad::BuildGraphBackwardNode(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);
  if (ir_bprop_ == nullptr) {
    ir_bprop_ = std::make_unique<IrBprop>(std::make_shared<AdParam>(), device_target_, grad_by_value_);
  }
  grad_param->is_func_grad = true;
  auto [cache_hit, bprop_graph] = ir_bprop_->GetBpropGraph(grad_param);
  bool is_jit_dynamic_shape = grad_param->is_jit_graph && grad_param->use_dynamic_shape_process;
  // Save the replace info the first time.
  if (!cache_hit && is_jit_dynamic_shape && grad_param->has_added_v) {
    const auto &jit = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->jit();
    jit->SaveForwardOutputTensorInfoInBpropGraph(bprop_graph);
  }
  VectorRef input_args;
  (void)std::transform(grad_param->op_grad_info->input_value.begin(), grad_param->op_grad_info->input_value.end(),
                       std::back_inserter(input_args), [](const ValuePtr &v) { return v; });
  PyNativeAlgo::Common::DumpGraphIR("call_graph.ir", bprop_graph);
  auto fn = std::make_shared<GraphBackwardNode>(
    bprop_graph->ToString(), bprop_graph, input_args, grad_param->op_grad_info->out_value,
    grad_param->op_grad_info->output_size, grad_param->graph_cache_key, grad_param->is_control_flow,
    grad_param->is_jit_graph, grad_param->use_dynamic_shape_process, grad_param->jit_out_has_dict);
  auto flatten_inputs = PyNativeAlgo::DataConvert::FlattenTensorSeqInValueSeq(grad_param->op_grad_info->input_value);
  ConstructParameterNodes(flatten_inputs);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

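// Core reverse pass: traverse the variable set in reverse order starting from the last node, run
// tensor hooks and each node's CallBackward, accumulate gradients for the next edges in an input
// buffer, and store the final gradient on leaf variables.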
void FuncGrad::BackPropagate() {
  MS_LOG(DEBUG) << "Begin BackPropagate";
  const auto &last_node_reverse_iter = GetLastNodeReverseIter();
  const auto &root_fn = (*last_node_reverse_iter)->func_node();
  mindspore::HashMap<BackwardNode *, ValuePtrList> input_buffer;
  (void)input_buffer.insert({root_fn.get(), root_gradients_});
  MS_LOG(DEBUG) << "Is running recompute grad " << is_run_recompute_;
  for (auto iter = last_node_reverse_iter; iter != variable_set_.rend(); ++iter) {
    const auto &variable = *iter;
    const auto &fn = variable->func_node();
    MS_LOG(DEBUG) << "Begin calculate op: " << fn->name() << " gradients!";
    if (!variable->is_need_propagate() || !variable->is_need_grad()) {
      MS_LOG(DEBUG) << "No need grad, variable is: " << variable->ToString();
      continue;
    }
    if (static_cast<bool>(MS_UNLIKELY(variable->is_fake_bprop()))) {
      MS_LOG(EXCEPTION) << "Illegal primitive " << variable->fake_prim_name() << "'s bprop not defined";
    }
    if (input_buffer.find(fn.get()) == input_buffer.end()) {
      MS_LOG(EXCEPTION) << "Fn has no gradient";
    }
    auto &gradient_in = input_buffer[fn.get()];
    MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(gradient_in, "Begin print gradient in: ");
    // If a hook is registered on a weight and that weight is used in a recompute cell, the hook would run
    // during recompute, which is not expected, so hooks are skipped when running recompute grad.
    if (!is_run_recompute_) {
      gradient_in = CallBackwardHooks(fn->op_output(), &gradient_in);
    }
    auto gradient_out = fn->CallBackward(gradient_in);
    MS_LOG(DEBUG) << PyNativeAlgo::Common::PrintDebugInfo(gradient_out, "Begin print gradient out: ");
    if (gradient_out.size() != fn->next_edges().size()) {
      MS_LOG(EXCEPTION) << "Fn gradient size should be the same as the next edges size";
    }
    for (size_t i = 0; i < fn->next_edges().size(); ++i) {
      const auto &next_edge = fn->next_edges()[i];
      const auto &last_variable = next_edge.variable;
      // If the network does not calculate input grads, some ops will be pruned, so we need to skip them.
      if (!last_variable->is_need_grad()) {
        MS_LOG(DEBUG) << "Variable does not need grad, " << last_variable->ToString();
        continue;
      }
      const auto &last_fn = last_variable->func_node();
      const auto &last_gradient = gradient_out[i];
      // If last_gradient is kNone, it means the gradient of this tensor is zeros.
      if (last_gradient->isa<None>()) {
        MS_LOG(DEBUG) << last_variable->ToString() << ", its gradient is kNone!";
        continue;
      }
      if (input_buffer.find(last_fn.get()) != input_buffer.end()) {
        Add(last_gradient, next_edge.input_index, func_impl_, &input_buffer[last_fn.get()]);
      } else {
        input_buffer[last_fn.get()] =
          PaddingGradientInput(last_gradient, last_fn->output_size(), next_edge.input_index);
      }
      last_variable->set_is_need_propagate(true);
    }
    if (variable->is_leaf()) {
      MS_LOG(DEBUG) << "Get leaf node " << variable->ToString();
      auto grads = input_buffer[fn.get()];
      if (grads.empty() || grads[0]->isa<None>()) {
        MS_LOG(EXCEPTION) << variable->ToString() << ", " << (grads.empty() ? "grad is empty" : "grad is kNone");
      }
      auto grad_tensor = grads[0]->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(grad_tensor);
      variable->set_grad(grad_tensor);
    }
    (void)input_buffer.erase(fn.get());
    variable->Release();
  }
  MS_LOG(DEBUG) << "End BackPropagate";
}

OrderedSet<FuncVariablePtr>::reverse_iterator FuncGrad::GetLastNodeReverseIter() {
  for (auto iter = variable_set_.rbegin(); iter != variable_set_.rend(); ++iter) {
    if (*iter == last_variable_) {
      last_variable_->set_is_need_propagate(true);
      return iter;
    }
  }
  return variable_set_.rend();
}

void FuncGrad::ConstructParameterNodes(const ValuePtrList &inputs) {
  for (const auto &value : inputs) {
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      auto auto_grad_meta_data = tensor->auto_grad_meta_data();
      MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
      if (auto_grad_meta_data->variable() != nullptr) {
        continue;
      }
      if (auto_grad_meta_data->input_type() == InputType::kParameter &&
          PyNativeAlgo::Common::IsParamRequiresGrad(tensor)) {
        auto fn = std::make_shared<BackwardNode>("parameter");
        fn->set_op_output(value);
        auto variable = std::make_shared<FuncVariable>(fn, true);
        auto_grad_meta_data->set_variable(variable);
        (void)variable_set_.insert(variable);
        weights_used_in_graph_.emplace_back(tensor);
      }
    }
  }
}

BackwardNodePtr FuncGrad::BuildFuncBackwardNode(const PrimitivePtr &prim, const expander::bprop::BpropBuilderFunc &func,
                                                const ValuePtrList &flatten_inputs, const OpGradInfoPtr &op_grad_info) {
  PyNativeAlgo::AutoGrad::CheckAndSetAbstract(op_grad_info);
  auto fn = std::make_shared<FuncBackwardNode>(
    prim->name(), func, prim->attrs(), op_grad_info->input_value, op_grad_info->input_abs, op_grad_info->out_value,
    op_grad_info->output_size, op_grad_info->out_abs, op_grad_info->input_value_grad_type);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

BackwardNodePtr FuncGrad::BuildCustomBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                  const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
  {
    py::gil_scoped_acquire gil;
    auto prim_py = prim->cast<PrimitivePyPtr>();
    if (prim_py == nullptr) {
      MS_LOG(DEBUG) << "Prim is not PrimitivePy, cannot find the Python bprop";
      return BuildFakeBackwardNode(prim, flatten_inputs, op_grad_info);
    }
    py::function fn = prim_py->GetBpropFunction();
    if (py::isinstance<py::none>(fn)) {
      fn = GetBpropFunction(prim->name());
    }
    if (!fn || py::isinstance<py::none>(fn)) {
      MS_LOG(INFO) << "Cannot find the bprop function for " << prim->name() << ". fn: " << ConvertPyObjToString(fn);
      return BuildFakeBackwardNode(prim, flatten_inputs, op_grad_info);
    }
    (void)prim_py->AddBackwardHookFn(0, fn);
    (void)prim_py->AddAttr("custom_op_bprop", MakeValue(true));
  }
  return BuildHookBackwardNode(prim, flatten_inputs, op_grad_info);
}

BackwardNodePtr FuncGrad::BuildHookBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  auto bprop_cut = PyNativeAlgo::AutoGrad::BuildBpropCutPrim(prim, op_grad_info->is_need_recompute);
  VectorRef args =
    GeneratePythonArgs(op_grad_info->input_value, op_grad_info->out_value, op_grad_info->is_need_recompute);
  auto fn = std::make_shared<HookBackwardNode>(prim->name(), bprop_cut, std::move(args), op_grad_info->output_size,
                                               op_grad_info->out_abs);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

BackwardNodePtr FuncGrad::BuildFakeBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
                                                const OpGradInfoPtr &op_grad_info) {
  MS_EXCEPTION_IF_NULL(prim);
  auto fn = std::make_shared<FakeBackwardNode>(prim->name(), op_grad_info->output_size);
  fn->UpdateNextEdges(flatten_inputs);
  return fn;
}

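// Assemble the final grad result according to the grad attributes: gradients of inputs, of weights,
// of both, or the gradient of the first input when none of grad_all_inputs, grad_weights and
// get_by_position is set.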
ValuePtr FuncGrad::GetGrads(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                            const GradAttr &grad_attr) {
  auto inputs_grad = GetInputGrads(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
  auto weights_grad = GetWeightGrads(grad_attr.grad_weights, weights, grad_attr.weight_param_is_tuple);
  // Gradients wrt inputs and weights.
  if (inputs_grad != nullptr && weights_grad != nullptr) {
    if (IsOutputBothEmpty(inputs_grad, weights_grad)) {
      return GenerateEmptyTupleValue();
    }
    ValuePtrList gradients{inputs_grad, weights_grad};
    return std::make_shared<ValueTuple>(gradients);
  }
  // Gradients wrt inputs.
  if (inputs_grad != nullptr) {
    return inputs_grad;
  }
  // Gradients wrt weights.
  if (weights_grad != nullptr) {
    return weights_grad;
  }
  // grad_all_inputs, grad_weights and get_by_position are all false.
  if (cell_inputs_.empty()) {
    // If there are no input nodes, return an empty tuple.
    return std::make_shared<ValueTuple>(ValuePtrList{});
  }

  // If there are input nodes, return the gradient of the first input node.
  // Tuple, List and scalar inputs are ignored.
  if (IsValidTensorInput(cell_inputs_[kIndex0].first)) {
    return PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(
      cell_inputs_[kIndex0].first, cell_inputs_[kIndex0].second->grad(), func_impl_.get(), SpecialType::kZerosLikeType);
  }
  MS_LOG(DEBUG) << "The first input node is not a tensor " << cell_inputs_[0].first->ToString();
  return std::make_shared<ValueTuple>(ValuePtrList{});
}

ValuePtr FuncGrad::GetInputGrads(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position) {
  std::vector<size_t> grad_pos_list;
  if (get_by_position) {
    grad_pos_list = grad_position;
  } else if (grad_all_inputs) {
    grad_pos_list.resize(cell_inputs_.size());
    iota(grad_pos_list.begin(), grad_pos_list.end(), 0);
  } else {
    return nullptr;
  }
  ValuePtrList input_grads;
  input_grads.reserve(cell_inputs_.size());
  if (!cell_inputs_.empty()) {
    for (size_t index : grad_pos_list) {
      if (index >= cell_inputs_.size()) {
        MS_LOG(EXCEPTION) << "Position index " << index << " exceeds the input size.";
      }
      // Tuple, List and scalar inputs are ignored.
      if (!IsValidTensorInput(cell_inputs_[index].first)) {
        MS_LOG(DEBUG) << cell_inputs_[index].first->ToString() << " is not a tensor";
        continue;
      }
      ValuePtr real_dout = PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(
        cell_inputs_[index].first, cell_inputs_[index].second->grad(), func_impl_.get(), SpecialType::kZerosLikeType);
      (void)input_grads.emplace_back(real_dout);
    }
    if (get_by_position && input_grads.size() == kSizeOne) {
      return input_grads[kIndex0];
    }
  }
  return std::make_shared<ValueTuple>(input_grads);
}

ValuePtr FuncGrad::GetWeightGrads(bool grad_weights, const tensor::BaseTensorPtrList &weights,
                                  bool weight_param_is_tuple) {
  // No need to return gradients of the weights.
  if (!grad_weights) {
    return nullptr;
  }
  if (weight_param_is_tuple) {
    ValuePtrList weight_grads;
    weight_grads.reserve(weights.size());
    for (const auto &weight : weights) {
      (void)weight_grads.emplace_back(GetWeightGrad(weight));
    }
    return std::make_shared<ValueTuple>(weight_grads);
  }
  return GetWeightGrad(weights[0]);
}

ValuePtr FuncGrad::GetWeightGrad(const tensor::BaseTensorPtr &weight) {
  MS_EXCEPTION_IF_NULL(weight);
  auto auto_grad_meta_data = weight->auto_grad_meta_data();
  if (auto_grad_meta_data == nullptr) {
    return func_impl_->Zeros(weight);
  }
  auto variable = auto_grad_meta_data->variable();
  const auto &func_variable = std::dynamic_pointer_cast<FuncVariable>(variable);
  if (variable != nullptr && variable->is_need_grad()) {
    // If the weight is used in the forward network but requires_grad is false, return zeros like the weight.
    if (func_variable->grad() == nullptr ||
        (weight->param_info() != nullptr && !weight->param_info()->requires_grad())) {
      MS_LOG(INFO) << "The weight participates in the forward calculation, but requires_grad is false";
      return func_impl_->Zeros(weight);
    }
    auto weight_grad = func_variable->grad();
    return weight_grad;
  }
  MS_LOG(INFO) << "The parameter does not need grad, tensor: " << PyNativeAlgo::Common::GetIdByValue(weight);
  return func_impl_->Zeros(weight);
}

void FuncGrad::ClearGrads(const tensor::BaseTensorPtrList &weights) {
  // Clear input grads.
  for (const auto &input : cell_inputs_) {
    input.second->set_grad(nullptr);
  }
  cell_inputs_.clear();
  // Clear weights grad info.
  for (const auto &weight : weights) {
    weight->set_auto_grad_meta_data(nullptr);
  }
}

ValuePtrList FuncGrad::OnsLike(const ValuePtr &sens) {
  MS_EXCEPTION_IF_NULL(sens);
  auto flatten_values = PyNativeAlgo::DataConvert::FlattenTensorSeqInValue(sens);
  const auto &v = PyNativeAlgo::AutoGrad::BuildSpecialValueGrad(std::make_shared<ValueTuple>(flatten_values), nullptr,
                                                                func_impl_.get(), SpecialType::kOnesLikeType);
  auto v_seq = v->cast<ValueTuplePtr>();
  return v_seq->value();
}

void FuncGrad::CheckSensShapeAndType(const ValuePtr &sens_gradient) {
  if (sens_gradient == nullptr) {
    return;
  }
  const auto sens_gradient_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_gradient->ToAbstract());
  const auto out_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_value_->ToAbstract());
  const auto &sens_gradient_shape = sens_gradient_abs->BuildShape()->ToString();
  const auto &out_shape = out_abs->BuildShape()->ToString();
  if (sens_gradient_shape != "()" && out_shape != "()") {
    if (sens_gradient_shape != out_shape) {
      // The sens shape in the ir graph is determined by the graph output, so it can be a dynamic shape; but the
      // input shape is determined by the user input, which cannot be a dynamic shape.
      if (!sens_gradient_abs->BuildShape()->IsDynamic()) {
        MS_EXCEPTION(ValueError) << "The shape should be " << out_shape << ", but got " << sens_gradient_shape
                                 << ", sens gradient abs " << sens_gradient_abs->ToString() << ", out abs "
                                 << out_abs->ToString();
      }
    }
    const auto &sens_gradient_dtype = sens_gradient_abs->BuildType()->ToString();
    const auto &out_dtype = out_abs->BuildType()->ToString();
    if (sens_gradient_dtype != out_dtype) {
      MS_EXCEPTION(TypeError) << "The dtype should be " << out_dtype << ", but got " << sens_gradient_dtype
                              << ", sens gradient abs " << sens_gradient_abs->ToString() << ", out abs "
                              << out_abs->ToString();
    }
  }
}

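// Prune the grad graph before back propagation: mark cell inputs and weights that were not requested
// as not needing grad, then mark every non-leaf variable whose next edges all skip grad as not
// needing grad either.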
void FuncGrad::PruningGradGraph(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr,
                                const std::vector<size_t> &grad_position) {
  PruningInput(grad_attr, grad_position);
  PruningWeights(weights, grad_attr);

  // Prune all non-leaf nodes in the grad graph.
  for (const auto &variable : variable_set_) {
    if (variable->is_leaf()) {
      continue;
    }
    bool is_need_grad =
      std::any_of(variable->func_node()->next_edges().begin(), variable->func_node()->next_edges().end(),
                  [](const auto &edge) { return edge.variable->is_need_grad(); });
    if (!is_need_grad) {
      variable->set_is_need_grad(false);
    }
  }
}

void FuncGrad::PruningInput(const GradAttr &grad_attr, const std::vector<size_t> &grad_position) {
  mindspore::HashSet<size_t> grad_pos_list{grad_position.begin(), grad_position.end()};
  // Prune inputs by position in the grad graph.
  if (grad_attr.get_by_position) {
    for (size_t i = 0; i < cell_inputs_.size(); ++i) {
      if (grad_pos_list.find(i) == grad_pos_list.end()) {
        cell_inputs_[i].second->set_is_need_grad(false);
      }
    }
    return;
  }

  // Keep only the first input in the grad graph.
  if (!grad_attr.grad_all_inputs && !grad_attr.get_by_position && !grad_attr.grad_weights) {
    for (size_t i = 1; i < cell_inputs_.size(); ++i) {
      cell_inputs_[i].second->set_is_need_grad(false);
    }
  }

  // Prune all inputs when inputs do not require grad.
  if (!grad_attr.grad_all_inputs && grad_attr.grad_weights) {
    for (auto &cell_input : cell_inputs_) {
      cell_input.second->set_is_need_grad(false);
    }
  }
}

void FuncGrad::PruningWeights(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr) {
  // Prune weights in the grad graph.
  if (grad_attr.grad_weights) {
    mindspore::HashSet<std::string> grad_weights_id;
    for (const auto &weight : weights) {
      (void)grad_weights_id.emplace(weight->id());
    }
    for (const auto &weight : weights_used_in_graph_) {
      if (grad_weights_id.find(weight->id()) == grad_weights_id.end()) {
        auto variable = weight->auto_grad_meta_data()->variable();
        MS_EXCEPTION_IF_NULL(variable);
        variable->set_is_need_grad(false);
      }
    }
  } else {
    for (const auto &weight : weights_used_in_graph_) {
      auto variable = weight->auto_grad_meta_data()->variable();
      MS_EXCEPTION_IF_NULL(variable);
      variable->set_is_need_grad(false);
    }
  }
}

ValuePtr FuncGrad::Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                          const GradAttr &grad_attr, const ValuePtr &sens) {
  CheckSensShapeAndType(sens);
  BuildForwardLastNode(sens);
  PruningGradGraph(weights, grad_attr, grad_position);
  if (last_variable_->is_need_grad()) {
    GilReleaseWithCheck gil_release;
    BackPropagate();
  }
  python_adapter::PyAdapterCallback::ProcessUnPairedCellHook(true);
  ValuePtr gradients = GetGrads(weights, grad_position, grad_attr);
  ClearGrads(weights);
  return gradients;
}
}  // namespace mindspore::pynative::autograd