1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/grad/ir/ir_bprop.h"
18 #include <string>
19 #include <vector>
20 #include <memory>
21 #include "pipeline/pynative/pynative_utils.h"
22 #include "include/common/utils/primitive_utils.h"
23 #include "pipeline/jit/ps/pass.h"
24 #include "ir/func_graph_cloner.h"
25 #include "ops/sequence_ops.h"
26 #include "ops/framework_ops.h"
27 #include "ops/structure_ops.h"
28 #include "ops/other_ops.h"
29 
30 namespace mindspore::pynative::autograd {
31 namespace {
32 constexpr size_t kOutAndDoutNum = 2;
33 const mindspore::HashSet<std::string> kMonadOp = {kLoadOpName, kDependOpName, kUpdateStateOpName};
34 const mindspore::HashSet<std::string> kMetaFuncGraphOp{
35   kPyExecuteOpName,
36   kAttrMutableOpName,
37   kMakeDictOpName,
38 };
39 mindspore::HashMap<std::string, FuncGraphPtr> pass_grad_graph_;
40 
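// Optimizes a freshly built bprop builder graph with the jit bprop graph pass; for jit dynamic-shape control flow
// it marks all managed sub-graphs as jit call graphs, and it plants a fixed-length tuple dout parameter when needed.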
41 FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph, const GradParamPtr &grad_param) {
42   PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
43   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
44   resource->set_func_graph(bprop_func_graph);
45   auto manager = resource->manager();
46   MS_EXCEPTION_IF_NULL(manager);
47   manager->AddFuncGraph(bprop_func_graph);
48   auto after_opt_bg = pipeline::JitBpropGraphPass(resource, true);
49   auto is_dynamic_shape_control_flow =
50     grad_param->is_jit_graph && grad_param->use_dynamic_shape_process && grad_param->is_control_flow;
51   if (is_dynamic_shape_control_flow) {
52     for (const auto &g : manager->func_graphs()) {
53       g->set_flag(kFlagJitCallGraph, true);
54     }
55   }
56   auto abs_seq = after_opt_bg->parameters().empty()
57                    ? nullptr
58                    : after_opt_bg->parameters().back()->abstract()->cast<abstract::AbstractSequencePtr>();
59   if (abs_seq != nullptr && !abs_seq->dynamic_len() && grad_param->is_jit_graph &&
60       grad_param->use_dynamic_shape_process) {
61     PyNativeAlgo::Common::ProcessTupleParam(after_opt_bg, after_opt_bg->parameters().size() - kIndex1);
62   }
63   PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
64   return after_opt_bg;
65 }
66 
67 bool ProcessMonadNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param) {
68   MS_EXCEPTION_IF_NULL(prim);
69   if (kMonadOp.find(prim->name()) != kMonadOp.end()) {
70     MS_LOG(DEBUG) << "Get monad cnode " << cnode->DebugString();
71     return true;
72   }
73   if ((prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) &&
74       (cnode->inputs().back()->abstract()->isa<abstract::AbstractMonad>())) {
75     AnfNodePtrList inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
76     cnode->set_inputs(inputs);
77   }
78   MS_EXCEPTION_IF_NULL(grad_param);
79   // Jit graph contains monad ops
80   if (grad_param->is_jit_graph) {
81     for (size_t i = 1; i < cnode->size(); ++i) {
82       cnode->set_input(i, common::AnfAlgo::VisitKernelWithReturnType(cnode->input(i), 0, false,
83                                                                      {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
84                             .first);
85     }
86   }
87   return false;
88 }
89 
90 void ClearGradMetaData(const ValuePtr &value) {
91   if (value->isa<tensor::BaseTensor>()) {
92     auto tensor = value->cast<tensor::BaseTensorPtr>();
93     tensor->set_auto_grad_meta_data(nullptr);
94   } else if (value->isa<ValueSequence>()) {
95     auto value_sequence = value->cast<ValueSequencePtr>();
96     for (const auto &val : value_sequence->value()) {
97       ClearGradMetaData(val);
98     }
99   }
100 }
101 
102 // Handle bprop of an op whose input dtype is a real number and whose output dtype is a complex number.
103 // If the dtype of a gradient (din) is a complex number while the corresponding input is a real number,
104 // only the real part of the gradient makes sense in back propagation, so we handle it by
105 // inserting a Real() op after the gradient.
106 // input: the op input whose dtype is a real number while the op output dtype is a complex number.
107 // din: CNodePtr with the gradient of the input.
108 // tape: FuncGraph which input and din belong to.
109 // return: New din with an inserted Real op if necessary.
110 AnfNodePtr HandleRealToComplex(const tensor::BaseTensorPtr &input, const AbstractBasePtr &abs, const AnfNodePtr &din,
111                                const KernelGraphPtr &tape) {
112   MS_EXCEPTION_IF_NULL(din);
113   TypePtr din_type = din->Type();
114   if (din_type == nullptr || !din_type->isa<TensorType>()) {
115     return din;
116   }
117   din_type = din_type->cast_ptr<TensorType>()->element();
118   MS_EXCEPTION_IF_NULL(din_type);
119   // cppcheck-suppress unreadVariable
120   if (MS_LIKELY(din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128)) {
121     return din;
122   }
123 
124   MS_EXCEPTION_IF_NULL(input);
125   TypePtr input_type = input->Dtype();
126   if (input_type == nullptr) {
127     return din;
128   }
129   if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
130     return din;
131   }
132 
133   AnfNodePtr new_din = tape->FuncGraph::NewCNode({NewValueNode(prim::kPrimReal), din});
134   AbstractBasePtr real_abs =
135     std::make_shared<abstract::AbstractTensor>(abstract::AbstractTensor(input_type, abs->GetShapeTrack()));
136   new_din->set_abstract(real_abs);
137   return new_din;
138 }
139 
140 void PlantFuncGradBpropGraphDout(const GradParamPtr &grad_param, const FuncGraphPtr &graph) {
141   MS_EXCEPTION_IF_NULL(graph);
142   MS_EXCEPTION_IF_NULL(grad_param);
143   if (!grad_param->is_func_grad) {
144     return;
145   }
146   // Plant dout tuple or dict
147   if (graph->parameters().back()->abstract()->isa<abstract::AbstractSequence>()) {
148     PyNativeAlgo::Common::ProcessTupleParam(graph, grad_param->input_size);
149   } else if (graph->parameters().back()->abstract()->isa<abstract::AbstractDictionary>()) {
150     PyNativeAlgo::Common::ProcessDictParam(graph, grad_param->input_size);
151   }
152 }
153 }  // namespace
154 
155 void ClearAutoGradCache() {
156   pass_grad_graph_.clear();
157   bprop_pass::ClearCache();
158   PyNativeAlgo::AutoGrad::ClearAutoGradStaticCache();
159 }
160 
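// Returns {cache_hit, bprop_graph}: control-flow graphs and jit self-dynamic-shape graphs are differentiated
// through the fprop graph, all other graphs through the bprop expander.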
161 std::pair<bool, FuncGraphPtr> IrBprop::GetBpropGraph(const GradParamPtr &grad_param) {
162   MS_EXCEPTION_IF_NULL(grad_param);
163   const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
164   bool cache_hit = (it != pass_grad_graph_.end());
165   if (grad_param->is_control_flow || grad_param->is_jit_self_dynamic_shape) {
166     MS_LOG(DEBUG) << "Get control flow graph or dynamic shape";
167     return std::make_pair(cache_hit, GetBpropGraphFromFprop(grad_param));
168   }
169   return std::make_pair(cache_hit, GetBpropGraphFromExpander(grad_param));
170 }
171 
172 void IrBprop::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs) {
173   MS_EXCEPTION_IF_NULL(prim);
174   MS_LOG(DEBUG) << "Try build custom bprop: " << prim->name();
175   {
176     py::gil_scoped_acquire gil;
177     auto prim_py = prim->cast<PrimitivePyPtr>();
178     if (prim_py == nullptr) {
179       MS_LOG(DEBUG) << "Prim is not PrimitivePy, can not find python bprop";
180       return;
181     }
182     py::function fn = prim_py->GetBpropFunction();
183     if (py::isinstance<py::none>(fn)) {
184       fn = GetBpropFunction(prim->name());
185     }
186     if (!fn || py::isinstance<py::none>(fn)) {
187       MS_LOG(INFO) << "Can not find bprop function for " << prim->name() << ". fn: " << ConvertPyObjToString(fn);
188       return;
189     }
190     (void)prim_py->AddBackwardHookFn(0, fn);
191     (void)prim_py->AddAttr("custom_op_bprop", MakeValue(true));
192   }
193   BuildBPropCutCNode(cnode, prim, outputs);
194 }
195 
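// Builds a BpropCut cnode that calls back into the Python bprop at runtime: it forwards the op inputs
// (plus out when not recomputed) and dout, then exposes each din as a TupleGetItem on the BpropCut output.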
196 void IrBprop::BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs,
197                                  bool is_need_recompute) {
198   MS_EXCEPTION_IF_NULL(prim);
199   auto bprop_cut = PyNativeAlgo::AutoGrad::BuildBpropCutPrim(prim, is_need_recompute);
200 
201   // Create gradient outputs cnode
202   AnfNodePtrList inputs{NewValueNode(bprop_cut)};
203   for (size_t i = 1; i < cnode->size() - kOutAndDoutNum; ++i) {
204     (void)inputs.emplace_back(cnode->input(i));
205   }
206   if (!is_need_recompute) {
207     // If not recompute, we should add out as bprop input.
208     (void)inputs.emplace_back(cnode->input(cnode->size() - kOutAndDoutNum));
209   }
210   (void)inputs.emplace_back(cnode->input(cnode->size() - 1));
211 
212   auto bprop_cut_cnode = ad_param_->tape_->FuncGraph::NewCNode(inputs);
213   AbstractBasePtrList abs_list;
214   // Only add last input dout to user.
215   AddUser(cnode->input(cnode->size() - 1), bprop_cut_cnode, bprop_cut_cnode->size() - 1);
216   for (size_t i = 1; i < cnode->size() - kOutAndDoutNum; ++i) {
217     // Input may be a parameter; we need to add it to the user map.
218     AddUser(cnode->input(i), bprop_cut_cnode, i);
219     auto din = ad_param_->tape_->FuncGraph::NewCNode(
220       {NewValueNode(prim::kPrimTupleGetItem), bprop_cut_cnode, NewValueNode(static_cast<int64_t>(i - 1))});
221     MS_EXCEPTION_IF_NULL(cnode->input(i)->abstract());
222     din->set_abstract(cnode->input(i)->abstract());
223     (void)abs_list.emplace_back(cnode->input(i)->abstract());
224     (void)outputs->emplace_back(din);
225   }
226   bprop_cut_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
227   ad_param_->tape_->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
228   bprop_graph_run_by_single_op_ = true;
229 }
230 
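// Maps an input value to a node on the tape: reuses an already mapped parameter, adds a new tape parameter for
// parameters that require grad, recurses into sequences via MakeTuple, maps sparse tensors through their indices,
// and falls back to a value node otherwise.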
231 AnfNodePtr IrBprop::MapParameter(const ValuePtr &value, const abstract::AbstractBasePtr &abs) {
232   if (value->isa<tensor::BaseTensor>()) {
233     const auto &tensor = value->cast<tensor::BaseTensorPtr>();
234     const auto &auto_grad_meta_data = tensor->auto_grad_meta_data();
235     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
236     const auto &param = auto_grad_meta_data->parameter();
237     if (param != nullptr) {
238       // In dynamic shape scenarios, abs may need to change
239       param->set_abstract(abs);
240       return param;
241     }
242     set_bprop_graph_run_by_single_op(auto_grad_meta_data->is_register_hook());
243     if (auto_grad_meta_data->input_type() == InputType::kParameter &&
244         PyNativeAlgo::Common::IsParamRequiresGrad(tensor)) {
245       return AddParameterNode(tensor, abs);
246     }
247     return PyNativeAlgo::Common::CreateValueNodeByValue(value, abs);
248   } else if (value->isa<ValueSequence>()) {
249     const auto &val_seq = value->cast<ValueSequencePtr>()->value();
250     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
251     MS_EXCEPTION_IF_NULL(abs_seq);
252     if (val_seq.size() != abs_seq->size()) {
253       MS_LOG(EXCEPTION) << "Get value sequence size " << val_seq.size() << " not equal to abstract size "
254                         << abs_seq->size();
255     }
256     AnfNodePtrList inputs;
257     (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
258     for (size_t i = 0; i < val_seq.size(); ++i) {
259       (void)inputs.emplace_back(MapParameter(val_seq[i], abs_seq->elements()[i]));
260     }
261     auto cnode = ad_param_->tape_->FuncGraph::NewCNode(inputs);
262     // For replacing fg parameter by user
263     for (size_t i = 1; i < inputs.size(); ++i) {
264       AddUser(inputs[i], cnode, i);
265     }
266     cnode->set_abstract(abs);
267     return cnode;
268   } else if (value->isa<tensor::COOTensor>()) {
269     const auto &coo_tensor = value->cast<tensor::COOTensorPtr>();
270     return MapParameter(coo_tensor->GetIndices(), abs);
271   } else if (value->isa<tensor::CSRTensor>()) {
272     const auto &csr_tensor = value->cast<tensor::CSRTensorPtr>();
273     return MapParameter(csr_tensor->GetIndices(), abs);
274   } else {
275     return PyNativeAlgo::Common::CreateValueNodeByValue(value, abs);
276   }
277 }
278 
279 ParameterPtr IrBprop::AddParameterNode(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs) {
280   MS_EXCEPTION_IF_NULL(tensor);
281   auto param = CreateTapeParameter(tensor, abs);
282   auto zeros_like_dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
283     ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), param->abstract(), SpecialType::kZerosLikeType);
284   auto func_node = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_dout);
285   auto input_adjoint = std::make_shared<IrVariable>(func_node, tensor, true);
286   (void)ad_param_->variable_adjoint_set_.insert(input_adjoint);
287   auto auto_grad_meta_data = tensor->auto_grad_meta_data();
288   MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
289   auto_grad_meta_data->set_variable(input_adjoint);
290   (void)ad_param_->weights_used_in_graph_.emplace_back(param);
291   return param;
292 }
293 
294 ParameterPtr IrBprop::CreateTapeParameter(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs) {
295   MS_EXCEPTION_IF_NULL(tensor);
296   MS_EXCEPTION_IF_NULL(abs);
297   auto param = ad_param_->fg_->add_parameter();
298   param->set_abstract(abs);
299   if (tensor->is_parameter()) {
300     param->set_default_param(tensor);
301   }
302   auto auto_grad_meta_data = tensor->auto_grad_meta_data();
303   if (auto_grad_meta_data == nullptr) {
304     auto_grad_meta_data = std::make_shared<AutoGradMetaData>();
305     tensor->set_auto_grad_meta_data(auto_grad_meta_data);
306   }
307   auto_grad_meta_data->set_input_type(InputType::kParameter);
308   auto_grad_meta_data->set_parameter(param);
309   return param;
310 }
311 
312 void IrBprop::UpdateNextEdges(const VariablePtr &variable, const std::vector<CNodePtr> &dins,
313                               const ValuePtrList &inputs_value, const abstract::AbstractBasePtrList &abs,
314                               const string &op_name) {
315   size_t input_size = inputs_value.size();
316   if (dins.size() != input_size) {
317     MS_LOG(EXCEPTION) << "The size of dins " << dins.size() << " is not same as input_value " << input_size;
318   }
319   const auto &fn = variable->ir_function_node();
320   for (size_t i = 0; i < input_size; ++i) {
321     auto din = dins[i];
322     MS_EXCEPTION_IF_NULL(din);
323     MS_LOG(DEBUG) << "Input arg id: " << PyNativeAlgo::Common::GetIdByValue(inputs_value[i]) << ", din "
324                   << din->DebugString();
325 #ifndef ENABLE_TEST
326     // VM no need run pass
327     din = pass_forward_->PassForDin(din, op_name, false);
328 #endif
329     UpdateNextEdge(fn, din, inputs_value[i], abs[i]);
330   }
331   if (fn->next_edges().empty()) {
332     variable->set_is_need_grad(false);
333   }
334   MS_LOG(DEBUG) << "Finish update next edges for variable: " << variable->ToString();
335 }
336 
337 void IrBprop::AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
338   MS_EXCEPTION_IF_NULL(ad_param_);
339   (void)ad_param_->users_.dout_user_[node].emplace_back(user, index);
340 }
341 
342 void IrBprop::AddReverseUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
343   (void)ad_param_->reverse_users_[node].emplace_back(user, index);
344 }
345 
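// Back-propagates douts over the recorded variables in reverse order from the last node: materializes zeroslike
// placeholders, applies the backward-hook pass when not recomputing, replaces fake douts with the accumulated
// douts, and pushes each din along the next edges.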
346 void IrBprop::BackPropagate() {
347   UpdateLazyUser();
348   const auto &last_node_reverse_iter = GetLastNodeReverseIter();
349 #ifndef ENABLE_TEST
350   SeenNum seen = NewSeenGeneration();
351 #endif
352   MS_LOG(DEBUG) << "Is running recompute grad " << is_run_recompute_;
353   for (auto iter = last_node_reverse_iter; iter != ad_param_->variable_adjoint_set_.rend(); ++iter) {
354     const auto &variable = *iter;
355     if (!variable->is_need_propagate() || !variable->is_need_grad()) {
356       MS_LOG(DEBUG) << "No need grad, variable is: " << variable->ToString();
357       continue;
358     }
359     if (static_cast<bool>(MS_UNLIKELY(variable->is_fake_bprop()))) {
360       MS_LOG(EXCEPTION) << "Illegal primitive " << variable->fake_prim_name() << "'s bprop not defined";
361     }
362     MS_LOG(DEBUG) << "Begin backpropagate: " << variable->ToString();
363     const auto &fn = variable->ir_function_node();
364     // If zeroslike is not used in the funcgraph, we need to replace the zeroslike placeholder with a real zeroslike value.
365     if (static_cast<bool>(MS_UNLIKELY(PyNativeAlgo::AutoGrad::IsZerosLikeNode(fn->accumulate_dout())))) {
366       fn->set_accumulate_dout(PyNativeAlgo::AutoGrad::BuildSpecialNode(
367         fn->tape(), variable->out_value(), fn->accumulate_dout()->abstract(), SpecialType::kZerosLikeType));
368     }
369     // If a hook is registered on a weight and the weight is in a recompute cell, the hook would execute during recompute, which is not expected.
370     if (!is_run_recompute_) {
371       fn->set_accumulate_dout(pass_forward_->PassBackwardHook(variable->out_value(), fn->accumulate_dout()));
372     }
373     // Replace fake dout with the real accumulated dout, and update the replace result to eliminate tuplegetitem
374     // when accumulate_dout is tuplegetitem
375     Replace(fn->fake_dout(), fn->accumulate_dout(), &ad_param_->users_.dout_user_, true);
376     // Replace edges which still contain fake dout
377     fn->ReplaceEdges();
378     const auto &next_edges = fn->next_edges();
379     for (const auto &next_edge : next_edges) {
380       const auto &last_variable = next_edge.first;
381       const auto &din = next_edge.second;
382 #ifndef ENABLE_TEST
383       // VM no need run pass
384       pass_forward_->ConvertMakeTupleInputToDynamicInput(din, seen, bprop_graph_run_by_single_op_);
385 #endif
386       last_variable->ir_function_node()->UpdateAccumulativeDout(din);
387       last_variable->set_is_need_propagate(true);
388     }
389   }
390   MS_LOG(DEBUG) << "End BackPropagate";
391 }
392 
393 OrderedSet<IrVariablePtr>::reverse_iterator IrBprop::GetLastNodeReverseIter() {
394   for (auto iter = ad_param_->variable_adjoint_set_.rbegin(); iter != ad_param_->variable_adjoint_set_.rend(); ++iter) {
395     if (*iter == ad_param_->last_variable_) {
396       ad_param_->last_variable_->set_is_need_propagate(true);
397       return iter;
398     }
399   }
400   return ad_param_->variable_adjoint_set_.rend();
401 }
402 
403 AbstractBasePtr IrBprop::BuildForwardLastNode() {
404   MS_LOG(DEBUG) << "Process last node info " << PyNativeAlgo::Common::GetIdByValue(ad_param_->sens_value_);
405   auto zeros_like_node = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, ad_param_->sens_value_, nullptr,
406                                                                   SpecialType::kZerosLikeType);
407   auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_node);
408   auto sens_variable = std::make_shared<IrVariable>(fn, ad_param_->sens_value_);
409   if (ad_param_->sens_value_->isa<tensor::BaseTensor>()) {
410     const auto &sens_tensor = ad_param_->sens_value_->cast<tensor::BaseTensorPtr>();
411     const auto &auto_grad_meta_data = sens_tensor->auto_grad_meta_data();
412     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
413     if (PyNativeAlgo::Common::IsConstant(auto_grad_meta_data->input_type())) {
414       sens_variable->set_is_need_grad(false);
415     }
416   }
417   UpdateNextEdge(fn, zeros_like_node, ad_param_->sens_value_, fn->accumulate_dout()->abstract());
418   (void)ad_param_->variable_adjoint_set_.insert(sens_variable);
419   ad_param_->last_variable_ = sens_variable;
420   return fn->accumulate_dout()->abstract();
421 }
422 
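// Builds the bprop graph from the fprop graph: reuses the cached graph when available, otherwise calls the fprop
// graph, takes its second output (the bprop), feeds dout into it, optimizes the builder graph, and caches the
// result for jit graphs and static-shape cases.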
423 FuncGraphPtr IrBprop::GetBpropGraphFromFprop(const GradParamPtr &grad_param) {
424   MS_EXCEPTION_IF_NULL(grad_param);
425   FuncGraphPtr after_opt_fg = nullptr;
426   // Find ad graph in cache
427   const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
428   bool cache_hit = (it != pass_grad_graph_.end());
429   if (cache_hit) {
430     MS_LOG(DEBUG) << "Get ad grad graph by cache";
431     after_opt_fg = BasicClone(it->second);
432   } else {
433     auto bprop_builder = std::make_shared<FuncGraph>();
434     bprop_builder->debug_info()->set_name("bprop_builder");
435 
436     AnfNodePtrList fprop_app_inputs{NewValueNode(grad_param->fg)};
437     for (const auto &abs : grad_param->op_grad_info->input_abs) {
438       auto param = bprop_builder->add_parameter();
439       param->set_abstract(abs);
440       (void)fprop_app_inputs.emplace_back(param);
441     }
442     auto fprop_app = bprop_builder->NewCNode(fprop_app_inputs);
443     // Get bprop from fprop_fg; it is the second output of fprop_fg
444     auto get_bprop = bprop_builder->NewCNode(
445       {NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(kIndex1))});
446 
447     AnfNodePtrList node_list{get_bprop};
448     auto dout = bprop_builder->add_parameter();
449     dout->set_abstract(grad_param->op_grad_info->out_abs);
450     (void)node_list.emplace_back(dout);
451     auto call_bprop = bprop_builder->NewCNode(node_list);
452 
453     AnfNodePtrList actual_out{NewValueNode(prim::kPrimMakeTuple)};
454     for (size_t i = 0; i < grad_param->input_size; ++i) {
455       // Index 0 env, skip
456       auto out =
457         bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), call_bprop, NewValueNode(SizeToLong(i + 1))});
458       (void)actual_out.emplace_back(out);
459     }
460     bprop_builder->set_output(bprop_builder->NewCNode(actual_out));
461     // Run optimization passes on the graph, such as inline
462     after_opt_fg = OptimizeBpropBuilder(bprop_builder, grad_param);
463     PlantFuncGradBpropGraphDout(grad_param, after_opt_fg);
464     if (grad_param->is_func_grad && grad_param->is_control_flow) {
465       after_opt_fg = LiftingClone(after_opt_fg);
466     }
467     if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
468       pass_grad_graph_[grad_param->graph_cache_key] = BasicClone(after_opt_fg);
469     }
470   }
471   return after_opt_fg;
472 }
473 
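// Differentiates grad_param->fg node by node with the bprop expander on a fresh tape, back-propagates to the graph
// parameters, outputs the tuple of their real douts, and caches the resulting ad graph for jit graphs and
// static-shape cases.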
474 FuncGraphPtr IrBprop::GetBpropGraphFromExpander(const GradParamPtr &grad_param) {
475   // Find ad graph in cache
476   if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
477     const auto it = pass_grad_graph_.find(grad_param->graph_cache_key);
478     if (it != pass_grad_graph_.end()) {
479       MS_LOG(DEBUG) << "Get ad grad graph by cache";
480       return BasicClone(it->second);
481     }
482   } else {
483     pass_grad_graph_.clear();
484   }
485 
486   // Create new ad param for graph ad
487   PyNativeAlgo::Common::DumpGraphIR("ad_input_graph.ir", grad_param->fg);
488   auto current_ad_param = ad_param_;
489   ad_param_ = std::make_shared<AdParam>();
490   ad_param_->tape_->debug_info()->set_name("ad_graph");
491   bprop_graph_run_by_single_op_ = bprop_graph_run_by_single_op_ || grad_param->use_dynamic_shape_process;
492 
493   GradGraphByExpander(grad_param);
494 
495   if (ad_param_->last_node_ != nullptr) {
496     // Set dout parameter
497     const auto last_prim = GetCNodePrimitive(ad_param_->last_node_);
498     if (kMonadOp.find(last_prim->name()) != kMonadOp.end()) {
499       ad_param_->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(
500                                 ad_param_->last_node_, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
501                                 .first;
502     }
503     if (ad_param_->anfnode_to_variable_adjoint_.count(ad_param_->last_node_) == 0) {
504       MS_LOG(EXCEPTION) << "Can not find last node " << ad_param_->last_node_->DebugString();
505     }
506     ad_param_->last_variable_ = ad_param_->anfnode_to_variable_adjoint_[ad_param_->last_node_];
507     auto ad_graph_dout = ad_param_->tape_->add_parameter();
508     ad_graph_dout->set_abstract(ad_param_->last_node_->abstract());
509     ad_param_->last_variable_->ir_function_node()->UpdateAccumulativeDout(ad_graph_dout);
510     (void)BackPropagate();
511   } else {
512     // Just have a return node
513     auto ad_graph_dout = ad_param_->tape_->add_parameter();
514     ad_graph_dout->set_abstract(grad_param->fg->output()->abstract());
515     ad_graph_dout->debug_info()->set_name("sens");
516     ad_param_->sens_value_ = grad_param->op_grad_info->out_value;
517     (void)BuildForwardLastNode();
518     // Update dout
519     MS_EXCEPTION_IF_NULL(ad_param_->last_variable_);
520     if (ad_param_->last_variable_->is_need_grad()) {
521       ad_param_->last_variable_->ir_function_node()->UpdateAccumulativeDout(ad_graph_dout);
522     }
523     (void)BackPropagate();
524   }
525 
526   AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
527   abstract::AbstractBasePtrList out_abs_list;
528   for (const auto &node : grad_param->fg->parameters()) {
529     (void)outputs.emplace_back(ad_param_->anfnode_to_variable_adjoint_.at(node)->RealDout());
530     (void)out_abs_list.emplace_back(outputs.back()->abstract());
531   }
532   auto ad_graph_out = ad_param_->tape_->FuncGraph::NewCNode(outputs);
533   ad_graph_out->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
534   ad_param_->tape_->set_output(ad_graph_out);
535   auto ad_graph = ad_param_->tape_;
536   auto abs_seq = ad_graph->parameters().empty()
537                    ? nullptr
538                    : ad_graph->parameters().back()->abstract()->cast<abstract::AbstractSequencePtr>();
539   if (abs_seq != nullptr && !abs_seq->dynamic_len() && grad_param->is_jit_graph &&
540       grad_param->use_dynamic_shape_process) {
541     auto manager = MakeManager();
542     MS_EXCEPTION_IF_NULL(manager);
543     manager->AddFuncGraph(ad_graph);
544     PyNativeAlgo::Common::ProcessTupleParam(ad_graph, ad_graph->parameters().size() - kIndex1);
545   }
546   PyNativeAlgo::Common::DumpGraphIR("ad_output_graph.ir", ad_graph);
547 
548   // Plant dout tuple
549   PlantFuncGradBpropGraphDout(grad_param, ad_graph);
550 
551   // Save ad graph in cache
552   if (grad_param->is_jit_graph || !grad_param->use_dynamic_shape_process) {
553     pass_grad_graph_[grad_param->graph_cache_key] = BasicClone(ad_graph);
554   }
555   // Replace cnode with valuenode to reduce computation
556   bool jit_by_value = grad_param->is_jit_graph && grad_by_value_;
557   if (jit_by_value) {
558     PyNativeAlgo::Common::ReplaceCNodeWithValueNode(ad_graph);
559   }
560   // Restore ad param
561   ad_param_ = current_ad_param;
562   return ad_graph;
563 }
564 
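// Redirects every recorded user of old_node to new_node; if the recorded input index is stale after attr
// conversion it is adjusted, and TupleGetItem users are tracked when need_update is set.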
565 void IrBprop::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, expander::bprop::UserType *user,
566                       bool need_update) {
567   MS_EXCEPTION_IF_NULL(user);
568   if (user->find(old_node) == user->end()) {
569     return;
570   }
571   const auto &old_node_users = (*user)[old_node];
572   for (const auto &pair_node : old_node_users) {
573     auto cnode = pair_node.first.lock();
574     if (cnode == nullptr) {
575       continue;
576     }
577     size_t index = pair_node.second;
578     if (index >= cnode->size()) {
579       // After attr conversion, the cnode has fewer inputs
580       if (auto v = cnode->GetAttr(kAttrConvertAttrNode); v != nullptr) {
581         index -= GetValue<size_t>(v);
582       } else {
583         MS_LOG(EXCEPTION) << "Exception for index: " << index << " greater than cnode size: " << cnode->size();
584       }
585     }
586     cnode->set_input(index, new_node);
587     if (need_update && IsPrimitiveCNode(new_node, prim::kPrimTupleGetItem)) {
588       AddTupleGetItemUser(new_node, cnode, index);
589     }
590   }
591 }
592 
593 void IrBprop::GradGraphByExpander(const GradParamPtr &grad_param) {
594   MS_EXCEPTION_IF_NULL(grad_param);
595   if (pass_forward_->need_reverse_graph()) {
596     pass_forward_->ReversePassFuncGraph(grad_param->fg);
597   }
598 
599   // First handle parameters
600   CreateParameterAdjoint(grad_param);
601 
602   // Second handle cnodes
603   const auto &order = TopoSort(grad_param->fg->output());
604   for (const auto &node : order) {
605     if (node == nullptr || !node->isa<CNode>()) {
606       continue;
607     }
608     auto cnode = node->cast<CNodePtr>();
609     MS_EXCEPTION_IF_NULL(cnode);
610     auto prim = GetCNodePrimitive(cnode);
611     if (prim == nullptr) {
612       MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
613     }
614     ad_param_->last_node_ = cnode;
615     if (ProcessMonadNode(prim, cnode, grad_param) || IsPrimitiveEquals(prim, prim::kPrimStopGradient)) {
616       continue;
617     }
618     MS_LOG(DEBUG) << "Get cnode " << cnode->DebugString() << ", " << cnode->fullname_with_scope();
619     ValuePtrList inputs_value;
620     AnfNodePtrList cnode_inputs;
621     PrepareGradCNodeInputs(prim, cnode, &inputs_value, &cnode_inputs);
622     // Do grad for every cnode
623     GradCNode(prim, cnode, grad_param, inputs_value, &cnode_inputs);
624   }
625 }
626 
627 void IrBprop::CreateParameterAdjoint(const GradParamPtr &grad_param) const {
628   auto &graph_parameters = grad_param->fg->parameters();
629   if (graph_parameters.size() != grad_param->input_size) {
630     MS_LOG(EXCEPTION) << "Parameters size " << graph_parameters.size() << " is not equal to graph input size "
631                       << grad_param->input_size;
632   }
633   for (size_t i = 0; i < graph_parameters.size(); ++i) {
634     MS_LOG(DEBUG) << "Get param " << graph_parameters[i]->DebugString();
635     ParameterPtr param = ad_param_->tape_->add_parameter();
636     param->set_abstract(graph_parameters[i]->abstract());
637     auto zeros_like_dout =
638       PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(),
639                                                graph_parameters[i]->abstract(), SpecialType::kZerosLikeType);
640     auto func_node = std::make_shared<IrFunctionNode>(ad_param_->tape_, zeros_like_dout);
641     // Copy to avoid corrupting the real input grad info.
642     auto op_arg = PyNativeAlgo::Common::CreateFakeValueWithoutDeviceAddress(grad_param->op_grad_info->input_value[i]);
643     ClearGradMetaData(op_arg);
644     auto adjoint = std::make_shared<IrVariable>(func_node, op_arg, true);
645     adjoint->set_k_node(param);
646     PyNativeAlgo::AutoGrad::SetGradMetaData(op_arg, adjoint, graph_parameters[i]->cast<ParameterPtr>());
647     (void)ad_param_->variable_adjoint_set_.insert(adjoint);
648     (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(graph_parameters[i], adjoint));
649   }
650 }
651 
652 void IrBprop::PrepareGradCNodeInputs(const PrimitivePtr &prim, const CNodePtr &cnode, ValuePtrList *inputs_value,
653                                      AnfNodePtrList *cnode_inputs) {
654   MS_EXCEPTION_IF_NULL(cnode);
655   MS_EXCEPTION_IF_NULL(inputs_value);
656   MS_EXCEPTION_IF_NULL(cnode_inputs);
657   (void)cnode_inputs->emplace_back(std::make_shared<ValueNode>(prim));
658   *inputs_value = GetInputArgs(cnode, cnode_inputs);
659   pass_forward_->ReversePassCNode(cnode, inputs_value, cnode_inputs);
660 }
661 
662 ValuePtrList IrBprop::GetInputArgs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs) const {
663   MS_EXCEPTION_IF_NULL(cnode);
664   MS_EXCEPTION_IF_NULL(cnode_inputs);
665   ValuePtrList input_value;
666   for (size_t i = 1; i < cnode->size(); ++i) {
667     const auto &input_node = cnode->input(i);
668     // Find knode and out value
669     const auto it = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
670     if (it != ad_param_->anfnode_to_variable_adjoint_.end()) {
671       (void)cnode_inputs->emplace_back(it->second->k_node());
672       (void)input_value.emplace_back(it->second->out_value());
673       continue;
674     }
675     if (input_node->isa<ValueNode>()) {
676       auto v_node = input_node->cast<ValueNodePtr>();
677       auto v = v_node->value();
678       if (v != nullptr && v->isa<tensor::BaseTensor>()) {
679         const auto &t = v->cast<tensor::BaseTensorPtr>();
680         const auto &grad_meta = t->auto_grad_meta_data();
681         // The jit forward graph has no parameters (its input is a tuple or a constant), so the input is used in the
682         // graph as a valuenode, but it is also used by tape_ as a parameter
683         if (grad_meta != nullptr && PyNativeAlgo::Common::IsParam(grad_meta->input_type())) {
684           auto new_tensor = std::make_shared<tensor::Tensor>(t->data_type(), t->shape(), t->data_ptr());
685           new_tensor->set_device_address(t->device_address());
686           v = new_tensor;
687         }
688       }
689       (void)PyNativeAlgo::Common::SetValueGradInfo(v, nullptr, InputType::kConstant);
690       // In case the jit forward graph and the pynative bprop graph use the same valuenode
691       auto new_v_node = PyNativeAlgo::Common::CreateValueNodeByValue(v, v_node->abstract());
692       (void)cnode_inputs->emplace_back(new_v_node);
693       (void)input_value.emplace_back(v);
694     } else {
695       // Make Fake value
696       auto v = MakeValue<int64_t>(0);
697       (void)cnode_inputs->emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(v, input_node->abstract()));
698       (void)input_value.emplace_back(v);
699       MS_LOG(DEBUG) << "Get input node " << input_node->DebugString();
700     }
701   }
702   return input_value;
703 }
704 
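// Differentiates a single cnode: builds its knode, queries the bprop expander (falling back to a Python custom
// bprop, then to a fake bprop), and hooks the produced dins into the next edges of a new variable adjoint.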
705 void IrBprop::GradCNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param,
706                         const ValuePtrList &inputs_value, AnfNodePtrList *cnode_inputs) {
707   MS_EXCEPTION_IF_NULL(prim);
708   MS_EXCEPTION_IF_NULL(cnode);
709   bool jit_by_value = grad_param->is_jit_graph && grad_by_value_;
710   if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
711     (void)BuildKNodeForMakeTuple(cnode);
712     return;
713   } else if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
714     (void)BuildKNodeForTupleGetItem(cnode);
715     return;
716   }
717   MS_EXCEPTION_IF_NULL(cnode_inputs);
718   auto k_node = GetKnode(prim, cnode, *cnode_inputs, jit_by_value);
719   if (bprop_graph_run_by_single_op_ && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
720       std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
721         MS_EXCEPTION_IF_NULL(node->abstract());
722         return node->abstract()->isa<abstract::AbstractSequence>();
723       })) {
724     k_node->cast<CNodePtr>()->AddAttr(kAttrIsPyboostTupleInput, MakeValue(true));
725   }
726   MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
727   // Set out
728   auto out = PyNativeAlgo::Common::CreatOutputTensorValueByAbstract(cnode->abstract());
729   (void)cnode_inputs->emplace_back(k_node);
730   // Set dout
731   AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(
732     ad_param_->tape_, PyNativeAlgo::AutoGrad::GetFakeZeroTensor(), cnode->abstract(), SpecialType::kZerosLikeType);
733   (void)cnode_inputs->emplace_back(dout);
734   auto input_node = ad_param_->tape_->FuncGraph::NewCNode(*cnode_inputs);
735   input_node->set_abstract(cnode->abstract());
736 
737   std::vector<CNodePtr> outputs;
738   // Get bprop by expander
739   auto ret = BpropExpander(&outputs, &ad_param_->users_).Run(input_node);
740   if (!ret || outputs.empty()) {
741     // Get bprop by python custom
742     MS_LOG(DEBUG) << "Expander has no bprop of this node: " << input_node->DebugString();
743     BuildCustomBpropCNode(input_node, prim, &outputs);
744   }
745 
746   auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
747   auto variable_adjoint = std::make_shared<IrVariable>(fn, out);
748   variable_adjoint->set_k_node(k_node);
749   // Get bprop by fake bprop
750   if (outputs.empty()) {
751     MS_LOG(DEBUG) << "Build fake bprop for this node: " << input_node->DebugString();
752     PyNativeAlgo::AutoGrad::BuildFakeBpropCNode(input_node, &outputs);
753     variable_adjoint->set_is_fake_bprop(true);
754     variable_adjoint->set_fake_prim_name(prim->name());
755   }
756   // Create current op node din edge
757   AbstractBasePtrList input_abs;
758   for (size_t i = 1; i < cnode->size(); ++i) {
759     (void)input_abs.emplace_back(cnode->input(i)->abstract());
760   }
761   UpdateNextEdges(variable_adjoint, outputs, inputs_value, input_abs);
762   PyNativeAlgo::AutoGrad::SetGradMetaData(out, variable_adjoint);
763   (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(cnode, variable_adjoint));
764   (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
765 }
766 
767 AnfNodePtr IrBprop::BuildKNodeForMakeTuple(const AnfNodePtr &input_node) {
768   MS_EXCEPTION_IF_NULL(input_node);
769   MS_LOG(DEBUG) << "Build knode for MakeTuple " << input_node->DebugString();
770   const auto &cnode = input_node->cast<CNodePtr>();
771   MS_EXCEPTION_IF_NULL(cnode);
772   AnfNodePtrList inputs{NewValueNode(prim::kPrimMakeTuple)};
773   ValuePtrList input_value;
774   AbstractBasePtrList input_abs;
775   for (size_t i = 1; i < cnode->size(); ++i) {
776     (void)inputs.emplace_back(BuildKNodeForCNodeInput(cnode->input(i)));
777     if (cnode->input(i)->isa<CNode>() || cnode->input(i)->isa<Parameter>()) {
778       const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(cnode->input(i));
779       if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
780         MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
781       }
782       (void)input_value.emplace_back(input_adjoint_iter->second->out_value());
783       (void)input_abs.emplace_back(cnode->input(i)->abstract());
784     } else {
785       auto value_node = cnode->input(i)->cast<ValueNodePtr>();
786       MS_EXCEPTION_IF_NULL(value_node);
787       (void)input_value.emplace_back(value_node->value());
788       (void)input_abs.emplace_back(value_node->abstract());
789     }
790   }
791   auto out_value = MakeValue(input_value);
792   AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, input_node->abstract(),
793                                                              SpecialType::kZerosLikeType);
794   auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
795   auto variable_adjoint = std::make_shared<IrVariable>(fn, out_value);
796   auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
797   k_node->set_abstract(input_node->abstract());
798   variable_adjoint->set_k_node(k_node);
799   // Create dout for maketuple
800   std::vector<CNodePtr> make_tuple_dout;
801   for (size_t i = 1; i < cnode->size(); ++i) {
802     auto d = ad_param_->tape_->FuncGraph::NewCNode(
803       {NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(SizeToLong(i - 1))});
804     d->set_abstract(cnode->input(i)->abstract());
805     (void)make_tuple_dout.emplace_back(d);
806     AddUser(dout, d, 1);
807   }
808   UpdateNextEdges(variable_adjoint, make_tuple_dout, input_value, input_abs);
809   (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
810   (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
811   return k_node;
812 }
813 
814 AnfNodePtr IrBprop::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
815   MS_EXCEPTION_IF_NULL(input_node);
816   if (input_node->isa<CNode>()) {
817     const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
818     if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
819       if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
820         return BuildKNodeForMakeTuple(input_node);
821       } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
822         return BuildKNodeForTupleGetItem(input_node);
823       }
824       MS_LOG(EXCEPTION) << "Can not find input in adjoint map, inp: " << input_node->DebugString();
825     }
826     return input_adjoint_iter->second->k_node();
827   } else {
828     // Tuple sens will come in
829     if (input_node->isa<Parameter>()) {
830       const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(input_node);
831       if (input_adjoint_iter != ad_param_->anfnode_to_variable_adjoint_.end() &&
832           input_adjoint_iter->second->k_node() != nullptr) {
833         return input_adjoint_iter->second->k_node();
834       }
835     }
836     return input_node;
837   }
838 }
839 
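// Builds the knode for a TupleGetItem: looks up the adjoint of the tuple input, creates a dout tuple that places
// the real dout at the selected index and zeroslike nodes elsewhere, and records the new variable adjoint.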
840 AnfNodePtr IrBprop::BuildKNodeForTupleGetItem(const AnfNodePtr &input_node) {
841   MS_EXCEPTION_IF_NULL(input_node);
842   MS_LOG(DEBUG) << "Build knode for TupleGetItem " << input_node->DebugString();
843   const auto &tuple_item_cnode = input_node->cast<CNodePtr>();
844   MS_EXCEPTION_IF_NULL(tuple_item_cnode);
845   // Find the make tuple or sens (tuple) node to get the out value
846   const auto input_adjoint_iter = ad_param_->anfnode_to_variable_adjoint_.find(tuple_item_cnode->input(kIndex1));
847   if (input_adjoint_iter == ad_param_->anfnode_to_variable_adjoint_.end()) {
848     MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << tuple_item_cnode->input(kIndex1)->DebugString();
849   }
850   const auto &v_tuple = input_adjoint_iter->second->out_value()->cast<ValueSequencePtr>();
851   MS_EXCEPTION_IF_NULL(v_tuple);
852   auto index_value = GetValueNode<Int64ImmPtr>(tuple_item_cnode->input(kIndex2));
853   auto index_value_int = LongToSize(index_value->value());
854   auto out_value = (*v_tuple)[index_value_int];
855   MS_EXCEPTION_IF_NULL(out_value);
856   AnfNodePtr dout = PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, input_node->abstract(),
857                                                              SpecialType::kZerosLikeType);
858   auto fn = std::make_shared<IrFunctionNode>(ad_param_->tape_, dout);
859   auto variable_adjoint = std::make_shared<IrVariable>(fn, out_value);
860 
861   AnfNodePtrList inputs{NewValueNode(prim::kPrimTupleGetItem)};
862   // Get make tuple knode
863   (void)inputs.emplace_back(BuildKNodeForCNodeInput(tuple_item_cnode->input(kIndex1)));
864   // Get index knode
865   (void)inputs.emplace_back(BuildKNodeForCNodeInput(tuple_item_cnode->input(kIndex2)));
866   auto k_node = ad_param_->tape_->FuncGraph::NewCNode(inputs);
867   k_node->set_abstract(input_node->abstract());
868   variable_adjoint->set_k_node(k_node);
869   // Create dout for tuplegetitem
870   AnfNodePtrList tuple_getitem_dout{NewValueNode(prim::kPrimMakeTuple)};
871   const auto &abs_tuple = tuple_item_cnode->input(kIndex1)->abstract()->cast<abstract::AbstractSequencePtr>();
872   for (size_t i = 0; i < v_tuple->size(); ++i) {
873     const auto &v = v_tuple->value()[i];
874     if (i == index_value_int) {
875       (void)tuple_getitem_dout.emplace_back(dout);
876     } else {
877       (void)tuple_getitem_dout.emplace_back(PyNativeAlgo::AutoGrad::BuildSpecialNode(
878         ad_param_->tape_, v, abs_tuple->elements()[i], SpecialType::kZerosLikeType));
879     }
880   }
881   CNodePtr tuple_getitem_dout_value = ad_param_->tape_->FuncGraph::NewCNode(tuple_getitem_dout);
882   tuple_getitem_dout_value->set_abstract(tuple_item_cnode->input(kIndex1)->abstract());
883   auto index_dout_value =
884     PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, index_value,
885                                              tuple_item_cnode->input(kIndex1)->abstract(), SpecialType::kZerosLikeType)
886       ->cast<CNodePtr>();
887   UpdateNextEdges(variable_adjoint, {tuple_getitem_dout_value, index_dout_value}, {v_tuple, index_value},
888                   {tuple_item_cnode->input(kIndex1)->abstract(), tuple_item_cnode->input(kIndex2)->abstract()});
889   AddUser(dout, tuple_getitem_dout_value, index_value_int + 1);
890   (void)ad_param_->anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
891   (void)ad_param_->variable_adjoint_set_.insert(variable_adjoint);
892   return k_node;
893 }
894 
895 AnfNodePtr IrBprop::GetKnode(const PrimitivePtr &prim, const CNodePtr &cnode, const AnfNodePtrList &cnode_inputs,
896                              bool jit_by_value) {
897   if (IsPrimitiveEquals(prim, prim::kPrimMirror)) {
898     return ad_param_->anfnode_to_variable_adjoint_.at(cnode->input(kIndex1))->k_node();
899   } else {
900     auto c_k_node = ad_param_->tape_->FuncGraph::NewCNode(cnode_inputs);
901     c_k_node->set_abstract(cnode->abstract());
902     // In jit, copy forward graph cnode info to bprop graph
903     if (jit_by_value && cnode->forward().first != nullptr) {
904       auto new_v_node = PyNativeAlgo::Common::CreateValueNodeByValue(cnode->forward().first->value(),
905                                                                      cnode->forward().first->abstract());
906       c_k_node->set_forward(new_v_node, cnode->forward().second);
907       ad_param_->tape_->set_used_forward_nodes({c_k_node});
908     }
909     c_k_node->AddAttr(bprop_pass::kIsKNode, MakeValue(true));
910     return c_k_node;
911   }
912 }
913 
914 void IrBprop::UpdateNextEdgeForDict(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
915                                     const AbstractBasePtr &abs) {
916   auto value_dict = input_arg->cast<ValueDictionaryPtr>()->value();
917   const auto &abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
918   MS_EXCEPTION_IF_NULL(abs_dict);
919   if (value_dict.size() != abs_dict->size()) {
920     MS_LOG(EXCEPTION) << "Get value dict size " << value_dict.size() << " not equal to abstract size "
921                       << abs_dict->size();
922   }
923   for (size_t i = 0; i < value_dict.size(); ++i) {
924     auto sub_value = value_dict[i];
925     auto key_item = PyNativeAlgo::Common::CreateValueNodeByValue(sub_value.first, abs_dict->elements()[i].first);
926     CNodePtr new_din = ad_param_->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimDictGetItem), din, key_item});
927     new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs_dict->elements()[i].second));
928     if (din == fn->fake_dout()) {
929       // The new_din's index input is fn->fake_dout()
930       LazyAddUser(fn->fake_dout(), new_din, 1);
931     }
932     // Add next edge to fn
933     UpdateNextEdge(fn, new_din, sub_value.second, abs_dict->elements()[i].second);
934   }
935 }
936 
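// Adds a next edge from fn to the variable owning input_arg; sequences, dictionaries, and sparse tensors are
// decomposed recursively so that each nested tensor receives its own din.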
937 void IrBprop::UpdateNextEdge(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
938                              const AbstractBasePtr &abs) {
939   MS_EXCEPTION_IF_NULL(din);
940   MS_EXCEPTION_IF_NULL(input_arg);
941   if (input_arg->isa<tensor::BaseTensor>()) {
942     tensor::BaseTensorPtr input_tensor = nullptr;
943     input_tensor = input_arg->cast<tensor::BaseTensorPtr>();
944     auto auto_grad_meta_data = input_tensor->auto_grad_meta_data();
945     MS_EXCEPTION_IF_NULL(auto_grad_meta_data);
946     auto variable = auto_grad_meta_data->variable();
947     if (variable == nullptr || !variable->is_need_grad()) {
948       return;
949     }
950     auto real_din = HandleRealToComplex(input_tensor, abs, din, fn->tape());
951     auto new_din = TraceInput(fn, variable->out_value(), variable->ir_function_node()->accumulate_dout()->abstract(),
952                               input_tensor, real_din);
953     fn->AddNextEdge(variable, new_din);
954   } else if (input_arg->isa<ValueSequence>()) {
955     auto value_seq = input_arg->cast<ValueSequencePtr>()->value();
956     const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
957     MS_EXCEPTION_IF_NULL(abs_seq);
958     if (value_seq.size() != abs_seq->size()) {
959       MS_LOG(EXCEPTION) << "Get value sequence size " << value_seq.size() << " not equal to abstract size "
960                         << abs_seq->size();
961     }
962     for (size_t i = 0; i < value_seq.size(); ++i) {
963       auto sub_value = value_seq[i];
964       CNodePtr new_din = ad_param_->tape_->FuncGraph::NewCNode(
965         {NewValueNode(prim::kPrimTupleGetItem), din, NewValueNode(SizeToLong(i))});
966       new_din->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(abs_seq->elements()[i]));
967       if (din == fn->fake_dout()) {
968         // The new_din's index input is fn->fake_dout()
969         LazyAddUser(fn->fake_dout(), new_din, 1);
970       }
971       // Add next edge to fn
972       UpdateNextEdge(fn, new_din, sub_value, abs_seq->elements()[i]);
973     }
974   } else if (input_arg->isa<tensor::COOTensor>()) {
975     auto input_tensor = input_arg->cast<tensor::COOTensorPtr>()->GetIndices();
976     UpdateNextEdge(fn, din, input_tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_tensor->ToAbstract()));
977   } else if (input_arg->isa<tensor::CSRTensor>()) {
978     auto input_tensor = input_arg->cast<tensor::CSRTensorPtr>()->GetIndices();
979     UpdateNextEdge(fn, din, input_tensor, PyNativeAlgo::Common::SetAbstractValueToAnyValue(input_tensor->ToAbstract()));
980   } else if (input_arg->isa<ValueDictionary>()) {
981     UpdateNextEdgeForDict(fn, din, input_arg, abs);
982   } else {
983     MS_LOG(DEBUG) << "It is not tensor, not need derivation " << input_arg->ToString();
984     return;
985   }
986 }
987 
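// Traces where input_tensor appears inside out_value: returns din for the matching tensor and zeroslike nodes for
// every other element, rebuilding tuples and dicts around the result so the din lines up with the producing output.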
988 AnfNodePtr IrBprop::TraceInput(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
989                                const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
990                                const AnfNodePtr &din) {
991   MS_EXCEPTION_IF_NULL(out_value);
992   MS_EXCEPTION_IF_NULL(out_abs);
993   MS_EXCEPTION_IF_NULL(input_tensor);
994   MS_EXCEPTION_IF_NULL(din);
995 
996   // The node corresponding output tensor is the same as the currently used tensor
997   if (out_value->isa<tensor::BaseTensor>()) {
998     // out_value is being used; it may be one of multiple outputs
999     auto out_tensor = out_value->cast<tensor::BaseTensorPtr>();
1000     if (input_tensor->id() == out_tensor->id()) {
1001       return din;
1002     }
1003     return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
1004   } else if (out_value->isa<ValueSequence>()) {
1005     // The corresponding output of the node is a ValueSequence, but only one element of it is used
1006     AnfNodePtrList inputs;
1007     (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1008     auto value_seq = out_value->cast<ValueSequencePtr>();
1009     auto abs_seq = out_abs->cast<abstract::AbstractSequencePtr>();
1010     if (abs_seq == nullptr) {
1011       MS_LOG(EXCEPTION) << "Get output abstract " << out_abs->ToString() << ", not abstract sequence";
1012     }
1013     int index = -1;
1014     for (size_t i = 0; i < value_seq->size(); ++i) {
1015       // Find the value's din: if the value equals sub_value, the value is used and gets din; otherwise the value's
1016       // din is zero, which is set by the branch above
1017       auto new_din = TraceInput(fn, value_seq->value()[i], abs_seq->elements()[i], input_tensor, din);
1018       (void)inputs.emplace_back(new_din);
1019 
1020       // If din == fake_dout, record it in the user vector
1021       if (din == fn->fake_dout() && new_din == din) {
1022         index = static_cast<int>(inputs.size()) - 1;
1023       }
1024     }
1025     auto new_din = ad_param_->tape_->FuncGraph::NewCNode(inputs);
1026     new_din->set_abstract(out_abs);
1027     if (index != -1) {
1028       LazyAddUser(fn->fake_dout(), new_din, index);
1029     }
1030     return new_din;
1031   } else if (out_value->isa<ValueDictionary>()) {
1032     return TraceInputForDict(fn, out_value, out_abs, input_tensor, din);
1033   }
1034   MS_LOG(DEBUG) << "Get non tensor input " << out_value->ToString();
1035   return PyNativeAlgo::AutoGrad::BuildSpecialNode(ad_param_->tape_, out_value, out_abs, SpecialType::kZerosLikeType);
1036 }
1037 
1038 AnfNodePtr IrBprop::TraceInputForDict(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
1039                                       const abstract::AbstractBasePtr &out_abs,
1040                                       const tensor::BaseTensorPtr &input_tensor, const AnfNodePtr &din) {
1041   // The corresponding output of the node is a ValueDictionary, but only one element of it is used
1042   AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1043   AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
1044   abstract::AbstractBasePtrList local_key_abs_inputs;
1045   abstract::AbstractBasePtrList local_value_abs_inputs;
1046   auto value_dict = out_value->cast<ValueDictionaryPtr>();
1047   auto abs_dict = out_abs->cast<abstract::AbstractDictionaryPtr>();
1048   MS_EXCEPTION_IF_NULL(abs_dict);
1049   int index = -1;
1050   for (size_t i = 0; i < value_dict->size(); ++i) {
1051     // Find the value's din: if the value equals sub_value, the value is used and gets din; otherwise the value's
1052     // din is zero, which is set by the branch above
1053     (void)key_inputs.emplace_back(
1054       PyNativeAlgo::Common::CreateValueNodeByValue(value_dict->value()[i].first, abs_dict->elements()[i].first));
1055     (void)local_key_abs_inputs.emplace_back(abs_dict->elements()[i].first);
1056     auto new_din = TraceInput(fn, value_dict->value()[i].second, abs_dict->elements()[i].second, input_tensor, din);
1057     (void)value_inputs.emplace_back(new_din);
1058     (void)local_value_abs_inputs.emplace_back(abs_dict->elements()[i].second);
1059 
1060     // If din == fake_dout, record it in the user vector
1061     if (din == fn->fake_dout() && new_din == din) {
1062       index = static_cast<int>(value_inputs.size()) - 1;
1063     }
1064   }
1065   auto local_key_node = ad_param_->tape_->NewCNode(key_inputs);
1066   local_key_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_key_abs_inputs));
1067   auto local_value_node = ad_param_->tape_->NewCNode(value_inputs);
1068   local_value_node->set_abstract(std::make_shared<abstract::AbstractTuple>(local_value_abs_inputs));
1069   auto new_din = ad_param_->tape_->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
1070   new_din->set_abstract(abs_dict);
1071   if (index != -1) {
1072     LazyAddUser(fn->fake_dout(), new_din, index);
1073   }
1074   return new_din;
1075 }
1076 
1077 void IrBprop::AddTupleGetItemUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
1078   (void)ad_param_->users_.tuple_getitem_user_[node].emplace_back(user, index);
1079 }
1080 
1081 void IrBprop::UpdateLazyUser() {
1082   // For lazily added user data, we need to emplace it into users.
1083   for (const auto &user_data : ad_param_->lazy_user_data_) {
1084     AddUser(std::get<kIndex0>(user_data), std::get<kIndex1>(user_data), std::get<kIndex2>(user_data));
1085   }
1086 }
1087 
1088 void IrBprop::LazyAddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index) {
1089   MS_EXCEPTION_IF_NULL(node);
1090   MS_EXCEPTION_IF_NULL(user);
1091   (void)ad_param_->lazy_user_data_.emplace_back(node, user, index);
1092 }
1093 }  // namespace mindspore::pynative::autograd
1094