1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2022-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/pynative/grad/ir/ir_grad.h"
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "frontend/expander/bprop/bprop.h"
26 #include "frontend/optimizer/ad/dfunctor.h"
27 #include "include/backend/optimizer/helper.h"
28 #include "include/common/utils/convert_utils_py.h"
29 #include "include/common/profiler.h"
30 #include "ir/anf.h"
31 #include "ir/func_graph_cloner.h"
32 #include "pipeline/jit/ps/action.h"
33 #include "pipeline/pynative/grad/jit/jit_call_graph.h"
34 #include "pipeline/pynative/pynative_utils.h"
35 #include "utils/info.h"
36 #include "utils/profile.h"
37 
38 namespace mindspore {
39 namespace pynative {
40 namespace autograd {
41 namespace {
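// Mark the cnode as a jit call node and attach, as user data, a JitCallGraph callback built from the call graph,
// the cache key and the call condition.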
42 void SetJitCallGraph(const CNodePtr &cnode, const FuncGraphPtr &call_graph, const std::string &cache_key,
43                      const GraphCallCondition &graph_call_condition) {
44   MS_EXCEPTION_IF_NULL(cnode);
45   common::AnfAlgo::SetNodeAttr(kAttrJitCallNode, MakeValue(true), cnode);
46   auto graph_call_back = PyNativeAlgo::AutoGrad::CreateGraphCallBack(call_graph, cache_key, graph_call_condition);
47   cnode->set_user_data<JitCallGraph>(std::make_shared<JitCallGraph>(graph_call_back));
48 }
49 
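// Both gradient outputs are "empty" when each is a MakeTuple cnode carrying only its primitive input, i.e. an empty tuple.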
50 bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
51   if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
52     return false;
53   }
54   auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
55   auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
56   if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
57       !IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
58     return false;
59   }
60   constexpr int kEmptyTupleSize = 1;
61   if (inputs_grad_cnode->size() != kEmptyTupleSize || weights_grad_cnode->size() != kEmptyTupleSize) {
62     return false;
63   }
64   return true;
65 }
66 
67 AnfNodePtr GenerateEmptyTupleValue() {
68   std::vector<ValuePtr> value_list;
69   auto inputs_value = std::make_shared<ValueTuple>(value_list);
70   auto weights_value = std::make_shared<ValueTuple>(value_list);
71   std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
72   auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
73   return PyNativeAlgo::Common::CreateValueNodeByValue(tuple_value);
74 }
75 
76 bool IsValidTensorInput(const abstract::AbstractBasePtr &abs) {
77   MS_EXCEPTION_IF_NULL(abs);
78   return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSparseTensor>();
79 }
80 
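// Rebuild the cnode on the tape graph; a nested TupleGetItem input is cloned recursively so the whole chain lives on the tape.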
81 AnfNodePtr GetTupleItemNodeInput(const KernelGraphPtr &tape, const AnfNodePtr &node) {
82   MS_EXCEPTION_IF_NULL(tape);
83   MS_EXCEPTION_IF_NULL(node);
84   auto cnode = node->cast<CNodePtr>();
85   MS_EXCEPTION_IF_NULL(cnode);
86   AnfNodePtr new_cnode = nullptr;
87   if (IsPrimitive(cnode->input(kIndex1), prim::kPrimTupleGetItem)) {
88     auto inner_cnode = cnode->input(kIndex1)->cast<CNodePtr>();
89     new_cnode = tape->FuncGraph::NewCNode(
90       {inner_cnode->input(kIndex0), GetTupleItemNodeInput(tape, inner_cnode), inner_cnode->input(kIndex2)});
91   } else {
92     AnfNodePtrList new_inputs{cnode->inputs().begin(), cnode->inputs().end()};
93     new_cnode = tape->FuncGraph::NewCNode(new_inputs);
94   }
95   MS_EXCEPTION_IF_NULL(new_cnode);
96   new_cnode->set_abstract(cnode->abstract());
97   return new_cnode;
98 }
99 
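// A value is constant when it is not an input or parameter tensor and has no k_node; sequences are checked element-wise,
// and COO/CSR tensors are checked through their indices.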
100 bool IsConstant(const ValuePtr &value) {
101   MS_EXCEPTION_IF_NULL(value);
102   if (value->isa<tensor::BaseTensor>()) {
103     const auto &tensor = value->cast<tensor::BaseTensorPtr>();
104     auto auto_grad_meta_data = tensor->auto_grad_meta_data();
105     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
106     if (auto_grad_meta_data->input_type() == InputType::kParameter ||
107         auto_grad_meta_data->input_type() == InputType::kInput) {
108       return false;
109     }
110     auto k_node = auto_grad_meta_data->k_node();
111     if (k_node != nullptr) {
112       return false;
113     }
114     return true;
115   } else if (value->isa<ValueSequence>()) {
116     auto val_seq = value->cast<ValueSequencePtr>();
117     return std::all_of(val_seq->value().begin(), val_seq->value().end(),
118                        [](const ValuePtr &value) { return IsConstant(value); });
119   } else if (value->isa<tensor::COOTensor>()) {
120     auto coo_tensor = value->cast<tensor::COOTensorPtr>();
121     return IsConstant(coo_tensor->GetIndices());
122   } else if (value->isa<tensor::CSRTensor>()) {
123     auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
124     return IsConstant(csr_tensor->GetIndices());
125   }
126   return true;
127 }
128 }  // namespace
129 
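// Accumulate two dout nodes: a zeros-like placeholder on either side is elided, plain nodes are combined with Add,
// and MakeTuple nodes are merged element by element recursively.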
130 AnfNodePtr IrFunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
131   MS_EXCEPTION_IF_NULL(left_node);
132   MS_EXCEPTION_IF_NULL(right_node);
133 
134   if (PyNativeAlgo::AutoGrad::IsZerosLikeNode(left_node)) {
135     return right_node;
136   }
137   if (PyNativeAlgo::AutoGrad::IsZerosLikeNode(right_node)) {
138     return left_node;
139   }
140   if (!IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
141     auto add_result = tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimAdd), left_node, right_node});
142     add_result->set_abstract(right_node->abstract());
143     return add_result;
144   }
145   if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple) && IsPrimitiveCNode(right_node, prim::kPrimMakeTuple)) {
146     auto left_cnode = left_node->cast<CNodePtr>();
147     auto right_cnode = right_node->cast<CNodePtr>();
148     MS_EXCEPTION_IF_NULL(right_cnode);
149     AnfNodePtrList inputs = {NewValueNode(prim::kPrimMakeTuple)};
150     AbstractBasePtrList abs;
151     for (size_t i = 1; i < left_cnode->size(); ++i) {
152       auto add_result = HyperAdd(left_cnode->input(i), right_cnode->input(i));
153       (void)abs.emplace_back(add_result->abstract());
154       (void)inputs.emplace_back(add_result);
155     }
156     auto add_tuple = tape_->FuncGraph::NewCNode(inputs);
157     add_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
158     return add_tuple;
159   }
160   MS_LOG(EXCEPTION) << "Unknown cnode type: " << left_node->DebugString();
161 }
162 
163 void IrFunctionNode::AddNextEdge(const VariablePtr &next_variable, const AnfNodePtr &din) {
164   MS_EXCEPTION_IF_NULL(next_variable);
165   MS_EXCEPTION_IF_NULL(din);
166   // next_node and its corresponding din
167   (void)next_edges_.emplace_back(next_variable, din);
168   if (din == fake_dout_) {
169     (void)need_replace_edges_.emplace_back(next_edges_.size() - 1);
170   }
171 }
172 
173 void IrFunctionNode::UpdateAccumulativeDout(const AnfNodePtr &new_dout) {
174   MS_EXCEPTION_IF_NULL(new_dout);
175   accumulate_dout_ = HyperAdd(accumulate_dout_, new_dout);
176 }
177 
178 void IrFunctionNode::ReplaceEdges() {
179   MS_EXCEPTION_IF_NULL(accumulate_dout_);
180   for (const auto index : need_replace_edges_) {
181     next_edges_[index].second = accumulate_dout_;
182   }
183 }
184 
185 IrGrad::IrGrad(const std::vector<ValuePtr> &input_param_values, const AbstractBasePtrList &abs_list,
186                size_t op_num_in_bprop_graph, const runtime::AsyncHqueuePtr &assist_queue, bool grad_by_value,
187                bool is_run_recompute)
188     : ad_param_(std::make_shared<AdParam>()) {
189   ad_param()->tape_->debug_info()->set_name("grad_top");
190   MS_LOG(DEBUG) << "Start IrGrad, input size: " << input_param_values.size();
191   ad_param()->variable_adjoint_set_.reserve(op_num_in_bprop_graph);
192   ad_param()->anfnode_to_variable_adjoint_.reserve(op_num_in_bprop_graph);
193   ad_param()->users_.dout_user_.reserve(op_num_in_bprop_graph);
194   ad_param()->weights_used_in_graph_.reserve(op_num_in_bprop_graph);
195 
196   for (size_t i = 0; i < input_param_values.size(); ++i) {
197     auto input_parameter = ad_param()->fg_->add_parameter();
198     input_parameter->set_abstract(abs_list[i]);
199     input_parameter->set_name(input_parameter->UniqueName());
200     TraceGuard trace_guard(std::make_shared<TraceCopy>(input_parameter->debug_info()));
201     auto tape_parameter = ad_param()->tape_->add_parameter();
202     tape_parameter->set_abstract(abs_list[i]);
203 
204     auto zeros_like_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
205       ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), abs_list[i], SpecialType::kZerosLikeType);
206     auto func_node = std::make_shared<IrFunctionNode>(ad_param()->tape_, zeros_like_dout);
207     auto input_adjoint = std::make_shared<IrVariable>(func_node, input_param_values[i], true);
208 
209     if (!input_param_values[i]->isa<ValueSequence>()) {
210       PyNativeAlgo::AutoGrad::SetGradInfoForInputs(input_param_values[i], input_adjoint, input_parameter);
211     } else {
212       input_adjoint->set_is_need_grad(false);
213     }
214     (void)cell_inputs_.emplace_back(input_parameter, input_adjoint);
215     (void)ad_param()->variable_adjoint_set_.insert(input_adjoint);
216   }
217 
218   assist_queue_ = assist_queue;
219   grad_by_value_ = grad_by_value;
220   device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
221   ir_bprop_ = std::make_unique<IrBprop>(ad_param_, device_target_, grad_by_value_, is_run_recompute);
222 }
223 
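// Record the backward of a single op on the tape: build a zeros-like dout placeholder, construct the bprop input cnode,
// expand the bprop (or build a bprop_cut cnode for hook primitives, and a fake bprop if nothing is found), then register
// the variable adjoint and connect it to its next edges.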
224 bool IrGrad::KPynativeOp(const GradParamPtr &grad_param) {
225   MS_EXCEPTION_IF_NULL(grad_param);
226 
227   auto &prim = grad_param->op_grad_info->op_prim;
228   if (!PyNativeAlgo::AutoGrad::IsPrimNeedGrad(prim) ||
229       (grad_by_value_ && !PyNativeAlgo::AutoGrad::NeedGrad(grad_param->op_grad_info->input_value))) {
230     MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need to do op grad.";
231     return true;
232   }
233 
234   auto cloned_value = grad_param->op_grad_info->out_value;
235   if (grad_param->op_grad_info->out_value->isa<ValueSequence>()) {
236     cloned_value = ShallowCopyTensorValue(grad_param->op_grad_info->out_value);
237     PyNativeAlgo::Common::ClearDeviceAddress(cloned_value);
238   }
239 
240   PyNativeAlgo::AutoGrad::CheckAndSetAbstract(grad_param->op_grad_info);
241   // Construct a zeros-like placeholder; if it is needed in bprop, it is replaced during backpropagation.
242   AnfNodePtr dout =
243     PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
244                                              grad_param->op_grad_info->out_abs, SpecialType::kZerosLikeType);
245   auto fn = std::make_shared<IrFunctionNode>(ad_param()->tape_, dout);
246   auto variable_adjoint = std::make_shared<IrVariable>(fn, cloned_value);
247   // A custom forward cnode does not need to be recorded in the bprop graph, because it is only a flag cnode for running
248   // Python, so just creating a bprop_cut grad op is enough.
249   bool is_custom_prim =
250     IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
251   AnfNodePtr k_node = nullptr;
252   if (!grad_by_value_ && !is_custom_prim) {
253     k_node = BuildKNode(NewValueNode(prim), grad_param, true);
254     SetKNodeInfo(grad_param->op_grad_info->out_value, k_node, grad_param->op_grad_info->out_abs);
255     need_do_manager_replace_ = true;
256   }
257   CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint, k_node, is_custom_prim);
258   MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
259   // Gradient outputs
260   std::vector<CNodePtr> outputs;
261   if (!is_custom_prim) {
262     auto ret = BpropExpander(&outputs, &ad_param()->users_).Run(input_node, grad_param->op_grad_info->input_value);
263     // cppcheck-suppress unreadVariable
264     if (MS_UNLIKELY(!ret || outputs.empty())) {
265       MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << prim->name();
266       ir_bprop_->BuildCustomBpropCNode(input_node, prim, &outputs);
267     }
268   } else {
269     PyNativeAlgo::AutoGrad::CheckRecomputeInputs(grad_param);
270     ir_bprop_->BuildBPropCutCNode(input_node, prim, &outputs, grad_param->op_grad_info->is_need_recompute);
271   }
272   // cppcheck-suppress unreadVariable
273   if (MS_UNLIKELY(outputs.empty())) {
274     MS_LOG(DEBUG) << "This op has no custom bprop: " << prim->name();
275     PyNativeAlgo::AutoGrad::BuildFakeBpropCNode(input_node, &outputs);
276     variable_adjoint->set_is_fake_bprop(true);
277     variable_adjoint->set_fake_prim_name(prim->name());
278   }
279   (void)ad_param()->variable_adjoint_set_.insert(variable_adjoint);
280   PyNativeAlgo::AutoGrad::SetGradMetaData(grad_param->op_grad_info->out_value, variable_adjoint);
281   ir_bprop_->UpdateNextEdges(variable_adjoint, outputs, grad_param->op_grad_info->input_value,
282                              grad_param->op_grad_info->input_abs, prim->name());
283   return true;
284 }
285 
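// Record the backward of a called func graph: map inputs to parameters or value nodes (or k-nodes for high-order grad),
// build the bprop call cnode, split its output into per-input din nodes with TupleGetItem, and register the variable adjoint.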
286 bool IrGrad::KPynativeWithFProp(const GradParamPtr &grad_param) {
287   MS_EXCEPTION_IF_NULL(grad_param);
288   MS_LOG(DEBUG) << "Do KPynativeWithFProp";
289   AnfNodePtrList args_node_list;
290   CNodePtr bprop_cnode = nullptr;
291   AnfNodePtr k_node = nullptr;
292   AnfNodePtr dout = nullptr;
293   if (grad_by_value_) {
294     for (size_t i = 0; i < grad_param->input_size; ++i) {
295       if (PyNativeAlgo::Common::IsParam(grad_param->op_grad_info->input_value_grad_type[i])) {
296         auto parameter =
297           ir_bprop_->MapParameter(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]);
298         MS_EXCEPTION_IF_NULL(parameter);
299         (void)args_node_list.emplace_back(parameter);
300         continue;
301       }
302       // Value node input
303       const auto value_node = PyNativeAlgo::Common::CreateValueNodeByValue(
304         grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]->Clone());
305       auto cnode = PyNativeAlgo::Common::ConvertValueSequenceToMakeTuple(value_node, ad_param()->tape_);
306       (void)args_node_list.emplace_back(cnode);
307     }
308     bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
309   } else {
310     k_node = BuildKNode(NewValueNode(grad_param->source_fg), grad_param, false);
311     BuildKNodeListForHighOrderGraph(grad_param->op_grad_info->input_value, grad_param->op_grad_info->input_abs,
312                                     &args_node_list);
313     bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
314   }
315   auto fn = std::make_shared<IrFunctionNode>(ad_param()->tape_, dout);
316   auto variable_adjoint = std::make_shared<IrVariable>(fn, grad_param->op_grad_info->out_value);
317   variable_adjoint->set_k_node(k_node);
318   std::vector<CNodePtr> outputs;
319   for (size_t i = 0; i < grad_param->input_size; ++i) {
320     CNodePtr din = ad_param()->tape_->FuncGraph::NewCNode(
321       {NewValueNode(prim::kPrimTupleGetItem), bprop_cnode, NewValueNode(SizeToLong(i))});
322     din->set_abstract(grad_param->op_grad_info->input_abs[i]);
323     (void)outputs.emplace_back(din);
324   }
325   ir_bprop_->UpdateNextEdges(variable_adjoint, outputs, grad_param->op_grad_info->input_value,
326                              grad_param->op_grad_info->input_abs);
327   (void)ad_param()->variable_adjoint_set_.insert(variable_adjoint);
328   (void)ad_param()->anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
329   PyNativeAlgo::AutoGrad::SetGradMetaData(grad_param->op_grad_info->out_value, variable_adjoint);
330   SetKNodeInfo(grad_param->op_grad_info->out_value, k_node, grad_param->op_grad_info->out_abs);
331   return true;
332 }
333 
334 CNodePtr IrGrad::GetBPropCNode(const GradParamPtr &grad_param, const AnfNodePtrList &args,
335                                const FuncGraphPtr &bprop_graph, bool cache_hit, AnfNodePtr *const tape_dout) {
336   AnfNodePtrList bprop_inputs(args.begin(), args.end());
337   bool is_jit_dynamic_shape = grad_param->is_jit_graph && grad_param->use_dynamic_shape_process;
338   // Save replace info the first time
339   if (!cache_hit && is_jit_dynamic_shape && grad_param->has_added_v) {
340     const auto &jit = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->jit();
341     jit->SaveForwardOutputTensorInfoInBpropGraph(bprop_graph);
342   }
343 
344   // Call by tape_
345   MS_EXCEPTION_IF_NULL(tape_dout);
346   *tape_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
347                                                         grad_param->op_grad_info->out_abs, SpecialType::kZerosLikeType);
348   if (is_jit_dynamic_shape && grad_param->op_grad_info->out_abs->isa<abstract::AbstractSequence>()) {
349     auto abs_seq = grad_param->op_grad_info->out_abs->cast<abstract::AbstractSequencePtr>();
350     // A dynamic-length sequence currently has no size
351     if (!abs_seq->dynamic_len()) {
352       for (size_t i = 0; i < abs_seq->size(); ++i) {
353         CNodePtr din = ad_param()->tape_->FuncGraph::NewCNode(
354           {NewValueNode(prim::kPrimTupleGetItem), *tape_dout, NewValueNode(SizeToLong(i))});
355         din->set_abstract(abs_seq->elements()[i]);
356         (void)bprop_inputs.emplace_back(din);
357         ir_bprop_->AddUser(*tape_dout, din, kIndex1);
358       }
359     }
360   } else {
361     (void)bprop_inputs.emplace_back(*tape_dout);
362   }
363   (void)bprop_inputs.insert(bprop_inputs.cbegin(), NewValueNode(bprop_graph));
364   // get_bprop is a call node
365   auto bprop_cnode = ad_param()->tape_->FuncGraph::NewCNode(bprop_inputs);
366   bprop_cnode->set_abstract(bprop_graph->output()->abstract());
367   if (is_jit_dynamic_shape) {
368     GraphCallCondition graph_call_condition{grad_param->is_control_flow, grad_param->is_jit_graph,
369                                             grad_param->use_dynamic_shape_process, false, false};
370     SetJitCallGraph(bprop_cnode, bprop_graph, grad_param->graph_cache_key, graph_call_condition);
371     ad_param()->tape_->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
372   }
373   // For replacing parameter and dout.
374   for (size_t i = 1; i < bprop_inputs.size(); ++i) {
375     ir_bprop_->AddUser(bprop_inputs[i], bprop_cnode, i);
376   }
377   return bprop_cnode;
378 }
379 
380 CNodePtr IrGrad::GetBpropGraphCNode(const GradParamPtr &grad_param, const AnfNodePtrList &args,
381                                     AnfNodePtr *const tape_dout) {
382   MS_EXCEPTION_IF_NULL(grad_param);
383   auto [cache_hit, bprop_graph] = ir_bprop_->GetBpropGraph(grad_param);
384   if (grad_param->is_control_flow || grad_param->is_jit_self_dynamic_shape) {
385     need_do_manager_replace_ = true;
386   }
387   return GetBPropCNode(grad_param, args, bprop_graph, cache_hit, tape_dout);
388 }
389 
390 void IrGrad::UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) {
391   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
392                                      runtime::ProfilerEvent::kPyNativeGradUpdateSens,
393                                      runtime::ProfilerRecorder::kNoName, true);
394   MS_EXCEPTION_IF_NULL(sens_out);
395   MS_LOG(DEBUG) << "Real output of top cell is " << PyNativeAlgo::Common::GetIdByValue(sens_out);
396   ad_param()->sens_value_ = sens_out;
397   UpdateSensParameter(ad_param()->sens_value_);
398 }
399 
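// Build the final bprop graph: bind the sens and weight parameters, backpropagate from the last variable, set the graph
// output according to grad_attr, replace primal parameters with tape parameters, and clear the weights' grad metadata.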
400 FuncGraphPtr IrGrad::Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
401                             const GradAttr &grad_attr) {
402   // Set sens node and weights node
403   SetSensAndWeights(weights, grad_attr.has_sens);
404 
405   // Backpropagate sensitivity, except when the last node is a value node, which may be produced by constant folding.
406   if (ad_param()->last_variable_->is_need_grad() && !ad_param()->last_variable_->is_leaf()) {
407     ir_bprop_->BackPropagate();
408   }
409   SetOutput(weights, grad_position, grad_attr);
410   // Replace Parameter of primal func graph with parameter of ad_param()->tape_;
411   ReplacePrimalParameter(grad_attr.has_sens);
412   PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", ad_param()->tape_);
413   // Clear weights grad info
414   for (const auto &weight : weights) {
415     weight->set_auto_grad_meta_data(nullptr);
416   }
417   return ad_param()->tape_;
418 }
419 
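// Assemble the bprop input cnode as (prim, inputs..., out, dout). Under grad-by-value or for custom primitives the inputs
// are parameters or value nodes; otherwise they are k-nodes taken from the primal cnode.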
420 CNodePtr IrGrad::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
421                                           const VariablePtr &variable_adjoint, const AnfNodePtr &k_node,
422                                           bool is_custom_prim) {
423   MS_EXCEPTION_IF_NULL(grad_param);
424   AnfNodePtrList node_list;
425   (void)node_list.emplace_back(NewValueNode(grad_param->op_grad_info->op_prim));
426   if (grad_by_value_ || is_custom_prim) {
427     for (size_t i = 0; i < grad_param->input_size; ++i) {
428       if (PyNativeAlgo::Common::IsParam(grad_param->op_grad_info->input_value_grad_type[i])) {
429         // Handle the case where the input is a tuple like (parameter, ...)
430         auto parameter =
431           ir_bprop_->MapParameter(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]);
432         MS_EXCEPTION_IF_NULL(parameter);
433         (void)node_list.emplace_back(parameter);
434         continue;
435       }
436       // The node's abstract object may have been freed, so the value node's abstract may not be correct
437       (void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(
438         grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]->Clone()));
439     }
440     // If the output tensor has a registered hook, the bprop graph must run op by op
441     if (!ir_bprop_->bprop_graph_run_by_single_op()) {
442       ir_bprop()->set_bprop_graph_run_by_single_op([&grad_param]() {
443         auto tensor = grad_param->op_grad_info->out_value->template cast<tensor::BaseTensorPtr>();
444         if (tensor == nullptr) {
445           return false;
446         }
447         auto auto_grad_meta = tensor->auto_grad_meta_data();
448         MS_EXCEPTION_IF_NULL(auto_grad_meta);
449         return auto_grad_meta->is_register_hook();
450       }());
451     }
452     // Set out
453     (void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(grad_param->op_grad_info->out_value,
454                                                                               grad_param->op_grad_info->out_abs));
455   } else {
456     // Input is a Parameter or cnode, not a value node
457     BuildKNodeListFromPrimalCNode(grad_param->op_grad_info->input_value, grad_param->op_grad_info->input_abs,
458                                   &node_list);
459     // Set out
460     MS_EXCEPTION_IF_NULL(variable_adjoint);
461     (void)node_list.emplace_back(k_node);
462   }
463   // Set dout
464   (void)node_list.emplace_back(dout);
465   auto input_node = ad_param()->tape_->FuncGraph::NewCNode(node_list);
466   return input_node;
467 }
468 
469 void IrGrad::BuildKNodeListFromPrimalCNode(const ValuePtrList &input_value,
470                                            const abstract::AbstractBasePtrList &input_abs,
471                                            AnfNodePtrList *const node_list) {
472   for (size_t i = 0; i < input_value.size(); ++i) {
473     (void)node_list->emplace_back(BuildKNodeForCNodeInput(input_value[i], input_abs[i]));
474     MS_LOG(DEBUG) << "Get knode for input:  " << PyNativeAlgo::Common::GetIdByValue(input_value[i]);
475   }
476 }
477 
478 AnfNodePtr IrGrad::BuildKNodeForCNodeInput(const ValuePtr &input, const abstract::AbstractBasePtr &abs) {
479   MS_EXCEPTION_IF_NULL(input);
480   if (input->isa<tensor::BaseTensor>()) {
481     const auto &tensor = input->cast<tensor::BaseTensorPtr>();
482     const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
483     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
484     auto k_node = auto_grad_meta_data->k_node();
485     if (k_node != nullptr) {
486       return k_node;
487     }
488     if (PyNativeAlgo::Common::IsParam(auto_grad_meta_data->input_type())) {
489       return ir_bprop_->MapParameter(input, abs);
490     }
491   } else if (input->isa<ValueSequence>() && !IsConstant(input)) {
492     AnfNodePtrList inputs;
493     (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
494     const auto &val_sequence = input->cast<ValueSequencePtr>()->value();
495     const auto &abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
496     MS_EXCEPTION_IF_NULL(abs_sequence);
497     if (val_sequence.size() != abs_sequence->size()) {
498       MS_LOG(EXCEPTION) << "Get value sequence size " << val_sequence.size() << " not equal to abstract size "
499                         << abs_sequence->size();
500     }
501     for (size_t i = 0; i < val_sequence.size(); ++i) {
502       (void)inputs.emplace_back(BuildKNodeForCNodeInput(val_sequence[i], abs_sequence->elements()[i]));
503     }
504     auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
505     k_node->set_abstract(abs);
506     return k_node;
507   }
508   auto value_node = NewValueNode(input);
509   value_node->set_abstract(abs);
510   return value_node;
511 }
512 
513 void IrGrad::BuildKNodeListForHighOrderGraph(const ValuePtrList &input_value,
514                                              const abstract::AbstractBasePtrList &input_abs,
515                                              AnfNodePtrList *const node_list) {
516   for (size_t i = 0; i < input_value.size(); ++i) {
517     const auto knode = BuildKNodeForCNodeInput(input_value[i], input_abs[i]);
518     // Convert a value sequence to MakeTuple so that the final pass can eliminate TupleGetItem.
519     // BuildKNodeForTupleGetItem does not yet support a value sequence input.
520     if (knode->isa<ValueNode>()) {
521       auto value_node = knode->cast<ValueNodePtr>();
522       (void)node_list->emplace_back(
523         PyNativeAlgo::Common::ConvertValueSequenceToMakeTuple(value_node, ad_param()->tape_));
524     } else {
525       (void)node_list->emplace_back(knode);
526     }
527 
528     MS_LOG(DEBUG) << "Get knode for input:  " << PyNativeAlgo::Common::GetIdByValue(input_value[i]);
529   }
530 }
531 
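// Attach the k_node to the tensor's auto-grad metadata; for a value sequence, create TupleGetItem sub-nodes and recurse
// element by element.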
532 void IrGrad::SetKNodeInfo(const ValuePtr &value, const AnfNodePtr &k_node, const AbstractBasePtr &out_abs) {
533   if (value->isa<tensor::BaseTensor>()) {
534     auto tensor = value->cast<tensor::BaseTensorPtr>();
535     auto auto_grad_meta_data = tensor->auto_grad_meta_data();
536     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
537     auto_grad_meta_data->set_k_node(k_node);
538     (void)k_nodes_used_in_graph_.emplace_back(k_node);
539   } else if (value->isa<ValueSequence>()) {
540     const auto &value_sequence = value->cast<ValueSequencePtr>()->value();
541     const auto &abs_seq = out_abs->cast<abstract::AbstractSequencePtr>();
542     MS_EXCEPTION_IF_NULL(abs_seq);
543     if (abs_seq->dynamic_len()) {
544       return;
545     }
546     if (value_sequence.size() != abs_seq->size()) {
547       MS_LOG(EXCEPTION) << "Get value sequence size " << value_sequence.size() << " not equal to abstract size "
548                         << abs_seq->size();
549     }
550     for (size_t i = 0; i < value_sequence.size(); ++i) {
551       auto sub_k_node = ad_param()->tape_->FuncGraph::NewCNode(
552         {NewValueNode(prim::kPrimTupleGetItem), k_node, NewValueNode(static_cast<int64_t>(i))});
553       sub_k_node->set_abstract(abs_seq->elements()[i]);
554       SetKNodeInfo(value_sequence[i], sub_k_node, abs_seq->elements()[i]);
555     }
556   }
557 }
558 
559 AnfNodePtr IrGrad::BuildKNode(const AnfNodePtr &prim, const GradParamPtr &grad_param, bool from_single_op) {
560   MS_EXCEPTION_IF_NULL(grad_param);
561   AnfNodePtrList node_list;
562   (void)node_list.emplace_back(prim);
563   for (size_t i = 0; i < grad_param->input_size; ++i) {
564     (void)node_list.emplace_back(
565       BuildKNodeForCNodeInput(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]));
566   }
567   auto k_node = ad_param()->tape_->FuncGraph::NewCNode(node_list);
568   k_node->set_abstract(grad_param->op_grad_info->out_abs);
569   k_node->AddAttr(bprop_pass::kIsKNode, MakeValue(true));
570   if (from_single_op && grad_param->out_used_in_bporp_graph) {
571     auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(grad_param->op_grad_info->out_value,
572                                                                grad_param->op_grad_info->out_abs);
573     k_node->set_forward(v_node, "");
574     ad_param()->tape_->set_used_forward_nodes({k_node});
575   }
576   MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
577   return k_node;
578 }
579 
580 void IrGrad::UpdateSensParameter(const ValuePtr &value) {
581   MS_EXCEPTION_IF_NULL(value);
582   if (value->isa<tensor::BaseTensor>()) {
583     const auto &sens_tensor = value->cast<tensor::BaseTensorPtr>();
584     const auto &auto_grad_meta_data = sens_tensor->auto_grad_meta_data();
585     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
586     const auto variable = auto_grad_meta_data->variable();
587     // Return the input parameter or weight parameter of the net; if the value is a parameter, it only needs to be added once
588     if (auto_grad_meta_data->input_type() == InputType::kParameter && variable == nullptr) {
589       (void)ir_bprop_->AddParameterNode(sens_tensor,
590                                         PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_tensor->ToAbstract()));
591     }
592   } else if (value->isa<ValueSequence>()) {
593     const auto &value_seq = value->cast<ValueSequencePtr>()->value();
594     for (const auto &v : value_seq) {
595       UpdateSensParameter(v);
596     }
597   } else if (value->isa<ValueDictionary>()) {
598     auto dic_v = value->cast<ValueDictionaryPtr>();
599     for (const auto &v : dic_v->value()) {
600       UpdateSensParameter(v.second);
601     }
602   }
603 }
604 
605 ParameterPtr IrGrad::ExtractParameter(const tensor::BaseTensorPtr &tensor) const {
606   MS_EXCEPTION_IF_NULL(tensor);
607   const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
608   if (auto_grad_meta_data != nullptr && PyNativeAlgo::Common::IsParam(auto_grad_meta_data->input_type())) {
609     return auto_grad_meta_data->parameter();
610   }
611   return nullptr;
612 }
613 
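// Bind the sens input (an explicit sens parameter or a ones-like node) to the last variable's dout, then add tape
// parameters for the requested weights and for any other weights used in the graph.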
614 void IrGrad::SetSensAndWeights(const tensor::BaseTensorPtrList &weights, bool has_sens_arg) {
615   const auto &sens_abstract = ir_bprop_->BuildForwardLastNode();
616   ParameterPtr sens_param = nullptr;
617   if (has_sens_arg) {
618     sens_param = ad_param()->tape_->add_parameter();
619     sens_param->set_name(sens_param->UniqueName());
620     sens_param->debug_info()->set_name("sens");
621     sens_param->set_abstract(sens_abstract);
622   }
623   // Update the dout of the last variable with the sens value
624   MS_EXCEPTION_IF_NULL(ad_param()->last_variable_);
625   if (ad_param()->last_variable_->is_need_grad()) {
626     if (has_sens_arg) {
627       ad_param()->last_variable_->ir_function_node()->UpdateAccumulativeDout(sens_param);
628     } else {
629       ad_param()->last_variable_->ir_function_node()->UpdateAccumulativeDout(PyNativeAlgo::AutoGrad::BuildSpecialNode(
630         ad_param()->tape_, ad_param()->sens_value_, sens_abstract, SpecialType::kOnesLikeType));
631     }
632   }
633   // Add weights parameter
634   need_grad_weights_.reserve(weights.size());
635   for (const auto &weight_tensor : weights) {
636     (void)need_grad_weights_.emplace(weight_tensor->id());
637     UpdateTapeParameter(weight_tensor);
638   }
639   for (auto &weight : ad_param_->weights_used_in_graph_) {
640     auto tensor = PyNativeAlgo::Common::GetTensorFromParam(weight);
641     MS_EXCEPTION_IF_NULL(tensor);
642     if (need_grad_weights_.find(tensor->id()) == need_grad_weights_.end()) {
643       UpdateTapeParameter(tensor);
644     }
645   }
646 }
647 
648 AnfNodePtr IrGrad::GetGradNodeByIndex(const tensor::BaseTensorPtr &tensor) {
649   MS_EXCEPTION_IF_NULL(tensor);
650   auto auto_grad_meta_data = tensor->auto_grad_meta_data();
651   MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
652   auto variable = auto_grad_meta_data->variable();
653   if (variable != nullptr && variable->is_need_grad()) {
654     // If the weight is used in the forward network but requires_grad is false, return a zeros-like node.
655     if (tensor->param_info() != nullptr && !tensor->param_info()->requires_grad()) {
656       MS_LOG(INFO) << "Weight participates in forward calculation, but requires_grad is false";
657       return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, tensor, nullptr, SpecialType::kZerosLikeType);
658     }
659     const auto &ir_variable = std::dynamic_pointer_cast<IrVariable>(variable);
660     MS_EXCEPTION_IF_NULL(ir_variable);
661     return ir_variable->RealDout();
662   }
663   MS_LOG(INFO) << "parameter does not need grad, tensor: " << PyNativeAlgo::Common::GetIdByValue(tensor);
664   return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, tensor, nullptr, SpecialType::kZerosLikeType);
665 }
666 
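// Collect the gradients of the network inputs as a MakeTuple node; an out-of-range position raises, non-tensor inputs
// are skipped, and a single requested position returns its gradient directly.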
667 AnfNodePtr IrGrad::GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position) {
668   std::vector<size_t> grad_pos_list;
669   if (get_by_position) {
670     grad_pos_list = grad_position;
671   } else if (grad_all_inputs) {
672     grad_pos_list.resize(cell_inputs_.size());
673     iota(grad_pos_list.begin(), grad_pos_list.end(), 0);
674   } else {
675     return nullptr;
676   }
677 
678   AnfNodePtrList inputs_grad_list{NewValueNode(prim::kPrimMakeTuple)};
679   AbstractBasePtrList inputs_grad_spec;
680   if (!cell_inputs_.empty()) {
681     for (size_t index : grad_pos_list) {
682       if (index >= cell_inputs_.size()) {
683         MS_LOG(EXCEPTION) << "Position index " << index << " exceeds input size.";
684       }
685       // Tuple, List, scalar will be ignored
686       if (!IsValidTensorInput(cell_inputs_[index].first->abstract())) {
687         MS_LOG(DEBUG) << "Input node is not a tensor"
688                       << ", abs " << cell_inputs_[index].first->abstract()->ToString();
689         continue;
690       }
691       auto ir_variable = std::dynamic_pointer_cast<IrVariable>(cell_inputs_[index].second);
692       MS_EXCEPTION_IF_NULL(ir_variable);
693       auto real_dout = ir_variable->RealDout();
694       MS_EXCEPTION_IF_NULL(real_dout);
695       (void)inputs_grad_list.emplace_back(real_dout);
696       (void)inputs_grad_spec.emplace_back(real_dout->abstract());
697     }
698     constexpr size_t single_pos_size = 1;
699     if (get_by_position && inputs_grad_spec.size() == single_pos_size) {
700       // First elem is prim
701       return inputs_grad_list[single_pos_size];
702     }
703   }
704   auto input_grad_ret = ad_param()->tape_->FuncGraph::NewCNode(inputs_grad_list);
705   input_grad_ret->set_abstract(std::make_shared<abstract::AbstractTuple>(inputs_grad_spec));
706   return input_grad_ret;
707 }
708 
709 AnfNodePtr IrGrad::GetWeightGrad(bool grad_weights, const tensor::BaseTensorPtrList &weights,
710                                  bool weight_param_is_tuple) {
711   // No need to return gradient of weights.
712   if (!grad_weights) {
713     return nullptr;
714   }
715   if (weight_param_is_tuple) {
716     AnfNodePtrList weights_grad_list{NewValueNode(prim::kPrimMakeTuple)};
717     AbstractBasePtrList weights_grad_spec;
718     for (const auto &weight : weights) {
719       auto grad_node = GetGradNodeByIndex(weight);
720       MS_EXCEPTION_IF_NULL(grad_node);
721       (void)weights_grad_list.emplace_back(grad_node);
722       (void)weights_grad_spec.emplace_back(grad_node->abstract());
723     }
724     auto weight_grad_ret = ad_param()->tape_->FuncGraph::NewCNode(weights_grad_list);
725     weight_grad_ret->set_abstract(std::make_shared<abstract::AbstractTuple>(weights_grad_spec));
726     return weight_grad_ret;
727   } else {
728     return GetGradNodeByIndex(weights[0]);
729   }
730 }
731 
732 void IrGrad::SetOutput(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
733                        const GradAttr &grad_attr) {
734   auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
735   auto weights_grad_ret = GetWeightGrad(grad_attr.grad_weights, weights, grad_attr.weight_param_is_tuple);
736   // Gradients wrt inputs and weights.
737   if (inputs_grad_ret != nullptr && weights_grad_ret != nullptr) {
738     if (IsOutputBothEmpty(inputs_grad_ret, weights_grad_ret)) {
739       auto tape_output = GenerateEmptyTupleValue();
740       ad_param()->tape_->set_output(tape_output);
741     } else {
742       auto tape_output =
743         ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs_grad_ret, weights_grad_ret});
744       tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(
745         abstract::AbstractBasePtrList{inputs_grad_ret->abstract(), weights_grad_ret->abstract()}));
746       ad_param()->tape_->set_output(tape_output);
747     }
748     return;
749   }
750   // Gradients wrt inputs.
751   if (inputs_grad_ret != nullptr) {
752     ad_param()->tape_->set_output(inputs_grad_ret);
753     return;
754   }
755   // Gradients wrt weights.
756   if (weights_grad_ret != nullptr) {
757     ad_param()->tape_->set_output(weights_grad_ret);
758     return;
759   }
760   // grad_all_inputs, grad_weights and get_by_position are all false.
761   AnfNodePtr tape_output = nullptr;
762   if (cell_inputs_.empty()) {
763     // If no input nodes, return empty tuple.
764     tape_output = ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple)});
765     abstract::AbstractBasePtrList abs{};
766     tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
767   } else {
768     // If there are input nodes, return gradient of first input node.
769     // Tuple, List, scalar will be ignored
770     if (IsValidTensorInput(cell_inputs_[0].first->abstract())) {
771       auto ir_variable = std::dynamic_pointer_cast<IrVariable>(cell_inputs_[kIndex0].second);
772       MS_EXCEPTION_IF_NULL(ir_variable);
773       tape_output = ir_variable->RealDout();
774     } else {
775       MS_LOG(DEBUG) << "First input node is not a tensor " << cell_inputs_[0].second->out_value()->ToString();
776       tape_output = NewValueNode(kNull);
777       tape_output->set_abstract(nullptr);
778     }
779   }
780   ad_param()->tape_->set_output(tape_output);
781 }
782 
783 void IrGrad::ElimateTupleGetItem() {
784   for (auto &user : ad_param()->users_.tuple_getitem_user_) {
785     auto old_node = user.first;
786     auto old_cnode = old_node->cast<CNodePtr>();
787     MS_EXCEPTION_IF_NULL(old_cnode);
788     auto tuple_node = old_cnode->input(kIndex1);
789     if (!IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
790       continue;
791     }
792     auto index_value = GetValueNode<Int64ImmPtr>(old_cnode->input(kIndex2));
793     size_t index = LongToSize(index_value->value());
794     auto tuple_cnode = tuple_node->cast<CNodePtr>();
795     ir_bprop_->Replace(old_node, tuple_cnode->input(index + 1), &ad_param()->users_.tuple_getitem_user_);
796   }
797 }
798 
799 void IrGrad::DoParameterReplaceByManager(bool has_sens_arg) {
800   const auto &parameters = ad_param()->tape_->parameters();
801   auto cell_inputs_size = cell_inputs_.size();
802   auto mng = MakeManager({ad_param()->tape_}, false);
803   auto tr = mng->Transact();
804   for (size_t i = 0; i < cell_inputs_size; ++i) {
805     (void)tr.Replace(cell_inputs_[i].first, parameters[i]);
806   }
807   // (Inputs, sens, weights) or (Inputs, weights)
808   size_t weight_offset = cell_inputs_size;
809   if (has_sens_arg) {
810     weight_offset = weight_offset + 1;
811   }
812   for (size_t i = weight_offset; i < parameters.size(); ++i) {
813     auto tensor = PyNativeAlgo::Common::GetTensorFromParam(parameters[i]);
814     MS_EXCEPTION_IF_NULL(tensor);
815     auto parameter = ExtractParameter(tensor);
816     MS_EXCEPTION_IF_NULL(parameter);
817     (void)tr.Replace(parameter, parameters[i]);
818   }
819   tr.Commit();
820 }
821 
822 void IrGrad::DoParameterReplaceByUser(bool has_sens_arg, expander::bprop::UserType *user) {
823   MS_EXCEPTION_IF_NULL(user);
824   const auto &parameters = ad_param()->tape_->parameters();
825   auto cell_inputs_size = cell_inputs_.size();
826   for (size_t i = 0; i < cell_inputs_size; ++i) {
827     ir_bprop_->Replace(cell_inputs_[i].first, parameters[i], user);
828   }
829   size_t weight_offset = cell_inputs_size;
830   if (has_sens_arg) {
831     weight_offset = weight_offset + 1;
832   }
833   for (size_t i = weight_offset; i < parameters.size(); ++i) {
834     auto tensor = PyNativeAlgo::Common::GetTensorFromParam(parameters[i]);
835     MS_EXCEPTION_IF_NULL(tensor);
836     auto parameter = ExtractParameter(tensor);
837     MS_EXCEPTION_IF_NULL(parameter);
838     ir_bprop_->Replace(parameter, parameters[i], user);
839   }
840 }
841 
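// Replace primal parameters with tape parameters, through the manager when control flow or a full replace is required,
// otherwise through the recorded users; finally eliminate redundant TupleGetItem nodes.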
842 void IrGrad::ReplacePrimalParameter(bool has_sens_arg) {
843   PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", ad_param()->tape_);
844   if (need_do_manager_replace_ || ad_param()->tape_->has_flag(kFlagIsControlFlow)) {
845     MS_LOG(DEBUG) << "Do parameter replace by manager.";
846     DoParameterReplaceByManager(has_sens_arg);
847     need_do_manager_replace_ = false;
848   } else {
849     MS_LOG(DEBUG) << "Do parameter replace by user.";
850     DoParameterReplaceByUser(has_sens_arg, &ad_param()->users_.dout_user_);
851   }
852   if (!ad_param()->reverse_users_.empty()) {
853     DoParameterReplaceByUser(has_sens_arg, &ad_param()->reverse_users_);
854   }
855   ElimateTupleGetItem();
856 }
857 
858 void IrGrad::UpdateTapeParameter(const tensor::BaseTensorPtr &tensor) {
859   auto p = ad_param()->tape_->add_parameter();
860   auto param = ExtractParameter(tensor);
861   if (param == nullptr) {
862     param =
863       ir_bprop_->CreateTapeParameter(tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
864   }
865   MS_EXCEPTION_IF_NULL(param);
866   const auto &param_info = tensor->param_info();
867   if (param_info != nullptr) {
868     const auto &param_name = param_info->name();
869     p->set_name(param_name);
870     p->debug_info()->set_name(param_name);
871   }
872   TraceGuard trace_guard(std::make_shared<TraceCopy>(p->debug_info()));
873   p->set_default_param(tensor);
874   p->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
875 }
876 }  // namespace autograd
877 }  // namespace pynative
878 }  // namespace mindspore
879