/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/ir/ir_grad.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "frontend/expander/bprop/bprop.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "include/backend/optimizer/helper.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/profiler.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/ps/action.h"
#include "pipeline/pynative/grad/jit/jit_call_graph.h"
#include "pipeline/pynative/pynative_utils.h"
#include "utils/info.h"
#include "utils/profile.h"

namespace mindspore {
namespace pynative {
namespace autograd {
namespace {
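// Mark cnode as a jit call node and attach a JitCallGraph callback built from call_graph, so that the
// jit bprop graph can be invoked when the tape is executed.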
void SetJitCallGraph(const CNodePtr &cnode, const FuncGraphPtr &call_graph, const std::string &cache_key,
                     const GraphCallCondition &graph_call_condition) {
  MS_EXCEPTION_IF_NULL(cnode);
  common::AnfAlgo::SetNodeAttr(kAttrJitCallNode, MakeValue(true), cnode);
  auto graph_call_back = PyNativeAlgo::AutoGrad::CreateGraphCallBack(call_graph, cache_key, graph_call_condition);
  cnode->set_user_data<JitCallGraph>(std::make_shared<JitCallGraph>(graph_call_back));
}

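// Return true when both the inputs gradient and the weights gradient are empty MakeTuple nodes,
// i.e. each tuple contains nothing but the MakeTuple primitive itself.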
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
  if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
    return false;
  }
  auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
  auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
  if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
      !IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
    return false;
  }
  constexpr int kEmptyTupeSize = 1;
  if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
    return false;
  }
  return true;
}

AnfNodePtr GenerateEmptyTupleValue() {
  std::vector<ValuePtr> value_list;
  auto inputs_value = std::make_shared<ValueTuple>(value_list);
  auto weights_value = std::make_shared<ValueTuple>(value_list);
  std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
  auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
  return PyNativeAlgo::Common::CreateValueNodeByValue(tuple_value);
}

bool IsValidTensorInput(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  return abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSparseTensor>();
}

AnfNodePtr GetTupleItemNodeInput(const KernelGraphPtr &tape, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(tape);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AnfNodePtr new_cnode = nullptr;
  if (IsPrimitive(cnode->input(kIndex1), prim::kPrimTupleGetItem)) {
    auto inner_cnode = cnode->input(kIndex1)->cast<CNodePtr>();
    new_cnode = tape->FuncGraph::NewCNode(
      {inner_cnode->input(kIndex0), GetTupleItemNodeInput(tape, inner_cnode), inner_cnode->input(kIndex2)});
  } else {
    AnfNodePtrList new_inputs{cnode->inputs().begin(), cnode->inputs().end()};
    new_cnode = tape->FuncGraph::NewCNode(new_inputs);
  }
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  return new_cnode;
}

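// A value is treated as constant (no gradient flows through it) when it is neither a network input nor
// a parameter and has no k-node recorded in its auto-grad meta data. Sequences are constant only if all
// of their elements are; COO/CSR tensors are judged by their indices tensor.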
bool IsConstant(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    const auto &tensor = value->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    if (auto_grad_meta_data->input_type() == InputType::kParameter ||
        auto_grad_meta_data->input_type() == InputType::kInput) {
      return false;
    }
    auto k_node = auto_grad_meta_data->k_node();
    if (k_node != nullptr) {
      return false;
    }
    return true;
  } else if (value->isa<ValueSequence>()) {
    auto val_seq = value->cast<ValueSequencePtr>();
    return std::all_of(val_seq->value().begin(), val_seq->value().end(),
                       [](const ValuePtr &value) { return IsConstant(value); });
  } else if (value->isa<tensor::COOTensor>()) {
    auto coo_tensor = value->cast<tensor::COOTensorPtr>();
    return IsConstant(coo_tensor->GetIndices());
  } else if (value->isa<tensor::CSRTensor>()) {
    auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
    return IsConstant(csr_tensor->GetIndices());
  }
  return true;
}
}  // namespace

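// Accumulate two gradient nodes. A zeros-like placeholder on either side is absorbed, plain tensors are
// combined with Add, and MakeTuple nodes are accumulated element-wise by recursing into their inputs.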
AnfNodePtr IrFunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);

  if (PyNativeAlgo::AutoGrad::IsZerosLikeNode(left_node)) {
    return right_node;
  }
  if (PyNativeAlgo::AutoGrad::IsZerosLikeNode(right_node)) {
    return left_node;
  }
  if (!IsPrimitiveCNode(left_node, prim::kPrimMakeTuple)) {
    auto add_result = tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimAdd), left_node, right_node});
    add_result->set_abstract(right_node->abstract());
    return add_result;
  }
  if (IsPrimitiveCNode(left_node, prim::kPrimMakeTuple) && IsPrimitiveCNode(right_node, prim::kPrimMakeTuple)) {
    auto left_cnode = left_node->cast<CNodePtr>();
    auto right_cnode = right_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(right_cnode);
    AnfNodePtrList inputs = {NewValueNode(prim::kPrimMakeTuple)};
    AbstractBasePtrList abs;
    for (size_t i = 1; i < left_cnode->size(); ++i) {
      auto add_result = HyperAdd(left_cnode->input(i), right_cnode->input(i));
      (void)abs.emplace_back(add_result->abstract());
      (void)inputs.emplace_back(add_result);
    }
    auto add_tuple = tape_->FuncGraph::NewCNode(inputs);
    add_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
    return add_tuple;
  }
  MS_LOG(EXCEPTION) << "Unknown cnode type: " << left_node->DebugString();
}

void IrFunctionNode::AddNextEdge(const VariablePtr &next_variable, const AnfNodePtr &din) {
  MS_EXCEPTION_IF_NULL(next_variable);
  MS_EXCEPTION_IF_NULL(din);
  // next_node and its corresponding din
  (void)next_edges_.emplace_back(next_variable, din);
  if (din == fake_dout_) {
    (void)need_replace_edges_.emplace_back(next_edges_.size() - 1);
  }
}

void IrFunctionNode::UpdateAccumulativeDout(const AnfNodePtr &new_dout) {
  MS_EXCEPTION_IF_NULL(new_dout);
  accumulate_dout_ = HyperAdd(accumulate_dout_, new_dout);
}

void IrFunctionNode::ReplaceEdges() {
  MS_EXCEPTION_IF_NULL(accumulate_dout_);
  for (const auto index : need_replace_edges_) {
    next_edges_[index].second = accumulate_dout_;
  }
}

IrGrad::IrGrad(const std::vector<ValuePtr> &input_param_values, const AbstractBasePtrList &abs_list,
               size_t op_num_in_bprop_graph, const runtime::AsyncHqueuePtr &assist_queue, bool grad_by_value,
               bool is_run_recompute)
    : ad_param_(std::make_shared<AdParam>()) {
  ad_param()->tape_->debug_info()->set_name("grad_top");
  MS_LOG(DEBUG) << "Start IrGrad, input size: " << input_param_values.size();
  ad_param()->variable_adjoint_set_.reserve(op_num_in_bprop_graph);
  ad_param()->anfnode_to_variable_adjoint_.reserve(op_num_in_bprop_graph);
  ad_param()->users_.dout_user_.reserve(op_num_in_bprop_graph);
  ad_param()->weights_used_in_graph_.reserve(op_num_in_bprop_graph);

  for (size_t i = 0; i < input_param_values.size(); ++i) {
    auto input_parameter = ad_param()->fg_->add_parameter();
    input_parameter->set_abstract(abs_list[i]);
    input_parameter->set_name(input_parameter->UniqueName());
    TraceGuard trace_guard(std::make_shared<TraceCopy>(input_parameter->debug_info()));
    auto tape_parameter = ad_param()->tape_->add_parameter();
    tape_parameter->set_abstract(abs_list[i]);

    auto zeros_like_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
      ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), abs_list[i], SpecialType::kZerosLikeType);
    auto func_node = std::make_shared<IrFunctionNode>(ad_param()->tape_, zeros_like_dout);
    auto input_adjoint = std::make_shared<IrVariable>(func_node, input_param_values[i], true);

    if (!input_param_values[i]->isa<ValueSequence>()) {
      PyNativeAlgo::AutoGrad::SetGradInfoForInputs(input_param_values[i], input_adjoint, input_parameter);
    } else {
      input_adjoint->set_is_need_grad(false);
    }
    (void)cell_inputs_.emplace_back(input_parameter, input_adjoint);
    (void)ad_param()->variable_adjoint_set_.insert(input_adjoint);
  }

  assist_queue_ = assist_queue;
  grad_by_value_ = grad_by_value;
  device_target_ = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  ir_bprop_ = std::make_unique<IrBprop>(ad_param_, device_target_, grad_by_value_, is_run_recompute);
}

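// Record the reverse propagation of a single PyNative op on the tape: build the zeros-like dout
// placeholder, expand (or fall back to a custom/fake) bprop for the primitive, and connect the
// resulting gradient outputs to the adjoint variables of the op inputs.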
bool IrGrad::KPynativeOp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);

  auto &prim = grad_param->op_grad_info->op_prim;
  if (!PyNativeAlgo::AutoGrad::IsPrimNeedGrad(prim) ||
      (grad_by_value_ && !PyNativeAlgo::AutoGrad::NeedGrad(grad_param->op_grad_info->input_value))) {
    MS_LOG(DEBUG) << "Prim " << prim->name() << " does not need to do op grad.";
    return true;
  }

  auto cloned_value = grad_param->op_grad_info->out_value;
  if (grad_param->op_grad_info->out_value->isa<ValueSequence>()) {
    cloned_value = ShallowCopyTensorValue(grad_param->op_grad_info->out_value);
    PyNativeAlgo::Common::ClearDeviceAddress(cloned_value);
  }

  PyNativeAlgo::AutoGrad::CheckAndSetAbstract(grad_param->op_grad_info);
  // Construct a zeros-like placeholder; if it is needed in bprop, it is replaced during backpropagation.
  AnfNodePtr dout =
    PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
                                             grad_param->op_grad_info->out_abs, SpecialType::kZerosLikeType);
  auto fn = std::make_shared<IrFunctionNode>(ad_param()->tape_, dout);
  auto variable_adjoint = std::make_shared<IrVariable>(fn, cloned_value);
  // A custom forward cnode does not need to be recorded in the bprop graph, because it is only a flag
  // cnode for running python. So just creating the bprop_cut grad op is enough.
  bool is_custom_prim =
    IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
  AnfNodePtr k_node = nullptr;
  if (!grad_by_value_ && !is_custom_prim) {
    k_node = BuildKNode(NewValueNode(prim), grad_param, true);
    SetKNodeInfo(grad_param->op_grad_info->out_value, k_node, grad_param->op_grad_info->out_abs);
    need_do_manager_replace_ = true;
  }
  CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint, k_node, is_custom_prim);
  MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
  // Gradient outputs
  std::vector<CNodePtr> outputs;
  if (!is_custom_prim) {
    auto ret = BpropExpander(&outputs, &ad_param()->users_).Run(input_node, grad_param->op_grad_info->input_value);
    // cppcheck-suppress unreadVariable
    if (MS_UNLIKELY(!ret || outputs.empty())) {
      MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << prim->name();
      ir_bprop_->BuildCustomBpropCNode(input_node, prim, &outputs);
    }
  } else {
    PyNativeAlgo::AutoGrad::CheckRecomputeInputs(grad_param);
    ir_bprop_->BuildBPropCutCNode(input_node, prim, &outputs, grad_param->op_grad_info->is_need_recompute);
  }
  // cppcheck-suppress unreadVariable
  if (MS_UNLIKELY(outputs.empty())) {
    MS_LOG(DEBUG) << "This op has no custom bprop: " << prim->name();
    PyNativeAlgo::AutoGrad::BuildFakeBpropCNode(input_node, &outputs);
    variable_adjoint->set_is_fake_bprop(true);
    variable_adjoint->set_fake_prim_name(prim->name());
  }
  (void)ad_param()->variable_adjoint_set_.insert(variable_adjoint);
  PyNativeAlgo::AutoGrad::SetGradMetaData(grad_param->op_grad_info->out_value, variable_adjoint);
  ir_bprop_->UpdateNextEdges(variable_adjoint, outputs, grad_param->op_grad_info->input_value,
                             grad_param->op_grad_info->input_abs, prim->name());
  return true;
}

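// Record the reverse propagation of a called func graph (for example a jit graph). The bprop graph of
// the callee is invoked on the tape, and each TupleGetItem of its output becomes the din of one input.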
bool IrGrad::KPynativeWithFProp(const GradParamPtr &grad_param) {
  MS_EXCEPTION_IF_NULL(grad_param);
  MS_LOG(DEBUG) << "Do KPynativeWithFProp";
  AnfNodePtrList args_node_list;
  CNodePtr bprop_cnode = nullptr;
  AnfNodePtr k_node = nullptr;
  AnfNodePtr dout = nullptr;
  if (grad_by_value_) {
    for (size_t i = 0; i < grad_param->input_size; ++i) {
      if (PyNativeAlgo::Common::IsParam(grad_param->op_grad_info->input_value_grad_type[i])) {
        auto parameter =
          ir_bprop_->MapParameter(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]);
        MS_EXCEPTION_IF_NULL(parameter);
        (void)args_node_list.emplace_back(parameter);
        continue;
      }
      // A non-parameter input is converted to a value node.
      const auto value_node = PyNativeAlgo::Common::CreateValueNodeByValue(
        grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]->Clone());
      auto cnode = PyNativeAlgo::Common::ConvertValueSequenceToMakeTuple(value_node, ad_param()->tape_);
      (void)args_node_list.emplace_back(cnode);
    }
    bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
  } else {
    k_node = BuildKNode(NewValueNode(grad_param->source_fg), grad_param, false);
    BuildKNodeListForHighOrderGraph(grad_param->op_grad_info->input_value, grad_param->op_grad_info->input_abs,
                                    &args_node_list);
    bprop_cnode = GetBpropGraphCNode(grad_param, args_node_list, &dout);
  }
  auto fn = std::make_shared<IrFunctionNode>(ad_param()->tape_, dout);
  auto variable_adjoint = std::make_shared<IrVariable>(fn, grad_param->op_grad_info->out_value);
  variable_adjoint->set_k_node(k_node);
  std::vector<CNodePtr> outputs;
  for (size_t i = 0; i < grad_param->input_size; ++i) {
    CNodePtr din = ad_param()->tape_->FuncGraph::NewCNode(
      {NewValueNode(prim::kPrimTupleGetItem), bprop_cnode, NewValueNode(SizeToLong(i))});
    din->set_abstract(grad_param->op_grad_info->input_abs[i]);
    (void)outputs.emplace_back(din);
  }
  ir_bprop_->UpdateNextEdges(variable_adjoint, outputs, grad_param->op_grad_info->input_value,
                             grad_param->op_grad_info->input_abs);
  (void)ad_param()->variable_adjoint_set_.insert(variable_adjoint);
  (void)ad_param()->anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
  PyNativeAlgo::AutoGrad::SetGradMetaData(grad_param->op_grad_info->out_value, variable_adjoint);
  SetKNodeInfo(grad_param->op_grad_info->out_value, k_node, grad_param->op_grad_info->out_abs);
  return true;
}

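// Build the call node that invokes a callee's bprop graph on the tape. The call takes the forward
// arguments plus a zeros-like dout placeholder (flattened per element for a jit dynamic-shape tuple
// output), and every input of the call is registered as a user so it can be replaced later.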
CNodePtr IrGrad::GetBPropCNode(const GradParamPtr &grad_param, const AnfNodePtrList &args,
                               const FuncGraphPtr &bprop_graph, bool cache_hit, AnfNodePtr *const tape_dout) {
  AnfNodePtrList bprop_inputs(args.begin(), args.end());
  bool is_jit_dynamic_shape = grad_param->is_jit_graph && grad_param->use_dynamic_shape_process;
  // Save replace info the first time only
  if (!cache_hit && is_jit_dynamic_shape && grad_param->has_added_v) {
    const auto &jit = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor()->jit();
    jit->SaveForwardOutputTensorInfoInBpropGraph(bprop_graph);
  }

  // Call by tape_
  MS_EXCEPTION_IF_NULL(tape_dout);
  *tape_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
                                                        grad_param->op_grad_info->out_abs, SpecialType::kZerosLikeType);
  if (is_jit_dynamic_shape && grad_param->op_grad_info->out_abs->isa<abstract::AbstractSequence>()) {
    auto abs_seq = grad_param->op_grad_info->out_abs->cast<abstract::AbstractSequencePtr>();
    // A dynamic-length sequence has no static size, so it cannot be flattened here.
    if (!abs_seq->dynamic_len()) {
      for (size_t i = 0; i < abs_seq->size(); ++i) {
        CNodePtr din = ad_param()->tape_->FuncGraph::NewCNode(
          {NewValueNode(prim::kPrimTupleGetItem), *tape_dout, NewValueNode(SizeToLong(i))});
        din->set_abstract(abs_seq->elements()[i]);
        (void)bprop_inputs.emplace_back(din);
        ir_bprop_->AddUser(*tape_dout, din, kIndex1);
      }
    }
  } else {
    (void)bprop_inputs.emplace_back(*tape_dout);
  }
  (void)bprop_inputs.insert(bprop_inputs.cbegin(), NewValueNode(bprop_graph));
  // get_bprop is a call node
  auto bprop_cnode = ad_param()->tape_->FuncGraph::NewCNode(bprop_inputs);
  bprop_cnode->set_abstract(bprop_graph->output()->abstract());
  if (is_jit_dynamic_shape) {
    GraphCallCondition graph_call_condition{grad_param->is_control_flow, grad_param->is_jit_graph,
                                            grad_param->use_dynamic_shape_process, false, false};
    SetJitCallGraph(bprop_cnode, bprop_graph, grad_param->graph_cache_key, graph_call_condition);
    ad_param()->tape_->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
  }
  // For replacing parameter and dout.
  for (size_t i = 1; i < bprop_inputs.size(); ++i) {
    ir_bprop_->AddUser(bprop_inputs[i], bprop_cnode, i);
  }
  return bprop_cnode;
}

CNodePtr IrGrad::GetBpropGraphCNode(const GradParamPtr &grad_param, const AnfNodePtrList &args,
                                    AnfNodePtr *const tape_dout) {
  MS_EXCEPTION_IF_NULL(grad_param);
  auto [cache_hit, bprop_graph] = ir_bprop_->GetBpropGraph(grad_param);
  if (grad_param->is_control_flow || grad_param->is_jit_self_dynamic_shape) {
    need_do_manager_replace_ = true;
  }
  return GetBPropCNode(grad_param, args, bprop_graph, cache_hit, tape_dout);
}

void IrGrad::UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative,
                                     runtime::ProfilerEvent::kPyNativeGradUpdateSens,
                                     runtime::ProfilerRecorder::kNoName, true);
  MS_EXCEPTION_IF_NULL(sens_out);
  MS_LOG(DEBUG) << "Real output of top cell is " << PyNativeAlgo::Common::GetIdByValue(sens_out);
  ad_param()->sens_value_ = sens_out;
  UpdateSensParameter(ad_param()->sens_value_);
}

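// Finalize the bprop graph: bind the sens and weight parameters, run the backpropagation over the
// recorded variables, set the requested gradients as the tape output, replace primal parameters with
// tape parameters, and clear the auto-grad meta data of the weights.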
FuncGraphPtr IrGrad::Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                            const GradAttr &grad_attr) {
  // Set sens node and weights node
  SetSensAndWeights(weights, grad_attr.has_sens);

  // BackPropagate sensitivity, except when the last node is a value node, which may be obtained by constant folding.
  if (ad_param()->last_variable_->is_need_grad() && !ad_param()->last_variable_->is_leaf()) {
    ir_bprop_->BackPropagate();
  }
  SetOutput(weights, grad_position, grad_attr);
  // Replace each parameter of the primal func graph with the corresponding parameter of ad_param()->tape_.
  ReplacePrimalParameter(grad_attr.has_sens);
  PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", ad_param()->tape_);
  // Clear weights grad info
  for (const auto &weight : weights) {
    weight->set_auto_grad_meta_data(nullptr);
  }
  return ad_param()->tape_;
}

CNodePtr IrGrad::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
                                          const VariablePtr &variable_adjoint, const AnfNodePtr &k_node,
                                          bool is_custom_prim) {
  MS_EXCEPTION_IF_NULL(grad_param);
  AnfNodePtrList node_list;
  (void)node_list.emplace_back(NewValueNode(grad_param->op_grad_info->op_prim));
  if (grad_by_value_ || is_custom_prim) {
    for (size_t i = 0; i < grad_param->input_size; ++i) {
      if (PyNativeAlgo::Common::IsParam(grad_param->op_grad_info->input_value_grad_type[i])) {
        // Handle the case where the input is a tuple like (parameter, ...)
        auto parameter =
          ir_bprop_->MapParameter(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]);
        MS_EXCEPTION_IF_NULL(parameter);
        (void)node_list.emplace_back(parameter);
        continue;
      }
      // The node's abstract object may be freed, so the value node's abstract could be incorrect; clone it.
      (void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(
        grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]->Clone()));
    }
    // Hook run by single op
    if (!ir_bprop_->bprop_graph_run_by_single_op()) {
      ir_bprop()->set_bprop_graph_run_by_single_op([&grad_param]() {
        auto tensor = grad_param->op_grad_info->out_value->template cast<tensor::BaseTensorPtr>();
        if (tensor == nullptr) {
          return false;
        }
        auto auto_grad_meta = tensor->auto_grad_meta_data();
        MS_EXCEPTION_IF_NULL(auto_grad_meta);
        return auto_grad_meta->is_register_hook();
      }());
    }
    // Set out
    (void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(grad_param->op_grad_info->out_value,
                                                                              grad_param->op_grad_info->out_abs));
  } else {
    // Input is a Parameter or cnode, not a value node
    BuildKNodeListFromPrimalCNode(grad_param->op_grad_info->input_value, grad_param->op_grad_info->input_abs,
                                  &node_list);
    // Set out
    MS_EXCEPTION_IF_NULL(variable_adjoint);
    (void)node_list.emplace_back(k_node);
  }
  // Set dout
  (void)node_list.emplace_back(dout);
  auto input_node = ad_param()->tape_->FuncGraph::NewCNode(node_list);
  return input_node;
}


void IrGrad::BuildKNodeListFromPrimalCNode(const ValuePtrList &input_value,
                                           const abstract::AbstractBasePtrList &input_abs,
                                           AnfNodePtrList *const node_list) {
  for (size_t i = 0; i < input_value.size(); ++i) {
    (void)node_list->emplace_back(BuildKNodeForCNodeInput(input_value[i], input_abs[i]));
    MS_LOG(DEBUG) << "Get knode for input: " << PyNativeAlgo::Common::GetIdByValue(input_value[i]);
  }
}

AnfNodePtr IrGrad::BuildKNodeForCNodeInput(const ValuePtr &input, const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(input);
  if (input->isa<tensor::BaseTensor>()) {
    const auto &tensor = input->cast<tensor::BaseTensorPtr>();
    const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    auto k_node = auto_grad_meta_data->k_node();
    if (k_node != nullptr) {
      return k_node;
    }
    if (PyNativeAlgo::Common::IsParam(auto_grad_meta_data->input_type())) {
      return ir_bprop_->MapParameter(input, abs);
    }
  } else if (input->isa<ValueSequence>() && !IsConstant(input)) {
    AnfNodePtrList inputs;
    (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    const auto &val_sequence = input->cast<ValueSequencePtr>()->value();
    const auto &abs_sequence = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_sequence);
    if (val_sequence.size() != abs_sequence->size()) {
      MS_LOG(EXCEPTION) << "Get value sequence size " << val_sequence.size() << " not equal to abstract size "
                        << abs_sequence->size();
    }
    for (size_t i = 0; i < val_sequence.size(); ++i) {
      (void)inputs.emplace_back(BuildKNodeForCNodeInput(val_sequence[i], abs_sequence->elements()[i]));
    }
    auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
    k_node->set_abstract(abs);
    return k_node;
  }
  auto value_node = NewValueNode(input);
  value_node->set_abstract(abs);
  return value_node;
}

void IrGrad::BuildKNodeListForHighOrderGraph(const ValuePtrList &input_value,
                                             const abstract::AbstractBasePtrList &input_abs,
                                             AnfNodePtrList *const node_list) {
  for (size_t i = 0; i < input_value.size(); ++i) {
    const auto knode = BuildKNodeForCNodeInput(input_value[i], input_abs[i]);
    // Convert a value sequence to MakeTuple, so that the final pass can eliminate TupleGetItem;
    // building the k-node for TupleGetItem does not yet support a value sequence input.
    if (knode->isa<ValueNode>()) {
      auto value_node = knode->cast<ValueNodePtr>();
      (void)node_list->emplace_back(
        PyNativeAlgo::Common::ConvertValueSequenceToMakeTuple(value_node, ad_param()->tape_));
    } else {
      (void)node_list->emplace_back(knode);
    }

    MS_LOG(DEBUG) << "Get knode for input: " << PyNativeAlgo::Common::GetIdByValue(input_value[i]);
  }
}

void IrGrad::SetKNodeInfo(const ValuePtr &value, const AnfNodePtr &k_node, const AbstractBasePtr &out_abs) {
  if (value->isa<tensor::BaseTensor>()) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    auto auto_grad_meta_data = tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    auto_grad_meta_data->set_k_node(k_node);
    (void)k_nodes_used_in_graph_.emplace_back(k_node);
  } else if (value->isa<ValueSequence>()) {
    const auto &value_sequence = value->cast<ValueSequencePtr>()->value();
    const auto &abs_seq = out_abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    if (abs_seq->dynamic_len()) {
      return;
    }
    if (value_sequence.size() != abs_seq->size()) {
      MS_LOG(EXCEPTION) << "Get value sequence size " << value_sequence.size() << " not equal to abstract size "
                        << abs_seq->size();
    }
    for (size_t i = 0; i < value_sequence.size(); ++i) {
      auto sub_k_node = ad_param()->tape_->FuncGraph::NewCNode(
        {NewValueNode(prim::kPrimTupleGetItem), k_node, NewValueNode(static_cast<int64_t>(i))});
      sub_k_node->set_abstract(abs_seq->elements()[i]);
      SetKNodeInfo(value_sequence[i], sub_k_node, abs_seq->elements()[i]);
    }
  }
}

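// Build the k-node (the forward-equivalent node on the tape) for an op. This is only called when
// grad_by_value_ is false, typically the high-order gradient path; the forward output value may be
// attached to the node so it can be reused by the bprop graph.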
AnfNodePtr IrGrad::BuildKNode(const AnfNodePtr &prim, const GradParamPtr &grad_param, bool from_single_op) {
  MS_EXCEPTION_IF_NULL(grad_param);
  AnfNodePtrList node_list;
  (void)node_list.emplace_back(prim);
  for (size_t i = 0; i < grad_param->input_size; ++i) {
    (void)node_list.emplace_back(
      BuildKNodeForCNodeInput(grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]));
  }
  auto k_node = ad_param()->tape_->FuncGraph::NewCNode(node_list);
  k_node->set_abstract(grad_param->op_grad_info->out_abs);
  k_node->AddAttr(bprop_pass::kIsKNode, MakeValue(true));
  if (from_single_op && grad_param->out_used_in_bporp_graph) {
    auto v_node = PyNativeAlgo::Common::CreateValueNodeByValue(grad_param->op_grad_info->out_value,
                                                               grad_param->op_grad_info->out_abs);
    k_node->set_forward(v_node, "");
    ad_param()->tape_->set_used_forward_nodes({k_node});
  }
  MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
  return k_node;
}

void IrGrad::UpdateSensParameter(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<tensor::BaseTensor>()) {
    const auto &sens_tensor = value->cast<tensor::BaseTensorPtr>();
    const auto &auto_grad_meta_data = sens_tensor->auto_grad_meta_data();
    MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
    const auto variable = auto_grad_meta_data->variable();
    // Register the input or weight parameter for the net; a parameter is only added once.
    if (auto_grad_meta_data->input_type() == InputType::kParameter && variable == nullptr) {
      (void)ir_bprop_->AddParameterNode(sens_tensor,
                                        PyNativeAlgo::Common::SetAbstractValueToAnyValue(sens_tensor->ToAbstract()));
    }
  } else if (value->isa<ValueSequence>()) {
    const auto &value_seq = value->cast<ValueSequencePtr>()->value();
    for (const auto &v : value_seq) {
      UpdateSensParameter(v);
    }
  } else if (value->isa<ValueDictionary>()) {
    auto dic_v = value->cast<ValueDictionaryPtr>();
    for (const auto &v : dic_v->value()) {
      UpdateSensParameter(v.second);
    }
  }
}

ParameterPtr IrGrad::ExtractParameter(const tensor::BaseTensorPtr &tensor) const {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
  if (auto_grad_meta_data != nullptr && PyNativeAlgo::Common::IsParam(auto_grad_meta_data->input_type())) {
    return auto_grad_meta_data->parameter();
  }
  return nullptr;
}

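// Add the sens parameter (when a sens argument is provided) and the weight parameters to the tape,
// and seed the dout of the last variable with either the sens parameter or a ones-like node.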
void IrGrad::SetSensAndWeights(const tensor::BaseTensorPtrList &weights, bool has_sens_arg) {
  const auto &sens_abstract = ir_bprop_->BuildForwardLastNode();
  ParameterPtr sens_param = nullptr;
  if (has_sens_arg) {
    sens_param = ad_param()->tape_->add_parameter();
    sens_param->set_name(sens_param->UniqueName());
    sens_param->debug_info()->set_name("sens");
    sens_param->set_abstract(sens_abstract);
  }
  // Seed the initial dout of the last variable.
  MS_EXCEPTION_IF_NULL(ad_param()->last_variable_);
  if (ad_param()->last_variable_->is_need_grad()) {
    if (has_sens_arg) {
      ad_param()->last_variable_->ir_function_node()->UpdateAccumulativeDout(sens_param);
    } else {
      ad_param()->last_variable_->ir_function_node()->UpdateAccumulativeDout(PyNativeAlgo::AutoGrad::BuildSpecialNode(
        ad_param()->tape_, ad_param()->sens_value_, sens_abstract, SpecialType::kOnesLikeType));
    }
  }
  // Add weights parameter
  need_grad_weights_.reserve(weights.size());
  for (const auto &weight_tensor : weights) {
    (void)need_grad_weights_.emplace(weight_tensor->id());
    UpdateTapeParameter(weight_tensor);
  }
  for (auto &weight : ad_param_->weights_used_in_graph_) {
    auto tensor = PyNativeAlgo::Common::GetTensorFromParam(weight);
    MS_EXCEPTION_IF_NULL(tensor);
    if (need_grad_weights_.find(tensor->id()) == need_grad_weights_.end()) {
      UpdateTapeParameter(tensor);
    }
  }
}

AnfNodePtr IrGrad::GetGradNodeByIndex(const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  auto auto_grad_meta_data = tensor->auto_grad_meta_data();
  MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
  auto variable = auto_grad_meta_data->variable();
  if (variable != nullptr && variable->is_need_grad()) {
    // If the weight is used in the forward network but requires_grad is false, return a zeros-like node.
    if (tensor->param_info() != nullptr && !tensor->param_info()->requires_grad()) {
      MS_LOG(INFO) << "Weight participates in forward calculation, but requires_grad is false";
      return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, tensor, nullptr, SpecialType::kZerosLikeType);
    }
    const auto &ir_variable = std::dynamic_pointer_cast<IrVariable>(variable);
    MS_EXCEPTION_IF_NULL(ir_variable);
    return ir_variable->RealDout();
  }
  MS_LOG(INFO) << "Parameter does not need grad, tensor: " << PyNativeAlgo::Common::GetIdByValue(tensor);
  return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param()->tape_, tensor, nullptr, SpecialType::kZerosLikeType);
}

AnfNodePtr IrGrad::GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position) {
  std::vector<size_t> grad_pos_list;
  if (get_by_position) {
    grad_pos_list = grad_position;
  } else if (grad_all_inputs) {
    grad_pos_list.resize(cell_inputs_.size());
    iota(grad_pos_list.begin(), grad_pos_list.end(), 0);
  } else {
    return nullptr;
  }

  AnfNodePtrList inputs_grad_list{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtrList inputs_grad_spec;
  if (!cell_inputs_.empty()) {
    for (size_t index : grad_pos_list) {
      if (index >= cell_inputs_.size()) {
        MS_LOG(EXCEPTION) << "Position index " << index << " exceeds the input size.";
      }
      // Tuple, List and scalar inputs will be ignored
      if (!IsValidTensorInput(cell_inputs_[index].first->abstract())) {
        MS_LOG(DEBUG) << "Get input node is not tensor "
                      << ", abs " << cell_inputs_[index].first->abstract()->ToString();
        continue;
      }
      auto ir_variable = std::dynamic_pointer_cast<IrVariable>(cell_inputs_[index].second);
      MS_EXCEPTION_IF_NULL(ir_variable);
      auto real_dout = ir_variable->RealDout();
      MS_EXCEPTION_IF_NULL(real_dout);
      (void)inputs_grad_list.emplace_back(real_dout);
      (void)inputs_grad_spec.emplace_back(real_dout->abstract());
    }
    constexpr size_t single_pos_size = 1;
    if (get_by_position && inputs_grad_spec.size() == single_pos_size) {
      // The first element is the MakeTuple primitive, so the single gradient sits at index 1.
      return inputs_grad_list[single_pos_size];
    }
  }
  auto input_grad_ret = ad_param()->tape_->FuncGraph::NewCNode(inputs_grad_list);
  input_grad_ret->set_abstract(std::make_shared<abstract::AbstractTuple>(inputs_grad_spec));
  return input_grad_ret;
}

AnfNodePtr IrGrad::GetWeightGrad(bool grad_weights, const tensor::BaseTensorPtrList &weights,
                                 bool weight_param_is_tuple) {
  // No need to return gradient of weights.
  if (!grad_weights) {
    return nullptr;
  }
  if (weight_param_is_tuple) {
    AnfNodePtrList weights_grad_list{NewValueNode(prim::kPrimMakeTuple)};
    AbstractBasePtrList weights_grad_spec;
    for (const auto &weight : weights) {
      auto grad_node = GetGradNodeByIndex(weight);
      MS_EXCEPTION_IF_NULL(grad_node);
      (void)weights_grad_list.emplace_back(grad_node);
      (void)weights_grad_spec.emplace_back(grad_node->abstract());
    }
    auto weight_grad_ret = ad_param()->tape_->FuncGraph::NewCNode(weights_grad_list);
    weight_grad_ret->set_abstract(std::make_shared<abstract::AbstractTuple>(weights_grad_spec));
    return weight_grad_ret;
  } else {
    return GetGradNodeByIndex(weights[0]);
  }
}

void IrGrad::SetOutput(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
                       const GradAttr &grad_attr) {
  auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
  auto weights_grad_ret = GetWeightGrad(grad_attr.grad_weights, weights, grad_attr.weight_param_is_tuple);
  // Gradients wrt inputs and weights.
  if (inputs_grad_ret != nullptr && weights_grad_ret != nullptr) {
    if (IsOutputBothEmpty(inputs_grad_ret, weights_grad_ret)) {
      auto tape_output = GenerateEmptyTupleValue();
      ad_param()->tape_->set_output(tape_output);
    } else {
      auto tape_output =
        ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs_grad_ret, weights_grad_ret});
      tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(
        abstract::AbstractBasePtrList{inputs_grad_ret->abstract(), weights_grad_ret->abstract()}));
      ad_param()->tape_->set_output(tape_output);
    }
    return;
  }
  // Gradients wrt inputs.
  if (inputs_grad_ret != nullptr) {
    ad_param()->tape_->set_output(inputs_grad_ret);
    return;
  }
  // Gradients wrt weights.
  if (weights_grad_ret != nullptr) {
    ad_param()->tape_->set_output(weights_grad_ret);
    return;
  }
  // grad_all_inputs, grad_weights and get_by_position are all false.
  AnfNodePtr tape_output = nullptr;
  if (cell_inputs_.empty()) {
    // If there are no input nodes, return an empty tuple.
    tape_output = ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimMakeTuple)});
    abstract::AbstractBasePtrList abs{};
    tape_output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
  } else {
    // If there are input nodes, return the gradient of the first input node.
    // Tuple, List and scalar inputs will be ignored
    if (IsValidTensorInput(cell_inputs_[0].first->abstract())) {
      auto ir_variable = std::dynamic_pointer_cast<IrVariable>(cell_inputs_[kIndex0].second);
      MS_EXCEPTION_IF_NULL(ir_variable);
      tape_output = ir_variable->RealDout();
    } else {
      MS_LOG(DEBUG) << "Get first input node is not tensor " << cell_inputs_[0].second->out_value()->ToString();
      tape_output = NewValueNode(kNull);
      tape_output->set_abstract(nullptr);
    }
  }
  ad_param()->tape_->set_output(tape_output);
}

void IrGrad::ElimateTupleGetItem() {
  for (auto &user : ad_param()->users_.tuple_getitem_user_) {
    auto old_node = user.first;
    auto old_cnode = old_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_cnode);
    auto tuple_node = old_cnode->input(kIndex1);
    if (!IsPrimitiveCNode(tuple_node, prim::kPrimMakeTuple)) {
      continue;
    }
    auto index_value = GetValueNode<Int64ImmPtr>(old_cnode->input(kIndex2));
    size_t index = LongToSize(index_value->value());
    auto tuple_cnode = tuple_node->cast<CNodePtr>();
    ir_bprop_->Replace(old_node, tuple_cnode->input(index + 1), &ad_param()->users_.tuple_getitem_user_);
  }
}

void IrGrad::DoParameterReplaceByManager(bool has_sens_arg) {
  const auto &parameters = ad_param()->tape_->parameters();
  auto cell_inputs_size = cell_inputs_.size();
  auto mng = MakeManager({ad_param()->tape_}, false);
  auto tr = mng->Transact();
  for (size_t i = 0; i < cell_inputs_size; ++i) {
    (void)tr.Replace(cell_inputs_[i].first, parameters[i]);
  }
  // (Inputs, sens, weights) or (Inputs, weights)
  size_t weight_offset = cell_inputs_size;
  if (has_sens_arg) {
    weight_offset = weight_offset + 1;
  }
  for (size_t i = weight_offset; i < parameters.size(); ++i) {
    auto tensor = PyNativeAlgo::Common::GetTensorFromParam(parameters[i]);
    MS_EXCEPTION_IF_NULL(tensor);
    auto parameter = ExtractParameter(tensor);
    MS_EXCEPTION_IF_NULL(parameter);
    (void)tr.Replace(parameter, parameters[i]);
  }
  tr.Commit();
}

void IrGrad::DoParameterReplaceByUser(bool has_sens_arg, expander::bprop::UserType *user) {
  MS_EXCEPTION_IF_NULL(user);
  const auto &parameters = ad_param()->tape_->parameters();
  auto cell_inputs_size = cell_inputs_.size();
  for (size_t i = 0; i < cell_inputs_size; ++i) {
    ir_bprop_->Replace(cell_inputs_[i].first, parameters[i], user);
  }
  size_t weight_offset = cell_inputs_size;
  if (has_sens_arg) {
    weight_offset = weight_offset + 1;
  }
  for (size_t i = weight_offset; i < parameters.size(); ++i) {
    auto tensor = PyNativeAlgo::Common::GetTensorFromParam(parameters[i]);
    MS_EXCEPTION_IF_NULL(tensor);
    auto parameter = ExtractParameter(tensor);
    MS_EXCEPTION_IF_NULL(parameter);
    ir_bprop_->Replace(parameter, parameters[i], user);
  }
}

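// Replace primal parameters (cell inputs and weights) with the corresponding tape parameters, either
// through a graph manager when a full replace was flagged (e.g. control flow), or through the recorded
// user lists otherwise, and finally eliminate redundant TupleGetItem(MakeTuple(...), i) nodes.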
void IrGrad::ReplacePrimalParameter(bool has_sens_arg) {
  PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", ad_param()->tape_);
  if (need_do_manager_replace_ || ad_param()->tape_->has_flag(kFlagIsControlFlow)) {
    MS_LOG(DEBUG) << "Do parameter replace by manager.";
    DoParameterReplaceByManager(has_sens_arg);
    need_do_manager_replace_ = false;
  } else {
    MS_LOG(DEBUG) << "Do parameter replace by user.";
    DoParameterReplaceByUser(has_sens_arg, &ad_param()->users_.dout_user_);
  }
  if (!ad_param()->reverse_users_.empty()) {
    DoParameterReplaceByUser(has_sens_arg, &ad_param()->reverse_users_);
  }
  ElimateTupleGetItem();
}

void IrGrad::UpdateTapeParameter(const tensor::BaseTensorPtr &tensor) {
  auto p = ad_param()->tape_->add_parameter();
  auto param = ExtractParameter(tensor);
  if (param == nullptr) {
    param =
      ir_bprop_->CreateTapeParameter(tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
  }
  MS_EXCEPTION_IF_NULL(param);
  const auto &param_info = tensor->param_info();
  if (param_info != nullptr) {
    const auto &param_name = param_info->name();
    p->set_name(param_name);
    p->debug_info()->set_name(param_name);
  }
  TraceGuard trace_guard(std::make_shared<TraceCopy>(p->debug_info()));
  p->set_default_param(tensor);
  p->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(tensor->ToAbstract()));
}
}  // namespace autograd
}  // namespace pynative
}  // namespace mindspore