1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/grad/ir/ir_bprop.h"
18 #include <string>
19 #include <vector>
20 #include <memory>
21 #include "pipeline/pynative/pynative_utils.h"
22 #include "include/common/utils/primitive_utils.h"
23 #include "pipeline/jit/ps/pass.h"
24 #include "ir/func_graph_cloner.h"
25 #include "ops/sequence_ops.h"
26 #include "ops/framework_ops.h"
27 #include "ops/structure_ops.h"
28 #include "ops/other_ops.h"
29
30 namespace mindspore::pynative::autograd {
31 namespace {
32 constexpr size_t kOutAndDoutNum = 2;
33 const mindspore::HashSet<std::string> kMonadOp = {kLoadOpName, kDependOpName, kUpdateStateOpName};
34 const mindspore::HashSet<std::string> kMetaFuncGraphOp{
35 kPyExecuteOpName,
36 kAttrMutableOpName,
37 kMakeDictOpName,
38 };
39 mindspore::HashMap<std::string, FuncGraphPtr> pass_grad_graph_;
40
41 FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph, const GradParamPtr &grad_param) {
42 PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
43 pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
44 resource->set_func_graph(bprop_func_graph);
45 auto manager = resource->manager();
46 MS_EXCEPTION_IF_NULL(manager);
47 manager->AddFuncGraph(bprop_func_graph);
48 auto after_opt_bg = pipeline::JitBpropGraphPass(resource, true);
49 auto is_dynamic_shape_control_flow =
50 grad_param->is_jit_graph && grad_param->use_dynamic_shape_process && grad_param->is_control_flow;
51 if (is_dynamic_shape_control_flow) {
52 for (const auto &g : manager->func_graphs()) {
53 g->set_flag(kFlagJitCallGraph, true);
54 }
55 }
56 auto abs_seq = after_opt_bg->parameters().empty()
57 ? nullptr
58 : after_opt_bg->parameters().back()->abstract()->cast<abstract::AbstractSequencePtr>();
59 if (abs_seq != nullptr && !abs_seq->dynamic_len() && grad_param->is_jit_graph &&
60 grad_param->use_dynamic_shape_process) {
61 PyNativeAlgo::Common::ProcessTupleParam(after_opt_bg, after_opt_bg->parameters().size() - kIndex1);
62 }
63 PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
64 return after_opt_bg;
65 }
66
67 bool ProcessMonadNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param) {
68 MS_EXCEPTION_IF_NULL(prim);
69 if (kMonadOp.find(prim->name()) != kMonadOp.end()) {
70 MS_LOG(DEBUG) << "Get monad cnode " << cnode->DebugString();
71 return true;
72 }
73 if ((prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) &&
74 (cnode->inputs().back()->abstract()->isa<abstract::AbstractMonad>())) {
75 AnfNodePtrList inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
76 cnode->set_inputs(inputs);
77 }
78 MS_EXCEPTION_IF_NULL(grad_param);
79 // Jit graph contains monad ops
80 if (grad_param->is_jit_graph) {
81 for (size_t i = 1; i < cnode->size(); ++i) {
82 cnode->set_input(i, common::AnfAlgo::VisitKernelWithReturnType(cnode->input(i), 0, false,
83 {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
84 .first);
85 }
86 }
87 return false;
88 }
89
90 void ClearGradMetaData(const ValuePtr &value) {
91 if (value->isa<tensor::BaseTensor>()) {
92 auto tensor = value->cast<tensor::BaseTensorPtr>();
93 tensor->set_auto_grad_meta_data(nullptr);
94 } else if (value->isa<ValueSequence>()) {
95 auto value_sequence = value->cast<ValueSequencePtr>();
96 for (const auto &val : value_sequence->value()) {
97 ClearGradMetaData(val);
98 }
99 }
100 }
101
102 // Handle the bprop of ops whose input dtype is a real number and output dtype is a complex number.
103 // If the dtype of a gradient (din) is a complex number while the corresponding input is a real number,
104 // only the real part of the gradient makes sense in backpropagation. So we handle it by
105 // inserting a Real() op after the gradient.
106 // input: AnfNode that is an input of an op whose input dtype is real and output dtype is complex.
107 // din: CNodePtr with the gradient of the input.
108 // tape: FuncGraph which input and din belong to.
109 // return: New din with an inserted Real op if necessary.
110 AnfNodePtr HandleRealToComplex(const tensor::BaseTensorPtr &input, const AbstractBasePtr &abs, const AnfNodePtr &din,
111 const KernelGraphPtr &tape) {
112 MS_EXCEPTION_IF_NULL(din);
113 TypePtr din_type = din->Type();
114 if (din_type == nullptr || !din_type->isa<TensorType>()) {
115 return din;
116 }
117 din_type = din_type->cast_ptr<TensorType>()->element();
118 MS_EXCEPTION_IF_NULL(din_type);
119 // cppcheck-suppress unreadVariable
120 if (MS_LIKELY(din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128)) {
121 return din;
122 }
123
124 MS_EXCEPTION_IF_NULL(input);
125 TypePtr input_type = input->Dtype();
126 if (input_type == nullptr) {
127 return din;
128 }
129 if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
130 return din;
131 }
132
133 AnfNodePtr new_din = tape->FuncGraph::NewCNode({NewValueNode(prim::kPrimReal), din});
134 AbstractBasePtr real_abs =
135 std::make_shared<abstract::AbstractTensor>(abstract::AbstractTensor(input_type, abs->GetShapeTrack()));
136 new_din->set_abstract(real_abs);
137 return new_din;
138 }
139
140 void PlantFuncGradBpropGraphDout(const GradParamPtr &grad_param, const FuncGraphPtr &graph) {
141 MS_EXCEPTION_IF_NULL(graph);
142 MS_EXCEPTION_IF_NULL(grad_param);
143 if (!grad_param->is_func_grad) {
144 return;
145 }
146 // Plant dout tuple or dict
147 if (graph->parameters().back()->abstract()->isa<abstract::AbstractSequence>()) {
148 PyNativeAlgo::Common::ProcessTupleParam(graph, grad_param->input_size);
149 } else if (graph->parameters().back()->abstract()->isa<abstract::AbstractDictionary>()) {
150 PyNativeAlgo::Common::ProcessDictParam(graph, grad_param->input_size);
151 }
152 }
153 } // namespace
154
155 void ClearAutoGradCache() {
156 pass_grad_graph_.clear();
157 bprop_pass::ClearCache();
158 PyNativeAlgo::AutoGrad::ClearAutoGradStaticCache();
159 }
160
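// Build the bprop graph for grad_param and report whether the pass cache already holds it.
// Control-flow and jit dynamic-shape graphs are derived from the fprop graph; all others go
// through the expander.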
161 std::pair<bool, FuncGraphPtr> IrBprop::GetBpropGraph(const GradParamPtr &grad_param) {
162 MS_EXCEPTION_IF_NULL(grad_param);
163 const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
164 bool cache_hit = (it != pass_grad_graph_.end());
165 if (grad_param->is_control_flow || grad_param->is_jit_self_dynamic_shape) {
166 MS_LOG(DEBUG) << "Get control flow graph or dynamic shape";
167 return std::make_pair(cache_hit, GetBpropGraphFromFprop(grad_param));
168 }
169 return std::make_pair(cache_hit, GetBpropGraphFromExpander(grad_param));
170 }
171
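// Build the bprop cnode for a primitive whose bprop is defined in Python: look up the bprop
// function on the PrimitivePy (or by name), register it as a backward hook, then delegate to
// BuildBPropCutCNode.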
172 void IrBprop::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs) {
173 MS_EXCEPTION_IF_NULL(prim);
174 MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
175 {
176 py::gil_scoped_acquire gil;
177 auto prim_py = prim->cast<PrimitivePyPtr>();
178 if (prim_py == nullptr) {
179 MS_LOG(DEBUG) << "Prim is not PrimitivePy, can not find python bprop";
180 return;
181 }
182 py::function fn = prim_py->GetBpropFunction();
183 if (py::isinstance<py::none>(fn)) {
184 fn = GetBpropFunction(prim->name());
185 }
186 if (!fn || py::isinstance<py::none>(fn)) {
187 MS_LOG(INFO) << "Can not find bprop function for " << prim->name() << ". fn: " << ConvertPyObjToString(fn);
188 return;
189 }
190 (void)prim_py->AddBackwardHookFn(0, fn);
191 (void)prim_py->AddAttr("custom_op_bprop", MakeValue(true));
192 }
193 BuildBPropCutCNode(cnode, prim, outputs);
194 }
195
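// Build a BpropCut cnode that calls back into Python for the gradient. Its inputs are the
// forward inputs (plus out when not recomputing) and dout; one TupleGetItem din per forward
// input is appended to outputs.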
196 void IrBprop::BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs,
197 bool is_need_recompute) {
198 MS_EXCEPTION_IF_NULL(prim);
199 auto bprop_cut = PyNativeAlgo::AutoGrad::BuildBpropCutPrim(prim, is_need_recompute);
200
201 // Create gradient outputs cnode
202 AnfNodePtrList inputs{NewValueNode(bprop_cut)};
203 for (size_t i = 1; i < cnode->size() - kOutAndDoutNum; ++i) {
204 (void)inputs.emplace_back(cnode->input(i));
205 }
206 if (!is_need_recompute) {
207 // If not recompute, we should add out as bprop input.
208 (void)inputs.emplace_back(cnode->input(cnode->size() - kOutAndDoutNum));
209 }
210 (void)inputs.emplace_back(cnode->input(cnode->size() - 1));
211
212 auto bprop_cut_cnode = ad_param_->tape_->FuncGraph::NewCNode(inputs);
213 AbstractBasePtrList abs_list;
214 // Only add last input dout to user.
215 AddUser(cnode->input(cnode->size() - 1), bprop_cut_cnode, bprop_cut_cnode->size() - 1);
216 for (size_t i = 1; i < cnode->size() - kOutAndDoutNum; ++i) {
217 // Input may be a parameter; we need to add it to the user map.
218 AddUser(cnode->input(i), bprop_cut_cnode, i);
219 auto din = ad_param_->tape_->FuncGraph::NewCNode(
220 {NewValueNode(prim::kPrimTupleGetItem), bprop_cut_cnode, NewValueNode(static_cast<int64_t>(i - 1))});
221 MS_EXCEPTION_IF_NULL(cnode->input(i)->abstract());
222 din->set_abstract(cnode->input(i)->abstract());
223 (void)abs_list.emplace_back(cnode->input(i)->abstract());
224 (void)outputs->emplace_back(din);
225 }
226 bprop_cut_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
227 ad_param_->tape_->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
228 bprop_graph_run_by_single_op_ = true;
229 }
230
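// Map a forward value onto a node of the bprop graph: tensors reuse their recorded parameter
// (or become a new parameter or value node), sequences are rebuilt recursively with MakeTuple,
// and sparse tensors are mapped through their indices.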
231 AnfNodePtr IrBprop::MapParameter(const ValuePtr &value, const abstract::AbstractBasePtr &abs) {
232 if (value->isa<tensor::BaseTensor>()) {
233 const auto &tensor = value->cast<tensor::BaseTensorPtr>();
234 const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
235 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
236 const auto ¶m = auto_grad_meta_data->parameter();
237 if (param != nullptr) {
238 // In dynamic shape scenarios, abs may need to change
239 param->set_abstract(abs);
240 return param;
241 }
242 set_bprop_graph_run_by_single_op(auto_grad_meta_data->is_register_hook());
243 if (auto_grad_meta_data->input_type() == InputType::kParameter &&
244 PyNativeAlgo::Common::IsParamRequiresGrad(tensor)) {
245 return AddParameterNode(tensor, abs);
246 }
247 return PyNativeAlgo::Common::CreateValueNodeByValue(value, abs);
248 } else if (value->isa<ValueSequence>()) {
249 const auto &val_seq = value->cast<ValueSequencePtr>()->value();
250 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
251 MS_EXCEPTION_IF_NULL(abs_seq);
252 if (val_seq.size() != abs_seq->size()) {
253 MS_LOG(EXCEPTION) << "Get value sequence size " << val_seq.size() << " not equal to abstract size "
254 << abs_seq->size();
255 }
256 AnfNodePtrList inputs;
257 (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
258 for (size_t i = 0; i < val_seq.size(); ++i) {
259 (void)inputs.emplace_back(MapParameter(val_seq[i], abs_seq->elements()[i]));
260 }
261 auto cnode = ad_param_->tape_->FuncGraph::NewCNode(inputs);
262 // For replacing fg parameter by user
263 for (size_t i = 1; i < inputs.size(); ++i) {
264 AddUser(inputs[i], cnode, i);
265 }
266 cnode->set_abstract(abs);
267 return cnode;
268 } else if (value->isa<tensor::COOTensor>()) {
269 const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
270 return MapParameter(coo_tensor->GetIndices(), abs);
271 } else if (value->isa<tensor::CSRTensor>()) {
272 const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
273 return MapParameter(csr_tensor->GetIndices(), abs);
274 } else {
275 return PyNativeAlgo::Common::CreateValueNodeByValue(value, abs);
276 }
277 }
278
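// Create a tape parameter for a weight tensor, attach a zeros-like dout and an IrVariable
// adjoint to it, and record it in weights_used_in_graph_.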
279 ParameterPtr IrBprop::AddParameterNode(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs) {
280 MS_EXCEPTION_IF_NULL(tensor);
281 auto param = CreateTapeParameter(tensor, abs);
282 auto zeros_like_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
283 ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), param->abstract(), SpecialType::kZerosLikeType);
284 auto func_node = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_dout);
285 auto input_adjoint = std::make_shared<IrVariable>(func_node, tensor, true);
286 (void)ad_param_->variable_adjoint_set_.insert(input_adjoint);
287 auto auto_grad_meta_data = tensor->auto_grad_meta_data();
288 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
289 auto_grad_meta_data->set_variable(input_adjoint);
290 (void)ad_param_->weights_used_in_graph_.emplace_back(param);
291 return param;
292 }
293
294 ParameterPtr IrBprop::CreateTapeParameter(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs) {
295 MS_EXCEPTION_IF_NULL(tensor);
296 MS_EXCEPTION_IF_NULL(abs);
297 auto param = ad_param_->fg_->add_parameter();
298 param->set_abstract(abs);
299 if (tensor->is_parameter()) {
300 param->set_default_param(tensor);
301 }
302 auto auto_grad_meta_data = tensor->auto_grad_meta_data();
303 if (auto_grad_meta_data == nullptr) {
304 auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
305 tensor->set_auto_grad_meta_data(auto_grad_meta_data);
306 }
307 auto_grad_meta_data->set_input_type(InputType::kParameter);
308 auto_grad_meta_data->set_parameter(param);
309 return param;
310 }
311
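// Connect the op's dins to the variables of its inputs: each din is wired as a next edge of
// the variable's function node; a variable that ends up with no next edges no longer needs grad.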
312 void IrBprop::UpdateNextEdges(const VariablePtr &variable, const std::vector<CNodePtr> &dins,
313 const ValuePtrList &inputs_value, const abstract::AbstractBasePtrList &abs,
314 const string &op_name) {
315 size_t input_size = inputs_value.size();
316 if (dins.size() != input_size) {
317 MS_LOG(EXCEPTION) << "The size of dins " << dins.size() << " is not same as input_value " << input_size;
318 }
319 const auto &fn = variable->ir_function_node();
320 for (size_t i = 0; i < input_size; ++i) {
321 auto din = dins[i];
322 MS_EXCEPTION_IF_NULL(din);
323 MS_LOG(DEBUG) << "Input arg id: " << PyNativeAlgo::Common::GetIdByValue(inputs_value[i]) << ", din "
324 << din->DebugString();
325 #ifndef ENABLE_TEST
326 // The VM does not need to run the pass
327 din = pass_forward_->PassForDin(din, op_name, false);
328 #endif
329 UpdateNextEdge(fn, din, inputs_value[i], abs[i]);
330 }
331 if (fn->next_edges().empty()) {
332 variable->set_is_need_grad(false);
333 }
334 MS_LOG(DEBUG) << "Finish update next edges for variable: " << variable->ToString();
335 }
336
337 void IrBprop::AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
338 MS_EXCEPTION_IF_NULL(ad_param_);
339 (void)ad_param_->users_.dout_user_[node].emplace_back(user, index);
340 }
341
342 void IrBprop::AddReverseUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
343 (void)ad_param_->reverse_users_[node].emplace_back(user, index);
344 }
345
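// Walk the variable adjoints in reverse order starting from the last node, replace fake douts
// with the accumulated real douts, and propagate every next-edge din to its producer variable.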
346 void IrBprop::BackPropagate() {
347 UpdateLazyUser();
348 const auto &last_node_reverse_iter = GetLastNodeReverseIter();
349 #ifndef ENABLE_TEST
350 SeenNum seen = NewSeenGeneration();
351 #endif
352 MS_LOG(DEBUG) << "Is running recompute grad " << is_run_recompute_;
353 for (auto iter = last_node_reverse_iter; iter != ad_param_->variable_adjoint_set_.rend(); ++iter) {
354 const auto &variable = *iter;
355 if (!variable->is_need_propagate() || !variable->is_need_grad()) {
356 MS_LOG(DEBUG) << "No need grad, variable is: " << variable->ToString();
357 continue;
358 }
359 if (static_cast<bool>(MS_UNLIKELY(variable->is_fake_bprop()))) {
360 MS_LOG(EXCEPTION) << "Illegal primitive " << variable->fake_prim_name() << "'s bprop not defined";
361 }
362 MS_LOG(DEBUG) << "Begin backpropagate: " << variable->ToString();
363 const auto &fn = variable->ir_function_node();
364 // If zeroslike is not used in the funcgraph, we need to replace the zeroslike placeholder with a real zeroslike value.
365 if (static_cast<bool>(MS_UNLIKELY(PyNativeAlgo::AutoGrad::IsZerosLikeNode(fn->accumulate_dout())))) {
366 fn->set_accumulate_dout(PyNativeAlgo::AutoGrad::BuildSpecialNode(
367 fn->tape(), variable->out_value(), fn->accumulate_dout()->abstract(), SpecialType::kZerosLikeType));
368 }
369 // If a hook is registered on a weight and the weight is in a recompute cell, the hook would execute, which is not expected.
370 if (!is_run_recompute_) {
371 fn->set_accumulate_dout(pass_forward_->PassBackwardHook(variable->out_value(), fn->accumulate_dout()));
372 }
373 // Replace the fake dout with the real accumulated dout; record the replacement result to eliminate tuplegetitem
374 // when accumulate_dout is tuplegetitem
375 Replace(fn->fake_dout(), fn->accumulate_dout(), &ad_param_->users_.dout_user_, true);
376 // replace edges which exist fake dout
377 fn->ReplaceEdges();
378 const auto &next_edges = fn->next_edges();
379 for (const auto &next_edge : next_edges) {
380 const auto &last_variable = next_edge.first;
381 const auto &din = next_edge.second;
382 #ifndef ENABLE_TEST
383 // The VM does not need to run the pass
384 pass_forward_->ConvertMakeTupleInputToDynamicInput(din, seen, bprop_graph_run_by_single_op_);
385 #endif
386 last_variable->ir_function_node()->UpdateAccumulativeDout(din);
387 last_variable->set_is_need_propagate(true);
388 }
389 }
390 MS_LOG(DEBUG) << "End BackPropagate";
391 }
392
393 OrderedSet<IrVariablePtr>::reverse_iterator IrBprop::GetLastNodeReverseIter() {
394 for (auto iter = ad_param_->variable_adjoint_set_.rbegin(); iter != ad_param_->variable_adjoint_set_.rend(); ++iter) {
395 if (*iter == ad_param_->last_variable_) {
396 ad_param_->last_variable_->set_is_need_propagate(true);
397 return iter;
398 }
399 }
400 return ad_param_->variable_adjoint_set_.rend();
401 }
402
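// Build the sens variable for the forward output: create a zeros-like placeholder dout, mark
// the variable as the starting point of backpropagation, and return its dout abstract.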
403 AbstractBasePtr IrBprop::BuildForwardLastNode() {
404 MS_LOG(DEBUG) << "Process last node info " << PyNativeAlgo::Common::GetIdByValue(ad_param_->sens_value_);
405 auto zeros_like_node = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, ad_param_->sens_value_, nullptr,
406 SpecialType::kZerosLikeType);
407 auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_node);
408 auto sens_variable = std::make_shared<IrVariable>(fn, ad_param_->sens_value_);
409 if (ad_param_->sens_value_->isa<tensor::BaseTensor>()) {
410 const auto &sens_tensor = ad_param_->sens_value_->cast<tensor::BaseTensorPtr>();
411 const auto &auto_grad_meta_data = sens_tensor->auto_grad_meta_data();
412 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
413 if (PyNativeAlgo::Common::IsConstant(auto_grad_meta_data->input_type())) {
414 sens_variable->set_is_need_grad(false);
415 }
416 }
417 UpdateNextEdge(fn, zeros_like_node, ad_param_->sens_value_, fn->accumulate_dout()->abstract());
418 (void)ad_param_->variable_adjoint_set_.insert(sens_variable);
419 ad_param_->last_variable_ = sens_variable;
420 return fn->accumulate_dout()->abstract();
421 }
422
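// Build the bprop graph from the fprop graph: call the fprop graph, take its bprop closure
// (output index 1), apply dout, collect the input gradients while skipping env at index 0,
// then optimize the result and cache it when allowed.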
423 FuncGraphPtr IrBprop::GetBpropGraphFromFprop(const GradParamPtr &grad_param) {
424 MS_EXCEPTION_IF_NULL(grad_param);
425 FuncGraphPtr after_opt_fg = nullptr;
426 // Find ad graph in cache
427 const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
428 bool cache_hit = (it != pass_grad_graph_.end());
429 if (cache_hit) {
430 MS_LOG(DEBUG) << "Get ad grad graph by cache";
431 after_opt_fg = BasicClone(it->second);
432 } else {
433 auto bprop_builder = std::make_shared<FuncGraph>();
434 bprop_builder->debug_info()->set_name("bprop_builder");
435
436 AnfNodePtrList fprop_app_inputs{NewValueNode(grad_param->fg)};
437 for (const auto &abs : grad_param->op_grad_info->input_abs) {
438 auto param = bprop_builder->add_parameter();
439 param->set_abstract(abs);
440 (void)fprop_app_inputs.emplace_back(param);
441 }
442 auto fprop_app = bprop_builder->NewCNode(fprop_app_inputs);
443 // Get bprop from fprop_fg; it is the second output of fprop_fg
444 auto get_bprop = bprop_builder->NewCNode(
445 {NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(kIndex1))});
446
447 AnfNodePtrList node_list{get_bprop};
448 auto dout = bprop_builder->add_parameter();
449 dout->set_abstract(grad_param->op_grad_info->out_abs);
450 (void)node_list.emplace_back(dout);
451 auto call_bprop = bprop_builder->NewCNode(node_list);
452
453 AnfNodePtrList actual_out{NewValueNode(prim::kPrimMakeTuple)};
454 for (size_t i = 0; i < grad_param->input_size; ++i) {
455 // Index 0 is env, skip it
456 auto out =
457 bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), call_bprop, NewValueNode(SizeToLong(i + 1))});
458 (void)actual_out.emplace_back(out);
459 }
460 bprop_builder->set_output(bprop_builder->NewCNode(actual_out));
461 // Call pass to optimize the graph, such as inline
462 after_opt_fg = OptimizeBpropBuilder(bprop_builder, grad_param);
463 PlantFuncGradBpropGraphDout(grad_param, after_opt_fg);
464 if (grad_param->is_func_grad && grad_param->is_control_flow) {
465 after_opt_fg = LiftingClone(after_opt_fg);
466 }
467 if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
468 pass_grad_graph_[grad_param->graph_cache_key] = BasicClone(after_opt_fg);
469 }
470 }
471 return after_opt_fg;
472 }
473
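// Build the bprop graph by expanding every cnode of the forward graph with its bprop, run
// backpropagation over the recorded adjoints, output the gradients of the graph parameters,
// and cache the result when allowed.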
474 FuncGraphPtr IrBprop::GetBpropGraphFromExpander(const GradParamPtr &grad_param) {
475 // Find ad graph in cache
476 if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
477 const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
478 if (it != pass_grad_graph_.end()) {
479 MS_LOG(DEBUG) << "Get ad grad graph by cache";
480 return BasicClone(it->second);
481 }
482 } else {
483 pass_grad_graph_.clear();
484 }
485
486 // Create new ad param for graph ad
487 PyNativeAlgo::Common::DumpGraphIR("ad_input_graph.ir", grad_param->fg);
488 auto current_ad_param = ad_param_;
489 ad_param_ = std::make_shared<AdParam>();
490 ad_param_->tape_->debug_info()->set_name("ad_graph");
491 bprop_graph_run_by_single_op_ = bprop_graph_run_by_single_op_ || grad_param->use_dynamic_shape_process;
492
493 GradGraphByExpander(grad_param);
494
495 if (ad_param_->last_node_ != nullptr) {
496 // Set dout parameter
497 const auto last_prim = GetCNodePrimitive(ad_param_->last_node_);
498 if (kMonadOp.find(last_prim->name()) != kMonadOp.end()) {
499 ad_param_->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(
500 ad_param_->last_node_, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
501 .first;
502 }
503 if (ad_param_->anfnode_to_variable_adjoint_.count(ad_param_->last_node_) == 0) {
504 MS_LOG(EXCEPTION) << "Can not find last node " << ad_param_->last_node_->DebugString();
505 }
506 ad_param_->last_variable_ = ad_param_->anfnode_to_variable_adjoint_[ad_param_->last_node_];
507 auto ad_graph_dout = ad_param_->tape_->add_parameter();
508 ad_graph_dout->set_abstract(ad_param_->last_node_->abstract());
509 ad_param_->last_variable_->ir_function_node()->UpdateAccumulativeDout(ad_graph_dout);
510 (void)BackPropagate();
511 } else {
512 // The graph just has a return node
513 auto ad_graph_dout = ad_param_->tape_->add_parameter();
514 ad_graph_dout->set_abstract(grad_param->fg->output()->abstract());
515 ad_graph_dout->debug_info()->set_name("sens");
516 ad_param_->sens_value_ = grad_param->op_grad_info->out_value;
517 (void)BuildForwardLastNode();
518 // Update dout
519 MS_EXCEPTION_IF_NULL(ad_param_->last_variable_);
520 if (ad_param_->last_variable_->is_need_grad()) {
521 ad_param_->last_variable_->ir_function_node()->UpdateAccumulativeDout(ad_graph_dout);
522 }
523 (void)BackPropagate();
524 }
525
526 AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
527 abstract::AbstractBasePtrList out_abs_list;
528 for (const auto &node : grad_param->fg->parameters()) {
529 (void)outputs.emplace_back(ad_param_->anfnode_to_variable_adjoint_.at(node)->RealDout());
530 (void)out_abs_list.emplace_back(outputs.back()->abstract());
531 }
532 auto ad_graph_out = ad_param_->tape_->FuncGraph::NewCNode(outputs);
533 ad_graph_out->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
534 ad_param_->tape_->set_output(ad_graph_out);
535 auto ad_graph = ad_param_->tape_;
536 auto abs_seq = ad_graph->parameters().empty()
537 ? nullptr
538 : ad_graph->parameters().back()->abstract()->cast<abstract::AbstractSequencePtr>();
539 if (abs_seq != nullptr && !abs_seq->dynamic_len() && grad_param->is_jit_graph &&
540 grad_param->use_dynamic_shape_process) {
541 auto manager = MakeManager();
542 MS_EXCEPTION_IF_NULL(manager);
543 manager->AddFuncGraph(ad_graph);
544 PyNativeAlgo::Common::ProcessTupleParam(ad_graph, ad_graph->parameters().size() - kIndex1);
545 }
546 PyNativeAlgo::Common::DumpGraphIR("ad_output_graph.ir", ad_graph);
547
548 // Plant dout tuple
549 PlantFuncGradBpropGraphDout(grad_param, ad_graph);
550
551 // Save ad graph in cache
552 if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
553 pass_grad_graph_[grad_param->graph_cache_key] = BasicClone(ad_graph);
554 }
555 // Replace cnode with valuenode to reduce computation
556 bool jit_by_value = grad_param->is_jit_graph && grad_by_value_;
557 if (jit_by_value) {
558 PyNativeAlgo::Common::ReplaceCNodeWithValueNode(ad_graph);
559 }
560 // Restore ad param
561 ad_param_ = current_ad_param;
562 return ad_graph;
563 }
564
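// Replace old_node with new_node at every recorded user position; when need_update is set and
// the new node is a TupleGetItem, also register it as a TupleGetItem user.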
565 void IrBprop::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, expander::bprop::UserType *user,
566 bool need_update) {
567 MS_EXCEPTION_IF_NULL(user);
568 if (user->find(old_node) == user->end()) {
569 return;
570 }
571 const auto &old_node_users = (*user)[old_node];
572 for (const auto &pair_node : old_node_users) {
573 auto cnode = pair_node.first.lock();
574 if (cnode == nullptr) {
575 continue;
576 }
577 size_t index = pair_node.second;
578 if (index >= cnode->size()) {
579 // After attr conversion, the cnode has fewer inputs
580 if (auto v = cnode->GetAttr(kAttrConvertAttrNode); v != nullptr) {
581 index -= GetValue<size_t>(v);
582 } else {
583 MS_LOG(EXCEPTION) << "Exception for index: " << index << " greater than cnode size: " << cnode->size();
584 }
585 }
586 cnode->set_input(index, new_node);
587 if (need_update && IsPrimitiveCNode(new_node, prim::kPrimTupleGetItem)) {
588 AddTupleGetItemUser(new_node, cnode, index);
589 }
590 }
591 }
592
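// Drive the expander-based grad: create adjoints for the graph parameters, then visit every
// cnode in topological order, skip monad and StopGradient nodes, and grad the rest.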
593 void IrBprop::GradGraphByExpander(const GradParamPtr &grad_param) {
594 MS_EXCEPTION_IF_NULL(grad_param);
595 if (pass_forward_->need_reverse_graph()) {
596 pass_forward_->ReversePassFuncGraph(grad_param->fg);
597 }
598
599 // First handle parameters
600 CreateParameterAdjoint(grad_param);
601
602 // Second handle cnodes
603 const auto &order = TopoSort(grad_param->fg->output());
604 for (const auto &node : order) {
605 if (node == nullptr || !node->isa<CNode>()) {
606 continue;
607 }
608 auto cnode = node->cast<CNodePtr>();
609 MS_EXCEPTION_IF_NULL(cnode);
610 auto prim = GetCNodePrimitive(cnode);
611 if (prim == nullptr) {
612 MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
613 }
614 ad_param_->last_node_ = cnode;
615 if (ProcessMonadNode(prim, cnode, grad_param) || IsPrimitiveEquals(prim, prim::kPrimStopGradient)) {
616 continue;
617 }
618 MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString() << ", " << cnode->fullname_with_scope();
619 ValuePtrList inputs_value;
620 AnfNodePtrList cnode_inputs;
621 PrepareGradCNodeInputs(prim, cnode, &inputs_value, &cnode_inputs);
622 // Do grad for every cnode
623 GradCNode(prim, cnode, grad_param, inputs_value, &cnode_inputs);
624 }
625 }
626
627 void IrBprop::CreateParameterAdjoint(const GradParamPtr &grad_param) const {
628 auto &graph_parameters = grad_param->fg->parameters();
629 if (graph_parameters.size() != grad_param->input_size) {
630 MS_LOG(EXCEPTION) << "Parameters size " << graph_parameters.size() << " is not equal to graph input size "
631 << grad_param->input_size;
632 }
633 for (size_t i = 0; i < graph_parameters.size(); ++i) {
634 MS_LOG(DEBUG) << "Get param " << graph_parameters[i]->DebugString();
635 ParameterPtr param = ad_param_->tape_->add_parameter();
636 param->set_abstract(graph_parameters[i]->abstract());
637 auto zeros_like_dout =
638 PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
639 graph_parameters[i]->abstract(), SpecialType::kZerosLikeType);
640 auto func_node = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_dout);
641 // Copy to avoid corrupting the real input's grad info.
642 auto op_arg = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(grad_param->op_grad_info->input_value[i]);
643 ClearGradMetaData(op_arg);
644 auto adjoint = std::make_shared<IrVariable>(func_node, op_arg, true);
645 adjoint->set_k_node(param);
646 PyNativeAlgo::AutoGrad::SetGradMetaData(op_arg, adjoint, graph_parameters[i]->cast<ParameterPtr>());
647 (void)ad_param_->variable_adjoint_set_.insert(adjoint);
648 (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(graph_parameters[i], adjoint));
649 }
650 }
651
652 void IrBprop::PrepareGradCNodeInputs(const PrimitivePtr &prim, const CNodePtr &cnode, ValuePtrList *inputs_value,
653 AnfNodePtrList *cnode_inputs) {
654 MS_EXCEPTION_IF_NULL(cnode);
655 MS_EXCEPTION_IF_NULL(inputs_value);
656 MS_EXCEPTION_IF_NULL(cnode_inputs);
657 (void)cnode_inputs->emplace_back(std::make_shared<ValueNode>(prim));
658 *inputs_value = GetInputArgs(cnode, cnode_inputs);
659 pass_forward_->ReversePassCNode(cnode, inputs_value, cnode_inputs);
660 }
661
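// Collect the forward values of a cnode's inputs together with the matching k-node inputs:
// known nodes reuse their adjoint, value nodes are copied into new value nodes, and any other
// node gets a fake scalar value.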
662 ValuePtrList IrBprop::GetInputArgs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs) const {
663 MS_EXCEPTION_IF_NULL(cnode);
664 MS_EXCEPTION_IF_NULL(cnode_inputs);
665 ValuePtrList input_value;
666 for (size_t i = 1; i < cnode->size(); ++i) {
667 const auto &input_node = cnode->input(i);
668 // Find knode and out value
669 const auto it = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
670 if (it != ad_param_->anfnode_to_variable_adjoint_.end()) {
671 (void)cnode_inputs->emplace_back(it->second->k_node());
672 (void)input_value.emplace_back(it->second->out_value());
673 continue;
674 }
675 if (input_node->isa<ValueNode>()) {
676 auto v_node = input_node->cast<ValueNodePtr>();
677 auto v = v_node->value();
678 if (v != nullptr && v->isa<tensor::BaseTensor>()) {
679 const auto &t = v->cast<tensor::BaseTensorPtr>();
680 const auto &grad_meta = t->auto_grad_meta_data();
681 // The jit forward graph has no parameters (its input is a tuple or constant), so the input is used in the graph as
682 // a valuenode, but it is also used by tape_ as a parameter
683 if (grad_meta != nullptr && PyNativeAlgo::Common::IsParam(grad_meta->input_type())) {
684 auto new_tensor = std::make_shared<tensor::Tensor>(t->data_type(), t->shape(), t->data_ptr());
685 new_tensor->set_device_address(t->device_address());
686 v = new_tensor;
687 }
688 }
689 (void)PyNativeAlgo::Common::SetValueGradInfo(v, nullptr, InputType::kConstant);
690 // In case the jit forward graph and the pynative bprop graph use the same valuenode
691 auto new_v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v, v_node->abstract());
692 (void)cnode_inputs->emplace_back(new_v_node);
693 (void)input_value.emplace_back(v);
694 } else {
695 // Make Fake value
696 auto v = MakeValue<int64_t>(0);
697 (void)cnode_inputs->emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(v, input_node->abstract()));
698 (void)input_value.emplace_back(v);
699 MS_LOG(DEBUG) << "Get input node " << input_node->DebugString();
700 }
701 }
702 return input_value;
703 }
704
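// Grad a single cnode. MakeTuple/MakeList and TupleGetItem only build k-nodes; other ops get
// their dins from the bprop expander (falling back to a Python custom bprop or a fake bprop),
// and the resulting variable and next edges are recorded.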
705 void IrBprop::GradCNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param,
706 const ValuePtrList &inputs_value, AnfNodePtrList *cnode_inputs) {
707 MS_EXCEPTION_IF_NULL(prim);
708 MS_EXCEPTION_IF_NULL(cnode);
709 bool jit_by_value = grad_param->is_jit_graph && grad_by_value_;
710 if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
711 (void)BuildKNodeForMakeTuple(cnode);
712 return;
713 } else if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
714 (void)BuildKNodeForTupleGetItem(cnode);
715 return;
716 }
717 MS_EXCEPTION_IF_NULL(cnode_inputs);
718 auto k_node = GetKnode(prim, cnode, *cnode_inputs, jit_by_value);
719 if (bprop_graph_run_by_single_op_ && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
720 std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
721 MS_EXCEPTION_IF_NULL(node->abstract());
722 return node->abstract()->isa<abstract::AbstractSequence>();
723 })) {
724 k_node->cast<CNodePtr>()->AddAttr(kAttrIsPyboostTupleInput, MakeValue(true));
725 }
726 MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
727 // Set out
728 auto out = PyNativeAlgo::Common::CreatOutputTensorValueByAbstract(cnode->abstract());
729 (void)cnode_inputs->emplace_back(k_node);
730 // Set dout
731 AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
732 ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), cnode->abstract(), SpecialType::kZerosLikeType);
733 (void)cnode_inputs->emplace_back(dout);
734 auto input_node = ad_param_->tape_->FuncGraph::NewCNode(*cnode_inputs);
735 input_node->set_abstract(cnode->abstract());
736
737 std::vector<CNodePtr> outputs;
738 // Get bprop by expander
739 auto ret = BpropExpander(&outputs, &ad_param_->users_).Run(input_node);
740 if (!ret || outputs.empty()) {
741 // Get bprop by python custom
742 MS_LOG(DEBUG) << "Expander has no bprop of this node: " << input_node->DebugString();
743 BuildCustomBpropCNode(input_node, prim, &outputs);
744 }
745
746 auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
747 auto variable_adjoint = std::make_shared<IrVariable>(fn, out);
748 variable_adjoint->set_k_node(k_node);
749 // Get bprop by fake bprop
750 if (outputs.empty()) {
751 MS_LOG(DEBUG) << "Build fake bprop for this node: " << input_node->DebugString();
752 PyNativeAlgo::AutoGrad::BuildFakeBpropCNode(input_node, &outputs);
753 variable_adjoint->set_is_fake_bprop(true);
754 variable_adjoint->set_fake_prim_name(prim->name());
755 }
756 // Create current op node din edge
757 AbstractBasePtrList input_abs;
758 for (size_t i = 1; i < cnode->size(); ++i) {
759 (void)input_abs.emplace_back(cnode->input(i)->abstract());
760 }
761 UpdateNextEdges(variable_adjoint, outputs, inputs_value, input_abs);
762 PyNativeAlgo::AutoGrad::SetGradMetaData(out, variable_adjoint);
763 (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(cnode, variable_adjoint));
764 (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
765 }
766
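// Build the k-node and adjoint for a MakeTuple: its dout is split with TupleGetItem so that
// every element receives its own din edge.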
767 AnfNodePtr IrBprop::BuildKNodeForMakeTuple(const AnfNodePtr &input_node) {
768 MS_EXCEPTION_IF_NULL(input_node);
769 MS_LOG(DEBUG) << "Build knode for MakeTuple " << input_node->DebugString();
770 const auto &cnode = input_node->cast<CNodePtr>();
771 MS_EXCEPTION_IF_NULL(cnode);
772 AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
773 ValuePtrList input_value;
774 AbstractBasePtrList input_abs;
775 for (size_t i = 1; i < cnode->size(); ++i) {
776 (void)inputs.emplace_back(BuildKNodeForCNodeInput(cnode->input(i)));
777 if (cnode->input(i)->isa<CNode>() || cnode->input(i)->isa<Parameter>()) {
778 const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(cnode->input(i));
779 if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
780 MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
781 }
782 (void)input_value.emplace_back(input_adjoint_iter->second->out_value());
783 (void)input_abs.emplace_back(cnode->input(i)->abstract());
784 } else {
785 auto value_node = cnode->input(i)->cast<ValueNodePtr>();
786 MS_EXCEPTION_IF_NULL(value_node);
787 (void)input_value.emplace_back(value_node->value());
788 (void)input_abs.emplace_back(value_node->abstract());
789 }
790 }
791 auto out_value = MakeValue(input_value);
792 AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, input_node->abstract(),
793 SpecialType::kZerosLikeType);
794 auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
795 auto variable_adjoint = std::make_shared<IrVariable>(fn, out_value);
796 auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
797 k_node->set_abstract(input_node->abstract());
798 variable_adjoint->set_k_node(k_node);
799 // Create dout for maketuple
800 std::vector<CNodePtr> make_tuple_dout;
801 for (size_t i = 1; i < cnode->size(); ++i) {
802 auto d = ad_param_->tape_->FuncGraph::NewCNode(
803 {NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(SizeToLong(i - 1))});
804 d->set_abstract(cnode->input(i)->abstract());
805 (void)make_tuple_dout.emplace_back(d);
806 AddUser(dout, d, 1);
807 }
808 UpdateNextEdges(variable_adjoint, make_tuple_dout, input_value, input_abs);
809 (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
810 (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
811 return k_node;
812 }
813
814 AnfNodePtr IrBprop::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
815 MS_EXCEPTION_IF_NULL(input_node);
816 if (input_node->isa<CNode>()) {
817 const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
818 if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
819 if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
820 return BuildKNodeForMakeTuple(input_node);
821 } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
822 return BuildKNodeForTupleGetItem(input_node);
823 }
824 MS_LOG(EXCEPTION) << "Can not find input in adjoint map, inp: " << input_node->DebugString();
825 }
826 return input_adjoint_iter->second->k_node();
827 } else {
828 // Tuple sens will come in
829 if (input_node->isa<Parameter>()) {
830 const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
831 if (input_adjoint_iter != ad_param_->anfnode_to_variable_adjoint_.end() &&
832 input_adjoint_iter->second->k_node() != nullptr) {
833 return input_adjoint_iter->second->k_node();
834 }
835 }
836 return input_node;
837 }
838 }
839
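// Build the k-node and adjoint for a TupleGetItem: the selected element receives dout as its
// din, while the remaining elements receive zeros-like dins.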
840 AnfNodePtr IrBprop::BuildKNodeForTupleGetItem(const AnfNodePtr &input_node) {
841 MS_EXCEPTION_IF_NULL(input_node);
842 MS_LOG(DEBUG) << "Build knode for TupleGetItem " << input_node->DebugString();
843 const auto &tuple_item_cnode = input_node->cast<CNodePtr>();
844 MS_EXCEPTION_IF_NULL(tuple_item_cnode);
845 // Find the make tuple or sens (tuple) node to get the out value
846 const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(tuple_item_cnode->input(kIndex1));
847 if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
848 MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << tuple_item_cnode->input(kIndex1)->DebugString();
849 }
850 const auto &v_tuple = input_adjoint_iter->second->out_value()->cast<ValueSequencePtr>();
851 MS_EXCEPTION_IF_NULL(v_tuple);
852 auto index_value = GetValueNode<Int64ImmPtr>(tuple_item_cnode->input(kIndex2));
853 auto index_value_int = LongToSize(index_value->value());
854 auto out_value = (*v_tuple)[index_value_int];
855 MS_EXCEPTION_IF_NULL(out_value);
856 AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, input_node->abstract(),
857 SpecialType::kZerosLikeType);
858 auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
859 auto variable_adjoint = std::make_shared<IrVariable>(fn, out_value);
860
861 AnfNodePtrList inputs{NewValueNode(prim::kPrimTupleGetItem)};
862 // Get make tuple knode
863 (void)inputs.emplace_back(BuildKNodeForCNodeInput(tuple_item_cnode->input(kIndex1)));
864 // Get index knode
865 (void)inputs.emplace_back(BuildKNodeForCNodeInput(tuple_item_cnode->input(kIndex2)));
866 auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
867 k_node->set_abstract(input_node->abstract());
868 variable_adjoint->set_k_node(k_node);
869 // Create dout for tuplegetitem
870 AnfNodePtrList tuple_getitem_dout{NewValueNode(prim::kPrimMakeTuple)};
871 const auto &abs_tuple = tuple_item_cnode->input(kIndex1)->abstract()->cast<abstract::AbstractSequencePtr>();
872 for (size_t i = 0; i < v_tuple->size(); ++i) {
873 const auto &v = v_tuple->value()[i];
874 if (i == index_value_int) {
875 (void)tuple_getitem_dout.emplace_back(dout);
876 } else {
877 (void)tuple_getitem_dout.emplace_back(PyNativeAlgo::AutoGrad::BuildSpecialNode(
878 ad_param_->tape_, v, abs_tuple->elements()[i], SpecialType::kZerosLikeType));
879 }
880 }
881 CNodePtr tuple_getitem_dout_value = ad_param_->tape_->FuncGraph::NewCNode(tuple_getitem_dout);
882 tuple_getitem_dout_value->set_abstract(tuple_item_cnode->input(kIndex1)->abstract());
883 auto index_dout_value =
884 PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, index_value,
885 tuple_item_cnode->input(kIndex1)->abstract(), SpecialType::kZerosLikeType)
886 ->cast<CNodePtr>();
887 UpdateNextEdges(variable_adjoint, {tuple_getitem_dout_value, index_dout_value}, {v_tuple, index_value},
888 {tuple_item_cnode->input(kIndex1)->abstract(), tuple_item_cnode->input(kIndex2)->abstract()});
889 AddUser(dout, tuple_getitem_dout_value, index_value_int + 1);
890 (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
891 (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
892 return k_node;
893 }
894
895 AnfNodePtr IrBprop::GetKnode(const PrimitivePtr &prim, const CNodePtr &cnode, const AnfNodePtrList &cnode_inputs,
896 bool jit_by_value) {
897 if (IsPrimitiveEquals(prim, prim::kPrimMirror)) {
898 return ad_param_->anfnode_to_variable_adjoint_.at(cnode->input(kIndex1))->k_node();
899 } else {
900 auto c_k_node = ad_param_->tape_->FuncGraph::NewCNode(cnode_inputs);
901 c_k_node->set_abstract(cnode->abstract());
902 // In jit, copy forward graph cnode info to bprop graph
903 if (jit_by_value && cnode->forward().first != nullptr) {
904 auto new_v_node = PyNativeAlgo::Common::CreateValueNodeByValue(cnode->forward().first->value(),
905 cnode->forward().first->abstract());
906 c_k_node->set_forward(new_v_node, cnode->forward().second);
907 ad_param_->tape_->set_used_forward_nodes({c_k_node});
908 }
909 c_k_node->AddAttr(bprop_pass::kIsKNode, MakeValue(true));
910 return c_k_node;
911 }
912 }
913
914 void IrBprop::UpdateNextEdgeForDict(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
915 const AbstractBasePtr &abs) {
916 auto value_dict = input_arg->cast<ValueDictionaryPtr>()->value();
917 const auto &abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
918 MS_EXCEPTION_IF_NULL(abs_dict);
919 if (value_dict.size() != abs_dict->size()) {
920 MS_LOG(EXCEPTION) << "Get value dict size " << value_dict.size() << " not equal to abstract size "
921 << abs_dict->size();
922 }
923 for (size_t i = 0; i < value_dict.size(); ++i) {
924 auto sub_value = value_dict[i];
925 auto key_item = PyNativeAlgo::Common::CreateValueNodeByValue(sub_value.first, abs_dict->elements()[i].first);
926 CNodePtr new_din = ad_param_->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimDictGetItem), din, key_item});
927 new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs_dict->elements()[i].second));
928 if (din == fn->fake_dout()) {
929 // The new_din's index input is fn->fake_dout()
930 LazyAddUser(fn->fake_dout(), new_din, 1);
931 }
932 // Add next edge to fn
933 UpdateNextEdge(fn, new_din, sub_value.second, abs_dict->elements()[i].second);
934 }
935 }
936
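// Add a next edge for one input: tensors are traced back to their producing variable,
// sequences and dicts are split element by element, and sparse tensors are handled through
// their indices.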
937 void IrBprop::UpdateNextEdge(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
938 const AbstractBasePtr &abs) {
939 MS_EXCEPTION_IF_NULL(din);
940 MS_EXCEPTION_IF_NULL(input_arg);
941 if (input_arg->isa<tensor::BaseTensor>()) {
942 tensor::BaseTensorPtr input_tensor = nullptr;
943 input_tensor = input_arg->cast<tensor::BaseTensorPtr>();
944 auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
945 MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
946 auto variable = auto_grad_meta_data->variable();
947 if (variable == nullptr || !variable->is_need_grad()) {
948 return;
949 }
950 auto real_din = HandleRealToComplex(input_tensor, abs, din, fn->tape());
951 auto new_din = TraceInput(fn, variable->out_value(), variable->ir_function_node()->accumulate_dout()->abstract(),
952 input_tensor, real_din);
953 fn->AddNextEdge(variable, new_din);
954 } else if (input_arg->isa<ValueSequence>()) {
955 auto value_seq = input_arg->cast<ValueSequencePtr>()->value();
956 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
957 MS_EXCEPTION_IF_NULL(abs_seq);
958 if (value_seq.size() != abs_seq->size()) {
959 MS_LOG(EXCEPTION) << "Get value sequence size " << value_seq.size() << " not equal to abstract size "
960 << abs_seq->size();
961 }
962 for (size_t i = 0; i < value_seq.size(); ++i) {
963 auto sub_value = value_seq[i];
964 CNodePtr new_din = ad_param_->tape_->FuncGraph::NewCNode(
965 {NewValueNode(prim::kPrimTupleGetItem), din, NewValueNode(SizeToLong(i))});
966 new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs_seq->elements()[i]));
967 if (din == fn->fake_dout()) {
968 // The new_din's index input is fn->fake_dout()
969 LazyAddUser(fn->fake_dout(), new_din, 1);
970 }
971 // Add next edge to fn
972 UpdateNextEdge(fn, new_din, sub_value, abs_seq->elements()[i]);
973 }
974 } else if (input_arg->isa<tensor::COOTensor>()) {
975 auto input_tensor = input_arg->cast<tensor::COOTensorPtr>()->GetIndices();
976 UpdateNextEdge(fn, din, input_tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_tensor->ToAbstract()));
977 } else if (input_arg->isa<tensor::CSRTensor>()) {
978 auto input_tensor = input_arg->cast<tensor::CSRTensorPtr>()->GetIndices();
979 UpdateNextEdge(fn, din, input_tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_tensor->ToAbstract()));
980 } else if (input_arg->isa<ValueDictionary>()) {
981 UpdateNextEdgeForDict(fn, din, input_arg, abs);
982 } else {
983 MS_LOG(DEBUG) << "It is not a tensor, no derivation is needed " << input_arg->ToString();
984 return;
985 }
986 }
987
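// Locate input_tensor inside out_value and return the matching din: the used element keeps
// din, every other element becomes a zeros-like node, and sequences/dicts are rebuilt around
// them.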
988 AnfNodePtr IrBprop::TraceInput(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
989 const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
990 const AnfNodePtr &din) {
991 MS_EXCEPTION_IF_NULL(out_value);
992 MS_EXCEPTION_IF_NULL(out_abs);
993 MS_EXCEPTION_IF_NULL(input_tensor);
994 MS_EXCEPTION_IF_NULL(din);
995
996 // The output tensor corresponding to the node is the same as the tensor currently being used
997 if (out_value->isa<tensor::BaseTensor>()) {
998 // out_value is used; it may be one of multiple outputs
999 auto out_tensor = out_value->cast<tensor::BaseTensorPtr>();
1000 if (input_tensor->id() == out_tensor->id()) {
1001 return din;
1002 }
1003 return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
1004 } else if (out_value->isa<ValueSequence>()) {
1005 // The corresponding output of the node is a ValueSequence, but only one element of it is used
1006 AnfNodePtrList inputs;
1007 (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1008 auto value_seq = out_value->cast<ValueSequencePtr>();
1009 auto abs_seq = out_abs->cast<abstract::AbstractSequencePtr>();
1010 if (abs_seq == nullptr) {
1011 MS_LOG(EXCEPTION) << "Get output abstract " << out_abs->ToString() << ", not abstract sequence";
1012 }
1013 int index = -1;
1014 for (size_t i = 0; i < value_seq->size(); ++i) {
1015 // Find each element's din: the element that matches the used input tensor gets din; the others get a
1016 // zeros-like din, which is set by the tensor branch above
1017 auto new_din = TraceInput(fn, value_seq->value()[i], abs_seq->elements()[i], input_tensor, din);
1018 (void)inputs.emplace_back(new_din);
1019
1020 // If din == fake_dout appears here, record it in the user vector
1021 if (din == fn->fake_dout() && new_din == din) {
1022 index = static_cast<int>(inputs.size()) - 1;
1023 }
1024 }
1025 auto new_din = ad_param_->tape_->FuncGraph::NewCNode(inputs);
1026 new_din->set_abstract(out_abs);
1027 if (index != -1) {
1028 LazyAddUser(fn->fake_dout(), new_din, index);
1029 }
1030 return new_din;
1031 } else if (out_value->isa<ValueDictionary>()) {
1032 return TraceInputForDict(fn, out_value, out_abs, input_tensor, din);
1033 }
1034 MS_LOG(DEBUG) << "Get non tensor input " << out_value->ToString();
1035 return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
1036 }
1037
1038 AnfNodePtr IrBprop::TraceInputForDict(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
1039 const abstract::AbstractBasePtr &out_abs,
1040 const tensor::BaseTensorPtr &input_tensor, const AnfNodePtr &din) {
1041 // The corresponding output of the node is a ValueDictionary, but only one element of it is used
1042 AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1043 AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1044 abstract::AbstractBasePtrList local_key_abs_inputs;
1045 abstract::AbstractBasePtrList local_value_abs_inputs;
1046 auto value_dict = out_value->cast<ValueDictionaryPtr>();
1047 auto abs_dict = out_abs->cast<abstract::AbstractDictionaryPtr>();
1048 MS_EXCEPTION_IF_NULL(abs_dict);
1049 int index = -1;
1050 for (size_t i = 0; i < value_dict->size(); ++i) {
1051 // Find each value's din: the value that matches the used input tensor gets din; the others get a
1052 // zeros-like din, which is set by the tensor branch in TraceInput
1053 (void)key_inputs.emplace_back(
1054 PyNativeAlgo::Common::CreateValueNodeByValue(value_dict->value()[i].first, abs_dict->elements()[i].first));
1055 (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
1056 auto new_din = TraceInput(fn, value_dict->value()[i].second, abs_dict->elements()[i].second, input_tensor, din);
1057 (void)value_inputs.emplace_back(new_din);
1058 (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
1059
1060 // If din == fake_dout appears here, record it in the user vector
1061 if (din == fn->fake_dout() && new_din == din) {
1062 index = static_cast<int>(value_inputs.size()) - 1;
1063 }
1064 }
1065 auto local_key_node = ad_param_->tape_->NewCNode(key_inputs);
1066 local_key_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
1067 auto local_value_node = ad_param_->tape_->NewCNode(value_inputs);
1068 local_value_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
1069 auto new_din = ad_param_->tape_->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
1070 new_din->set_abstract(abs_dict);
1071 if (index != -1) {
1072 LazyAddUser(fn->fake_dout(), new_din, index);
1073 }
1074 return new_din;
1075 }
1076
1077 void IrBprop::AddTupleGetItemUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
1078 (void)ad_param_->users_.tuple_getitem_user_[node].emplace_back(user, index);
1079 }
1080
1081 void IrBprop::UpdateLazyUser() {
1082 // For lazily added user data, we need to emplace it into the users map.
1083 for (const auto &user_data : ad_param_->lazy_user_data_) {
1084 AddUser(std::get<kIndex0>(user_data), std::get<kIndex1>(user_data), std::get<kIndex2>(user_data));
1085 }
1086 }
1087
1088 void IrBprop::LazyAddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
1089 MS_EXCEPTION_IF_NULL(node);
1090 MS_EXCEPTION_IF_NULL(user);
1091 (void)ad_param_->lazy_user_data_.emplace_back(node, user, index);
1092 }
1093 } // namespace mindspore::pynative::autograd
1094