/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
#include "frontend/operator/ops.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/trace.h"

namespace mindspore {
namespace ad {
using CacheKey = std::pair<std::string, size_t>;

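// Process-wide singletons and caches shared by every KPynativeCell instance.
// The zeros_like/ones_like func graphs are keyed by the abstract signature of
// their argument, so repeated ops on identically-shaped inputs reuse one
// specialized graph. ClearKPynativeCellStaticRes() at the end of this file
// resets all of them.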
static KPrim g_k_prims_pynative;
static ValuePtr add_ops;
static ValuePtr ones_like_ops;
static ValuePtr zeros_like_ops;
static std::shared_ptr<const opt::irpass::OptimizeIRPassLib> irpass;
static std::map<CacheKey, FuncGraphPtr> bprop_func_graph_cache;
static std::unordered_map<abstract::AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
                          abstract::AbstractBasePtrListEqual>
  zeros_like_funcgraph_cache;
static std::unordered_map<abstract::AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
                          abstract::AbstractBasePtrListEqual>
  ones_like_funcgraph_cache;

namespace {
FuncGraphPtr ZerosLikePrimOptPass(const pipeline::ResourcePtr &res) {
  if (irpass == nullptr) {
    irpass = std::make_shared<opt::irpass::OptimizeIRPassLib>();
  }
  opt::OptPassConfig eliminate_zeros_like_prim_pass = opt::OptPassConfig({
    irpass->zero_like_fill_zero_,
  });

  opt::OptPassGroupMap map({{"eliminate_zeros_like_prim_", eliminate_zeros_like_prim_pass}});

  auto eliminate_zeros_like_prim = opt::Optimizer::MakeOptimizer("eliminate_zeros_like_prim", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("eliminate_zeros_like_prim"))[&eliminate_zeros_like_prim, &func_graph]() {
    func_graph = eliminate_zeros_like_prim->step(func_graph, true);
  };
  return func_graph;
}

FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
  if (zeros_like_ops == nullptr) {
    zeros_like_ops = prim::GetPythonOps("zeros_like");
  }
  auto iter = zeros_like_funcgraph_cache.find(args_spec);
  if (iter != zeros_like_funcgraph_cache.end()) {
    MS_LOG(DEBUG) << "Cache hit for zeros_like: " << mindspore::ToString(args_spec);
    return BasicClone(iter->second);
  }
  if (!zeros_like_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "zeros_like is not a MetaFuncGraph";
  }
  auto zeros_like = zeros_like_ops->cast<MetaFuncGraphPtr>();
  auto zeros_like_fg = zeros_like->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(zeros_like_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_zeros_like_fg = pipeline::Renormalize(resource, zeros_like_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
  auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
  MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
  }
  return opted_zeros_like_fg;
}

FuncGraphPtr GetHyperAdd(const abstract::AbstractBasePtrList &args_spec) {
  if (add_ops == nullptr) {
    add_ops = prim::GetPythonOps("hyper_add");
  }
  if (!add_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "hyper_add is not a MetaFuncGraph";
  }
  auto add = add_ops->cast<MetaFuncGraphPtr>();
  auto add_fg = add->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(add_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_add_fg = pipeline::Renormalize(resource, add_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_add_fg);
  return specialized_add_fg;
}

AnfNodePtr BuildZerosLikeNode(const FuncGraphPtr &tape, const AnfNodePtr &node) {
  // Build zeros_like(node) as dout
  abstract::AbstractBasePtrList args_spec{node->abstract()->Broaden()};
  auto zeros_like_fg = GetZerosLike(args_spec);
  auto zeros_like_node = tape->NewCNode({NewValueNode(zeros_like_fg), node});
  zeros_like_node->set_abstract(zeros_like_fg->output()->abstract());
  return zeros_like_node;
}

AnfNodePtr BuildZerosLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
  // Build zeros_like(out) as dout
  abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()};
  auto zeros_like_fg = GetZerosLike(args_spec);
  auto zeros_like_value = tape->NewCNode({NewValueNode(zeros_like_fg), NewValueNode(out)});
  zeros_like_value->set_abstract(zeros_like_fg->output()->abstract());
  return zeros_like_value;
}

FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
  if (ones_like_ops == nullptr) {
    ones_like_ops = prim::GetPythonOps("ones_like");
  }
  auto iter = ones_like_funcgraph_cache.find(args_spec);
  if (iter != ones_like_funcgraph_cache.end()) {
    MS_LOG(DEBUG) << "Cache hit for ones_like: " << mindspore::ToString(args_spec);
    return BasicClone(iter->second);
  }
  if (!ones_like_ops->isa<MetaFuncGraph>()) {
    MS_LOG(EXCEPTION) << "ones_like is not a MetaFuncGraph";
  }
  auto ones_like = ones_like_ops->cast<MetaFuncGraphPtr>();
  auto ones_like_fg = ones_like->GenerateFuncGraph(args_spec);
  MS_EXCEPTION_IF_NULL(ones_like_fg);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
  MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
  }
  return specialized_ones_like_fg;
}

AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
  // Build ones_like(out) as dout
  abstract::AbstractBasePtrList args_spec{out->ToAbstract()->Broaden()};
  auto ones_like_fg = GetOnesLike(args_spec);
  auto ones_like_value = tape->NewCNode({NewValueNode(ones_like_fg), NewValueNode(out)});
  ones_like_value->set_abstract(ones_like_fg->output()->abstract());
  return ones_like_value;
}

// This faked bprop func_graph must not appear in the final top bprop func_graph.
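// Schematically, for a primitive with two inputs it builds:
//   bprop(x0, x1, out, dout) = (fake_input_sens, fake_input_sens)
// where executing fake_bprop would report that the primitive's bprop is not
// defined; the graph only keeps construction going and is expected to be
// eliminated before execution.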
FuncGraphPtr BuildFakeBProp(const PrimitivePtr &prim, size_t inputs_num) {
  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;
  outputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
  (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
  auto fake_input_sens = func_graph->NewCNode({NewValueNode(fake_bprop), NewValueNode(true)});

  for (size_t i = 0; i < inputs_num; ++i) {
    // Mock params for inputs
    auto param = func_graph->add_parameter();
    MS_EXCEPTION_IF_NULL(param);
    // Mock derivatives for each input
    outputs.push_back(fake_input_sens);
  }
  // Mock params for out and dout
  (void)func_graph->add_parameter();
  (void)func_graph->add_parameter();
  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}
}  // namespace

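// PynativeAdjoint is the per-node record kept while taping: the recorded
// forward inputs (op_args_) and output (out_), the bprop or fprop graph used
// to differentiate the node (fg_), the accumulated gradient (dout_), the
// nodes consuming this node's output (users_), and, for higher-order
// gradients, the k-mapped node on tape_ (k_node_).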
class PynativeAdjoint {
 public:
  enum FuncGraphType { kForwardPropagate, kBackwardPropagate };
  PynativeAdjoint(const FuncGraphPtr &tape, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &fg,
                  FuncGraphType fg_type = kBackwardPropagate)
      : tape_(tape), op_args_(op_args), out_(out), fg_(fg), fg_type_(fg_type) {}

  ~PynativeAdjoint() = default;
  AnfNodePtrList &users() { return users_; }
  const ValuePtrList &op_args() const { return op_args_; }
  const ValuePtr &out() const { return out_; }
  const FuncGraphPtr &fg() const { return fg_; }
  const FuncGraphType &fg_type() const { return fg_type_; }
  AnfNodePtr RealDout() {
    if (dout_ != nullptr) {
      return dout_;
    }
    return BuildZerosLikeValue(tape_, out_);
  }

  void AccumulateDout(const AnfNodePtr &dout_factor) {
    if (dout_factor->abstract() == nullptr) {
      MS_LOG(EXCEPTION) << "Abstract of dout_factor should not be null: " << dout_factor->ToString();
    }
    if (dout_ != nullptr) {
      MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString();
      auto arg = out_->ToAbstract()->Broaden();
      abstract::AbstractBasePtrList args_spec{arg, arg};
      auto add_fg = GetHyperAdd(args_spec);
      MS_EXCEPTION_IF_NULL(add_fg);
      dout_ = tape_->NewCNode({NewValueNode(add_fg), dout_, dout_factor});
      dout_->set_abstract(add_fg->output()->abstract());
      MS_LOG(DEBUG) << "New dout_ " << dout_->DebugString();
      return;
    }
    dout_ = dout_factor;
  }

  AnfNodePtr k_node() const { return k_node_; }
  void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }

 private:
  const FuncGraphPtr tape_;
  AnfNodePtr dout_{nullptr};
  // CNodes that use this node's output.
  AnfNodePtrList users_;
  // Cache the arguments passed in by the ad caller.
  const ValuePtrList op_args_;
  // For a CNode, the output of the cnode; for a Parameter or ValueNode, its value.
  const ValuePtr out_;
  // fg_ is either a bprop_fg generated from a Primitive or an fprop_fg passed in by the caller;
  // it is the FuncGraph spliced into tape_.
  const FuncGraphPtr fg_;
  const FuncGraphType fg_type_;
  // k-mapped cnode for the primal CNode; the primal CNode is owned by the primal funcgraph, this one by tape_.
  AnfNodePtr k_node_;
};
using PynativeAdjointPtr = std::shared_ptr<PynativeAdjoint>;

class KPynativeCellImpl : public KPynativeCell {
 public:
  KPynativeCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
      : tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
    tape_->debug_info()->set_name("grad_top");
    for (size_t i = 0; i < cell_inputs.size(); ++i) {
      TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
      (void)tape_->add_parameter();
      // Build adjoint for every input parameter
      auto input_adjoint =
        std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, input_param_values[i], FuncGraphPtr(nullptr));
      (void)anfnode_to_adjoin_.insert(std::make_pair(cell_inputs[i], input_adjoint));
    }
  }
  ~KPynativeCellImpl() override = default;
  bool KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out);
  bool KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                          const FuncGraphPtr &bprop_fg);
  bool KPynativeWithFProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                          const FuncGraphPtr &fprop_fg) override;
  void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) override;
  // Build the back-propagation funcgraph; each cnode in the primal funcgraph is replaced by a value node or a
  // formal cnode, so the result can be differentiated again.
  FuncGraphPtr Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg,
                      bool build_formal_param);

 private:
  bool need_propagate_stop_gradient_{false};
  // Last cnode of this Cell; may be a primitive op or a cell with a user-defined bprop.
  AnfNodePtr last_node_{nullptr};
  FuncGraphPtr tape_;
  AnfNodePtrList cell_inputs_;
  // Weights that need their gradient calculated.
  std::unordered_set<AnfNodePtr> need_grad_weights_;
  OrderedMap<AnfNodePtr, PynativeAdjointPtr> anfnode_to_adjoin_;

  // For CNodes like TupleGetItem, ListGetItem, MakeTuple and MakeList, the caller bypasses them, so
  // KPynativeOp is never called for them. Here we forge adjoints for these CNodes.
  PynativeAdjointPtr ForgeCNodeAdjoint(const CNodePtr &cnode);
  PynativeAdjointPtr ForgeGetItemAdjoint(const CNodePtr &cnode);
  PynativeAdjointPtr ForgeMakeSequenceAdjoint(const CNodePtr &cnode);
  bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                    const FuncGraphPtr &bprop_fg,
                    PynativeAdjoint::FuncGraphType fg_type = PynativeAdjoint::kBackwardPropagate);
  void BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args);
  void PropagateStopGradient();
  bool AllReferencesStopped(const CNodePtr &curr_cnode);
  OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
  // Back propagate for all nodes;
  // if by_value is true, every input of a bprop_app cnode is a value node;
  // if by_value is false, the inputs are the k-mapped nodes, so the result can be differentiated again.
  bool BackPropagate(bool by_value);
  bool BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint,
                                               const FuncGraphPtr &bprop_fg, bool by_value);
  bool BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint,
                                               const FuncGraphPtr &fprop_fg, bool by_value);
  bool BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app);
  AnfNodePtr BuildKNodeForCNodeInput(const PynativeAdjointPtr &cnode_adjoint, const AnfNodePtr &input_node,
                                     size_t input_index);
  const AnfNodePtrList BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const PynativeAdjointPtr &adjoint);
  FuncGraphPtr BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode);
  // The bprop for MakeList or MakeTuple is generated from a MetaFuncGraph.
  FuncGraphPtr BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode);
  // Replace input and weight parameters of the primal funcgraph with parameters of tape_;
  void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
  // Set sens and weights parameter nodes by user input info
  void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
  // Set the return node according to the grad flags
  void SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights);

  // For higher-order gradients:
  // build a k-mapped node owned by tape_ for each cnode in the primal funcgraph, so these nodes can be
  // used in tape_ to keep tracking the cnode dependencies.
  bool BuildKNode();
  CNodePtr GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args);
};
using KPynativeCellImplPtr = std::shared_ptr<KPynativeCellImpl>;

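// Typical driver sequence (an illustrative sketch; the real caller lives in
// the pynative frontend):
//   auto k_cell = GradPynativeCellBegin(cell_inputs, input_param_values);
//   // for each op executed in the cell:
//   GradPynativeOp(k_cell, op_cnode, op_args, op_out);
//   k_cell->UpdateOutputNodeOfTopCell(output_node);
//   FuncGraphPtr grad_fg = GradPynativeCellEnd(k_cell, weights, /*grad_inputs=*/true,
//                                              /*grad_weights=*/true, /*has_sens_arg=*/false,
//                                              /*build_formal_param=*/false);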
KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
                                       const std::vector<ValuePtr> &input_param_values) {
  auto abstract_are_set = std::all_of(cell_inputs.cbegin(), cell_inputs.cend(),
                                      [](const AnfNodePtr &node) { return node->abstract() != nullptr; });
  if (!abstract_are_set) {
    MS_LOG(EXCEPTION) << "Not all abstract values in cell_inputs are set";
  }
  if (cell_inputs.size() != input_param_values.size()) {
    MS_LOG(EXCEPTION) << "The size of cell inputs " << cell_inputs.size()
                      << " is not equal to the size of input parameter values " << input_param_values.size();
  }
  return std::make_shared<KPynativeCellImpl>(cell_inputs, input_param_values);
}

FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
                                 bool grad_weights, bool has_sens_arg, bool build_formal_param) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  MS_EXCEPTION_IF_NULL(k_cell_impl);
  return k_cell_impl->Finish(weights, grad_inputs, grad_weights, has_sens_arg, build_formal_param);
}

FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights,
                                       bool has_sens_arg, bool build_formal_param) {
  // Propagate the stop_gradient flag to cnodes before back propagation;
  PropagateStopGradient();
  // Set the sens node and weights nodes
  SetSensAndWeights(weights, has_sens_arg);
  // Build forward CNodes (k-nodes);
  if (build_formal_param) {
    (void)BuildKNode();
  }
  // Back propagate the sensitivity, except when the last node is a valuenode, which may result from constant folding;
  if (!last_node_->isa<ValueNode>()) {
    (void)BackPropagate(!build_formal_param);
  }
  // Return the gradient;
  SetOutput(weights, grad_inputs, grad_weights);
  // Replace parameters of the primal funcgraph with parameters of tape_;
  ReplacePrimalParameter(weights, has_sens_arg);
#ifdef ENABLE_DUMP_IR
  auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs_flg) {
    DumpIR("before_final_opt.ir", tape_);
  }
#endif
  return tape_;
}

bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                    const ValuePtr &out) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  MS_EXCEPTION_IF_NULL(k_cell_impl);
  return k_cell_impl->KPynativeOp(cnode, op_args, out);
}

bool KPynativeCellImpl::KPynativeOp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(EXCEPTION) << "Expected a primitive, but got: " << cnode->DebugString();
  }
  if (IsPrimitiveEquals(prim, prim::kPrimStopGradient) || IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
    need_propagate_stop_gradient_ = true;
  }

  FuncGraphPtr bprop_fg = nullptr;
  if (IsPrimitiveEquals(prim, prim::kPrimHookBackward)) {
    bprop_fg = BuildBPropCutFuncGraph(prim, cnode);
  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    bprop_fg = BuildMakeSequenceBprop(prim, cnode);
  } else {
    bprop_fg = g_k_prims_pynative.GetPossibleBprop(prim);
    if (bprop_fg == nullptr) {
      MS_LOG(DEBUG) << "Cannot find defined bprop for cnode prim: " << cnode->DebugString();
      bprop_fg = BuildFakeBProp(prim, cnode->size() - 1);
    }
  }
  MS_EXCEPTION_IF_NULL(bprop_fg);
  (void)BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &cnode, const ValuePtrList &op_args,
                           const ValuePtr &out, const FuncGraphPtr &bprop_fg) {
  auto k_cell_impl = std::dynamic_pointer_cast<KPynativeCellImpl>(k_cell);
  MS_EXCEPTION_IF_NULL(k_cell_impl);
  return k_cell_impl->KPynativeWithBProp(cnode, op_args, out, bprop_fg);
}

bool KPynativeCellImpl::KPynativeWithBProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                           const FuncGraphPtr &bprop_fg) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto primal_fg = GetCNodeFuncGraph(cnode);
  if (primal_fg == nullptr) {
    MS_LOG(EXCEPTION) << "Expected a func graph, but got: " << cnode->DebugString();
  }
  MS_EXCEPTION_IF_NULL(bprop_fg);
  (void)BuildAdjoint(cnode, op_args, out, bprop_fg);

  return true;
}

bool KPynativeCellImpl::KPynativeWithFProp(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                           const FuncGraphPtr &fprop_fg) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(fprop_fg);

  (void)BuildAdjoint(cnode, op_args, out, fprop_fg, PynativeAdjoint::kForwardPropagate);

  return true;
}

void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
  last_node_ = output_node;

  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    if (IsPrimitiveCNode(output_node, prim::kPrimTupleGetItem) ||
        IsPrimitiveCNode(output_node, prim::kPrimListGetItem)) {
      MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << output_node->DebugString();
      auto cnode = output_node->cast<CNodePtr>();
      (void)ForgeGetItemAdjoint(cnode);
      return;
    } else if (output_node->isa<ValueNode>()) {
      auto v_node = output_node->cast<ValueNodePtr>();
      MS_LOG(DEBUG) << "Build adjoint for valuenode: " << v_node->ToString();
      auto v_node_pynative_adjoint =
        std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, v_node->value(), FuncGraphPtr(nullptr));
      (void)anfnode_to_adjoin_.insert(std::make_pair(output_node, v_node_pynative_adjoint));
      return;
    }
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
  }
}

namespace {
ValuePtr ShallowCopyValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<mindspore::tensor::Tensor>()) {
    auto tensor_value = value->cast<mindspore::tensor::TensorPtr>();
    return std::make_shared<mindspore::tensor::Tensor>(*tensor_value);
  } else if (value->isa<ValueTuple>()) {
    std::vector<ValuePtr> values;
    auto value_tuple = value->cast<ValueTuplePtr>();
    (void)std::transform(value_tuple->value().begin(), value_tuple->value().end(), std::back_inserter(values),
                         [](const ValuePtr &elem) { return ShallowCopyValue(elem); });
    return std::make_shared<ValueTuple>(values);
  } else {
    return value;
  }
}
}  // namespace

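// Forge an adjoint for a TupleGetItem/ListGetItem cnode that the caller
// bypassed: recover the item value from the recorded sequence output of
// input 1, then record the getitem op itself via KPynativeOp.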
PynativeAdjointPtr KPynativeCellImpl::ForgeGetItemAdjoint(const CNodePtr &cnode) {
  if (cnode->size() != 3) {
    MS_LOG(EXCEPTION) << "TupleGetItem/ListGetItem CNode should have 3 inputs, but CNode: " << cnode->DebugString();
  }
  // Input 1 of CNode;
  PynativeAdjointPtr input_1_adjoint = nullptr;
  auto input_1 = cnode->input(1);
  auto input_1_adjoint_iter = anfnode_to_adjoin_.find(input_1);
  if (input_1_adjoint_iter == anfnode_to_adjoin_.end()) {
    if (!input_1->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Input 1 of CNode should be a CNode, CNode: " << cnode->DebugString();
    }
    input_1_adjoint = ForgeCNodeAdjoint(input_1->cast<CNodePtr>());
    if (input_1_adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "Build adjoint for input 1 of CNode failed, CNode: " << cnode->DebugString();
    }
    input_1_adjoint->users().push_back(cnode);
  } else {
    input_1_adjoint = input_1_adjoint_iter->second;
  }
  if (!input_1_adjoint->out()->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "Input of CNode should be evaluated to a ValueSequence. CNode: " << cnode->DebugString()
                      << ", out of input1: " << input_1_adjoint->out()->ToString();
  }
  auto input_1_out = input_1_adjoint->out()->cast<ValueSequeuePtr>();

  // Input 2 of CNode;
  auto index_value = GetValueNode<Int64ImmPtr>(cnode->input(2));
  if (index_value == nullptr) {
    MS_LOG(EXCEPTION) << "CNode input 2 should be an Int64Imm, CNode: " << cnode->DebugString();
  }
  if (index_value->value() < 0) {
    MS_LOG(EXCEPTION) << "CNode input 2 should not be less than 0, CNode: " << cnode->DebugString();
  }
  size_t index_value_imm = LongToSize(index_value->value());
  if (index_value_imm >= input_1_out->size()) {
    MS_LOG(EXCEPTION) << "CNode input 2 should be an index in [0, " << input_1_out->size()
                      << "), but: " << index_value->ToString();
  }
  auto cnode_out = (*input_1_out)[index_value_imm];
  ValuePtrList op_args{input_1_out, index_value};
  auto built = KPynativeOp(cnode, op_args, cnode_out);
  if (!built) {
    MS_LOG(EXCEPTION) << "Build adjoint for GetItem node failed, CNode: " << cnode->DebugString();
  }
  auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Build adjoint for GetItem node failed, CNode: " << cnode->DebugString();
  }
  return cnode_adjoint_iter->second;
}

PynativeAdjointPtr KPynativeCellImpl::ForgeMakeSequenceAdjoint(const CNodePtr &cnode) {
  // () or [] is not supported yet.
  if (cnode->size() <= 1) {
    MS_LOG(DEBUG) << "MakeTuple/MakeList CNode is an empty Tuple/List, CNode: " << cnode->DebugString();
    auto empty_tuple = MakeValue(std::vector<ValuePtr>{});
    auto dummy_adjoint =
      std::make_shared<PynativeAdjoint>(FuncGraphPtr(nullptr), ValuePtrList{}, empty_tuple, FuncGraphPtr(nullptr));
    anfnode_to_adjoin_[cnode] = dummy_adjoint;
    cnode->set_stop_gradient(true);
    return dummy_adjoint;
  }
  ValuePtrList op_args;
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->input(i);
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(DEBUG) << "Item in CNode cannot be found in cache. Input is: " << input->DebugString();
      if (input->isa<CNode>()) {
        const auto input_cnode = input->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(input_cnode);
        auto forged_input_adjoint = ForgeCNodeAdjoint(input_cnode);
        MS_EXCEPTION_IF_NULL(forged_input_adjoint);
        op_args.push_back(forged_input_adjoint->out());
      } else if (input->isa<ValueNode>()) {
        const auto &input_value = GetValueNode(input);
        op_args.push_back(input_value);
      } else {
        MS_LOG(EXCEPTION) << "Input of MakeTuple/MakeList is not a CNode or ValueNode, but: " << input->DebugString();
      }
    } else {
      op_args.push_back(input_adjoint_iter->second->out());
    }
  }
  ValuePtr cnode_out = nullptr;
  if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
    cnode_out = MakeValue(op_args);
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    cnode_out = std::make_shared<ValueList>(op_args);
  }
  // op_args are the real inputs, recovered from the outputs of the preceding cnodes
  auto built = KPynativeOp(cnode, op_args, cnode_out);
  if (!built) {
    MS_LOG(EXCEPTION) << "Build adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString();
  }
  auto cnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (cnode_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Build adjoint for MakeTuple/MakeList node failed, CNode: " << cnode->DebugString();
  }
  return cnode_adjoint_iter->second;
}

PynativeAdjointPtr KPynativeCellImpl::ForgeCNodeAdjoint(const CNodePtr &cnode) {
  if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimListGetItem)) {
    MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString();
    return ForgeGetItemAdjoint(cnode);
  }

  if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    MS_LOG(DEBUG) << "Build cnode adjoint for anfnode: " << cnode->DebugString();
    return ForgeMakeSequenceAdjoint(cnode);
  }
  MS_LOG(EXCEPTION) << "Unknown cnode: " << cnode->DebugString();
}

void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args) {
  auto anfnode_adjoint_iter = anfnode_to_adjoin_.find(cnode);
  if (anfnode_adjoint_iter != anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "CNode should be unique, but: " << cnode->DebugString();
  }
  // Bookkeep the last cnode, as the dout of this node will be provided from outside;
  last_node_ = cnode;

  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    auto input = cnode->input(i);
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      if (input->isa<CNode>()) {
        auto cnode_input = input->cast<CNodePtr>();
        auto forged_adjoint = ForgeCNodeAdjoint(cnode_input);
        if (forged_adjoint == nullptr) {
          MS_LOG(EXCEPTION) << "Cannot forge adjoint for anfnode: " << input->DebugString();
        }
        forged_adjoint->users().push_back(cnode);
      } else {
        MS_EXCEPTION_IF_NULL(op_args[i - 1]);
        auto input_adjoint =
          std::make_shared<PynativeAdjoint>(tape_, ValuePtrList{}, op_args[i - 1], FuncGraphPtr(nullptr));
        (void)anfnode_to_adjoin_.insert(std::make_pair(input, input_adjoint));
        input_adjoint->users().push_back(cnode);
      }
    } else {
      input_adjoint_iter->second->users().push_back(cnode);
    }
  }
}

bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                     const FuncGraphPtr &fg, const PynativeAdjoint::FuncGraphType fg_type) {
  // Optimize the bprop_fg based on value.
  // Clone op_args and out, so the address of the tensor data can be reset to nullptr if the value of the tensor
  // is not used in bprop_fg;
  ValuePtrList cloned_op_args;
  (void)std::transform(op_args.begin(), op_args.end(), std::back_inserter(cloned_op_args),
                       [](const ValuePtr &value) { return ShallowCopyValue(value); });
  ValuePtr cloned_out = ShallowCopyValue(out);
  PynativeAdjointPtr cnode_adjoint;
  if (fg_type == PynativeAdjoint::kBackwardPropagate) {
    auto optimized_bprop_fg = OptimizeBPropFuncGraph(fg, cnode, cloned_op_args, cloned_out);
    cnode_adjoint = std::make_shared<PynativeAdjoint>(tape_, cloned_op_args, cloned_out, optimized_bprop_fg);
  } else {
    cnode_adjoint = std::make_shared<PynativeAdjoint>(tape_, cloned_op_args, cloned_out, fg, fg_type);
  }

  BuildAdjointForInput(cnode, op_args);

  (void)anfnode_to_adjoin_.insert(std::make_pair(cnode, cnode_adjoint));

  return true;
}

FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
                                    const ValuePtr &out) {
  auto optimized_bprop_fg =
    PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
  return optimized_bprop_fg;
}

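// Two bprop output layouts are handled below: a primitive bprop returns
// (din_1, ..., din_n), so primal input i maps to tuple index i - 1, while a
// bprop extracted from an fprop returns (env, din_1, ..., din_n), so primal
// input i maps to tuple index i (the "bprop_app[0] is env" branch).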
bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodePtr &bprop_app) {
  abstract::AbstractTuplePtr abstract_tuple = nullptr;
  auto bprop_app_abstract = bprop_app->abstract();
  // If input 0 of bprop_app is a CNode rather than a FuncGraph ValueNode, bprop_app_abstract is nullptr;
  // after tape_ is returned, the caller should renormalize tape_ to set the abstract of each AnfNode.
  if (bprop_app_abstract != nullptr) {
    abstract_tuple = bprop_app_abstract->cast<abstract::AbstractTuplePtr>();
    MS_EXCEPTION_IF_NULL(abstract_tuple);
    if (abstract_tuple->size() != (cnode_primal->size() - 1)) {
      MS_LOG(EXCEPTION) << "AbstractTuple size: " << abstract_tuple->ToString()
                        << " does not match primal cnode input size: " << cnode_primal->DebugString();
    }
  }
  for (size_t i = 1; i < cnode_primal->size(); i++) {
    auto input = cnode_primal->input(i);
    // It is useless to accumulate sens for a ValueNode; the sens for a ValueNode should stay zeros_like;
    if (input->isa<ValueNode>()) {
      continue;
    }
    auto cnode_input = input->cast<CNodePtr>();
    if (cnode_input != nullptr && cnode_input->stop_gradient()) {
      MS_LOG(DEBUG) << "Bypass accumulating dout to cnode with stop_gradient flag, cnode: " << input->DebugString();
      continue;
    }
    // Backprop sens wrt inputs.
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input[" << i << "] " << input->DebugString();
    }
    AnfNodePtr din;
    if (abstract_tuple != nullptr) {
      din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i - 1))});
      din->set_abstract((*abstract_tuple)[i - 1]);
    } else {
      // bprop_app[0] is env;
      din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
      din->set_abstract(input_adjoint_iter->second->out()->ToAbstract()->Broaden());
    }
    input_adjoint_iter->second->AccumulateDout(din);
  }
  return true;
}

AnfNodePtr KPynativeCellImpl::BuildKNodeForCNodeInput(const PynativeAdjointPtr &cnode_adjoint,
                                                      const AnfNodePtr &input_node, size_t input_index) {
  MS_EXCEPTION_IF_NULL(cnode_adjoint);
  MS_EXCEPTION_IF_NULL(input_node);
  if (input_node->isa<CNode>()) {
    auto input_adjoint_iter = anfnode_to_adjoin_.find(input_node);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, input: " << input_node->DebugString();
    }
    return input_adjoint_iter->second->k_node();
  } else {
    if (input_node->isa<Parameter>()) {
      bool is_weight = input_node->cast<ParameterPtr>()->has_default();
      // If the weight does not need its gradient calculated, it is converted to a value node.
      if (is_weight && need_grad_weights_.find(input_node) == need_grad_weights_.end()) {
        return NewValueNode(cnode_adjoint->op_args()[input_index - 1]);
      }
    }
    return input_node;
  }
}

const AnfNodePtrList KPynativeCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode,
                                                                      const PynativeAdjointPtr &adjoint) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(adjoint);
  AnfNodePtrList node_list;
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    (void)node_list.emplace_back(BuildKNodeForCNodeInput(adjoint, cnode->input(i), i));
  }
  return node_list;
}

bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode,
                                                                const PynativeAdjointPtr &adjoint,
                                                                const FuncGraphPtr &bprop_fg, bool by_value) {
  AnfNodePtrList node_list;
  auto bprop_output_abs = bprop_fg->output()->abstract();
  if (bprop_output_abs == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract of bprop output is null";
  }
  if (!bprop_output_abs->isa<abstract::AbstractTuple>()) {
    MS_LOG(EXCEPTION) << "Abstract of bprop output is not an AbstractTuple, but: " << bprop_output_abs->ToString();
  }
  node_list.push_back(NewValueNode(bprop_fg));

  if (by_value) {
    for (size_t i = 0; i < adjoint->op_args().size(); ++i) {
      auto input_node = cnode->input(i + 1);
      if (input_node->isa<Parameter>()) {
        bool is_weight = input_node->cast<ParameterPtr>()->has_default();
        if (!is_weight || need_grad_weights_.find(input_node) != need_grad_weights_.end()) {
          node_list.push_back(input_node);
          continue;
        }
      }
      auto v_node = NewValueNode(adjoint->op_args()[i]);
      v_node->set_abstract(adjoint->op_args()[i]->ToAbstract()->Broaden());
      node_list.push_back(v_node);
    }
    // out;
    auto out_node = NewValueNode(adjoint->out());
    out_node->set_abstract(adjoint->out()->ToAbstract()->Broaden());
    node_list.push_back(out_node);
    // dout
    node_list.push_back(adjoint->RealDout());
  } else {
    const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint);
    (void)node_list.insert(node_list.end(), k_node_list.begin(), k_node_list.end());
    // out;
    node_list.push_back(adjoint->k_node());
    // dout
    node_list.push_back(adjoint->RealDout());
  }
  // Back propagation process
  auto bprop_app = tape_->NewCNode(node_list);
  bprop_app->set_abstract(bprop_output_abs);
  (void)BackPropagate(cnode, bprop_app);
  return true;
}

bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &cnode,
                                                                const PynativeAdjointPtr &adjoint,
                                                                const FuncGraphPtr &fprop_fg, bool by_value) {
  MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();

  AnfNodePtrList node_list;
  CNodePtr bprop_cnode;
  if (by_value) {
    AnfNodePtrList args_node_list;
    for (size_t i = 0; i < adjoint->op_args().size(); ++i) {
      auto input_node = cnode->input(i + 1);
      if (input_node->isa<Parameter>()) {
        bool is_weight = input_node->cast<ParameterPtr>()->has_default();
        if (!is_weight || need_grad_weights_.find(input_node) != need_grad_weights_.end()) {
          args_node_list.push_back(input_node);
          continue;
        }
      }
      auto v_node = NewValueNode(adjoint->op_args()[i]);
      v_node->set_abstract(adjoint->op_args()[i]->ToAbstract()->Broaden());
      args_node_list.push_back(v_node);
    }
    bprop_cnode = GetBPropFromFProp(fprop_fg, args_node_list);
  } else {
    const auto &k_node_list = BuildKNodeListFromPrimalCNode(cnode, adjoint);
    bprop_cnode = GetBPropFromFProp(fprop_fg, k_node_list);
  }
  node_list.push_back(bprop_cnode);
  // dout;
  node_list.push_back(adjoint->RealDout());
  // Back propagation process
  auto bprop_app = tape_->NewCNode(node_list);
  (void)BackPropagate(cnode, bprop_app);
  return true;
}

OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator KPynativeCellImpl::GetLastNodeReverseIter() {
  for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }
    if (iter->first->cast<CNodePtr>() == last_node_) {
      return iter;
    }
  }
  return anfnode_to_adjoin_.rend();
}

bool KPynativeCellImpl::BackPropagate(bool by_value) {
  auto last_node_reverse_iter = GetLastNodeReverseIter();
  for (auto iter = last_node_reverse_iter; iter != anfnode_to_adjoin_.rend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }
    auto cnode = iter->first->cast<CNodePtr>();
    if (cnode->stop_gradient()) {
      MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->DebugString();
      continue;
    }
    MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();
    auto fg = iter->second->fg();
    auto fg_type = iter->second->fg_type();
    if (fg_type == PynativeAdjoint::kBackwardPropagate) {
      (void)BackPropagateOneCNodeWithBPropFuncGraph(cnode, iter->second, fg, by_value);
    } else {
      (void)BackPropagateOneCNodeWithFPropFuncGraph(cnode, iter->second, fg, by_value);
    }
  }
  return true;
}

bool KPynativeCellImpl::AllReferencesStopped(const CNodePtr &curr_cnode) {
  // If every CNode that uses curr_cnode has the stop_gradient_ flag set, then curr_cnode can be flagged as well.
  auto iter = anfnode_to_adjoin_.find(curr_cnode);
  if (iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find adjoint for cnode: " << curr_cnode->DebugString();
  }
  auto users = iter->second->users();
  if (users.empty()) {
    return false;
  }
  auto all_users_have_stopped = std::all_of(users.cbegin(), users.cend(), [](const AnfNodePtr &user) {
    return user->isa<CNode>() && user->cast<CNodePtr>()->stop_gradient();
  });
  return all_users_have_stopped;
}

void KPynativeCellImpl::PropagateStopGradient() {
  // Propagate the stop_gradient flag to cnodes before back propagation;
  if (need_propagate_stop_gradient_) {
    for (auto iter = anfnode_to_adjoin_.rbegin(); iter != anfnode_to_adjoin_.rend(); ++iter) {
      const auto &node = iter->first;
      if (node->isa<CNode>()) {
        auto cnode = node->cast<CNodePtr>();
        if (!cnode->stop_gradient()) {
          // Cut off the cnode only when it is no longer referenced
          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
              AllReferencesStopped(cnode)) {
            MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->DebugString();
            cnode->set_stop_gradient(true);
          }
        }
      }
    }
  }
}

FuncGraphPtr KPynativeCellImpl::BuildBPropCutFuncGraph(const PrimitivePtr &prim, const CNodePtr &cnode) {
  auto inputs_num = cnode->size() - 1;

  auto func_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> outputs;

  auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
  bprop_cut->CopyHookFunction(prim);

  auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
  if (!cell_id.empty()) {
    (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
    (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
  }

  outputs.push_back(NewValueNode(bprop_cut));
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = func_graph->add_parameter();
    outputs.push_back(param);
  }
  // out, dout
  auto p1 = func_graph->add_parameter();
  auto p2 = func_graph->add_parameter();
  outputs.push_back(p1);
  outputs.push_back(p2);

  func_graph->set_output(func_graph->NewCNode(outputs));
  return func_graph;
}

FuncGraphPtr KPynativeCellImpl::BuildMakeSequenceBprop(const PrimitivePtr &prim, const CNodePtr &cnode) {
  auto inputs_num = cnode->size() - 1;
  CacheKey key{prim->name(), inputs_num};
  auto bprop_func_graph_iter = bprop_func_graph_cache.find(key);
  if (bprop_func_graph_iter != bprop_func_graph_cache.end()) {
    return bprop_func_graph_iter->second;
  }

  FuncGraphPtr b = std::make_shared<FuncGraph>();

  std::ostringstream ss;
  ss << "◀" << prim->ToString() << inputs_num;
  b->debug_info()->set_name(ss.str());
  for (size_t i = 0; i < inputs_num; ++i) {
    auto param = b->add_parameter();
    MS_EXCEPTION_IF_NULL(param);
  }
  // out, dout
  auto p1 = b->add_parameter();
  MS_EXCEPTION_IF_NULL(p1);
  AnfNodePtr dout = b->add_parameter();

  std::vector<AnfNodePtr> grads;
  PrimitivePtr getitem_prim;

  if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
    getitem_prim = prim::kPrimTupleGetItem;
  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    getitem_prim = prim::kPrimListGetItem;
  } else {
    MS_LOG(EXCEPTION) << "Prim should be MakeTuple or MakeList, invalid prim: " << prim->ToString();
  }

  grads.push_back(NewValueNode(prim));
  for (size_t i = 0; i < inputs_num; ++i) {
    grads.push_back(b->NewCNode({NewValueNode(getitem_prim), dout, NewValueNode(SizeToLong(i))}));
  }

  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  b->set_output(b->NewCNode(grads));

  bprop_func_graph_cache[key] = b;
  return b;
}

void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
  MS_EXCEPTION_IF_NULL(last_node_);
  MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
  auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
  if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
    MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
  }
  // Add sens parameter
  if (has_sens_arg) {
    auto sens_param = tape_->add_parameter();
    sens_param->debug_info()->set_name("sens");
    sens_param->set_abstract(last_node_adjoint_iter->second->out()->ToAbstract()->Broaden());
    // Set the dout of the last node to sens;
    last_node_adjoint_iter->second->AccumulateDout(sens_param);
  } else {
    auto sens_node = BuildOnesLikeValue(tape_, last_node_adjoint_iter->second->out());
    last_node_adjoint_iter->second->AccumulateDout(sens_node);
  }
  // Add weights parameters
  need_grad_weights_.clear();
  for (const auto &weight : weights) {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
    auto p = tape_->add_parameter();
    (void)need_grad_weights_.emplace(weight);
    auto input_w = weight->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(input_w);
    // Use the name to match the weight parameter in higher-order gradients
    p->set_name(input_w->name());
    p->set_default_param(input_w->default_param());
  }
}

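// Output layout of tape_ produced by SetOutput:
//   grad_inputs && grad_weights : ((din_1, ..., din_n), (dw_1, ..., dw_m))
//   grad_inputs only            : (din_1, ..., din_n)
//   grad_weights only           : (dw_1, ..., dw_m)
//   neither                     : din of the first cell input, or an empty
//                                 tuple when the cell has no inputs.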
void KPynativeCellImpl::SetOutput(const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights) {
  AnfNodePtrList grad_inputs_list{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtr grad_inputs_spec;
  if (grad_inputs) {
    AbstractBasePtrList grad_inputs_abs_list;
    for (const auto &input : cell_inputs_) {
      MS_EXCEPTION_IF_NULL(input);
      auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        // If the input is not used in the network, just return zeros_like() as its dout;
        MS_LOG(WARNING) << "Input is not used in network, input: " << input->ToString();
        auto dout = BuildZerosLikeNode(tape_, input);
        grad_inputs_list.push_back(dout);
      } else {
        grad_inputs_list.push_back(input_adjoint_iter->second->RealDout());
      }
      grad_inputs_abs_list.push_back(grad_inputs_list.back()->abstract());
    }
    grad_inputs_spec = std::make_shared<abstract::AbstractTuple>(grad_inputs_abs_list);
  }

  AnfNodePtrList grad_weights_list{NewValueNode(prim::kPrimMakeTuple)};
  AbstractBasePtr grad_weights_spec;
  if (grad_weights) {
    AbstractBasePtrList grad_weights_abs_list;
    for (const auto &weight : weights) {
      MS_EXCEPTION_IF_NULL(weight);
      auto input_adjoint_iter = anfnode_to_adjoin_.find(weight);
      if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
        // If the weight is not used in the network, just return zeros_like() as its dout;
        MS_LOG(WARNING) << "Weight is not used in network, weight: " << weight->ToString();
        auto input_w = weight->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(input_w);
        auto default_param = input_w->default_param();
        MS_EXCEPTION_IF_NULL(default_param);
        auto dout = BuildZerosLikeValue(tape_, default_param);
        grad_weights_list.push_back(dout);
      } else {
        grad_weights_list.push_back(input_adjoint_iter->second->RealDout());
      }
      grad_weights_abs_list.push_back(grad_weights_list.back()->abstract());
    }
    grad_weights_spec = std::make_shared<abstract::AbstractTuple>(grad_weights_abs_list);
  }

  AnfNodePtr tape_output;
  if (grad_inputs && grad_weights) {
    tape_output = tape_->NewCNode(
      {NewValueNode(prim::kPrimMakeTuple), tape_->NewCNode(grad_inputs_list), tape_->NewCNode(grad_weights_list)});
    tape_output->set_abstract(
      std::make_shared<abstract::AbstractTuple>(abstract::AbstractBasePtrList{grad_inputs_spec, grad_weights_spec}));
  } else if (grad_inputs) {
    tape_output = tape_->NewCNode(grad_inputs_list);
    tape_output->set_abstract(grad_inputs_spec);
  } else if (grad_weights) {
    tape_output = tape_->NewCNode(grad_weights_list);
    tape_output->set_abstract(grad_weights_spec);
  } else if (cell_inputs_.empty()) {
    tape_output = tape_->NewCNode(grad_inputs_list);
    tape_output->set_abstract(grad_inputs_spec);
  } else {
    auto input_adjoint_iter = anfnode_to_adjoin_.find(cell_inputs_[0]);
    if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
      // If the input is not used in the network, just return zeros_like() as its dout;
      MS_LOG(WARNING) << "Input is not used in network, input: " << cell_inputs_[0]->ToString();
      tape_output = BuildZerosLikeNode(tape_, cell_inputs_[0]);
    } else {
      tape_output = input_adjoint_iter->second->RealDout();
    }
  }
  tape_->set_output(tape_output);
}

bool KPynativeCellImpl::BuildKNode() {
  for (auto iter = anfnode_to_adjoin_.cbegin(); iter != anfnode_to_adjoin_.cend(); ++iter) {
    if (!iter->first->isa<CNode>()) {
      continue;
    }

    AnfNodePtrList node_list;
    auto cnode = iter->first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    for (size_t i = 0; i < cnode->inputs().size(); ++i) {
      (void)node_list.emplace_back(BuildKNodeForCNodeInput(iter->second, cnode->input(i), i));
    }
    auto k_node = tape_->NewCNode(node_list);
    k_node->set_abstract(iter->second->out()->ToAbstract()->Broaden());
    iter->second->set_k_node(k_node);
  }
  return true;
}

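// Given an fprop graph  fprop(args...) -> (out, bprop),  GetBPropFromFProp
// builds a wrapper graph  bprop_builder(args...) = tuple_getitem(fprop(args...), 1)
// and emits a call to it on tape_, yielding the bprop closure for this cnode.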
CNodePtr KPynativeCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const AnfNodePtrList &args) {
  // Wrap tuple_getitem(fprop_app, 1) in a FuncGraph to extract the bprop closure;
  auto bprop_builder = std::make_shared<FuncGraph>();
  bprop_builder->debug_info()->set_name("bprop_builder");

  AnfNodePtrList fprop_app_inputs{NewValueNode(fprop_fg)};
  AnfNodePtrList bprop_builder_inputs;
  for (const auto &arg : args) {
    auto param = bprop_builder->add_parameter();
    fprop_app_inputs.push_back(param);
    bprop_builder_inputs.push_back(arg);
  }
  auto fprop_app = bprop_builder->NewCNode(fprop_app_inputs);
  auto get_bprop =
    bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});
  bprop_builder->set_output(get_bprop);
  (void)bprop_builder_inputs.insert(bprop_builder_inputs.begin(), NewValueNode(bprop_builder));
  get_bprop = tape_->NewCNode(bprop_builder_inputs);

  return get_bprop;
}

void KPynativeCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
  auto mng = MakeManager({tape_}, false);
  auto tr = mng->Transact();
  const auto &parameters = tape_->parameters();
  auto cell_inputs_size = cell_inputs_.size();
  for (size_t i = 0; i < cell_inputs_size; ++i) {
    (void)tr.Replace(cell_inputs_[i], parameters[i]);
  }
  // Parameter layout of tape_: (inputs, sens, weights) or (inputs, weights)
  size_t weight_offset = cell_inputs_size;
  if (has_sens_arg) {
    weight_offset = weight_offset + 1;
  }
  for (size_t i = 0; i < weights.size(); ++i) {
    (void)tr.Replace(weights[i], parameters[weight_offset + i]);
  }
  tr.Commit();
}

void ClearKPynativeCellStaticRes() {
  irpass = nullptr;
  add_ops = nullptr;
  ones_like_ops = nullptr;
  zeros_like_ops = nullptr;
  g_k_prims_pynative.clear();
  bprop_func_graph_cache.clear();
  zeros_like_funcgraph_cache.clear();
  ones_like_funcgraph_cache.clear();
}
}  // namespace ad
}  // namespace mindspore